Added switch_collection context manager and method (#220)
This commit is contained in:
		
							
								
								
									
										153
									
								
								tests/test_context_managers.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										153
									
								
								tests/test_context_managers.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,153 @@ | ||||
| from __future__ import with_statement | ||||
| import unittest | ||||
|  | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db | ||||
| from mongoengine.context_managers import (switch_db, switch_collection, | ||||
|                                           no_dereference, query_counter) | ||||
|  | ||||
|  | ||||
| class ContextManagersTest(unittest.TestCase): | ||||
|  | ||||
|     def test_switch_db_context_manager(self): | ||||
|         connect('mongoenginetest') | ||||
|         register_connection('testdb-1', 'mongoenginetest2') | ||||
|  | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         Group(name="hello - default").save() | ||||
|         self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|         with switch_db(Group, 'testdb-1') as Group: | ||||
|  | ||||
|             self.assertEqual(0, Group.objects.count()) | ||||
|  | ||||
|             Group(name="hello").save() | ||||
|  | ||||
|             self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|             Group.drop_collection() | ||||
|             self.assertEqual(0, Group.objects.count()) | ||||
|  | ||||
|         self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|     def test_switch_collection_context_manager(self): | ||||
|         connect('mongoenginetest') | ||||
|         register_connection('testdb-1', 'mongoenginetest2') | ||||
|  | ||||
|         class Group(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         Group.drop_collection() | ||||
|         with switch_collection(Group, 'group1') as Group: | ||||
|             Group.drop_collection() | ||||
|  | ||||
|         Group(name="hello - group").save() | ||||
|         self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|         with switch_collection(Group, 'group1') as Group: | ||||
|  | ||||
|             self.assertEqual(0, Group.objects.count()) | ||||
|  | ||||
|             Group(name="hello - group1").save() | ||||
|  | ||||
|             self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|             Group.drop_collection() | ||||
|             self.assertEqual(0, Group.objects.count()) | ||||
|  | ||||
|         self.assertEqual(1, Group.objects.count()) | ||||
|  | ||||
|     def test_no_dereference_context_manager_object_id(self): | ||||
|         """Ensure that DBRef items in ListFields aren't dereferenced. | ||||
|         """ | ||||
|         connect('mongoenginetest') | ||||
|  | ||||
|         class User(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class Group(Document): | ||||
|             ref = ReferenceField(User, dbref=False) | ||||
|             generic = GenericReferenceField() | ||||
|             members = ListField(ReferenceField(User, dbref=False)) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         user = User.objects.first() | ||||
|         Group(ref=user, members=User.objects, generic=user).save() | ||||
|  | ||||
|         with no_dereference(Group) as NoDeRefGroup: | ||||
|             self.assertTrue(Group._fields['members']._auto_dereference) | ||||
|             self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) | ||||
|  | ||||
|         with no_dereference(Group) as Group: | ||||
|             group = Group.objects.first() | ||||
|             self.assertTrue(all([not isinstance(m, User) | ||||
|                                 for m in group.members])) | ||||
|             self.assertFalse(isinstance(group.ref, User)) | ||||
|             self.assertFalse(isinstance(group.generic, User)) | ||||
|  | ||||
|         self.assertTrue(all([isinstance(m, User) | ||||
|                              for m in group.members])) | ||||
|         self.assertTrue(isinstance(group.ref, User)) | ||||
|         self.assertTrue(isinstance(group.generic, User)) | ||||
|  | ||||
|     def test_no_dereference_context_manager_dbref(self): | ||||
|         """Ensure that DBRef items in ListFields aren't dereferenced. | ||||
|         """ | ||||
|         connect('mongoenginetest') | ||||
|  | ||||
|         class User(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class Group(Document): | ||||
|             ref = ReferenceField(User, dbref=True) | ||||
|             generic = GenericReferenceField() | ||||
|             members = ListField(ReferenceField(User, dbref=True)) | ||||
|  | ||||
|         User.drop_collection() | ||||
|         Group.drop_collection() | ||||
|  | ||||
|         for i in xrange(1, 51): | ||||
|             User(name='user %s' % i).save() | ||||
|  | ||||
|         user = User.objects.first() | ||||
|         Group(ref=user, members=User.objects, generic=user).save() | ||||
|  | ||||
|         with no_dereference(Group) as NoDeRefGroup: | ||||
|             self.assertTrue(Group._fields['members']._auto_dereference) | ||||
|             self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference) | ||||
|  | ||||
|         with no_dereference(Group) as Group: | ||||
|             group = Group.objects.first() | ||||
|             self.assertTrue(all([not isinstance(m, User) | ||||
|                                 for m in group.members])) | ||||
|             self.assertFalse(isinstance(group.ref, User)) | ||||
|             self.assertFalse(isinstance(group.generic, User)) | ||||
|  | ||||
|         self.assertTrue(all([isinstance(m, User) | ||||
|                              for m in group.members])) | ||||
|         self.assertTrue(isinstance(group.ref, User)) | ||||
|         self.assertTrue(isinstance(group.generic, User)) | ||||
|  | ||||
|     def test_query_counter(self): | ||||
|         connect('mongoenginetest') | ||||
|         db = get_db() | ||||
|  | ||||
|         with query_counter() as q: | ||||
|             self.assertEqual(0, q) | ||||
|  | ||||
|             for i in xrange(1, 51): | ||||
|                 db.test.find({}).count() | ||||
|  | ||||
|             self.assertEqual(50, q) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user