Don't follow references in _get_changed_fields (#422, #417)

A better fix so we dont follow down a references rabbit hole.
This commit is contained in:
Ross Lawley 2013-07-29 17:22:24 +00:00
parent 93a2adb3e6
commit 1e4d48d371
3 changed files with 35 additions and 27 deletions

View File

@ -4,7 +4,7 @@ Changelog
Changes in 0.8.4
================
- Fixed _delta including referenced fields when dbref=False (#417)
- Don't follow references in _get_changed_fields (#422, #417)
- Allow args and kwargs to be passed through to_json (#420)
Changes in 0.8.3

View File

@ -395,6 +395,7 @@ class BaseDocument(object):
"""
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
@ -405,31 +406,36 @@ class BaseDocument(object):
inspected.add(self.id)
for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
field = self._data.get(field_name, None)
if hasattr(field, 'id'):
if field.id in inspected:
continue
inspected.add(field.id)
data = self._data.get(field_name, None)
field = self._fields.get(field_name)
if (isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument))
if hasattr(data, 'id'):
if data.id in inspected:
continue
inspected.add(data.id)
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields):
# Find all embedded fields that have been changed
changed = field._get_changed_fields(inspected)
changed = data._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(field, (list, tuple, dict)) and
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(field, 'items'):
iterator = enumerate(field)
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = field.iteritems()
iterator = data.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k)

View File

@ -328,14 +328,9 @@ class DeltaTest(unittest.TestCase):
Person.drop_collection()
Organization.drop_collection()
person = Person(name="owner")
person.save()
employee = Person(name="employee")
employee.save()
organization = Organization(name="company")
organization.save()
person = Person(name="owner").save()
employee = Person(name="employee").save()
organization = Organization(name="company").save()
person.owns.append(organization)
organization.owner = person
@ -692,25 +687,32 @@ class DeltaTest(unittest.TestCase):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, True)
employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), ['employees.0.name'])
self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees.0' in updates)
self.assertEqual({}, updates)
organization.save()
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
def test_delta_with_dbref_false(self):
person, organization, employee = self.circular_reference_deltas_2(Document, Document, False)
employee.name = 'test'
self.assertEqual(organization._get_changed_fields(), ['employees.0.name'])
self.assertEqual(organization._get_changed_fields(), [])
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees.0' in updates)
self.assertEqual({}, updates)
organization.employees.append(person)
updates, removals = organization._delta()
self.assertEqual({}, removals)
self.assertTrue('employees' in updates)
organization.save()
if __name__ == '__main__':
unittest.main()