From a5fb009b62e7e7aca00ca1e53bf1648c30ac9c38 Mon Sep 17 00:00:00 2001 From: Emmanuel Leblond Date: Mon, 6 Jul 2015 02:33:43 +0200 Subject: [PATCH 1/3] Fix GenericReferenceField choices with DBRef and let it possible to set Document choice as string --- mongoengine/base/fields.py | 33 ++++++++++++---------- mongoengine/fields.py | 29 ++++++++++++++++++++ tests/fields/fields.py | 56 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 103 insertions(+), 15 deletions(-) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 21a24358..5b62cb7b 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -165,26 +165,29 @@ class BaseField(object): """ pass - def _validate(self, value, **kwargs): + def _validate_choices(self, value): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') + choice_list = self.choices + if isinstance(choice_list[0], (list, tuple)): + choice_list = [k for k, v in choice_list] + + # Choices which are other types of Documents + if isinstance(value, (Document, EmbeddedDocument)): + if not any(isinstance(value, c) for c in choice_list): + self.error( + 'Value must be instance of %s' % unicode(choice_list) + ) + # Choices which are types other than Documents + elif value not in choice_list: + self.error('Value must be one of %s' % unicode(choice_list)) + + + def _validate(self, value, **kwargs): # Check the Choices Constraint if self.choices: - - choice_list = self.choices - if isinstance(self.choices[0], (list, tuple)): - choice_list = [k for k, v in self.choices] - - # Choices which are other types of Documents - if isinstance(value, (Document, EmbeddedDocument)): - if not any(isinstance(value, c) for c in choice_list): - self.error( - 'Value must be instance of %s' % unicode(choice_list) - ) - # Choices which are types other than Documents - elif value not in choice_list: - self.error('Value must be one of %s' % unicode(choice_list)) + self._validate_choices(value) # check validation argument if self.validation is not None: diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 63c708b2..098988e0 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1140,6 +1140,35 @@ class GenericReferenceField(BaseField): .. versionadded:: 0.3 """ + def __init__(self, *args, **kwargs): + choices = kwargs.pop('choices', None) + super(GenericReferenceField, self).__init__(*args, **kwargs) + self._original_choices = choices or [] + self._cooked_choices = None + + def _validate_choices(self, value): + if isinstance(value, dict): + # If the field has not been dereferenced, it is still a dict + # of class and DBRef + if value.get('_cls') in [c.__name__ for c in self.choices]: + return + super(GenericReferenceField, self)._validate_choices(value) + + @property + def choices(self): + if self._cooked_choices is None: + self._cooked_choices = [] + for choice in self._original_choices: + if isinstance(choice, basestring): + choice = get_document(choice) + self._cooked_choices.append(choice) + return self._cooked_choices + + @choices.setter + def choices(self, value): + self._original_choices = value + self._cooked_choices = None + def __get__(self, instance, owner): if instance is None: return self diff --git a/tests/fields/fields.py b/tests/fields/fields.py index a772de6d..9f9db25d 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2396,6 +2396,62 @@ class FieldTest(unittest.TestCase): bm = Bookmark.objects.first() self.assertEqual(bm.bookmark_object, post_1) + def test_generic_reference_string_choices(self): + """Ensure that a GenericReferenceField can handle choices as strings + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField(choices=('Post', Link)) + + Link.drop_collection() + Post.drop_collection() + Bookmark.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=link_1) + bm.save() + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm = Bookmark(bookmark_object=bm) + self.assertRaises(ValidationError, bm.validate) + + def test_generic_reference_choices_no_dereference(self): + """Ensure that a GenericReferenceField can handle choices on + non-derefenreced (i.e. DBRef) elements + """ + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField(choices=(Post, )) + other_field = StringField() + + Post.drop_collection() + Bookmark.drop_collection() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm = Bookmark.objects.get(id=bm.id) + # bookmark_object is now a DBRef + bm.other_field = 'dummy_change' + bm.save() + def test_generic_reference_list_choices(self): """Ensure that a ListField properly dereferences generic references and respects choices. From 34c67907624773e0bf2f1f14589d8c3c88b438b9 Mon Sep 17 00:00:00 2001 From: Emmanuel Leblond Date: Mon, 6 Jul 2015 10:10:05 +0200 Subject: [PATCH 2/3] Simplify implementation of choices in GenericReferenceField --- mongoengine/fields.py | 33 ++++++++++++++------------------- 1 file changed, 14 insertions(+), 19 deletions(-) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 098988e0..695f5caa 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1143,32 +1143,27 @@ class GenericReferenceField(BaseField): def __init__(self, *args, **kwargs): choices = kwargs.pop('choices', None) super(GenericReferenceField, self).__init__(*args, **kwargs) - self._original_choices = choices or [] - self._cooked_choices = None + self.choices = [] + # Keep the choices as a list of allowed Document class names + if choices: + for choice in choices: + if isinstance(choice, basestring): + self.choices.append(choice) + elif isinstance(choice, type) and issubclass(choice, Document): + self.choices.append(choice._class_name) + else: + self.error('Invalid choices provided: must be a list of' + 'Document subclasses and/or basestrings') def _validate_choices(self, value): if isinstance(value, dict): # If the field has not been dereferenced, it is still a dict # of class and DBRef - if value.get('_cls') in [c.__name__ for c in self.choices]: - return + value = value.get('_cls') + elif isinstance(value, Document): + value = value._class_name super(GenericReferenceField, self)._validate_choices(value) - @property - def choices(self): - if self._cooked_choices is None: - self._cooked_choices = [] - for choice in self._original_choices: - if isinstance(choice, basestring): - choice = get_document(choice) - self._cooked_choices.append(choice) - return self._cooked_choices - - @choices.setter - def choices(self, value): - self._original_choices = value - self._cooked_choices = None - def __get__(self, instance, owner): if instance is None: return self From bebce2c0538f7fa39fe34058e4adcecc16afecf9 Mon Sep 17 00:00:00 2001 From: Emmanuel Leblond Date: Thu, 9 Jul 2015 10:51:04 +0200 Subject: [PATCH 3/3] Clean ununsed variables in iterations --- mongoengine/base/fields.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 5b62cb7b..304c084d 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -171,7 +171,7 @@ class BaseField(object): choice_list = self.choices if isinstance(choice_list[0], (list, tuple)): - choice_list = [k for k, v in choice_list] + choice_list = [k for k, _ in choice_list] # Choices which are other types of Documents if isinstance(value, (Document, EmbeddedDocument)): @@ -311,7 +311,7 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_python(v) if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), + return [v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict @@ -378,7 +378,7 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_mongo(v) if is_list: # Convert back to a list - return [v for k, v in sorted(value_dict.items(), + return [v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict