Fix multi level nested fields getting marked as changed (#523)

This commit is contained in:
Ross Lawley
2013-11-29 16:24:32 +00:00
parent c074f4d925
commit d868cfdeb0
4 changed files with 69 additions and 38 deletions

View File

@@ -375,20 +375,41 @@ class BaseDocument(object):
self._changed_fields.append(key)
def _clear_changed_fields(self):
"""Using get_changed_fields iterate and remove any fields that are
marked as changed"""
for changed in self._get_changed_fields():
parts = changed.split(".")
data = self
for part in parts:
if isinstance(data, list):
try:
data = data[int(part)]
except IndexError:
data = None
elif isinstance(data, dict):
data = data.get(part, None)
else:
data = getattr(data, part, None)
if hasattr(data, "_changed_fields"):
data._changed_fields = []
self._changed_fields = []
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
for field_name, field in self._fields.iteritems():
if (isinstance(field, ComplexBaseField) and
isinstance(field.field, EmbeddedDocumentField)):
field_value = getattr(self, field_name, None)
if field_value:
for idx in (field_value if isinstance(field_value, dict)
else xrange(len(field_value))):
field_value[idx]._clear_changed_fields()
elif isinstance(field, EmbeddedDocumentField):
field_value = getattr(self, field_name, None)
if field_value:
field_value._clear_changed_fields()
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
list_key = "%s%s." % (key, index)
if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected)
changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
def _get_changed_fields(self, inspected=None):
"""Returns a list of all fields that have explicitly been changed.
@@ -396,13 +417,12 @@ class BaseDocument(object):
EmbeddedDocument = _import_class("EmbeddedDocument")
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
ReferenceField = _import_class("ReferenceField")
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
changed_fields = []
changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id'):
if hasattr(self, 'id') and not isinstance(self.id, dict):
if self.id in inspected:
return _changed_fields
return changed_fields
inspected.add(self.id)
for field_name in self._fields_ordered:
@@ -418,29 +438,17 @@ class BaseDocument(object):
if isinstance(field, ReferenceField):
continue
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
and db_field_name not in _changed_fields):
and db_field_name not in changed_fields):
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
changed_fields += ["%s%s" % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in _changed_fields):
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
iterator = enumerate(data)
else:
iterator = data.iteritems()
for index, value in iterator:
if not hasattr(value, '_get_changed_fields'):
continue
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
list_key = "%s%s." % (key, index)
changed = value._get_changed_fields(inspected)
_changed_fields += ["%s%s" % (list_key, k)
for k in changed if k]
return _changed_fields
db_field_name not in changed_fields):
if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)):
continue
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
return changed_fields
def _delta(self):
"""Returns the delta (set, unset) of the changes for a document.