diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 21a24358..5b62cb7b 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -165,26 +165,29 @@ class BaseField(object): """ pass - def _validate(self, value, **kwargs): + def _validate_choices(self, value): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') + choice_list = self.choices + if isinstance(choice_list[0], (list, tuple)): + choice_list = [k for k, v in choice_list] + + # Choices which are other types of Documents + if isinstance(value, (Document, EmbeddedDocument)): + if not any(isinstance(value, c) for c in choice_list): + self.error( + 'Value must be instance of %s' % unicode(choice_list) + ) + # Choices which are types other than Documents + elif value not in choice_list: + self.error('Value must be one of %s' % unicode(choice_list)) + + + def _validate(self, value, **kwargs): # Check the Choices Constraint if self.choices: - - choice_list = self.choices - if isinstance(self.choices[0], (list, tuple)): - choice_list = [k for k, v in self.choices] - - # Choices which are other types of Documents - if isinstance(value, (Document, EmbeddedDocument)): - if not any(isinstance(value, c) for c in choice_list): - self.error( - 'Value must be instance of %s' % unicode(choice_list) - ) - # Choices which are types other than Documents - elif value not in choice_list: - self.error('Value must be one of %s' % unicode(choice_list)) + self._validate_choices(value) # check validation argument if self.validation is not None: diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 63c708b2..098988e0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1140,6 +1140,35 @@ class GenericReferenceField(BaseField): .. versionadded:: 0.3 """ + def __init__(self, *args, **kwargs): + choices = kwargs.pop('choices', None) + super(GenericReferenceField, self).__init__(*args, **kwargs) + self._original_choices = choices or [] + self._cooked_choices = None + + def _validate_choices(self, value): + if isinstance(value, dict): + # If the field has not been dereferenced, it is still a dict + # of class and DBRef + if value.get('_cls') in [c.__name__ for c in self.choices]: + return + super(GenericReferenceField, self)._validate_choices(value) + + @property + def choices(self): + if self._cooked_choices is None: + self._cooked_choices = [] + for choice in self._original_choices: + if isinstance(choice, basestring): + choice = get_document(choice) + self._cooked_choices.append(choice) + return self._cooked_choices + + @choices.setter + def choices(self, value): + self._original_choices = value + self._cooked_choices = None + def __get__(self, instance, owner): if instance is None: return self diff --git a/tests/fields/fields.py b/tests/fields/fields.py index a772de6d..9f9db25d 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2396,6 +2396,62 @@ class FieldTest(unittest.TestCase): bm = Bookmark.objects.first() self.assertEqual(bm.bookmark_object, post_1) + def test_generic_reference_string_choices(self): + """Ensure that a GenericReferenceField can handle choices as strings + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField(choices=('Post', Link)) + + Link.drop_collection() + Post.drop_collection() + Bookmark.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=link_1) + bm.save() + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm = Bookmark(bookmark_object=bm) + self.assertRaises(ValidationError, bm.validate) + + def test_generic_reference_choices_no_dereference(self): + """Ensure that a GenericReferenceField can handle choices on + non-derefenreced (i.e. DBRef) elements + """ + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField(choices=(Post, )) + other_field = StringField() + + Post.drop_collection() + Bookmark.drop_collection() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm = Bookmark.objects.get(id=bm.id) + # bookmark_object is now a DBRef + bm.other_field = 'dummy_change' + bm.save() + def test_generic_reference_list_choices(self): """Ensure that a ListField properly dereferences generic references and respects choices.