Added SwitchDB context manager (#106)

This commit is contained in:
Ross Lawley
2013-01-23 12:54:14 +00:00
parent 6d68ad735c
commit e5e88d792e
5 changed files with 85 additions and 1 deletions

View File

@@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
'DEFAULT_CONNECTION_NAME', 'SwitchDB']
DEFAULT_CONNECTION_NAME = 'default'
@@ -163,6 +163,47 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
return get_connection(alias)
class SwitchDB(object):
""" SwitchDB alias contextmanager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with SwitchDB(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the query_counter.
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
# Support old naming convention
_get_connection = get_connection
_get_db = get_db