From 9797d7a7fb9c21684f360d03b06800b99b8093c4 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 23 Jan 2013 21:19:21 +0000 Subject: [PATCH] Added switch_collection context manager and method (#220) --- mongoengine/context_managers.py | 45 +++++++++- mongoengine/document.py | 27 +++++- tests/test_connection.py | 24 ----- tests/test_context_managers.py | 153 ++++++++++++++++++++++++++++++++ tests/test_dereference.py | 71 --------------- 5 files changed, 223 insertions(+), 97 deletions(-) create mode 100644 tests/test_context_managers.py diff --git a/mongoengine/context_managers.py b/mongoengine/context_managers.py index 7255d51c..e73d4a2a 100644 --- a/mongoengine/context_managers.py +++ b/mongoengine/context_managers.py @@ -2,7 +2,7 @@ from mongoengine.common import _import_class from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.queryset import OperationError, QuerySet -__all__ = ("switch_db", "no_dereference", "query_counter") +__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter") class switch_db(object): @@ -47,6 +47,49 @@ class switch_db(object): self.cls._collection = self.collection +class switch_collection(object): + """ switch_collection alias context manager. + + Example :: + + class Group(Document): + name = StringField() + + Group(name="test").save() # Saves in the default db + + with switch_collection(Group, 'group1') as Group: + Group(name="hello testdb!").save() # Saves in group1 collection + + """ + + def __init__(self, cls, collection_name): + """ Construct the switch_collection context manager + + :param cls: the class to change the registered db + :param collection_name: the name of the collection to use + """ + self.cls = cls + self.ori_collection = cls._get_collection() + self.ori_get_collection_name = cls._get_collection_name + self.collection_name = collection_name + + def __enter__(self): + """ change the _get_collection_name and clear the cached collection """ + + @classmethod + def _get_collection_name(cls): + return self.collection_name + + self.cls._get_collection_name = _get_collection_name + self.cls._collection = None + return self.cls + + def __exit__(self, t, value, traceback): + """ Reset the collection """ + self.cls._collection = self.ori_collection + self.cls._get_collection_name = self.ori_get_collection_name + + class no_dereference(object): """ no_dereference context manager. diff --git a/mongoengine/document.py b/mongoengine/document.py index 9d4a1e6e..75873b4b 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -11,7 +11,7 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass, ALLOW_INHERITANCE, get_document) from mongoengine.queryset import OperationError, NotUniqueError from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME -from mongoengine.context_managers import switch_db +from mongoengine.context_managers import switch_db, switch_collection __all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', 'OperationError', @@ -398,6 +398,31 @@ class Document(BaseDocument): self._objects._collection_obj = collection return self + def switch_collection(self, collection_name): + """ + Temporarily switch the collection for a document instance. + + Only really useful for archiving off data and calling `save()`:: + + user = User.objects.get(id=user_id) + user.switch_collection('old-users') + user.save() + + If you need to read from another database see + :class:`~mongoengine.context_managers.switch_collection` + + :param collection_name: The database alias to use for saving the + document + """ + with switch_collection(self.__class__, collection_name) as cls: + collection = cls._get_collection() + self._get_collection = lambda: collection + self._collection = collection + self._created = True + self._objects = self.__class__.objects + self._objects._collection_obj = collection + return self + def select_related(self, max_depth=1): """Handles dereferencing of :class:`~bson.dbref.DBRef` objects to a maximum depth in order to cut down the number queries to mongodb. diff --git a/tests/test_connection.py b/tests/test_connection.py index 2a216fef..c32d231f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -95,30 +95,6 @@ class ConnectionTest(unittest.TestCase): date_doc = DateDoc.objects.first() self.assertEqual(d, date_doc.the_date) - def test_switch_db_context_manager(self): - connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') - - class Group(Document): - name = StringField() - - Group.drop_collection() - - Group(name="hello - default").save() - self.assertEqual(1, Group.objects.count()) - - with switch_db(Group, 'testdb-1') as Group: - - self.assertEqual(0, Group.objects.count()) - - Group(name="hello").save() - - self.assertEqual(1, Group.objects.count()) - - Group.drop_collection() - self.assertEqual(0, Group.objects.count()) - - self.assertEqual(1, Group.objects.count()) if __name__ == '__main__': unittest.main() diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py new file mode 100644 index 00000000..10fe7b8e --- /dev/null +++ b/tests/test_context_managers.py @@ -0,0 +1,153 @@ +from __future__ import with_statement +import unittest + +from mongoengine import * +from mongoengine.connection import get_db +from mongoengine.context_managers import (switch_db, switch_collection, + no_dereference, query_counter) + + +class ContextManagersTest(unittest.TestCase): + + def test_switch_db_context_manager(self): + connect('mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + + Group(name="hello - default").save() + self.assertEqual(1, Group.objects.count()) + + with switch_db(Group, 'testdb-1') as Group: + + self.assertEqual(0, Group.objects.count()) + + Group(name="hello").save() + + self.assertEqual(1, Group.objects.count()) + + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + self.assertEqual(1, Group.objects.count()) + + def test_switch_collection_context_manager(self): + connect('mongoenginetest') + register_connection('testdb-1', 'mongoenginetest2') + + class Group(Document): + name = StringField() + + Group.drop_collection() + with switch_collection(Group, 'group1') as Group: + Group.drop_collection() + + Group(name="hello - group").save() + self.assertEqual(1, Group.objects.count()) + + with switch_collection(Group, 'group1') as Group: + + self.assertEqual(0, Group.objects.count()) + + Group(name="hello - group1").save() + + self.assertEqual(1, Group.objects.count()) + + Group.drop_collection() + self.assertEqual(0, Group.objects.count()) + + self.assertEqual(1, Group.objects.count()) + + def test_no_dereference_context_manager_object_id(self): + """Ensure that DBRef items in ListFields aren't dereferenced. + """ + connect('mongoenginetest') + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=False) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + + def test_no_dereference_context_manager_dbref(self): + """Ensure that DBRef items in ListFields aren't dereferenced. + """ + connect('mongoenginetest') + + class User(Document): + name = StringField() + + class Group(Document): + ref = ReferenceField(User, dbref=True) + generic = GenericReferenceField() + members = ListField(ReferenceField(User, dbref=True)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + user = User.objects.first() + Group(ref=user, members=User.objects, generic=user).save() + + with no_dereference(Group) as NoDeRefGroup: + self.assertTrue(Group._fields['members']._auto_dereference) + self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) + + with no_dereference(Group) as Group: + group = Group.objects.first() + self.assertTrue(all([not isinstance(m, User) + for m in group.members])) + self.assertFalse(isinstance(group.ref, User)) + self.assertFalse(isinstance(group.generic, User)) + + self.assertTrue(all([isinstance(m, User) + for m in group.members])) + self.assertTrue(isinstance(group.ref, User)) + self.assertTrue(isinstance(group.generic, User)) + + def test_query_counter(self): + connect('mongoenginetest') + db = get_db() + + with query_counter() as q: + self.assertEqual(0, q) + + for i in xrange(1, 51): + db.test.find({}).count() + + self.assertEqual(50, q) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 8e4ffdd8..adbc5192 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1121,77 +1121,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(q, 2) - def test_no_dereference_context_manager_object_id(self): - """Ensure that DBRef items in ListFields aren't dereferenced. - """ - class User(Document): - name = StringField() - - class Group(Document): - ref = ReferenceField(User, dbref=False) - generic = GenericReferenceField() - members = ListField(ReferenceField(User, dbref=False)) - - User.drop_collection() - Group.drop_collection() - - for i in xrange(1, 51): - User(name='user %s' % i).save() - - user = User.objects.first() - Group(ref=user, members=User.objects, generic=user).save() - - with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields['members']._auto_dereference) - self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) - - with no_dereference(Group) as Group: - group = Group.objects.first() - self.assertTrue(all([not isinstance(m, User) - for m in group.members])) - self.assertFalse(isinstance(group.ref, User)) - self.assertFalse(isinstance(group.generic, User)) - - self.assertTrue(all([isinstance(m, User) - for m in group.members])) - self.assertTrue(isinstance(group.ref, User)) - self.assertTrue(isinstance(group.generic, User)) - - def test_no_dereference_context_manager_dbref(self): - """Ensure that DBRef items in ListFields aren't dereferenced. - """ - class User(Document): - name = StringField() - - class Group(Document): - ref = ReferenceField(User, dbref=True) - generic = GenericReferenceField() - members = ListField(ReferenceField(User, dbref=True)) - - User.drop_collection() - Group.drop_collection() - - for i in xrange(1, 51): - User(name='user %s' % i).save() - - user = User.objects.first() - Group(ref=user, members=User.objects, generic=user).save() - - with no_dereference(Group) as NoDeRefGroup: - self.assertTrue(Group._fields['members']._auto_dereference) - self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) - - with no_dereference(Group) as Group: - group = Group.objects.first() - self.assertTrue(all([not isinstance(m, User) - for m in group.members])) - self.assertFalse(isinstance(group.ref, User)) - self.assertFalse(isinstance(group.generic, User)) - - self.assertTrue(all([isinstance(m, User) - for m in group.members])) - self.assertTrue(isinstance(group.ref, User)) - self.assertTrue(isinstance(group.generic, User)) if __name__ == '__main__': unittest.main()