diff --git a/mongoengine/base.py b/mongoengine/base.py index 7e6d0aa7..1d4e2d39 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -1121,6 +1121,22 @@ class BaseDocument(object): key not in self._changed_fields): self._changed_fields.append(key) + def _clear_changed_fields(self): + self._changed_fields = [] + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + for field_name, field in self._fields.iteritems(): + if (isinstance(field, ComplexBaseField) and + isinstance(field.field, EmbeddedDocumentField)): + field_value = getattr(self, field_name, None) + if field_value: + for idx in (field_value if isinstance(field_value, dict) + else xrange(len(field_value))): + field_value[idx]._clear_changed_fields() + elif isinstance(field, EmbeddedDocumentField): + field_value = getattr(self, field_name, None) + if field_value: + field_value._clear_changed_fields() + def _get_changed_fields(self, key='', inspected=None): """Returns a list of all fields that have explicitly been changed. """ diff --git a/mongoengine/document.py b/mongoengine/document.py index a251f589..d9cc2344 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -269,7 +269,7 @@ class Document(BaseDocument): if id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) - self._changed_fields = [] + self._clear_changed_fields() self._created = False signals.post_save.send(self.__class__, document=self, created=created) return self diff --git a/tests/test_document.py b/tests/test_document.py index 051dc2a3..6bbd2590 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -3370,7 +3370,7 @@ class DocumentTest(unittest.TestCase): } ) ]), "1,2") - def test_data_contains_idfield(self): + def test_data_contains_id_field(self): """Ensure that asking for _data returns 'id' """ class Person(Document): @@ -3383,6 +3383,47 @@ class DocumentTest(unittest.TestCase): self.assertTrue('_id' in person._data.keys()) self.assertEqual(person._data.get('_id'), person.id) + def test_complex_nesting_document_and_embedded_document(self): + + class Macro(EmbeddedDocument): + value = DynamicField(default="UNDEFINED") + + class Parameter(EmbeddedDocument): + macros = MapField(EmbeddedDocumentField(Macro)) + + def expand(self): + self.macros["test"] = Macro() + + class Node(Document): + parameters = MapField(EmbeddedDocumentField(Parameter)) + + def expand(self): + self.flattened_parameter = {} + for parameter_name, parameter in self.parameters.iteritems(): + parameter.expand() + + class System(Document): + name = StringField(required=True) + nodes = MapField(ReferenceField(Node, dbref=False)) + + def save(self, *args, **kwargs): + for node_name, node in self.nodes.iteritems(): + node.expand() + node.save(*args, **kwargs) + super(System, self).save(*args, **kwargs) + + System.drop_collection() + Node.drop_collection() + + system = System(name="system") + system.nodes["node"] = Node() + system.save() + system.nodes["node"].parameters["param"] = Parameter() + system.save() + + system = System.objects.first() + self.assertEqual("UNDEFINED", system.nodes["node"].parameters["param"].macros["test"].value) + class ValidatorErrorTest(unittest.TestCase):