From 363e50abbe1db318472de82ad583c98cef3e61c3 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 8 Nov 2012 14:46:56 +0000 Subject: [PATCH] Updated documents with embedded documents can be created in a single operation (MongoEngine/mongoengine#6) --- docs/changelog.rst | 1 + mongoengine/base/document.py | 18 ++++++++++++++++-- mongoengine/common.py | 5 +++-- tests/document/instance.py | 36 ++++++++++++++++++++++++++++++++++++ 4 files changed, 56 insertions(+), 4 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 26108b5d..778a047f 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8 ============== +- Added support for creating documents with embedded documents in a single operation (MongoEngine/mongoengine#6) - Added to_json and from_json to Document (MongoEngine/mongoengine#1) - Added to_json and from_json to QuerySet (MongoEngine/mongoengine#131) - Updated index creation now tied to Document class (MongoEngine/mongoengine#102) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 939c9fbc..2dd4b03b 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -28,7 +28,14 @@ class BaseDocument(object): _dynamic_lock = True _initialised = False - def __init__(self, **values): + def __init__(self, __auto_convert=True, **values): + """ + Initialise a document or embedded document + + :param __auto_convert: Try and will cast python objects to Object types + :param values: A dictionary of values for the document + """ + signals.pre_init.send(self.__class__, document=self, values=values) self._data = {} @@ -50,9 +57,16 @@ class BaseDocument(object): elif self._dynamic: dynamic_data[key] = value else: + FileField = _import_class('FileField') for key, value in values.iteritems(): key = self._reverse_db_field_map.get(key, key) + if (value is not None and __auto_convert and + key in self._fields): + field = self._fields.get(key) + if not isinstance(field, FileField): + value = field.to_python(value) setattr(self, key, value) + # Set any get_fieldname_display methods self.__set_field_display() @@ -487,7 +501,7 @@ class BaseDocument(object): % (cls._class_name, errors)) raise InvalidDocumentError(msg) - obj = cls(**data) + obj = cls(__auto_convert=False, **data) obj._changed_fields = changed_fields obj._created = False return obj diff --git a/mongoengine/common.py b/mongoengine/common.py index c76801ce..a8422c09 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -9,8 +9,9 @@ def _import_class(cls_name): doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', 'MapReduceDocument') field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', - 'GenericReferenceField', 'GenericEmbeddedDocumentField', - 'GeoPointField', 'ReferenceField', 'StringField') + 'FileField', 'GenericReferenceField', + 'GenericEmbeddedDocumentField', 'GeoPointField', + 'ReferenceField', 'StringField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/tests/document/instance.py b/tests/document/instance.py index 2118575e..8fb4fd72 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -2005,5 +2005,41 @@ class ValidatorErrorTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) + def test_kwargs_simple(self): + + class Embedded(EmbeddedDocument): + name = StringField() + + class Doc(Document): + doc_name = StringField() + doc = EmbeddedDocumentField(Embedded) + + classic_doc = Doc(doc_name="my doc", doc=Embedded(name="embedded doc")) + dict_doc = Doc(**{"doc_name": "my doc", + "doc": {"name": "embedded doc"}}) + + self.assertEqual(classic_doc, dict_doc) + self.assertEqual(classic_doc._data, dict_doc._data) + + def test_kwargs_complex(self): + + class Embedded(EmbeddedDocument): + name = StringField() + + class Doc(Document): + doc_name = StringField() + docs = ListField(EmbeddedDocumentField(Embedded)) + + classic_doc = Doc(doc_name="my doc", docs=[ + Embedded(name="embedded doc1"), + Embedded(name="embedded doc2")]) + dict_doc = Doc(**{"doc_name": "my doc", + "docs": [{"name": "embedded doc1"}, + {"name": "embedded doc2"}]}) + + self.assertEqual(classic_doc, dict_doc) + self.assertEqual(classic_doc._data, dict_doc._data) + + if __name__ == '__main__': unittest.main()