Added switch_db method to document instances (#106)
This commit is contained in:
		| @@ -33,6 +33,7 @@ Changes in 0.8.X | |||||||
| - Fixed reverse delete rule with inheritance (#197) | - Fixed reverse delete rule with inheritance (#197) | ||||||
| - Fixed validation for GenericReferences which havent been dereferenced | - Fixed validation for GenericReferences which havent been dereferenced | ||||||
| - Added SwitchDB context manager (#106) | - Added SwitchDB context manager (#106) | ||||||
|  | - Added switch_db method to document instances (#106) | ||||||
|  |  | ||||||
| Changes in 0.7.9 | Changes in 0.7.9 | ||||||
| ================ | ================ | ||||||
|   | |||||||
| @@ -168,6 +168,7 @@ class SwitchDB(object): | |||||||
|     """ SwitchDB alias contextmanager. |     """ SwitchDB alias contextmanager. | ||||||
|  |  | ||||||
|     Example :: |     Example :: | ||||||
|  |  | ||||||
|         # Register connections |         # Register connections | ||||||
|         register_connection('default', 'mongoenginetest') |         register_connection('default', 'mongoenginetest') | ||||||
|         register_connection('testdb-1', 'mongoenginetest2') |         register_connection('testdb-1', 'mongoenginetest2') | ||||||
|   | |||||||
| @@ -9,7 +9,7 @@ from mongoengine import signals, queryset | |||||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||||
|                   BaseDict, BaseList, ALLOW_INHERITANCE, get_document) |                   BaseDict, BaseList, ALLOW_INHERITANCE, get_document) | ||||||
| from queryset import OperationError, NotUniqueError | from queryset import OperationError, NotUniqueError | ||||||
| from connection import get_db, DEFAULT_CONNECTION_NAME | from connection import get_db, DEFAULT_CONNECTION_NAME, SwitchDB | ||||||
|  |  | ||||||
| __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', | ||||||
|            'DynamicEmbeddedDocument', 'OperationError', |            'DynamicEmbeddedDocument', 'OperationError', | ||||||
| @@ -222,7 +222,7 @@ class Document(BaseDocument): | |||||||
|         created = ('_id' not in doc or self._created or force_insert) |         created = ('_id' not in doc or self._created or force_insert) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             collection = self.__class__.objects._collection |             collection = self._get_collection() | ||||||
|             if created: |             if created: | ||||||
|                 if force_insert: |                 if force_insert: | ||||||
|                     object_id = collection.insert(doc, safe=safe, |                     object_id = collection.insert(doc, safe=safe, | ||||||
| @@ -321,6 +321,16 @@ class Document(BaseDocument): | |||||||
|                 ref.save(**kwargs) |                 ref.save(**kwargs) | ||||||
|                 ref._changed_fields = [] |                 ref._changed_fields = [] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def _qs(self): | ||||||
|  |         """ | ||||||
|  |         Returns the queryset to use for updating / reloading / deletions | ||||||
|  |         """ | ||||||
|  |         qs = self.__class__.objects | ||||||
|  |         if hasattr(self, '_objects'): | ||||||
|  |             qs = self._objects | ||||||
|  |         return qs | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def _object_key(self): |     def _object_key(self): | ||||||
|         """Dict to identify object in collection |         """Dict to identify object in collection | ||||||
| @@ -342,7 +352,7 @@ class Document(BaseDocument): | |||||||
|             raise OperationError('attempt to update a document not yet saved') |             raise OperationError('attempt to update a document not yet saved') | ||||||
|  |  | ||||||
|         # Need to add shard key to query, or you get an error |         # Need to add shard key to query, or you get an error | ||||||
|         return self.__class__.objects(**self._object_key).update_one(**kwargs) |         return self._qs.filter(**self._object_key).update_one(**kwargs) | ||||||
|  |  | ||||||
|     def delete(self, safe=False): |     def delete(self, safe=False): | ||||||
|         """Delete the :class:`~mongoengine.Document` from the database. This |         """Delete the :class:`~mongoengine.Document` from the database. This | ||||||
| @@ -353,13 +363,39 @@ class Document(BaseDocument): | |||||||
|         signals.pre_delete.send(self.__class__, document=self) |         signals.pre_delete.send(self.__class__, document=self) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             self.__class__.objects(**self._object_key).delete(safe=safe) |             self._qs.filter(**self._object_key).delete(safe=safe) | ||||||
|         except pymongo.errors.OperationFailure, err: |         except pymongo.errors.OperationFailure, err: | ||||||
|             message = u'Could not delete document (%s)' % err.message |             message = u'Could not delete document (%s)' % err.message | ||||||
|             raise OperationError(message) |             raise OperationError(message) | ||||||
|  |  | ||||||
|         signals.post_delete.send(self.__class__, document=self) |         signals.post_delete.send(self.__class__, document=self) | ||||||
|  |  | ||||||
|  |     def switch_db(self, db_alias): | ||||||
|  |         """ | ||||||
|  |         Temporarily switch the database for a document instance. | ||||||
|  |  | ||||||
|  |         Only really useful for archiving off data and calling `save()`:: | ||||||
|  |  | ||||||
|  |             user = User.objects.get(id=user_id) | ||||||
|  |             user.switch_db('archive-db') | ||||||
|  |             user.save() | ||||||
|  |  | ||||||
|  |         If you need to read from another database see | ||||||
|  |         :class:`~mongoengine.SwitchDB` | ||||||
|  |  | ||||||
|  |         :param db_alias: The database alias to use for saving the document | ||||||
|  |         """ | ||||||
|  |         with SwitchDB(self.__class__, db_alias) as cls: | ||||||
|  |             collection = cls._get_collection() | ||||||
|  |             db = cls._get_db | ||||||
|  |         self._get_collection = lambda: collection | ||||||
|  |         self._get_db = lambda: db | ||||||
|  |         self._collection = collection | ||||||
|  |         self._created = True | ||||||
|  |         self._objects = self.__class__.objects | ||||||
|  |         self._objects._collection_obj = collection | ||||||
|  |         return self | ||||||
|  |  | ||||||
|     def select_related(self, max_depth=1): |     def select_related(self, max_depth=1): | ||||||
|         """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to |         """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to | ||||||
|         a maximum depth in order to cut down the number queries to mongodb. |         a maximum depth in order to cut down the number queries to mongodb. | ||||||
| @@ -377,7 +413,7 @@ class Document(BaseDocument): | |||||||
|         .. versionchanged:: 0.6  Now chainable |         .. versionchanged:: 0.6  Now chainable | ||||||
|         """ |         """ | ||||||
|         id_field = self._meta['id_field'] |         id_field = self._meta['id_field'] | ||||||
|         obj = self.__class__.objects( |         obj = self._qs.filter( | ||||||
|                 **{id_field: self[id_field]} |                 **{id_field: self[id_field]} | ||||||
|               ).limit(1).select_related(max_depth=max_depth) |               ).limit(1).select_related(max_depth=max_depth) | ||||||
|         if obj: |         if obj: | ||||||
|   | |||||||
| @@ -2114,6 +2114,56 @@ class ValidatorErrorTest(unittest.TestCase): | |||||||
|         self.assertEqual(classic_doc, dict_doc) |         self.assertEqual(classic_doc, dict_doc) | ||||||
|         self.assertEqual(classic_doc._data, dict_doc._data) |         self.assertEqual(classic_doc._data, dict_doc._data) | ||||||
|  |  | ||||||
|  |     def test_switch_db_instance(self): | ||||||
|  |         register_connection('testdb-1', 'mongoenginetest2') | ||||||
|  |  | ||||||
|  |         class Group(Document): | ||||||
|  |             name = StringField() | ||||||
|  |  | ||||||
|  |         Group.drop_collection() | ||||||
|  |         with SwitchDB(Group, 'testdb-1') as Group: | ||||||
|  |             Group.drop_collection() | ||||||
|  |  | ||||||
|  |         Group(name="hello - default").save() | ||||||
|  |         self.assertEqual(1, Group.objects.count()) | ||||||
|  |  | ||||||
|  |         group = Group.objects.first() | ||||||
|  |         group.switch_db('testdb-1') | ||||||
|  |         group.name = "hello - testdb!" | ||||||
|  |         group.save() | ||||||
|  |  | ||||||
|  |         with SwitchDB(Group, 'testdb-1') as Group: | ||||||
|  |             group = Group.objects.first() | ||||||
|  |             self.assertEqual("hello - testdb!", group.name) | ||||||
|  |  | ||||||
|  |         group = Group.objects.first() | ||||||
|  |         self.assertEqual("hello - default", group.name) | ||||||
|  |  | ||||||
|  |         # Slightly contrived now - perform an update | ||||||
|  |         # Only works as they have the same object_id | ||||||
|  |         group.switch_db('testdb-1') | ||||||
|  |         group.update(set__name="hello - update") | ||||||
|  |  | ||||||
|  |         with SwitchDB(Group, 'testdb-1') as Group: | ||||||
|  |             group = Group.objects.first() | ||||||
|  |             self.assertEqual("hello - update", group.name) | ||||||
|  |             Group.drop_collection() | ||||||
|  |             self.assertEqual(0, Group.objects.count()) | ||||||
|  |  | ||||||
|  |         group = Group.objects.first() | ||||||
|  |         self.assertEqual("hello - default", group.name) | ||||||
|  |  | ||||||
|  |         # Totally contrived now - perform a delete | ||||||
|  |         # Only works as they have the same object_id | ||||||
|  |         group.switch_db('testdb-1') | ||||||
|  |         group.delete() | ||||||
|  |  | ||||||
|  |         with SwitchDB(Group, 'testdb-1') as Group: | ||||||
|  |             self.assertEqual(0, Group.objects.count()) | ||||||
|  |  | ||||||
|  |         group = Group.objects.first() | ||||||
|  |         self.assertEqual("hello - default", group.name) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
| @@ -94,6 +94,7 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         self.assertEqual(d, date_doc.the_date) |         self.assertEqual(d, date_doc.the_date) | ||||||
|  |  | ||||||
|     def test_switch_db_context_manager(self): |     def test_switch_db_context_manager(self): | ||||||
|  |         connect('mongoenginetest') | ||||||
|         register_connection('testdb-1', 'mongoenginetest2') |         register_connection('testdb-1', 'mongoenginetest2') | ||||||
|  |  | ||||||
|         class Group(Document): |         class Group(Document): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user