Fixed validation for GenericReferences

Where the references haven't been dereferenced
This commit is contained in:
Ross Lawley 2013-01-22 17:56:15 +00:00
parent c44b98a7e1
commit 6d68ad735c
3 changed files with 112 additions and 2 deletions

View File

@ -31,6 +31,7 @@ Changes in 0.8.X
- Uses getlasterror to test created on updated saves (#163) - Uses getlasterror to test created on updated saves (#163)
- Fixed inheritance and unique index creation (#140) - Fixed inheritance and unique index creation (#140)
- Fixed reverse delete rule with inheritance (#197) - Fixed reverse delete rule with inheritance (#197)
- Fixed validation for GenericReferences which havent been dereferenced
Changes in 0.7.9 Changes in 0.7.9
================ ================

View File

@ -865,11 +865,15 @@ class GenericReferenceField(BaseField):
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
def validate(self, value): def validate(self, value):
if not isinstance(value, (Document, DBRef)): if not isinstance(value, (Document, DBRef, dict, SON)):
self.error('GenericReferences can only contain documents') self.error('GenericReferences can only contain documents')
if isinstance(value, (dict, SON)):
if '_ref' not in value or '_cls' not in value:
self.error('GenericReferences can only contain documents')
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
if isinstance(value, Document) and value.id is None: elif isinstance(value, Document) and value.id is None:
self.error('You can only reference documents once they have been' self.error('You can only reference documents once they have been'
' saved to the database') ' saved to the database')

View File

@ -1,4 +1,7 @@
from __future__ import with_statement from __future__ import with_statement
import sys
sys.path[0:0] = [""]
import unittest import unittest
from bson import DBRef, ObjectId from bson import DBRef, ObjectId
@ -1018,3 +1021,105 @@ class FieldTest(unittest.TestCase):
msg = Message.objects.get(id=1) msg = Message.objects.get(id=1)
self.assertEqual(0, msg.comments[0].id) self.assertEqual(0, msg.comments[0].id)
self.assertEqual(1, msg.comments[1].id) self.assertEqual(1, msg.comments[1].id)
def test_list_item_dereference_dref_false_save_doesnt_cause_extra_queries(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
name = StringField()
members = ListField(ReferenceField(User, dbref=False))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
Group(name="Test", members=User.objects).save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
group_obj.name = "new test"
group_obj.save()
self.assertEqual(q, 2)
def test_list_item_dereference_dref_true_save_doesnt_cause_extra_queries(self):
"""Ensure that DBRef items in ListFields are dereferenced.
"""
class User(Document):
name = StringField()
class Group(Document):
name = StringField()
members = ListField(ReferenceField(User, dbref=True))
User.drop_collection()
Group.drop_collection()
for i in xrange(1, 51):
User(name='user %s' % i).save()
Group(name="Test", members=User.objects).save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
group_obj.name = "new test"
group_obj.save()
self.assertEqual(q, 2)
def test_generic_reference_save_doesnt_cause_extra_queries(self):
class UserA(Document):
name = StringField()
class UserB(Document):
name = StringField()
class UserC(Document):
name = StringField()
class Group(Document):
name = StringField()
members = ListField(GenericReferenceField())
UserA.drop_collection()
UserB.drop_collection()
UserC.drop_collection()
Group.drop_collection()
members = []
for i in xrange(1, 51):
a = UserA(name='User A %s' % i).save()
b = UserB(name='User B %s' % i).save()
c = UserC(name='User C %s' % i).save()
members += [a, b, c]
Group(name="test", members=members).save()
with query_counter() as q:
self.assertEqual(q, 0)
group_obj = Group.objects.first()
self.assertEqual(q, 1)
group_obj.name = "new test"
group_obj.save()
self.assertEqual(q, 2)
if __name__ == '__main__':
unittest.main()