diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 26999204..e1b43664 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -604,6 +604,11 @@ class ReferenceField(BaseField): def validate(self, value): assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) + if isinstance(value, Document) and value.id is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + + def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -628,6 +633,15 @@ class GenericReferenceField(BaseField): return super(GenericReferenceField, self).__get__(instance, owner) + def validate(self, value): + if not isinstance(value, (Document, pymongo.dbref.DBRef)): + raise ValidationError('GenericReferences can only contain documents') + + # We need the id from the saved object to create the DBRef + if isinstance(value, Document) and value.id is None: + raise ValidationError('You can only reference documents once ' + 'they have been saved to the database') + def dereference(self, value): doc_cls = get_document(value['_cls']) reference = value['_ref'] diff --git a/tests/fields.py b/tests/fields.py index c13f9e34..22049309 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -377,6 +377,7 @@ class FieldTest(unittest.TestCase): comments = ListField(EmbeddedDocumentField(Comment)) tags = ListField(StringField()) authors = ListField(ReferenceField(User)) + generic = ListField(GenericReferenceField()) post = BlogPost(content='Went for a walk today...') post.validate() @@ -404,8 +405,28 @@ class FieldTest(unittest.TestCase): self.assertRaises(ValidationError, post.validate) post.authors = [User()] + self.assertRaises(ValidationError, post.validate) + + user = User() + user.save() + post.authors = [user] post.validate() + post.generic = [1, 2] + self.assertRaises(ValidationError, post.validate) + + post.generic = [User(), Comment()] + self.assertRaises(ValidationError, post.validate) + + post.generic = [Comment()] + self.assertRaises(ValidationError, post.validate) + + post.generic = [user] + post.validate() + + User.drop_collection() + BlogPost.drop_collection() + def test_sorted_list_sorting(self): """Ensure that a sorted list field properly sorts values. """