diff --git a/AUTHORS b/AUTHORS index e8a43dac..21b0ec64 100644 --- a/AUTHORS +++ b/AUTHORS @@ -250,3 +250,4 @@ that much better: * Trevor Hall (https://github.com/tjhall13) * Gleb Voropaev (https://github.com/buggyspace) * Paulo Amaral (https://github.com/pauloAmaral) + * Gaurav Dadhania (https://github.com/GVRV) diff --git a/docs/changelog.rst b/docs/changelog.rst index 9d87d889..53373302 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -17,6 +17,7 @@ Changes in 0.17.0 - Fix InvalidStringData error when using modify on a BinaryField #1127 - DEPRECATION: `EmbeddedDocument.save` & `.reload` are marked as deprecated and will be removed in a next version of mongoengine #1552 - Fix test suite and CI to support MongoDB 3.4 #1445 +- Fix reference fields querying the database on each access if value contains orphan DBRefs ================= Changes in 0.16.3 diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 5586c5b7..598eb606 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -276,11 +276,16 @@ class ComplexBaseField(BaseField): _dereference = _import_class('DeReference')() - if instance._initialised and dereference and instance._data.get(self.name): + if (instance._initialised and + dereference and + instance._data.get(self.name) and + not getattr(instance._data[self.name], '_dereferenced', False)): instance._data[self.name] = _dereference( instance._data.get(self.name), max_depth=1, instance=instance, name=self.name ) + if hasattr(instance._data[self.name], '_dereferenced'): + instance._data[self.name]._dereferenced = True value = super(ComplexBaseField, self).__get__(instance, owner) diff --git a/tests/test_dereference.py b/tests/test_dereference.py index cf1194f4..9c565810 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -105,6 +105,14 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) # Document select_related with query_counter() as q: @@ -125,6 +133,46 @@ class FieldTest(unittest.TestCase): [m for m in group_obj.members] self.assertEqual(q, 2) + def test_list_item_dereference_orphan_dbref(self): + """Ensure that orphan DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in range(1, 51): + user = User(name='user %s' % i) + user.save() + + group = Group(members=User.objects) + group.save() + group.reload() # Confirm reload works + + # Delete one User so one of the references in the + # Group.members list is an orphan DBRef + User.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + + # verifies that no additional queries gets executed + # if we re-iterate over the ListField once it is + # dereferenced + [m for m in group_obj.members] + self.assertEqual(q, 2) + self.assertTrue(group_obj._data['members']._dereferenced) + User.drop_collection() Group.drop_collection() @@ -505,6 +553,61 @@ class FieldTest(unittest.TestCase): for m in group_obj.members: self.assertIn('User', m.__class__.__name__) + + def test_generic_reference_orphan_dbref(self): + """Ensure that generic orphan DBRef items in ListFields are dereferenced. + """ + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in range(1, 51): + a = UserA(name='User A %s' % i) + a.save() + + b = UserB(name='User B %s' % i) + b.save() + + c = UserC(name='User C %s' % i) + c.save() + + members += [a, b, c] + + group = Group(members=members) + group.save() + + # Delete one UserA instance so that there is + # an orphan DBRef in the GenericReference ListField + UserA.objects[0].delete() + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + + [m for m in group_obj.members] + self.assertEqual(q, 4) + self.assertTrue(group_obj._data['members']._dereferenced) + UserA.drop_collection() UserB.drop_collection() UserC.drop_collection()