Fixes 2 bugs in no_subclasses context mgr (__exit__ swallows exception + repair feature)

This commit is contained in:
Bastien Gérard 2018-09-01 23:30:50 +02:00
parent 1b0c761fc0
commit a7852a89cc
2 changed files with 38 additions and 15 deletions

View File

@ -145,18 +145,17 @@ class no_sub_classes(object):
:param cls: the class to turn querying sub classes on :param cls: the class to turn querying sub classes on
""" """
self.cls = cls self.cls = cls
self.cls_initial_subclasses = None
def __enter__(self): def __enter__(self):
"""Change the objects default and _auto_dereference values.""" """Change the objects default and _auto_dereference values."""
self.cls._all_subclasses = self.cls._subclasses self.cls_initial_subclasses = self.cls._subclasses
self.cls._subclasses = (self.cls,) self.cls._subclasses = (self.cls._class_name,)
return self.cls return self.cls
def __exit__(self, t, value, traceback): def __exit__(self, t, value, traceback):
"""Reset the default and _auto_dereference values.""" """Reset the default and _auto_dereference values."""
self.cls._subclasses = self.cls._all_subclasses self.cls._subclasses = self.cls_initial_subclasses
delattr(self.cls, '_all_subclasses')
return self.cls
class query_counter(object): class query_counter(object):

View File

@ -140,8 +140,6 @@ class ContextManagersTest(unittest.TestCase):
def test_no_sub_classes(self): def test_no_sub_classes(self):
class A(Document): class A(Document):
x = IntField() x = IntField()
y = IntField()
meta = {'allow_inheritance': True} meta = {'allow_inheritance': True}
class B(A): class B(A):
@ -152,29 +150,29 @@ class ContextManagersTest(unittest.TestCase):
A.drop_collection() A.drop_collection()
A(x=10, y=20).save() A(x=10).save()
A(x=15, y=30).save() A(x=15).save()
B(x=20, y=40).save() B(x=20).save()
B(x=30, y=50).save() B(x=30).save()
C(x=40, y=60).save() C(x=40).save()
self.assertEqual(A.objects.count(), 5) self.assertEqual(A.objects.count(), 5)
self.assertEqual(B.objects.count(), 3) self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
with no_sub_classes(A) as A: with no_sub_classes(A):
self.assertEqual(A.objects.count(), 2) self.assertEqual(A.objects.count(), 2)
for obj in A.objects: for obj in A.objects:
self.assertEqual(obj.__class__, A) self.assertEqual(obj.__class__, A)
with no_sub_classes(B) as B: with no_sub_classes(B):
self.assertEqual(B.objects.count(), 2) self.assertEqual(B.objects.count(), 2)
for obj in B.objects: for obj in B.objects:
self.assertEqual(obj.__class__, B) self.assertEqual(obj.__class__, B)
with no_sub_classes(C) as C: with no_sub_classes(C):
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
for obj in C.objects: for obj in C.objects:
@ -185,6 +183,32 @@ class ContextManagersTest(unittest.TestCase):
self.assertEqual(B.objects.count(), 3) self.assertEqual(B.objects.count(), 3)
self.assertEqual(C.objects.count(), 1) self.assertEqual(C.objects.count(), 1)
def test_no_sub_classes_modification_to_document_class_are_temporary(self):
class A(Document):
x = IntField()
meta = {'allow_inheritance': True}
class B(A):
z = IntField()
self.assertEqual(A._subclasses, ('A', 'A.B'))
with no_sub_classes(A):
self.assertEqual(A._subclasses, ('A',))
self.assertEqual(A._subclasses, ('A', 'A.B'))
self.assertEqual(B._subclasses, ('A.B',))
with no_sub_classes(B):
self.assertEqual(B._subclasses, ('A.B',))
self.assertEqual(B._subclasses, ('A.B',))
def test_no_subclass_context_manager_does_not_swallow_exception(self):
class User(Document):
name = StringField()
with self.assertRaises(TypeError):
with no_sub_classes(User):
raise TypeError()
def test_query_counter(self): def test_query_counter(self):
connect('mongoenginetest') connect('mongoenginetest')
db = get_db() db = get_db()