From d868cfdeb08ec2de046f0efac34ae735a4745b6f Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 29 Nov 2013 16:24:32 +0000 Subject: [PATCH] Fix multi level nested fields getting marked as changed (#523) --- AUTHORS | 1 + docs/changelog.rst | 1 + mongoengine/base/document.py | 84 ++++++++++++++++++++---------------- tests/document/delta.py | 21 +++++++++ 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/AUTHORS b/AUTHORS index 21b0ae1b..d6994d50 100644 --- a/AUTHORS +++ b/AUTHORS @@ -188,3 +188,4 @@ that much better: * Dmytro Popovych (https://github.com/drudim) * Tom (https://github.com/tomprimozic) * j0hnsmith (https://github.com/j0hnsmith) + * Damien Churchill (https://github.com/damoxc) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6b950318..774b1dbc 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.5 ================ +- Fix multi level nested fields getting marked as changed (#523) - Django 1.6 login fix (#522) - Django 1.6 session fix (#509) - EmbeddedDocument._instance is now set when settng the attribute (#506) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 9bf77302..e425cb84 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -375,20 +375,41 @@ class BaseDocument(object): self._changed_fields.append(key) def _clear_changed_fields(self): + """Using get_changed_fields iterate and remove any fields that are + marked as changed""" + for changed in self._get_changed_fields(): + parts = changed.split(".") + data = self + for part in parts: + if isinstance(data, list): + try: + data = data[int(part)] + except IndexError: + data = None + elif isinstance(data, dict): + data = data.get(part, None) + else: + data = getattr(data, part, None) + if hasattr(data, "_changed_fields"): + data._changed_fields = [] self._changed_fields = [] - EmbeddedDocumentField = _import_class("EmbeddedDocumentField") - for field_name, field in self._fields.iteritems(): - if (isinstance(field, ComplexBaseField) and - isinstance(field.field, EmbeddedDocumentField)): - field_value = getattr(self, field_name, None) - if field_value: - for idx in (field_value if isinstance(field_value, dict) - else xrange(len(field_value))): - field_value[idx]._clear_changed_fields() - elif isinstance(field, EmbeddedDocumentField): - field_value = getattr(self, field_name, None) - if field_value: - field_value._clear_changed_fields() + + def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): + # Loop list / dict fields as they contain documents + # Determine the iterator to use + if not hasattr(data, 'items'): + iterator = enumerate(data) + else: + iterator = data.iteritems() + + for index, value in iterator: + list_key = "%s%s." % (key, index) + if hasattr(value, '_get_changed_fields'): + changed = value._get_changed_fields(inspected) + changed_fields += ["%s%s" % (list_key, k) + for k in changed if k] + elif isinstance(value, (list, tuple, dict)): + self._nestable_types_changed_fields(changed_fields, list_key, value, inspected) def _get_changed_fields(self, inspected=None): """Returns a list of all fields that have explicitly been changed. @@ -396,13 +417,12 @@ class BaseDocument(object): EmbeddedDocument = _import_class("EmbeddedDocument") DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") ReferenceField = _import_class("ReferenceField") - _changed_fields = [] - _changed_fields += getattr(self, '_changed_fields', []) - + changed_fields = [] + changed_fields += getattr(self, '_changed_fields', []) inspected = inspected or set() - if hasattr(self, 'id'): + if hasattr(self, 'id') and not isinstance(self.id, dict): if self.id in inspected: - return _changed_fields + return changed_fields inspected.add(self.id) for field_name in self._fields_ordered: @@ -418,29 +438,17 @@ class BaseDocument(object): if isinstance(field, ReferenceField): continue elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) - and db_field_name not in _changed_fields): + and db_field_name not in changed_fields): # Find all embedded fields that have been changed changed = data._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (key, k) for k in changed if k] + changed_fields += ["%s%s" % (key, k) for k in changed if k] elif (isinstance(data, (list, tuple, dict)) and - db_field_name not in _changed_fields): - # Loop list / dict fields as they contain documents - # Determine the iterator to use - if not hasattr(data, 'items'): - iterator = enumerate(data) - else: - iterator = data.iteritems() - for index, value in iterator: - if not hasattr(value, '_get_changed_fields'): - continue - if (hasattr(field, 'field') and - isinstance(field.field, ReferenceField)): - continue - list_key = "%s%s." % (key, index) - changed = value._get_changed_fields(inspected) - _changed_fields += ["%s%s" % (list_key, k) - for k in changed if k] - return _changed_fields + db_field_name not in changed_fields): + if (hasattr(field, 'field') and + isinstance(field.field, ReferenceField)): + continue + self._nestable_types_changed_fields(changed_fields, key, data, inspected) + return changed_fields def _delta(self): """Returns the delta (set, unset) of the changes for a document. diff --git a/tests/document/delta.py b/tests/document/delta.py index b4749f38..b0f5f01a 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -713,6 +713,27 @@ class DeltaTest(unittest.TestCase): self.assertEqual({}, removals) self.assertTrue('employees' in updates) + def test_nested_nested_fields_mark_as_changed(self): + class EmbeddedDoc(EmbeddedDocument): + name = StringField() + + class MyDoc(Document): + subs = MapField(MapField(EmbeddedDocumentField(EmbeddedDoc))) + name = StringField() + + MyDoc.drop_collection() + + mydoc = MyDoc(name='testcase1', subs={'a': {'b': EmbeddedDoc(name='foo')}}).save() + + mydoc = MyDoc.objects.first() + subdoc = mydoc.subs['a']['b'] + subdoc.name = 'bar' + + self.assertEqual(["name"], subdoc._get_changed_fields()) + self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields()) + + mydoc._clear_changed_fields() + self.assertEqual([], mydoc._get_changed_fields()) if __name__ == '__main__': unittest.main()