diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index e697fe40..55b40228 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -537,6 +537,9 @@ class BaseDocument: """Using _get_changed_fields iterate and remove any fields that are marked as changed. """ + ReferenceField = _import_class("ReferenceField") + GenericReferenceField = _import_class("GenericReferenceField") + for changed in self._get_changed_fields(): parts = changed.split(".") data = self @@ -549,7 +552,8 @@ class BaseDocument: elif isinstance(data, dict): data = data.get(part, None) else: - data = getattr(data, part, None) + field_name = data._reverse_db_field_map.get(part, part) + data = getattr(data, field_name, None) if not isinstance(data, LazyReference) and hasattr( data, "_changed_fields" @@ -558,10 +562,40 @@ class BaseDocument: continue data._changed_fields = [] + elif isinstance(data, (list, tuple, dict)): + if hasattr(data, "field") and isinstance( + data.field, (ReferenceField, GenericReferenceField) + ): + continue + BaseDocument._nestable_types_clear_changed_fields(data) self._changed_fields = [] - def _nestable_types_changed_fields(self, changed_fields, base_key, data): + @staticmethod + def _nestable_types_clear_changed_fields(data): + """Inspect nested data for changed fields + + :param data: data to inspect for changes + """ + Document = _import_class("Document") + + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(data, "items"): + iterator = enumerate(data) + else: + iterator = data.items() + + for index_or_key, value in iterator: + if hasattr(value, "_get_changed_fields") and not isinstance( + value, Document + ): # don't follow references + value._clear_changed_fields() + elif isinstance(value, (list, tuple, dict)): + BaseDocument._nestable_types_clear_changed_fields(value) + + @staticmethod + def _nestable_types_changed_fields(changed_fields, base_key, data): """Inspect nested data for changed fields :param changed_fields: Previously collected changed fields @@ -586,7 +620,9 @@ class BaseDocument: changed = value._get_changed_fields() changed_fields += ["{}{}".format(item_key, k) for k in changed if k] elif isinstance(value, (list, tuple, dict)): - self._nestable_types_changed_fields(changed_fields, item_key, value) + BaseDocument._nestable_types_changed_fields( + changed_fields, item_key, value + ) def _get_changed_fields(self): """Return a list of all fields that have explicitly been changed. diff --git a/tests/document/test_delta.py b/tests/document/test_delta.py index 2324211b..27439bc2 100644 --- a/tests/document/test_delta.py +++ b/tests/document/test_delta.py @@ -537,6 +537,7 @@ class TestDelta(MongoDBTestCase): {}, ) doc.save() + assert doc._get_changed_fields() == [] doc = doc.reload(10) assert doc.embedded_field.list_field[0] == "1" @@ -767,9 +768,7 @@ class TestDelta(MongoDBTestCase): MyDoc.drop_collection() - mydoc = MyDoc( - name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}} - ).save() + MyDoc(name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}).save() mydoc = MyDoc.objects.first() subdoc = mydoc.subs["a"]["b"] @@ -781,6 +780,35 @@ class TestDelta(MongoDBTestCase): mydoc._clear_changed_fields() assert [] == mydoc._get_changed_fields() + def test_nested_nested_fields_db_field_set__gets_mark_as_changed_and_cleaned(self): + class EmbeddedDoc(EmbeddedDocument): + name = StringField(db_field="db_name") + + class MyDoc(Document): + embed = EmbeddedDocumentField(EmbeddedDoc, db_field="db_embed") + name = StringField(db_field="db_name") + + MyDoc.drop_collection() + + MyDoc(name="testcase1", embed=EmbeddedDoc(name="foo")).save() + + mydoc = MyDoc.objects.first() + mydoc.embed.name = "foo1" + + assert mydoc.embed._get_changed_fields() == ["db_name"] + assert mydoc._get_changed_fields() == ["db_embed.db_name"] + + mydoc = MyDoc.objects.first() + embed = EmbeddedDoc(name="foo2") + embed.name = "bar" + mydoc.embed = embed + + assert embed._get_changed_fields() == ["db_name"] + assert mydoc._get_changed_fields() == ["db_embed"] + + mydoc._clear_changed_fields() + assert mydoc._get_changed_fields() == [] + def test_lower_level_mark_as_changed(self): class EmbeddedDoc(EmbeddedDocument): name = StringField() diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 0f9f412c..8ba429f4 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -370,8 +370,7 @@ class FieldTest(unittest.TestCase): assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] def test_circular_reference(self): - """Ensure you can handle circular references - """ + """Ensure you can handle circular references""" class Relation(EmbeddedDocument): name = StringField() @@ -426,6 +425,7 @@ class FieldTest(unittest.TestCase): daughter.relations.append(mother) daughter.relations.append(daughter) + assert daughter._get_changed_fields() == ["relations"] daughter.save() assert "[, ]" == "%s" % Person.objects()