diff --git a/docs/apireference.rst b/docs/apireference.rst index 0f8901a1..69b1db03 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -7,6 +7,7 @@ Connecting .. autofunction:: mongoengine.connect .. autofunction:: mongoengine.register_connection +.. autoclass:: mongoengine.SwitchDB Documents ========= diff --git a/docs/changelog.rst b/docs/changelog.rst index ba2c04c8..354d4718 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -32,6 +32,7 @@ Changes in 0.8.X - Fixed inheritance and unique index creation (#140) - Fixed reverse delete rule with inheritance (#197) - Fixed validation for GenericReferences which havent been dereferenced +- Added SwitchDB context manager (#106) Changes in 0.7.9 ================ diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 657c46c2..b39ccda4 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -69,3 +69,21 @@ to point across databases and collections. Below is an example schema, using book = ReferenceField(Book) meta = {"db_alias": "users-books-db"} + + +Switch Database Context Manager +=============================== + +Sometimes you might want to switch the database to query against for a class. +The SwitchDB context manager allows you to change the database alias for a +class eg :: + + from mongoengine import SwitchDB + + class User(Document): + name = StringField() + + meta = {"db_alias": "user-db"} + + with SwitchDB(User, 'archive-user-db') as User: + User(name="Ross").save() # Saves the 'archive-user-db' diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 87308ba3..b6c78e84 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME'] + 'DEFAULT_CONNECTION_NAME', 'SwitchDB'] DEFAULT_CONNECTION_NAME = 'default' @@ -163,6 +163,47 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): return get_connection(alias) + +class SwitchDB(object): + """ SwitchDB alias contextmanager. + + Example :: + # Register connections + register_connection('default', 'mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group(name="test").save() # Saves in the default db + + with SwitchDB(Group, 'testdb-1') as Group: + Group(name="hello testdb!").save() # Saves in testdb-1 + + """ + + def __init__(self, cls, db_alias): + """ Construct the query_counter. + + :param cls: the class to change the registered db + :param db_alias: the name of the specific database to use + """ + self.cls = cls + self.collection = cls._get_collection() + self.db_alias = db_alias + self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME) + + def __enter__(self): + """ change the db_alias and clear the cached collection """ + self.cls._meta["db_alias"] = self.db_alias + self.cls._collection = None + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the db_alias and collection """ + self.cls._meta["db_alias"] = self.ori_db_alias + self.cls._collection = self.collection + # Support old naming convention _get_connection = get_connection _get_db = get_db diff --git a/tests/test_connection.py b/tests/test_connection.py index cd03df0b..4931dc9f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -93,6 +93,29 @@ class ConnectionTest(unittest.TestCase): date_doc = DateDoc.objects.first() self.assertEqual(d, date_doc.the_date) + def test_switch_db_context_manager(self): + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + + Group(name="hello - default").save() + self.assertEqual(1, Group.objects.count()) + + with SwitchDB(Group, 'testdb-1') as Group: + + self.assertEqual(0, Group.objects.count()) + + Group(name="hello").save() + + self.assertEqual(1, Group.objects.count()) + + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + self.assertEqual(1, Group.objects.count()) if __name__ == '__main__': unittest.main()