Added SwitchDB context manager (#106)
This commit is contained in:
@@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser
|
||||
|
||||
|
||||
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
||||
'DEFAULT_CONNECTION_NAME']
|
||||
'DEFAULT_CONNECTION_NAME', 'SwitchDB']
|
||||
|
||||
|
||||
DEFAULT_CONNECTION_NAME = 'default'
|
||||
@@ -163,6 +163,47 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
||||
|
||||
return get_connection(alias)
|
||||
|
||||
|
||||
class SwitchDB(object):
|
||||
""" SwitchDB alias contextmanager.
|
||||
|
||||
Example ::
|
||||
# Register connections
|
||||
register_connection('default', 'mongoenginetest')
|
||||
register_connection('testdb-1', 'mongoenginetest2')
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group(name="test").save() # Saves in the default db
|
||||
|
||||
with SwitchDB(Group, 'testdb-1') as Group:
|
||||
Group(name="hello testdb!").save() # Saves in testdb-1
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cls, db_alias):
|
||||
""" Construct the query_counter.
|
||||
|
||||
:param cls: the class to change the registered db
|
||||
:param db_alias: the name of the specific database to use
|
||||
"""
|
||||
self.cls = cls
|
||||
self.collection = cls._get_collection()
|
||||
self.db_alias = db_alias
|
||||
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
|
||||
|
||||
def __enter__(self):
|
||||
""" change the db_alias and clear the cached collection """
|
||||
self.cls._meta["db_alias"] = self.db_alias
|
||||
self.cls._collection = None
|
||||
return self.cls
|
||||
|
||||
def __exit__(self, t, value, traceback):
|
||||
""" Reset the db_alias and collection """
|
||||
self.cls._meta["db_alias"] = self.ori_db_alias
|
||||
self.cls._collection = self.collection
|
||||
|
||||
# Support old naming convention
|
||||
_get_connection = get_connection
|
||||
_get_db = get_db
|
||||
|
||||
Reference in New Issue
Block a user