from __future__ import with_statement

import unittest

from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.context_managers import (switch_db, switch_collection,
                                          no_dereference, query_counter)


class ContextManagersTest(unittest.TestCase):
    """Tests for the context managers in ``mongoengine.context_managers``.

    These tests call ``connect``/``register_connection`` and read and write
    the ``mongoenginetest`` and ``mongoenginetest2`` databases, so they need
    a reachable MongoDB server.
    """

    def test_switch_db_context_manager(self):
        """``switch_db`` points a Document class at another connection alias.

        Documents saved and dropped while switched to ``testdb-1`` must not
        affect the collection reached through the default connection.
        """
        connect('mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        # Empty the collection in BOTH databases first, so documents left
        # behind by an earlier failed run cannot break the count checks
        # (same pre-clean as test_switch_collection_context_manager uses).
        Group.drop_collection()
        with switch_db(Group, 'testdb-1') as Group:
            Group.drop_collection()

        Group(name="hello - default").save()
        self.assertEqual(1, Group.objects.count())

        with switch_db(Group, 'testdb-1') as Group:
            self.assertEqual(0, Group.objects.count())

            Group(name="hello").save()
            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        # Back on the default alias: the document saved there is untouched.
        self.assertEqual(1, Group.objects.count())

    def test_switch_collection_context_manager(self):
        """``switch_collection`` points a Document class at another collection.

        Documents saved and dropped in ``group1`` must not affect the
        class's default collection.
        """
        connect('mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        # Start from empty default and 'group1' collections.
        Group.drop_collection()
        with switch_collection(Group, 'group1') as Group:
            Group.drop_collection()

        Group(name="hello - group").save()
        self.assertEqual(1, Group.objects.count())

        with switch_collection(Group, 'group1') as Group:
            self.assertEqual(0, Group.objects.count())

            Group(name="hello - group1").save()
            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        # Back on the default collection: its document is untouched.
        self.assertEqual(1, Group.objects.count())

    def _assert_no_dereference(self, dbref):
        """Shared body of the two ``no_dereference`` tests.

        Saves 50 ``User`` documents and one ``Group`` that references them
        through a ``ReferenceField``, a ``GenericReferenceField`` and a
        ``ListField`` of ``ReferenceField``. It then checks two things.
        Inside ``no_dereference(Group)`` those references load as values
        that are not ``User`` instances. Once the context exits, the same
        loaded instance yields ``User`` objects again.

        :param dbref: passed as ``dbref=`` to every ``ReferenceField``;
            the public tests run this once with ``False`` (ObjectId
            storage) and once with ``True`` (DBRef storage).
        """
        connect('mongoenginetest')

        class User(Document):
            name = StringField()

        class Group(Document):
            ref = ReferenceField(User, dbref=dbref)
            generic = GenericReferenceField()
            members = ListField(ReferenceField(User, dbref=dbref))

        User.drop_collection()
        Group.drop_collection()

        for i in range(1, 51):
            User(name='user %s' % i).save()

        user = User.objects.first()
        Group(ref=user, members=User.objects, generic=user).save()

        # The class yielded by the context manager has auto-dereferencing
        # switched off, while the original Group class still reports it on.
        with no_dereference(Group) as NoDeRefGroup:
            self.assertTrue(Group._fields['members']._auto_dereference)
            self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)

        with no_dereference(Group) as NoDeRefGroup:
            group = NoDeRefGroup.objects.first()
            self.assertFalse(any(isinstance(m, User) for m in group.members))
            self.assertFalse(isinstance(group.ref, User))
            self.assertFalse(isinstance(group.generic, User))

        # After the context exits, the same loaded instance dereferences.
        self.assertTrue(all(isinstance(m, User) for m in group.members))
        self.assertTrue(isinstance(group.ref, User))
        self.assertTrue(isinstance(group.generic, User))

    def test_no_dereference_context_manager_object_id(self):
        """Ensure references stored as ObjectIds (``dbref=False``) are not
        dereferenced inside ``no_dereference``.
        """
        self._assert_no_dereference(dbref=False)

    def test_no_dereference_context_manager_dbref(self):
        """Ensure references stored as DBRefs (``dbref=True``) are not
        dereferenced inside ``no_dereference``.
        """
        self._assert_no_dereference(dbref=True)

    def test_query_counter(self):
        """``query_counter`` counts only the queries issued inside its block."""
        connect('mongoenginetest')

        db = get_db()
        # Runs before the counter exists, so it must not contribute to it.
        db.test.find({})

        with query_counter() as q:
            self.assertEqual(0, q)

            # The final assertion expects exactly one counted query per call.
            for _ in range(50):
                db.test.find({}).count()

            self.assertEqual(50, q)


if __name__ == '__main__':
    unittest.main()