diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 6c4a06c9..0aa51b2d 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -688,6 +688,11 @@ class GenericEmbeddedDocumentField(BaseField): return value def validate(self, value, clean=True): + if self.choices and isinstance(value, SON): + for choice in self.choices: + if value['_cls'] == choice._class_name: + return True + if not isinstance(value, EmbeddedDocument): self.error('Invalid embedded document instance provided to an ' 'GenericEmbeddedDocumentField') @@ -705,7 +710,6 @@ class GenericEmbeddedDocumentField(BaseField): def to_mongo(self, document, use_db_field=True, fields=None): if document is None: return None - data = document.to_mongo(use_db_field, fields) if '_cls' not in data: data['_cls'] = document._class_name diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 28e84af4..43800fff 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2086,6 +2086,23 @@ class QuerySetTest(unittest.TestCase): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__user=['Ross']) + def test_pull_in_genericembedded_field(self): + + class Foo(EmbeddedDocument): + name = StringField() + + class Bar(Document): + foos = ListField(GenericEmbeddedDocumentField( + choices=[Foo, ])) + + Bar.drop_collection() + + foo = Foo(name="bar") + bar = Bar(foos=[foo]).save() + Bar.objects(id=bar.id).update(pull__foos=foo) + bar.reload() + self.assertEqual(len(bar.foos), 0) + def test_update_one_pop_generic_reference(self): class BlogTag(Document): @@ -2179,6 +2196,24 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(message.authors[1].name, "Ross") self.assertEqual(message.authors[2].name, "Adam") + def test_set_generic_embedded_documents(self): + + class Bar(EmbeddedDocument): + name = StringField() + + class User(Document): + username = StringField() + bar = GenericEmbeddedDocumentField(choices=[Bar,]) + + User.drop_collection() + + User(username='abc').save() + User.objects(username='abc').update( + set__bar=Bar(name='test'), upsert=True) + + user = User.objects(username='abc').first() + self.assertEqual(user.bar.name, "test") + def test_reload_embedded_docs_instance(self): class SubDoc(EmbeddedDocument):