Added switch_collection context manager and method (#220)

This commit is contained in:
Ross Lawley
2013-01-23 21:19:21 +00:00
parent c8b65317ef
commit 9797d7a7fb
5 changed files with 223 additions and 97 deletions

View File

@@ -2,7 +2,7 @@ from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.queryset import OperationError, QuerySet
__all__ = ("switch_db", "no_dereference", "query_counter")
__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter")
class switch_db(object):
@@ -47,6 +47,49 @@ class switch_db(object):
self.cls._collection = self.collection
class switch_collection(object):
""" switch_collection alias context manager.
Example ::
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with switch_collection(Group, 'group1') as Group:
Group(name="hello testdb!").save() # Saves in group1 collection
"""
def __init__(self, cls, collection_name):
""" Construct the switch_collection context manager
:param cls: the class to change the registered db
:param collection_name: the name of the collection to use
"""
self.cls = cls
self.ori_collection = cls._get_collection()
self.ori_get_collection_name = cls._get_collection_name
self.collection_name = collection_name
def __enter__(self):
""" change the _get_collection_name and clear the cached collection """
@classmethod
def _get_collection_name(cls):
return self.collection_name
self.cls._get_collection_name = _get_collection_name
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the collection """
self.cls._collection = self.ori_collection
self.cls._get_collection_name = self.ori_get_collection_name
class no_dereference(object):
""" no_dereference context manager.