From 2aa8b04c21c180546ee6de7643a9f035257fb25e Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 9 May 2012 13:21:53 +0100 Subject: [PATCH] Implemented Choices for GenericReferenceFields Refs mongoengine/mongoengine#13 --- docs/changelog.rst | 3 +- mongoengine/base.py | 5 ++- mongoengine/fields.py | 5 +++ tests/fields.py | 73 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 628c1ee3..c1019375 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,7 +4,8 @@ Changelog Changes in 0.6.X ================ -- Added choices for GenericEmbeddedDocuments +- Added support for choices with GenericReferenceFields +- Added support for choices with GenericEmbeddedDocumentFields - Fixed Django 1.4 sessions first save data loss - FileField now automatically delete files on .delete() - Fix for GenericReference to_mongo method diff --git a/mongoengine/base.py b/mongoengine/base.py index 139c326c..347332b2 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -223,11 +223,10 @@ class BaseField(object): pass def _validate(self, value): - from mongoengine import EmbeddedDocument - + from mongoengine import Document, EmbeddedDocument # check choices if self.choices: - is_cls = isinstance(value, EmbeddedDocument) + is_cls = isinstance(value, (Document, EmbeddedDocument)) value_to_check = value.__class__ if is_cls else value err_msg = 'an instance' if is_cls else 'one' if isinstance(self.choices[0], (list, tuple)): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 2e614d28..3e8f09a4 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -441,6 +441,9 @@ class GenericEmbeddedDocumentField(BaseField): :class:`~mongoengine.EmbeddedDocument` to be stored. Only valid values are subclasses of :class:`~mongoengine.EmbeddedDocument`. + + ..note :: You can use the choices param to limit the acceptable + EmbeddedDocument types """ def prepare_query_value(self, op, value): @@ -701,6 +704,8 @@ class GenericReferenceField(BaseField): ..note :: Any documents used as a generic reference must be registered in the document registry. Importing the model will automatically register it. + ..note :: You can use the choices param to limit the acceptable Document types + .. versionadded:: 0.3 """ diff --git a/tests/fields.py b/tests/fields.py index 9b9ea980..ea5262db 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1301,6 +1301,74 @@ class FieldTest(unittest.TestCase): self.assertEquals(repr(Person.objects(city=None)), "[]") + + def test_generic_reference_choices(self): + """Ensure that a GenericReferenceField can handle choices + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class Bookmark(Document): + bookmark_object = GenericReferenceField(choices=(Post,)) + + Link.drop_collection() + Post.drop_collection() + Bookmark.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + bm = Bookmark(bookmark_object=link_1) + self.assertRaises(ValidationError, bm.validate) + + bm = Bookmark(bookmark_object=post_1) + bm.save() + + bm = Bookmark.objects.first() + self.assertEqual(bm.bookmark_object, post_1) + + def test_generic_reference_list_choices(self): + """Ensure that a ListField properly dereferences generic references and + respects choices. + """ + class Link(Document): + title = StringField() + + class Post(Document): + title = StringField() + + class User(Document): + bookmarks = ListField(GenericReferenceField(choices=(Post,))) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() + + link_1 = Link(title="Pitchfork") + link_1.save() + + post_1 = Post(title="Behind the Scenes of the Pavement Reunion") + post_1.save() + + user = User(bookmarks=[link_1]) + self.assertRaises(ValidationError, user.validate) + + user = User(bookmarks=[post_1]) + user.save() + + user = User.objects.first() + self.assertEqual(user.bookmarks, [post_1]) + + Link.drop_collection() + Post.drop_collection() + User.drop_collection() + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """ @@ -1893,6 +1961,8 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(person.like, Dish)) def test_generic_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices + """ class Car(EmbeddedDocument): name = StringField() @@ -1917,6 +1987,9 @@ class FieldTest(unittest.TestCase): self.assertTrue(isinstance(person.like, Dish)) def test_generic_list_embedded_document_choices(self): + """Ensure you can limit GenericEmbeddedDocument choices inside a list + field + """ class Car(EmbeddedDocument): name = StringField()