From 6d68ad735cfcb6f37b964612867f31ff188b622f Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 22 Jan 2013 17:56:15 +0000 Subject: [PATCH] Fixed validation for GenericReferences Where the references haven't been dereferenced --- docs/changelog.rst | 1 + mongoengine/fields.py | 8 ++- tests/test_dereference.py | 105 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index c0757fb1..ba2c04c8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -31,6 +31,7 @@ Changes in 0.8.X - Uses getlasterror to test created on updated saves (#163) - Fixed inheritance and unique index creation (#140) - Fixed reverse delete rule with inheritance (#197) +- Fixed validation for GenericReferences which havent been dereferenced Changes in 0.7.9 ================ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 5f11ae3b..f7817742 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -865,11 +865,15 @@ class GenericReferenceField(BaseField): return super(GenericReferenceField, self).__get__(instance, owner) def validate(self, value): - if not isinstance(value, (Document, DBRef)): + if not isinstance(value, (Document, DBRef, dict, SON)): self.error('GenericReferences can only contain documents') + if isinstance(value, (dict, SON)): + if '_ref' not in value or '_cls' not in value: + self.error('GenericReferences can only contain documents') + # We need the id from the saved object to create the DBRef - if isinstance(value, Document) and value.id is None: + elif isinstance(value, Document) and value.id is None: self.error('You can only reference documents once they have been' ' saved to the database') diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 8557ec5c..f42482d1 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1,4 +1,7 @@ from __future__ import with_statement +import sys +sys.path[0:0] = [""] + import unittest from bson import DBRef, ObjectId @@ -1018,3 +1021,105 @@ class FieldTest(unittest.TestCase): msg = Message.objects.get(id=1) self.assertEqual(0, msg.comments[0].id) self.assertEqual(1, msg.comments[1].id) + + def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(ReferenceField(User, dbref=False)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + Group(name="Test", members=User.objects).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + + def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self): + """Ensure that DBRef items in ListFields are dereferenced. + """ + class User(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(ReferenceField(User, dbref=True)) + + User.drop_collection() + Group.drop_collection() + + for i in xrange(1, 51): + User(name='user %s' % i).save() + + Group(name="Test", members=User.objects).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + + def test_generic_reference_save_doesnt_cause_extra_queries(self): + + class UserA(Document): + name = StringField() + + class UserB(Document): + name = StringField() + + class UserC(Document): + name = StringField() + + class Group(Document): + name = StringField() + members = ListField(GenericReferenceField()) + + UserA.drop_collection() + UserB.drop_collection() + UserC.drop_collection() + Group.drop_collection() + + members = [] + for i in xrange(1, 51): + a = UserA(name='User A %s' % i).save() + b = UserB(name='User B %s' % i).save() + c = UserC(name='User C %s' % i).save() + + members += [a, b, c] + + Group(name="test", members=members).save() + + with query_counter() as q: + self.assertEqual(q, 0) + + group_obj = Group.objects.first() + self.assertEqual(q, 1) + + group_obj.name = "new test" + group_obj.save() + + self.assertEqual(q, 2) + +if __name__ == '__main__': + unittest.main()