Added switch_collection context manager and method (#220)
This commit is contained in:
parent
c8b65317ef
commit
9797d7a7fb
@ -2,7 +2,7 @@ from mongoengine.common import _import_class
|
||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||
from mongoengine.queryset import OperationError, QuerySet
|
||||
|
||||
__all__ = ("switch_db", "no_dereference", "query_counter")
|
||||
__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter")
|
||||
|
||||
|
||||
class switch_db(object):
|
||||
@ -47,6 +47,49 @@ class switch_db(object):
|
||||
self.cls._collection = self.collection
|
||||
|
||||
|
||||
class switch_collection(object):
|
||||
""" switch_collection alias context manager.
|
||||
|
||||
Example ::
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group(name="test").save() # Saves in the default db
|
||||
|
||||
with switch_collection(Group, 'group1') as Group:
|
||||
Group(name="hello testdb!").save() # Saves in group1 collection
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, cls, collection_name):
|
||||
""" Construct the switch_collection context manager
|
||||
|
||||
:param cls: the class to change the registered db
|
||||
:param collection_name: the name of the collection to use
|
||||
"""
|
||||
self.cls = cls
|
||||
self.ori_collection = cls._get_collection()
|
||||
self.ori_get_collection_name = cls._get_collection_name
|
||||
self.collection_name = collection_name
|
||||
|
||||
def __enter__(self):
|
||||
""" change the _get_collection_name and clear the cached collection """
|
||||
|
||||
@classmethod
|
||||
def _get_collection_name(cls):
|
||||
return self.collection_name
|
||||
|
||||
self.cls._get_collection_name = _get_collection_name
|
||||
self.cls._collection = None
|
||||
return self.cls
|
||||
|
||||
def __exit__(self, t, value, traceback):
|
||||
""" Reset the collection """
|
||||
self.cls._collection = self.ori_collection
|
||||
self.cls._get_collection_name = self.ori_get_collection_name
|
||||
|
||||
|
||||
class no_dereference(object):
|
||||
""" no_dereference context manager.
|
||||
|
||||
|
@ -11,7 +11,7 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
||||
ALLOW_INHERITANCE, get_document)
|
||||
from mongoengine.queryset import OperationError, NotUniqueError
|
||||
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
||||
from mongoengine.context_managers import switch_db
|
||||
from mongoengine.context_managers import switch_db, switch_collection
|
||||
|
||||
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
|
||||
'DynamicEmbeddedDocument', 'OperationError',
|
||||
@ -398,6 +398,31 @@ class Document(BaseDocument):
|
||||
self._objects._collection_obj = collection
|
||||
return self
|
||||
|
||||
def switch_collection(self, collection_name):
|
||||
"""
|
||||
Temporarily switch the collection for a document instance.
|
||||
|
||||
Only really useful for archiving off data and calling `save()`::
|
||||
|
||||
user = User.objects.get(id=user_id)
|
||||
user.switch_collection('old-users')
|
||||
user.save()
|
||||
|
||||
If you need to read from another database see
|
||||
:class:`~mongoengine.context_managers.switch_collection`
|
||||
|
||||
:param collection_name: The database alias to use for saving the
|
||||
document
|
||||
"""
|
||||
with switch_collection(self.__class__, collection_name) as cls:
|
||||
collection = cls._get_collection()
|
||||
self._get_collection = lambda: collection
|
||||
self._collection = collection
|
||||
self._created = True
|
||||
self._objects = self.__class__.objects
|
||||
self._objects._collection_obj = collection
|
||||
return self
|
||||
|
||||
def select_related(self, max_depth=1):
|
||||
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects to
|
||||
a maximum depth in order to cut down the number queries to mongodb.
|
||||
|
@ -95,30 +95,6 @@ class ConnectionTest(unittest.TestCase):
|
||||
date_doc = DateDoc.objects.first()
|
||||
self.assertEqual(d, date_doc.the_date)
|
||||
|
||||
def test_switch_db_context_manager(self):
|
||||
connect('mongoenginetest')
|
||||
register_connection('testdb-1', 'mongoenginetest2')
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group.drop_collection()
|
||||
|
||||
Group(name="hello - default").save()
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
with switch_db(Group, 'testdb-1') as Group:
|
||||
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
Group(name="hello").save()
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
Group.drop_collection()
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
153
tests/test_context_managers.py
Normal file
153
tests/test_context_managers.py
Normal file
@ -0,0 +1,153 @@
|
||||
from __future__ import with_statement
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.context_managers import (switch_db, switch_collection,
|
||||
no_dereference, query_counter)
|
||||
|
||||
|
||||
class ContextManagersTest(unittest.TestCase):
|
||||
|
||||
def test_switch_db_context_manager(self):
|
||||
connect('mongoenginetest')
|
||||
register_connection('testdb-1', 'mongoenginetest2')
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group.drop_collection()
|
||||
|
||||
Group(name="hello - default").save()
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
with switch_db(Group, 'testdb-1') as Group:
|
||||
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
Group(name="hello").save()
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
Group.drop_collection()
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
def test_switch_collection_context_manager(self):
|
||||
connect('mongoenginetest')
|
||||
register_connection('testdb-1', 'mongoenginetest2')
|
||||
|
||||
class Group(Document):
|
||||
name = StringField()
|
||||
|
||||
Group.drop_collection()
|
||||
with switch_collection(Group, 'group1') as Group:
|
||||
Group.drop_collection()
|
||||
|
||||
Group(name="hello - group").save()
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
with switch_collection(Group, 'group1') as Group:
|
||||
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
Group(name="hello - group1").save()
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
Group.drop_collection()
|
||||
self.assertEqual(0, Group.objects.count())
|
||||
|
||||
self.assertEqual(1, Group.objects.count())
|
||||
|
||||
def test_no_dereference_context_manager_object_id(self):
|
||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||
"""
|
||||
connect('mongoenginetest')
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
ref = ReferenceField(User, dbref=False)
|
||||
generic = GenericReferenceField()
|
||||
members = ListField(ReferenceField(User, dbref=False))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in xrange(1, 51):
|
||||
User(name='user %s' % i).save()
|
||||
|
||||
user = User.objects.first()
|
||||
Group(ref=user, members=User.objects, generic=user).save()
|
||||
|
||||
with no_dereference(Group) as NoDeRefGroup:
|
||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||
|
||||
with no_dereference(Group) as Group:
|
||||
group = Group.objects.first()
|
||||
self.assertTrue(all([not isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertFalse(isinstance(group.ref, User))
|
||||
self.assertFalse(isinstance(group.generic, User))
|
||||
|
||||
self.assertTrue(all([isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertTrue(isinstance(group.ref, User))
|
||||
self.assertTrue(isinstance(group.generic, User))
|
||||
|
||||
def test_no_dereference_context_manager_dbref(self):
|
||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||
"""
|
||||
connect('mongoenginetest')
|
||||
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
ref = ReferenceField(User, dbref=True)
|
||||
generic = GenericReferenceField()
|
||||
members = ListField(ReferenceField(User, dbref=True))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in xrange(1, 51):
|
||||
User(name='user %s' % i).save()
|
||||
|
||||
user = User.objects.first()
|
||||
Group(ref=user, members=User.objects, generic=user).save()
|
||||
|
||||
with no_dereference(Group) as NoDeRefGroup:
|
||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||
|
||||
with no_dereference(Group) as Group:
|
||||
group = Group.objects.first()
|
||||
self.assertTrue(all([not isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertFalse(isinstance(group.ref, User))
|
||||
self.assertFalse(isinstance(group.generic, User))
|
||||
|
||||
self.assertTrue(all([isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertTrue(isinstance(group.ref, User))
|
||||
self.assertTrue(isinstance(group.generic, User))
|
||||
|
||||
def test_query_counter(self):
|
||||
connect('mongoenginetest')
|
||||
db = get_db()
|
||||
|
||||
with query_counter() as q:
|
||||
self.assertEqual(0, q)
|
||||
|
||||
for i in xrange(1, 51):
|
||||
db.test.find({}).count()
|
||||
|
||||
self.assertEqual(50, q)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1121,77 +1121,6 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(q, 2)
|
||||
|
||||
def test_no_dereference_context_manager_object_id(self):
|
||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||
"""
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
ref = ReferenceField(User, dbref=False)
|
||||
generic = GenericReferenceField()
|
||||
members = ListField(ReferenceField(User, dbref=False))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in xrange(1, 51):
|
||||
User(name='user %s' % i).save()
|
||||
|
||||
user = User.objects.first()
|
||||
Group(ref=user, members=User.objects, generic=user).save()
|
||||
|
||||
with no_dereference(Group) as NoDeRefGroup:
|
||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||
|
||||
with no_dereference(Group) as Group:
|
||||
group = Group.objects.first()
|
||||
self.assertTrue(all([not isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertFalse(isinstance(group.ref, User))
|
||||
self.assertFalse(isinstance(group.generic, User))
|
||||
|
||||
self.assertTrue(all([isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertTrue(isinstance(group.ref, User))
|
||||
self.assertTrue(isinstance(group.generic, User))
|
||||
|
||||
def test_no_dereference_context_manager_dbref(self):
|
||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||
"""
|
||||
class User(Document):
|
||||
name = StringField()
|
||||
|
||||
class Group(Document):
|
||||
ref = ReferenceField(User, dbref=True)
|
||||
generic = GenericReferenceField()
|
||||
members = ListField(ReferenceField(User, dbref=True))
|
||||
|
||||
User.drop_collection()
|
||||
Group.drop_collection()
|
||||
|
||||
for i in xrange(1, 51):
|
||||
User(name='user %s' % i).save()
|
||||
|
||||
user = User.objects.first()
|
||||
Group(ref=user, members=User.objects, generic=user).save()
|
||||
|
||||
with no_dereference(Group) as NoDeRefGroup:
|
||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||
|
||||
with no_dereference(Group) as Group:
|
||||
group = Group.objects.first()
|
||||
self.assertTrue(all([not isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertFalse(isinstance(group.ref, User))
|
||||
self.assertFalse(isinstance(group.generic, User))
|
||||
|
||||
self.assertTrue(all([isinstance(m, User)
|
||||
for m in group.members]))
|
||||
self.assertTrue(isinstance(group.ref, User))
|
||||
self.assertTrue(isinstance(group.generic, User))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user