diff --git a/docs/changelog.rst b/docs/changelog.rst index 354d4718..65e11034 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -33,6 +33,7 @@ Changes in 0.8.X - Fixed reverse delete rule with inheritance (#197) - Fixed validation for GenericReferences which havent been dereferenced - Added SwitchDB context manager (#106) +- Added switch_db method to document instances (#106) Changes in 0.7.9 ================ diff --git a/mongoengine/connection.py b/mongoengine/connection.py index b6c78e84..80791e5f 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -168,6 +168,7 @@ class SwitchDB(object): """ SwitchDB alias contextmanager. Example :: + # Register connections register_connection('default', 'mongoenginetest') register_connection('testdb-1', 'mongoenginetest2') diff --git a/mongoengine/document.py b/mongoengine/document.py index f40f1c9f..3bc4caed 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -9,7 +9,7 @@ from mongoengine import signals, queryset from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, BaseDict, BaseList, ALLOW_INHERITANCE, get_document) from queryset import OperationError, NotUniqueError -from connection import get_db, DEFAULT_CONNECTION_NAME +from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', @@ -222,7 +222,7 @@ class Document(BaseDocument): created = ('_id' not in doc or self._created or force_insert) try: - collection = self.__class__.objects._collection + collection = self._get_collection() if created: if force_insert: object_id = collection.insert(doc, safe=safe, @@ -321,6 +321,16 @@ class Document(BaseDocument): ref.save(**kwargs) ref._changed_fields = [] + @property + def _qs(self): + """ + Returns the queryset to use for updating / reloading / deletions + """ + qs = self.__class__.objects + if hasattr(self, '_objects'): + qs = self._objects + return qs + @property def _object_key(self): """Dict to identify object in collection @@ -342,7 +352,7 @@ class Document(BaseDocument): raise OperationError('attempt to update a document not yet saved') # Need to add shard key to query, or you get an error - return self.__class__.objects(**self._object_key).update_one(**kwargs) + return self._qs.filter(**self._object_key).update_one(**kwargs) def delete(self, safe=False): """Delete the :class:`~mongoengine.Document` from the database. This @@ -353,13 +363,39 @@ class Document(BaseDocument): signals.pre_delete.send(self.__class__, document=self) try: - self.__class__.objects(**self._object_key).delete(safe=safe) + self._qs.filter(**self._object_key).delete(safe=safe) except pymongo.errors.OperationFailure, err: message = u'Could not delete document (%s)' % err.message raise OperationError(message) signals.post_delete.send(self.__class__, document=self) + def switch_db(self, db_alias): + """ + Temporarily switch the database for a document instance. + + Only really useful for archiving off data and calling `save()`:: + + user = User.objects.get(id=user_id) + user.switch_db('archive-db') + user.save() + + If you need to read from another database see + :class:`~mongoengine.SwitchDB` + + :param db_alias: The database alias to use for saving the document + """ + with SwitchDB(self.__class__, db_alias) as cls: + collection = cls._get_collection() + db = cls._get_db + self._get_collection = lambda: collection + self._get_db = lambda: db + self._collection = collection + self._created = True + self._objects = self.__class__.objects + self._objects._collection_obj = collection + return self + def select_related(self, max_depth=1): """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. @@ -377,7 +413,7 @@ class Document(BaseDocument): .. versionchanged:: 0.6 Now chainable """ id_field = self._meta['id_field'] - obj = self.__class__.objects( + obj = self._qs.filter( **{id_field: self[id_field]} ).limit(1).select_related(max_depth=max_depth) if obj: diff --git a/tests/document/instance.py b/tests/document/instance.py index 07c4f0e0..3b5a4bd9 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2114,6 +2114,56 @@ class ValidatorErrorTest(unittest.TestCase): self.assertEqual(classic_doc, dict_doc) self.assertEqual(classic_doc._data, dict_doc._data) + def test_switch_db_instance(self): + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + with SwitchDB(Group, 'testdb-1') as Group: + Group.drop_collection() + + Group(name="hello - default").save() + self.assertEqual(1, Group.objects.count()) + + group = Group.objects.first() + group.switch_db('testdb-1') + group.name = "hello - testdb!" + group.save() + + with SwitchDB(Group, 'testdb-1') as Group: + group = Group.objects.first() + self.assertEqual("hello - testdb!", group.name) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + + # Slightly contrived now - perform an update + # Only works as they have the same object_id + group.switch_db('testdb-1') + group.update(set__name="hello - update") + + with SwitchDB(Group, 'testdb-1') as Group: + group = Group.objects.first() + self.assertEqual("hello - update", group.name) + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + + # Totally contrived now - perform a delete + # Only works as they have the same object_id + group.switch_db('testdb-1') + group.delete() + + with SwitchDB(Group, 'testdb-1') as Group: + self.assertEqual(0, Group.objects.count()) + + group = Group.objects.first() + self.assertEqual("hello - default", group.name) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_connection.py b/tests/test_connection.py index 4931dc9f..7ff18a38 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -94,6 +94,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(d, date_doc.the_date) def test_switch_db_context_manager(self): + connect('mongoenginetest') register_connection('testdb-1', 'mongoenginetest2') class Group(Document):