Added switch_collection context manager and method (#220)
This commit is contained in:
parent
c8b65317ef
commit
9797d7a7fb
@ -2,7 +2,7 @@ from mongoengine.common import _import_class
|
|||||||
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
|
||||||
from mongoengine.queryset import OperationError, QuerySet
|
from mongoengine.queryset import OperationError, QuerySet
|
||||||
|
|
||||||
__all__ = ("switch_db", "no_dereference", "query_counter")
|
__all__ = ("switch_db", "switch_collection", "no_dereference", "query_counter")
|
||||||
|
|
||||||
|
|
||||||
class switch_db(object):
|
class switch_db(object):
|
||||||
@ -47,6 +47,49 @@ class switch_db(object):
|
|||||||
self.cls._collection = self.collection
|
self.cls._collection = self.collection
|
||||||
|
|
||||||
|
|
||||||
|
class switch_collection(object):
|
||||||
|
""" switch_collection alias context manager.
|
||||||
|
|
||||||
|
Example ::
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
Group(name="test").save() # Saves in the default db
|
||||||
|
|
||||||
|
with switch_collection(Group, 'group1') as Group:
|
||||||
|
Group(name="hello testdb!").save() # Saves in group1 collection
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cls, collection_name):
|
||||||
|
""" Construct the switch_collection context manager
|
||||||
|
|
||||||
|
:param cls: the class to change the registered db
|
||||||
|
:param collection_name: the name of the collection to use
|
||||||
|
"""
|
||||||
|
self.cls = cls
|
||||||
|
self.ori_collection = cls._get_collection()
|
||||||
|
self.ori_get_collection_name = cls._get_collection_name
|
||||||
|
self.collection_name = collection_name
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
""" change the _get_collection_name and clear the cached collection """
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_collection_name(cls):
|
||||||
|
return self.collection_name
|
||||||
|
|
||||||
|
self.cls._get_collection_name = _get_collection_name
|
||||||
|
self.cls._collection = None
|
||||||
|
return self.cls
|
||||||
|
|
||||||
|
def __exit__(self, t, value, traceback):
|
||||||
|
""" Reset the collection """
|
||||||
|
self.cls._collection = self.ori_collection
|
||||||
|
self.cls._get_collection_name = self.ori_get_collection_name
|
||||||
|
|
||||||
|
|
||||||
class no_dereference(object):
|
class no_dereference(object):
|
||||||
""" no_dereference context manager.
|
""" no_dereference context manager.
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ from mongoengine.base import (DocumentMetaclass, TopLevelDocumentMetaclass,
|
|||||||
ALLOW_INHERITANCE, get_document)
|
ALLOW_INHERITANCE, get_document)
|
||||||
from mongoengine.queryset import OperationError, NotUniqueError
|
from mongoengine.queryset import OperationError, NotUniqueError
|
||||||
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
from mongoengine.connection import get_db, DEFAULT_CONNECTION_NAME
|
||||||
from mongoengine.context_managers import switch_db
|
from mongoengine.context_managers import switch_db, switch_collection
|
||||||
|
|
||||||
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
|
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
|
||||||
'DynamicEmbeddedDocument', 'OperationError',
|
'DynamicEmbeddedDocument', 'OperationError',
|
||||||
@ -398,6 +398,31 @@ class Document(BaseDocument):
|
|||||||
self._objects._collection_obj = collection
|
self._objects._collection_obj = collection
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def switch_collection(self, collection_name):
|
||||||
|
"""
|
||||||
|
Temporarily switch the collection for a document instance.
|
||||||
|
|
||||||
|
Only really useful for archiving off data and calling `save()`::
|
||||||
|
|
||||||
|
user = User.objects.get(id=user_id)
|
||||||
|
user.switch_collection('old-users')
|
||||||
|
user.save()
|
||||||
|
|
||||||
|
If you need to read from another database see
|
||||||
|
:class:`~mongoengine.context_managers.switch_collection`
|
||||||
|
|
||||||
|
:param collection_name: The database alias to use for saving the
|
||||||
|
document
|
||||||
|
"""
|
||||||
|
with switch_collection(self.__class__, collection_name) as cls:
|
||||||
|
collection = cls._get_collection()
|
||||||
|
self._get_collection = lambda: collection
|
||||||
|
self._collection = collection
|
||||||
|
self._created = True
|
||||||
|
self._objects = self.__class__.objects
|
||||||
|
self._objects._collection_obj = collection
|
||||||
|
return self
|
||||||
|
|
||||||
def select_related(self, max_depth=1):
|
def select_related(self, max_depth=1):
|
||||||
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects to
|
"""Handles dereferencing of :class:`~bson.dbref.DBRef` objects to
|
||||||
a maximum depth in order to cut down the number queries to mongodb.
|
a maximum depth in order to cut down the number queries to mongodb.
|
||||||
|
@ -95,30 +95,6 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
date_doc = DateDoc.objects.first()
|
date_doc = DateDoc.objects.first()
|
||||||
self.assertEqual(d, date_doc.the_date)
|
self.assertEqual(d, date_doc.the_date)
|
||||||
|
|
||||||
def test_switch_db_context_manager(self):
|
|
||||||
connect('mongoenginetest')
|
|
||||||
register_connection('testdb-1', 'mongoenginetest2')
|
|
||||||
|
|
||||||
class Group(Document):
|
|
||||||
name = StringField()
|
|
||||||
|
|
||||||
Group.drop_collection()
|
|
||||||
|
|
||||||
Group(name="hello - default").save()
|
|
||||||
self.assertEqual(1, Group.objects.count())
|
|
||||||
|
|
||||||
with switch_db(Group, 'testdb-1') as Group:
|
|
||||||
|
|
||||||
self.assertEqual(0, Group.objects.count())
|
|
||||||
|
|
||||||
Group(name="hello").save()
|
|
||||||
|
|
||||||
self.assertEqual(1, Group.objects.count())
|
|
||||||
|
|
||||||
Group.drop_collection()
|
|
||||||
self.assertEqual(0, Group.objects.count())
|
|
||||||
|
|
||||||
self.assertEqual(1, Group.objects.count())
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
153
tests/test_context_managers.py
Normal file
153
tests/test_context_managers.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
from __future__ import with_statement
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from mongoengine import *
|
||||||
|
from mongoengine.connection import get_db
|
||||||
|
from mongoengine.context_managers import (switch_db, switch_collection,
|
||||||
|
no_dereference, query_counter)
|
||||||
|
|
||||||
|
|
||||||
|
class ContextManagersTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_switch_db_context_manager(self):
|
||||||
|
connect('mongoenginetest')
|
||||||
|
register_connection('testdb-1', 'mongoenginetest2')
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
Group(name="hello - default").save()
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
with switch_db(Group, 'testdb-1') as Group:
|
||||||
|
|
||||||
|
self.assertEqual(0, Group.objects.count())
|
||||||
|
|
||||||
|
Group(name="hello").save()
|
||||||
|
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
self.assertEqual(0, Group.objects.count())
|
||||||
|
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
def test_switch_collection_context_manager(self):
|
||||||
|
connect('mongoenginetest')
|
||||||
|
register_connection('testdb-1', 'mongoenginetest2')
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
with switch_collection(Group, 'group1') as Group:
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
Group(name="hello - group").save()
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
with switch_collection(Group, 'group1') as Group:
|
||||||
|
|
||||||
|
self.assertEqual(0, Group.objects.count())
|
||||||
|
|
||||||
|
Group(name="hello - group1").save()
|
||||||
|
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
Group.drop_collection()
|
||||||
|
self.assertEqual(0, Group.objects.count())
|
||||||
|
|
||||||
|
self.assertEqual(1, Group.objects.count())
|
||||||
|
|
||||||
|
def test_no_dereference_context_manager_object_id(self):
|
||||||
|
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||||
|
"""
|
||||||
|
connect('mongoenginetest')
|
||||||
|
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
ref = ReferenceField(User, dbref=False)
|
||||||
|
generic = GenericReferenceField()
|
||||||
|
members = ListField(ReferenceField(User, dbref=False))
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
|
user = User.objects.first()
|
||||||
|
Group(ref=user, members=User.objects, generic=user).save()
|
||||||
|
|
||||||
|
with no_dereference(Group) as NoDeRefGroup:
|
||||||
|
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||||
|
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||||
|
|
||||||
|
with no_dereference(Group) as Group:
|
||||||
|
group = Group.objects.first()
|
||||||
|
self.assertTrue(all([not isinstance(m, User)
|
||||||
|
for m in group.members]))
|
||||||
|
self.assertFalse(isinstance(group.ref, User))
|
||||||
|
self.assertFalse(isinstance(group.generic, User))
|
||||||
|
|
||||||
|
self.assertTrue(all([isinstance(m, User)
|
||||||
|
for m in group.members]))
|
||||||
|
self.assertTrue(isinstance(group.ref, User))
|
||||||
|
self.assertTrue(isinstance(group.generic, User))
|
||||||
|
|
||||||
|
def test_no_dereference_context_manager_dbref(self):
|
||||||
|
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
||||||
|
"""
|
||||||
|
connect('mongoenginetest')
|
||||||
|
|
||||||
|
class User(Document):
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Group(Document):
|
||||||
|
ref = ReferenceField(User, dbref=True)
|
||||||
|
generic = GenericReferenceField()
|
||||||
|
members = ListField(ReferenceField(User, dbref=True))
|
||||||
|
|
||||||
|
User.drop_collection()
|
||||||
|
Group.drop_collection()
|
||||||
|
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
User(name='user %s' % i).save()
|
||||||
|
|
||||||
|
user = User.objects.first()
|
||||||
|
Group(ref=user, members=User.objects, generic=user).save()
|
||||||
|
|
||||||
|
with no_dereference(Group) as NoDeRefGroup:
|
||||||
|
self.assertTrue(Group._fields['members']._auto_dereference)
|
||||||
|
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
||||||
|
|
||||||
|
with no_dereference(Group) as Group:
|
||||||
|
group = Group.objects.first()
|
||||||
|
self.assertTrue(all([not isinstance(m, User)
|
||||||
|
for m in group.members]))
|
||||||
|
self.assertFalse(isinstance(group.ref, User))
|
||||||
|
self.assertFalse(isinstance(group.generic, User))
|
||||||
|
|
||||||
|
self.assertTrue(all([isinstance(m, User)
|
||||||
|
for m in group.members]))
|
||||||
|
self.assertTrue(isinstance(group.ref, User))
|
||||||
|
self.assertTrue(isinstance(group.generic, User))
|
||||||
|
|
||||||
|
def test_query_counter(self):
|
||||||
|
connect('mongoenginetest')
|
||||||
|
db = get_db()
|
||||||
|
|
||||||
|
with query_counter() as q:
|
||||||
|
self.assertEqual(0, q)
|
||||||
|
|
||||||
|
for i in xrange(1, 51):
|
||||||
|
db.test.find({}).count()
|
||||||
|
|
||||||
|
self.assertEqual(50, q)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -1121,77 +1121,6 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEqual(q, 2)
|
self.assertEqual(q, 2)
|
||||||
|
|
||||||
def test_no_dereference_context_manager_object_id(self):
|
|
||||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
|
||||||
"""
|
|
||||||
class User(Document):
|
|
||||||
name = StringField()
|
|
||||||
|
|
||||||
class Group(Document):
|
|
||||||
ref = ReferenceField(User, dbref=False)
|
|
||||||
generic = GenericReferenceField()
|
|
||||||
members = ListField(ReferenceField(User, dbref=False))
|
|
||||||
|
|
||||||
User.drop_collection()
|
|
||||||
Group.drop_collection()
|
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
|
||||||
User(name='user %s' % i).save()
|
|
||||||
|
|
||||||
user = User.objects.first()
|
|
||||||
Group(ref=user, members=User.objects, generic=user).save()
|
|
||||||
|
|
||||||
with no_dereference(Group) as NoDeRefGroup:
|
|
||||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
|
||||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
|
||||||
|
|
||||||
with no_dereference(Group) as Group:
|
|
||||||
group = Group.objects.first()
|
|
||||||
self.assertTrue(all([not isinstance(m, User)
|
|
||||||
for m in group.members]))
|
|
||||||
self.assertFalse(isinstance(group.ref, User))
|
|
||||||
self.assertFalse(isinstance(group.generic, User))
|
|
||||||
|
|
||||||
self.assertTrue(all([isinstance(m, User)
|
|
||||||
for m in group.members]))
|
|
||||||
self.assertTrue(isinstance(group.ref, User))
|
|
||||||
self.assertTrue(isinstance(group.generic, User))
|
|
||||||
|
|
||||||
def test_no_dereference_context_manager_dbref(self):
|
|
||||||
"""Ensure that DBRef items in ListFields aren't dereferenced.
|
|
||||||
"""
|
|
||||||
class User(Document):
|
|
||||||
name = StringField()
|
|
||||||
|
|
||||||
class Group(Document):
|
|
||||||
ref = ReferenceField(User, dbref=True)
|
|
||||||
generic = GenericReferenceField()
|
|
||||||
members = ListField(ReferenceField(User, dbref=True))
|
|
||||||
|
|
||||||
User.drop_collection()
|
|
||||||
Group.drop_collection()
|
|
||||||
|
|
||||||
for i in xrange(1, 51):
|
|
||||||
User(name='user %s' % i).save()
|
|
||||||
|
|
||||||
user = User.objects.first()
|
|
||||||
Group(ref=user, members=User.objects, generic=user).save()
|
|
||||||
|
|
||||||
with no_dereference(Group) as NoDeRefGroup:
|
|
||||||
self.assertTrue(Group._fields['members']._auto_dereference)
|
|
||||||
self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)
|
|
||||||
|
|
||||||
with no_dereference(Group) as Group:
|
|
||||||
group = Group.objects.first()
|
|
||||||
self.assertTrue(all([not isinstance(m, User)
|
|
||||||
for m in group.members]))
|
|
||||||
self.assertFalse(isinstance(group.ref, User))
|
|
||||||
self.assertFalse(isinstance(group.generic, User))
|
|
||||||
|
|
||||||
self.assertTrue(all([isinstance(m, User)
|
|
||||||
for m in group.members]))
|
|
||||||
self.assertTrue(isinstance(group.ref, User))
|
|
||||||
self.assertTrue(isinstance(group.generic, User))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user