Improved validation of (Generic)Reference fields

This commit is contained in:
Ross Lawley 2011-06-16 15:25:09 +01:00
parent 5e8604967c
commit 5cc9188c5b
2 changed files with 35 additions and 0 deletions

View File

@ -604,6 +604,11 @@ class ReferenceField(BaseField):
def validate(self, value): def validate(self, value):
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) assert isinstance(value, (self.document_type, pymongo.dbref.DBRef))
if isinstance(value, Document) and value.id is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -628,6 +633,15 @@ class GenericReferenceField(BaseField):
return super(GenericReferenceField, self).__get__(instance, owner) return super(GenericReferenceField, self).__get__(instance, owner)
def validate(self, value):
if not isinstance(value, (Document, pymongo.dbref.DBRef)):
raise ValidationError('GenericReferences can only contain documents')
# We need the id from the saved object to create the DBRef
if isinstance(value, Document) and value.id is None:
raise ValidationError('You can only reference documents once '
'they have been saved to the database')
def dereference(self, value): def dereference(self, value):
doc_cls = get_document(value['_cls']) doc_cls = get_document(value['_cls'])
reference = value['_ref'] reference = value['_ref']

View File

@ -377,6 +377,7 @@ class FieldTest(unittest.TestCase):
comments = ListField(EmbeddedDocumentField(Comment)) comments = ListField(EmbeddedDocumentField(Comment))
tags = ListField(StringField()) tags = ListField(StringField())
authors = ListField(ReferenceField(User)) authors = ListField(ReferenceField(User))
generic = ListField(GenericReferenceField())
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
post.validate() post.validate()
@ -404,8 +405,28 @@ class FieldTest(unittest.TestCase):
self.assertRaises(ValidationError, post.validate) self.assertRaises(ValidationError, post.validate)
post.authors = [User()] post.authors = [User()]
self.assertRaises(ValidationError, post.validate)
user = User()
user.save()
post.authors = [user]
post.validate() post.validate()
post.generic = [1, 2]
self.assertRaises(ValidationError, post.validate)
post.generic = [User(), Comment()]
self.assertRaises(ValidationError, post.validate)
post.generic = [Comment()]
self.assertRaises(ValidationError, post.validate)
post.generic = [user]
post.validate()
User.drop_collection()
BlogPost.drop_collection()
def test_sorted_list_sorting(self): def test_sorted_list_sorting(self):
"""Ensure that a sorted list field properly sorts values. """Ensure that a sorted list field properly sorts values.
""" """