diff --git a/docs/changelog.rst b/docs/changelog.rst index d93bf7bc..628c1ee3 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,11 +4,12 @@ Changelog Changes in 0.6.X ================ +- Added choices for GenericEmbeddedDocuments - Fixed Django 1.4 sessions first save data loss - FileField now automatically delete files on .delete() - Fix for GenericReference to_mongo method - Fixed connection regression -- Django User document allows inheritance +- Updated Django User document, now allows inheritance Changes in 0.6.7 ================ diff --git a/mongoengine/base.py b/mongoengine/base.py index 995f1326..139c326c 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -223,16 +223,19 @@ class BaseField(object): pass def _validate(self, value): + from mongoengine import EmbeddedDocument # check choices if self.choices: + is_cls = isinstance(value, EmbeddedDocument) + value_to_check = value.__class__ if is_cls else value + err_msg = 'an instance' if is_cls else 'one' if isinstance(self.choices[0], (list, tuple)): option_keys = [option_key for option_key, option_value in self.choices] - if value not in option_keys: - self.error('Value must be one of %s' % unicode(option_keys)) - else: - if value not in self.choices: - self.error('Value must be one of %s' % unicode(self.choices)) + if value_to_check not in option_keys: + self.error('Value must be %s of %s' % (err_msg, unicode(option_keys))) + elif value_to_check not in self.choices: + self.error('Value must be %s of %s' % (err_msg, unicode(self.choices))) # check validation argument if self.validation is not None: @@ -400,7 +403,7 @@ class ComplexBaseField(BaseField): sequence = enumerate(value) for k, v in sequence: try: - self.field.validate(v) + self.field._validate(v) except (ValidationError, AssertionError), error: if hasattr(error, 'errors'): errors[k] = error.errors diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f7344488..2e614d28 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -659,7 +659,7 @@ class ReferenceField(BaseField): def to_mongo(self, document): if isinstance(document, DBRef): return document - + id_field_name = self.document_type._meta['id_field'] id_field = self.document_type._fields[id_field_name] @@ -734,9 +734,9 @@ class GenericReferenceField(BaseField): def to_mongo(self, document): if document is None: return None - + if isinstance(document, (dict, SON)): - return document + return document id_field_name = document.__class__._meta['id_field'] id_field = document.__class__._fields[id_field_name] diff --git a/tests/fields.py b/tests/fields.py index 31d3b588..9b9ea980 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1877,6 +1877,8 @@ class FieldTest(unittest.TestCase): name = StringField() like = GenericEmbeddedDocumentField() + Person.drop_collection() + person = Person(name='Test User') person.like = Car(name='Fiat') person.save() @@ -1890,6 +1892,54 @@ class FieldTest(unittest.TestCase): person = Person.objects.first() self.assertTrue(isinstance(person.like, Dish)) + def test_generic_embedded_document_choices(self): + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + like = GenericEmbeddedDocumentField(choices=(Dish,)) + + Person.drop_collection() + + person = Person(name='Test User') + person.like = Car(name='Fiat') + self.assertRaises(ValidationError, person.validate) + + person.like = Dish(food="arroz", number=15) + person.save() + + person = Person.objects.first() + self.assertTrue(isinstance(person.like, Dish)) + + def test_generic_list_embedded_document_choices(self): + class Car(EmbeddedDocument): + name = StringField() + + class Dish(EmbeddedDocument): + food = StringField(required=True) + number = IntField() + + class Person(Document): + name = StringField() + likes = ListField(GenericEmbeddedDocumentField(choices=(Dish,))) + + Person.drop_collection() + + person = Person(name='Test User') + person.likes = [Car(name='Fiat')] + self.assertRaises(ValidationError, person.validate) + + person.likes = [Dish(food="arroz", number=15)] + person.save() + + person = Person.objects.first() + self.assertTrue(isinstance(person.likes[0], Dish)) + def test_recursive_validation(self): """Ensure that a validation result to_dict is available. """