From 1e4d48d371e2920dd3397bb20b2f6f1456ed1566 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 29 Jul 2013 17:22:24 +0000 Subject: [PATCH] Don't follow references in _get_changed_fields (#422, #417) A better fix so we dont follow down a references rabbit hole. --- docs/changelog.rst | 2 +- mongoengine/base/document.py | 30 ++++++++++++++++++------------ tests/document/delta.py | 30 ++++++++++++++++-------------- 3 files changed, 35 insertions(+), 27 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index b9c74f89..9112b2b8 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,7 +4,7 @@ Changelog Changes in 0.8.4 ================ -- Fixed _delta including referenced fields when dbref=False (#417) +- Don't follow references in _get_changed_fields (#422, #417) - Allow args and kwargs to be passed through to_json (#420) Changes in 0.8.3 diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f1c1d55f..80111f70 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -395,6 +395,7 @@ class BaseDocument(object): """ EmbeddedDocument = _import_class("EmbeddedDocument") DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") + ReferenceField = _import_class("ReferenceField") _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) @@ -405,31 +406,36 @@ class BaseDocument(object): inspected.add(self.id) for field_name in self._fields_ordered: - db_field_name = self._db_field_map.get(field_name, field_name) key = '%s.' % db_field_name - field = self._data.get(field_name, None) - if hasattr(field, 'id'): - if field.id in inspected: - continue - inspected.add(field.id) + data = self._data.get(field_name, None) + field = self._fields.get(field_name) - if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) + if hasattr(data, 'id'): + if data.id in inspected: + continue + inspected.add(data.id) + if isinstance(field, ReferenceField): + continue + elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields): # Find all embedded fields that have been changed - changed = field._get_changed_fields(inspected) + changed = data._get_changed_fields(inspected) _changed_fields += ["%s%s" % (key, k) for k in changed if k] - elif (isinstance(field, (list, tuple, dict)) and + elif (isinstance(data, (list, tuple, dict)) and db_field_name not in _changed_fields): # Loop list / dict fields as they contain documents # Determine the iterator to use - if not hasattr(field, 'items'): - iterator = enumerate(field) + if not hasattr(data, 'items'): + iterator = enumerate(data) else: - iterator = field.iteritems() + iterator = data.iteritems() for index, value in iterator: if not hasattr(value, '_get_changed_fields'): continue + if (hasattr(field, 'field') and + isinstance(field.field, ReferenceField)): + continue list_key = "%s%s." % (key, index) changed = value._get_changed_fields(inspected) _changed_fields += ["%s%s" % (list_key, k) diff --git a/tests/document/delta.py b/tests/document/delta.py index c6efc028..b4749f38 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -328,14 +328,9 @@ class DeltaTest(unittest.TestCase): Person.drop_collection() Organization.drop_collection() - person = Person(name="owner") - person.save() - - employee = Person(name="employee") - employee.save() - - organization = Organization(name="company") - organization.save() + person = Person(name="owner").save() + employee = Person(name="employee").save() + organization = Organization(name="company").save() person.owns.append(organization) organization.owner = person @@ -692,25 +687,32 @@ class DeltaTest(unittest.TestCase): person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) employee.name = 'test' - self.assertEqual(organization._get_changed_fields(), ['employees.0.name']) + self.assertEqual(organization._get_changed_fields(), []) updates, removals = organization._delta() self.assertEqual({}, removals) - self.assertTrue('employees.0' in updates) + self.assertEqual({}, updates) - organization.save() + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) def test_delta_with_dbref_false(self): person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) employee.name = 'test' - self.assertEqual(organization._get_changed_fields(), ['employees.0.name']) + self.assertEqual(organization._get_changed_fields(), []) updates, removals = organization._delta() self.assertEqual({}, removals) - self.assertTrue('employees.0' in updates) + self.assertEqual({}, updates) + + organization.employees.append(person) + updates, removals = organization._delta() + self.assertEqual({}, removals) + self.assertTrue('employees' in updates) - organization.save() if __name__ == '__main__': unittest.main()