fix inconsistencies in ._changed_fields computation
This commit is contained in:
		| @@ -537,6 +537,9 @@ class BaseDocument: | |||||||
|         """Using _get_changed_fields iterate and remove any fields that |         """Using _get_changed_fields iterate and remove any fields that | ||||||
|         are marked as changed. |         are marked as changed. | ||||||
|         """ |         """ | ||||||
|  |         ReferenceField = _import_class("ReferenceField") | ||||||
|  |         GenericReferenceField = _import_class("GenericReferenceField") | ||||||
|  |  | ||||||
|         for changed in self._get_changed_fields(): |         for changed in self._get_changed_fields(): | ||||||
|             parts = changed.split(".") |             parts = changed.split(".") | ||||||
|             data = self |             data = self | ||||||
| @@ -549,7 +552,8 @@ class BaseDocument: | |||||||
|                 elif isinstance(data, dict): |                 elif isinstance(data, dict): | ||||||
|                     data = data.get(part, None) |                     data = data.get(part, None) | ||||||
|                 else: |                 else: | ||||||
|                     data = getattr(data, part, None) |                     field_name = data._reverse_db_field_map.get(part, part) | ||||||
|  |                     data = getattr(data, field_name, None) | ||||||
|  |  | ||||||
|                 if not isinstance(data, LazyReference) and hasattr( |                 if not isinstance(data, LazyReference) and hasattr( | ||||||
|                     data, "_changed_fields" |                     data, "_changed_fields" | ||||||
| @@ -558,10 +562,40 @@ class BaseDocument: | |||||||
|                         continue |                         continue | ||||||
|  |  | ||||||
|                     data._changed_fields = [] |                     data._changed_fields = [] | ||||||
|  |                 elif isinstance(data, (list, tuple, dict)): | ||||||
|  |                     if hasattr(data, "field") and isinstance( | ||||||
|  |                         data.field, (ReferenceField, GenericReferenceField) | ||||||
|  |                     ): | ||||||
|  |                         continue | ||||||
|  |                     BaseDocument._nestable_types_clear_changed_fields(data) | ||||||
|  |  | ||||||
|         self._changed_fields = [] |         self._changed_fields = [] | ||||||
|  |  | ||||||
|     def _nestable_types_changed_fields(self, changed_fields, base_key, data): |     @staticmethod | ||||||
|  |     def _nestable_types_clear_changed_fields(data): | ||||||
|  |         """Inspect nested data for changed fields | ||||||
|  |  | ||||||
|  |         :param data: data to inspect for changes | ||||||
|  |         """ | ||||||
|  |         Document = _import_class("Document") | ||||||
|  |  | ||||||
|  |         # Loop list / dict fields as they contain documents | ||||||
|  |         # Determine the iterator to use | ||||||
|  |         if not hasattr(data, "items"): | ||||||
|  |             iterator = enumerate(data) | ||||||
|  |         else: | ||||||
|  |             iterator = data.items() | ||||||
|  |  | ||||||
|  |         for index_or_key, value in iterator: | ||||||
|  |             if hasattr(value, "_get_changed_fields") and not isinstance( | ||||||
|  |                 value, Document | ||||||
|  |             ):  # don't follow references | ||||||
|  |                 value._clear_changed_fields() | ||||||
|  |             elif isinstance(value, (list, tuple, dict)): | ||||||
|  |                 BaseDocument._nestable_types_clear_changed_fields(value) | ||||||
|  |  | ||||||
|  |     @staticmethod | ||||||
|  |     def _nestable_types_changed_fields(changed_fields, base_key, data): | ||||||
|         """Inspect nested data for changed fields |         """Inspect nested data for changed fields | ||||||
|  |  | ||||||
|         :param changed_fields: Previously collected changed fields |         :param changed_fields: Previously collected changed fields | ||||||
| @@ -586,7 +620,9 @@ class BaseDocument: | |||||||
|                 changed = value._get_changed_fields() |                 changed = value._get_changed_fields() | ||||||
|                 changed_fields += ["{}{}".format(item_key, k) for k in changed if k] |                 changed_fields += ["{}{}".format(item_key, k) for k in changed if k] | ||||||
|             elif isinstance(value, (list, tuple, dict)): |             elif isinstance(value, (list, tuple, dict)): | ||||||
|                 self._nestable_types_changed_fields(changed_fields, item_key, value) |                 BaseDocument._nestable_types_changed_fields( | ||||||
|  |                     changed_fields, item_key, value | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|     def _get_changed_fields(self): |     def _get_changed_fields(self): | ||||||
|         """Return a list of all fields that have explicitly been changed. |         """Return a list of all fields that have explicitly been changed. | ||||||
|   | |||||||
| @@ -537,6 +537,7 @@ class TestDelta(MongoDBTestCase): | |||||||
|             {}, |             {}, | ||||||
|         ) |         ) | ||||||
|         doc.save() |         doc.save() | ||||||
|  |         assert doc._get_changed_fields() == [] | ||||||
|         doc = doc.reload(10) |         doc = doc.reload(10) | ||||||
|  |  | ||||||
|         assert doc.embedded_field.list_field[0] == "1" |         assert doc.embedded_field.list_field[0] == "1" | ||||||
| @@ -767,9 +768,7 @@ class TestDelta(MongoDBTestCase): | |||||||
|  |  | ||||||
|         MyDoc.drop_collection() |         MyDoc.drop_collection() | ||||||
|  |  | ||||||
|         mydoc = MyDoc( |         MyDoc(name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}}).save() | ||||||
|             name="testcase1", subs={"a": {"b": EmbeddedDoc(name="foo")}} |  | ||||||
|         ).save() |  | ||||||
|  |  | ||||||
|         mydoc = MyDoc.objects.first() |         mydoc = MyDoc.objects.first() | ||||||
|         subdoc = mydoc.subs["a"]["b"] |         subdoc = mydoc.subs["a"]["b"] | ||||||
| @@ -781,6 +780,35 @@ class TestDelta(MongoDBTestCase): | |||||||
|         mydoc._clear_changed_fields() |         mydoc._clear_changed_fields() | ||||||
|         assert [] == mydoc._get_changed_fields() |         assert [] == mydoc._get_changed_fields() | ||||||
|  |  | ||||||
|  |     def test_nested_nested_fields_db_field_set__gets_mark_as_changed_and_cleaned(self): | ||||||
|  |         class EmbeddedDoc(EmbeddedDocument): | ||||||
|  |             name = StringField(db_field="db_name") | ||||||
|  |  | ||||||
|  |         class MyDoc(Document): | ||||||
|  |             embed = EmbeddedDocumentField(EmbeddedDoc, db_field="db_embed") | ||||||
|  |             name = StringField(db_field="db_name") | ||||||
|  |  | ||||||
|  |         MyDoc.drop_collection() | ||||||
|  |  | ||||||
|  |         MyDoc(name="testcase1", embed=EmbeddedDoc(name="foo")).save() | ||||||
|  |  | ||||||
|  |         mydoc = MyDoc.objects.first() | ||||||
|  |         mydoc.embed.name = "foo1" | ||||||
|  |  | ||||||
|  |         assert mydoc.embed._get_changed_fields() == ["db_name"] | ||||||
|  |         assert mydoc._get_changed_fields() == ["db_embed.db_name"] | ||||||
|  |  | ||||||
|  |         mydoc = MyDoc.objects.first() | ||||||
|  |         embed = EmbeddedDoc(name="foo2") | ||||||
|  |         embed.name = "bar" | ||||||
|  |         mydoc.embed = embed | ||||||
|  |  | ||||||
|  |         assert embed._get_changed_fields() == ["db_name"] | ||||||
|  |         assert mydoc._get_changed_fields() == ["db_embed"] | ||||||
|  |  | ||||||
|  |         mydoc._clear_changed_fields() | ||||||
|  |         assert mydoc._get_changed_fields() == [] | ||||||
|  |  | ||||||
|     def test_lower_level_mark_as_changed(self): |     def test_lower_level_mark_as_changed(self): | ||||||
|         class EmbeddedDoc(EmbeddedDocument): |         class EmbeddedDoc(EmbeddedDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
|   | |||||||
| @@ -370,8 +370,7 @@ class FieldTest(unittest.TestCase): | |||||||
|         assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] |         assert Post.objects.all()[0].user_lists == [[u1, u2], [u3]] | ||||||
|  |  | ||||||
|     def test_circular_reference(self): |     def test_circular_reference(self): | ||||||
|         """Ensure you can handle circular references |         """Ensure you can handle circular references""" | ||||||
|         """ |  | ||||||
|  |  | ||||||
|         class Relation(EmbeddedDocument): |         class Relation(EmbeddedDocument): | ||||||
|             name = StringField() |             name = StringField() | ||||||
| @@ -426,6 +425,7 @@ class FieldTest(unittest.TestCase): | |||||||
|  |  | ||||||
|         daughter.relations.append(mother) |         daughter.relations.append(mother) | ||||||
|         daughter.relations.append(daughter) |         daughter.relations.append(daughter) | ||||||
|  |         assert daughter._get_changed_fields() == ["relations"] | ||||||
|         daughter.save() |         daughter.save() | ||||||
|  |  | ||||||
|         assert "[<Person: Mother>, <Person: Daughter>]" == "%s" % Person.objects() |         assert "[<Person: Mother>, <Person: Daughter>]" == "%s" % Person.objects() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user