diff --git a/AUTHORS b/AUTHORS index fb9a1ce5..18258953 100644 --- a/AUTHORS +++ b/AUTHORS @@ -215,4 +215,4 @@ that much better: * André Ericson https://github.com/aericson) * Mikhail Moshnogorsky (https://github.com/mikhailmoshnogorsky) * Diego Berrocal (https://github.com/cestdiego) - * Matthew Ellison (https://github.com/mmelliso) + * Matthew Ellison (https://github.com/seglberg) diff --git a/docs/changelog.rst b/docs/changelog.rst index fc8b281c..ae4ba85a 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Field Choices Now Accept Subclasses of Documents - Ensure Indexes before Each Save #812 - Generate Unique Indices for Lists of EmbeddedDocuments #358 - Sparse fields #515 diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 6fad5d63..359ea6d2 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -158,21 +158,23 @@ class BaseField(object): def _validate(self, value, **kwargs): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') - # check choices + + # Check the Choices Constraint if self.choices: - is_cls = isinstance(value, (Document, EmbeddedDocument)) - value_to_check = value.__class__ if is_cls else value - err_msg = 'an instance' if is_cls else 'one' + + choice_list = self.choices if isinstance(self.choices[0], (list, tuple)): - option_keys = [k for k, v in self.choices] - if value_to_check not in option_keys: - msg = ('Value must be %s of %s' % - (err_msg, unicode(option_keys))) - self.error(msg) - elif value_to_check not in self.choices: - msg = ('Value must be %s of %s' % - (err_msg, unicode(self.choices))) - self.error(msg) + choice_list = [k for k, v in self.choices] + + # Choices which are other types of Documents + if isinstance(value, (Document, EmbeddedDocument)): + if not any(isinstance(value, c) for c in choice_list): + self.error( + 'Value must be instance of %s' % unicode(choice_list) + ) + # Choices which are types other than Documents + elif value not in choice_list: + self.error('Value must be one of %s' % unicode(choice_list)) # check validation argument if self.validation is not None: diff --git a/tests/fields/fields.py b/tests/fields/fields.py index a95001b4..ab220b25 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2446,6 +2446,79 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_choices_validation_documents(self): + """ + Ensure fields with document choices validate given a valid choice. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Ensure Validation Passes + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + + def test_choices_validation_documents_invalid(self): + """ + Ensure fields with document choices validate given an invalid choice. + This should throw a ValidationError exception. + """ + class UserComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class ModeratorComments(EmbeddedDocument): + author = StringField() + message = StringField() + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(UserComments,)) + ) + + # Single Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + ]) + self.assertRaises(ValidationError, post.save) + + # Mixed Entry Failure + post = BlogPost(comments=[ + ModeratorComments(author='mod1', message='message1'), + UserComments(author='user2', message='message2'), + ]) + self.assertRaises(ValidationError, post.save) + + def test_choices_validation_documents_inheritance(self): + """ + Ensure fields with document choices validate given subclass of choice. + """ + class Comments(EmbeddedDocument): + meta = { + 'abstract': True + } + author = StringField() + message = StringField() + + class UserComments(Comments): + pass + + class BlogPost(Document): + comments = ListField( + GenericEmbeddedDocumentField(choices=(Comments,)) + ) + + # Save Valid EmbeddedDocument Type + BlogPost(comments=[ + UserComments(author='user2', message='message2'), + ]).save() + def test_choices_get_field_display(self): """Test dynamic helper for returning the display value of a choices field.