Don't follow references in _get_changed_fields (#422, #417)

A better fix so we dont follow down a references rabbit hole.
This commit is contained in:
Ross Lawley 2013-07-29 17:22:24 +00:00
parent 93a2adb3e6
commit 1e4d48d371
3 changed files with 35 additions and 27 deletions

View File

@ -4,7 +4,7 @@ Changelog
Changes in 0.8.4 Changes in 0.8.4
================ ================
- Fixed _delta including referenced fields when dbref=False (#417) - Don't follow references in _get_changed_fields (#422, #417)
- Allow args and kwargs to be passed through to_json (#420) - Allow args and kwargs to be passed through to_json (#420)
Changes in 0.8.3 Changes in 0.8.3

View File

@ -395,6 +395,7 @@ class BaseDocument(object):
""" """
EmbeddedDocument = _import_class("EmbeddedDocument") EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument") DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
_changed_fields = [] _changed_fields = []
_changed_fields += getattr(self, '_changed_fields', []) _changed_fields += getattr(self, '_changed_fields', [])
@ -405,31 +406,36 @@ class BaseDocument(object):
inspected.add(self.id) inspected.add(self.id)
for field_name in self._fields_ordered: for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name) db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name key = '%s.' % db_field_name
field = self._data.get(field_name, None) data = self._data.get(field_name, None)
if hasattr(field, 'id'): field = self._fields.get(field_name)
if field.id in inspected:
continue
inspected.add(field.id)
if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) if hasattr(data, 'id'):
if data.id in inspected:
continue
inspected.add(data.id)
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields): and db_field_name not in _changed_fields):
# Find all embedded fields that have been changed # Find all embedded fields that have been changed
changed = field._get_changed_fields(inspected) changed = data._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k] _changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(field, (list, tuple, dict)) and elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in _changed_fields): db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents # Loop list / dict fields as they contain documents
# Determine the iterator to use # Determine the iterator to use
if not hasattr(field, 'items'): if not hasattr(data, 'items'):
iterator = enumerate(field) iterator = enumerate(data)
else: else:
iterator = field.iteritems() iterator = data.iteritems()
for index, value in iterator: for index, value in iterator:
if not hasattr(value, '_get_changed_fields'): if not hasattr(value, '_get_changed_fields'):
continue continue
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
list_key = "%s%s." % (key, index) list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected) changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k) _changed_fields += ["%s%s" % (list_key, k)

View File

@ -328,14 +328,9 @@ class DeltaTest(unittest.TestCase):
Person.drop_collection() Person.drop_collection()
Organization.drop_collection() Organization.drop_collection()
person = Person(name="owner") person = Person(name="owner").save()
person.save() employee = Person(name="employee").save()
organization = Organization(name="company").save()
employee = Person(name="employee")
employee.save()
organization = Organization(name="company")
organization.save()
person.owns.append(organization) person.owns.append(organization)
organization.owner = person organization.owner = person
@ -692,25 +687,32 @@ class DeltaTest(unittest.TestCase):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, True) person, organization, employee = self.circular_reference_deltas_2(Document, Document, True)
employee.name = 'test' employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), ['employees.0.name']) self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta() updates, removals = organization._delta()
self.assertEqual({}, removals) self.assertEqual({}, removals)
self.assertTrue('employees.0' in updates) self.assertEqual({}, updates)
organization.save() organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
def test_delta_with_dbref_false(self): def test_delta_with_dbref_false(self):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, False) person, organization, employee = self.circular_reference_deltas_2(Document, Document, False)
employee.name = 'test' employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), ['employees.0.name']) self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta() updates, removals = organization._delta()
self.assertEqual({}, removals) self.assertEqual({}, removals)
self.assertTrue('employees.0' in updates) self.assertEqual({}, updates)
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
organization.save()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()