import unittest

from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.context_managers import (switch_db, switch_collection,
                                          no_sub_classes, no_dereference,
                                          query_counter)
from mongoengine.pymongo_support import count_documents


class ContextManagersTest(unittest.TestCase):
    """Tests for the context managers in ``mongoengine.context_managers``."""

    def test_switch_db_context_manager(self):
        """``switch_db`` routes a Document class to the 'testdb-1' alias
        only for the duration of the ``with`` block.
        """
        connect('mongoenginetest')
        register_connection('testdb-1', 'mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group.drop_collection()

        Group(name="hello - default").save()
        self.assertEqual(1, Group.objects.count())

        with switch_db(Group, 'testdb-1') as Group:

            # The document saved in the default db is not visible here.
            self.assertEqual(0, Group.objects.count())

            Group(name="hello").save()

            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        # Back on the default db: the original document is still there.
        self.assertEqual(1, Group.objects.count())

    def test_switch_collection_context_manager(self):
        """``switch_collection`` routes a Document class to the 'group1'
        collection only for the duration of the ``with`` block.
        """
        connect('mongoenginetest')
        register_connection(alias='testdb-1', db='mongoenginetest2')

        class Group(Document):
            name = StringField()

        Group.drop_collection()  # drops in default

        with switch_collection(Group, 'group1') as Group:
            Group.drop_collection()  # drops in group1

        Group(name="hello - group").save()
        self.assertEqual(1, Group.objects.count())

        with switch_collection(Group, 'group1') as Group:

            self.assertEqual(0, Group.objects.count())

            Group(name="hello - group1").save()

            self.assertEqual(1, Group.objects.count())

            Group.drop_collection()
            self.assertEqual(0, Group.objects.count())

        self.assertEqual(1, Group.objects.count())

    def test_no_dereference_context_manager_object_id(self):
        """Ensure that references stored as ObjectIds (``dbref=False``),
        including those inside ListFields, aren't dereferenced inside
        ``no_dereference`` and are dereferenced again afterwards.
        """
        connect('mongoenginetest')

        class User(Document):
            name = StringField()

        class Group(Document):
            ref = ReferenceField(User, dbref=False)
            generic = GenericReferenceField()
            members = ListField(ReferenceField(User, dbref=False))

        User.drop_collection()
        Group.drop_collection()

        for i in range(1, 51):
            User(name='user %s' % i).save()

        user = User.objects.first()
        Group(ref=user, members=User.objects, generic=user).save()

        with no_dereference(Group) as NoDeRefGroup:
            self.assertTrue(Group._fields['members']._auto_dereference)
            self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)

        with no_dereference(Group) as Group:
            group = Group.objects.first()
            for m in group.members:
                self.assertNotIsInstance(m, User)
            self.assertNotIsInstance(group.ref, User)
            self.assertNotIsInstance(group.generic, User)

        # Outside the block, accessing the same fields dereferences them.
        for m in group.members:
            self.assertIsInstance(m, User)
        self.assertIsInstance(group.ref, User)
        self.assertIsInstance(group.generic, User)

    def test_no_dereference_context_manager_dbref(self):
        """Ensure that DBRef items (``dbref=True``), including those inside
        ListFields, aren't dereferenced inside ``no_dereference`` and are
        dereferenced again afterwards.
        """
        connect('mongoenginetest')

        class User(Document):
            name = StringField()

        class Group(Document):
            ref = ReferenceField(User, dbref=True)
            generic = GenericReferenceField()
            members = ListField(ReferenceField(User, dbref=True))

        User.drop_collection()
        Group.drop_collection()

        for i in range(1, 51):
            User(name='user %s' % i).save()

        user = User.objects.first()
        Group(ref=user, members=User.objects, generic=user).save()

        with no_dereference(Group) as NoDeRefGroup:
            self.assertTrue(Group._fields['members']._auto_dereference)
            self.assertFalse(NoDeRefGroup._fields['members']._auto_dereference)

        with no_dereference(Group) as Group:
            group = Group.objects.first()
            self.assertTrue(all(not isinstance(m, User) for m in group.members))
            self.assertNotIsInstance(group.ref, User)
            self.assertNotIsInstance(group.generic, User)

        self.assertTrue(all(isinstance(m, User) for m in group.members))
        self.assertIsInstance(group.ref, User)
        self.assertIsInstance(group.generic, User)

    def test_no_sub_classes(self):
        """Inside ``no_sub_classes(X)`` only documents of exactly class X
        are counted and returned; normal behavior resumes on exit.
        """
        # Connect explicitly so the test does not depend on execution order.
        connect('mongoenginetest')

        class A(Document):
            x = IntField()
            meta = {'allow_inheritance': True}

        class B(A):
            z = IntField()

        class C(B):
            zz = IntField()

        A.drop_collection()

        A(x=10).save()
        A(x=15).save()
        B(x=20).save()
        B(x=30).save()
        C(x=40).save()

        self.assertEqual(A.objects.count(), 5)
        self.assertEqual(B.objects.count(), 3)
        self.assertEqual(C.objects.count(), 1)

        with no_sub_classes(A):
            self.assertEqual(A.objects.count(), 2)

            for obj in A.objects:
                self.assertEqual(obj.__class__, A)

        with no_sub_classes(B):
            self.assertEqual(B.objects.count(), 2)

            for obj in B.objects:
                self.assertEqual(obj.__class__, B)

        with no_sub_classes(C):
            self.assertEqual(C.objects.count(), 1)

            for obj in C.objects:
                self.assertEqual(obj.__class__, C)

        # Confirm context manager exit correctly
        self.assertEqual(A.objects.count(), 5)
        self.assertEqual(B.objects.count(), 3)
        self.assertEqual(C.objects.count(), 1)

    def test_no_sub_classes_modification_to_document_class_are_temporary(self):
        """``no_sub_classes`` narrows ``_subclasses`` only inside the block."""
        class A(Document):
            x = IntField()
            meta = {'allow_inheritance': True}

        class B(A):
            z = IntField()

        self.assertEqual(A._subclasses, ('A', 'A.B'))
        with no_sub_classes(A):
            self.assertEqual(A._subclasses, ('A',))
        self.assertEqual(A._subclasses, ('A', 'A.B'))

        self.assertEqual(B._subclasses, ('A.B',))
        with no_sub_classes(B):
            self.assertEqual(B._subclasses, ('A.B',))
        self.assertEqual(B._subclasses, ('A.B',))

    def test_no_subclass_context_manager_does_not_swallow_exception(self):
        """An exception raised inside ``no_sub_classes`` propagates."""
        class User(Document):
            name = StringField()

        with self.assertRaises(TypeError):
            with no_sub_classes(User):
                raise TypeError()

    def test_query_counter_does_not_swallow_exception(self):
        """An exception raised inside ``query_counter`` propagates."""
        # query_counter needs a database; connect so this test can run alone.
        connect('mongoenginetest')

        with self.assertRaises(TypeError):
            with query_counter():
                raise TypeError()

    def test_query_counter_temporarily_modifies_profiling_level(self):
        """``query_counter`` sets the profiling level to 2 while active and
        restores the previous level on exit.
        """
        connect('mongoenginetest')
        db = get_db()

        initial_profiling_level = db.profiling_level()

        try:
            new_level = 1
            db.set_profiling_level(new_level)
            self.assertEqual(db.profiling_level(), new_level)
            with query_counter():
                self.assertEqual(db.profiling_level(), 2)
            self.assertEqual(db.profiling_level(), new_level)
        finally:
            # Ensures it gets reset no matter the outcome of the test
            db.set_profiling_level(initial_profiling_level)

    def test_query_counter(self):
        """``query_counter`` counts inserts, finds and counts, and supports
        int/repr/comparison protocols.
        """
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        collection.drop()

        def issue_1_count_query():
            count_documents(collection, {})

        def issue_1_insert_query():
            collection.insert_one({'test': 'garbage'})

        def issue_1_find_query():
            collection.find_one()

        counter = 0
        with query_counter() as q:
            self.assertEqual(q, counter)
            self.assertEqual(q, counter)    # Ensures previous count query did not get counted

            for _ in range(10):
                issue_1_insert_query()
                counter += 1
            self.assertEqual(q, counter)

            for _ in range(4):
                issue_1_find_query()
                counter += 1
            self.assertEqual(q, counter)

            for _ in range(3):
                issue_1_count_query()
                counter += 1
            self.assertEqual(q, counter)

            self.assertEqual(int(q), counter)       # test __int__
            self.assertEqual(repr(q), str(int(q)))  # test __repr__
            self.assertGreater(q, -1)               # test __gt__
            self.assertGreaterEqual(q, int(q))      # test __ge__
            self.assertNotEqual(q, -1)
            self.assertLess(q, 1000)
            self.assertLessEqual(q, int(q))

    def test_query_counter_counts_getmore_queries(self):
        """A getmore issued while exhausting a cursor is counted as a query."""
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        collection.drop()

        many_docs = [{'test': 'garbage %s' % i} for i in range(150)]
        collection.insert_many(many_docs)   # first batch of documents contains 101 documents

        with query_counter() as q:
            self.assertEqual(q, 0)
            list(collection.find())
            self.assertEqual(q, 2)  # 1st select + 1 getmore

    def test_query_counter_ignores_particular_queries(self):
        """``killcursors`` and queries on ``system.indexes`` are not counted."""
        connect('mongoenginetest')
        db = get_db()

        collection = db.query_counter
        # Start from a known state, like the other query_counter tests.
        collection.drop()
        collection.insert_many([{'test': 'garbage %s' % i} for i in range(10)])

        with query_counter() as q:
            self.assertEqual(q, 0)
            cursor = collection.find()
            self.assertEqual(q, 0)      # cursor wasn't opened yet
            _ = next(cursor)            # opens the cursor and fires the find query
            self.assertEqual(q, 1)

            cursor.close()              # issues a `killcursors` query that is ignored by the context
            self.assertEqual(q, 1)

            _ = db.system.indexes.find_one()    # queries on db.system.indexes are ignored as well
            self.assertEqual(q, 1)


if __name__ == '__main__':
    unittest.main()