Fix multi level nested fields getting marked as changed (#523)
This commit is contained in:
parent
c074f4d925
commit
d868cfdeb0
1
AUTHORS
1
AUTHORS
@ -188,3 +188,4 @@ that much better:
|
||||
* Dmytro Popovych (https://github.com/drudim)
|
||||
* Tom (https://github.com/tomprimozic)
|
||||
* j0hnsmith (https://github.com/j0hnsmith)
|
||||
* Damien Churchill (https://github.com/damoxc)
|
||||
|
@ -4,6 +4,7 @@ Changelog
|
||||
|
||||
Changes in 0.8.5
|
||||
================
|
||||
- Fix multi level nested fields getting marked as changed (#523)
|
||||
- Django 1.6 login fix (#522)
|
||||
- Django 1.6 session fix (#509)
|
||||
- EmbeddedDocument._instance is now set when settng the attribute (#506)
|
||||
|
@ -375,20 +375,41 @@ class BaseDocument(object):
|
||||
self._changed_fields.append(key)
|
||||
|
||||
def _clear_changed_fields(self):
|
||||
"""Using get_changed_fields iterate and remove any fields that are
|
||||
marked as changed"""
|
||||
for changed in self._get_changed_fields():
|
||||
parts = changed.split(".")
|
||||
data = self
|
||||
for part in parts:
|
||||
if isinstance(data, list):
|
||||
try:
|
||||
data = data[int(part)]
|
||||
except IndexError:
|
||||
data = None
|
||||
elif isinstance(data, dict):
|
||||
data = data.get(part, None)
|
||||
else:
|
||||
data = getattr(data, part, None)
|
||||
if hasattr(data, "_changed_fields"):
|
||||
data._changed_fields = []
|
||||
self._changed_fields = []
|
||||
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
|
||||
for field_name, field in self._fields.iteritems():
|
||||
if (isinstance(field, ComplexBaseField) and
|
||||
isinstance(field.field, EmbeddedDocumentField)):
|
||||
field_value = getattr(self, field_name, None)
|
||||
if field_value:
|
||||
for idx in (field_value if isinstance(field_value, dict)
|
||||
else xrange(len(field_value))):
|
||||
field_value[idx]._clear_changed_fields()
|
||||
elif isinstance(field, EmbeddedDocumentField):
|
||||
field_value = getattr(self, field_name, None)
|
||||
if field_value:
|
||||
field_value._clear_changed_fields()
|
||||
|
||||
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
|
||||
# Loop list / dict fields as they contain documents
|
||||
# Determine the iterator to use
|
||||
if not hasattr(data, 'items'):
|
||||
iterator = enumerate(data)
|
||||
else:
|
||||
iterator = data.iteritems()
|
||||
|
||||
for index, value in iterator:
|
||||
list_key = "%s%s." % (key, index)
|
||||
if hasattr(value, '_get_changed_fields'):
|
||||
changed = value._get_changed_fields(inspected)
|
||||
changed_fields += ["%s%s" % (list_key, k)
|
||||
for k in changed if k]
|
||||
elif isinstance(value, (list, tuple, dict)):
|
||||
self._nestable_types_changed_fields(changed_fields, list_key, value, inspected)
|
||||
|
||||
def _get_changed_fields(self, inspected=None):
|
||||
"""Returns a list of all fields that have explicitly been changed.
|
||||
@ -396,13 +417,12 @@ class BaseDocument(object):
|
||||
EmbeddedDocument = _import_class("EmbeddedDocument")
|
||||
DynamicEmbeddedDocument = _import_class("DynamicEmbeddedDocument")
|
||||
ReferenceField = _import_class("ReferenceField")
|
||||
_changed_fields = []
|
||||
_changed_fields += getattr(self, '_changed_fields', [])
|
||||
|
||||
changed_fields = []
|
||||
changed_fields += getattr(self, '_changed_fields', [])
|
||||
inspected = inspected or set()
|
||||
if hasattr(self, 'id'):
|
||||
if hasattr(self, 'id') and not isinstance(self.id, dict):
|
||||
if self.id in inspected:
|
||||
return _changed_fields
|
||||
return changed_fields
|
||||
inspected.add(self.id)
|
||||
|
||||
for field_name in self._fields_ordered:
|
||||
@ -418,29 +438,17 @@ class BaseDocument(object):
|
||||
if isinstance(field, ReferenceField):
|
||||
continue
|
||||
elif (isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument))
|
||||
and db_field_name not in _changed_fields):
|
||||
and db_field_name not in changed_fields):
|
||||
# Find all embedded fields that have been changed
|
||||
changed = data._get_changed_fields(inspected)
|
||||
_changed_fields += ["%s%s" % (key, k) for k in changed if k]
|
||||
changed_fields += ["%s%s" % (key, k) for k in changed if k]
|
||||
elif (isinstance(data, (list, tuple, dict)) and
|
||||
db_field_name not in _changed_fields):
|
||||
# Loop list / dict fields as they contain documents
|
||||
# Determine the iterator to use
|
||||
if not hasattr(data, 'items'):
|
||||
iterator = enumerate(data)
|
||||
else:
|
||||
iterator = data.iteritems()
|
||||
for index, value in iterator:
|
||||
if not hasattr(value, '_get_changed_fields'):
|
||||
continue
|
||||
if (hasattr(field, 'field') and
|
||||
isinstance(field.field, ReferenceField)):
|
||||
continue
|
||||
list_key = "%s%s." % (key, index)
|
||||
changed = value._get_changed_fields(inspected)
|
||||
_changed_fields += ["%s%s" % (list_key, k)
|
||||
for k in changed if k]
|
||||
return _changed_fields
|
||||
db_field_name not in changed_fields):
|
||||
if (hasattr(field, 'field') and
|
||||
isinstance(field.field, ReferenceField)):
|
||||
continue
|
||||
self._nestable_types_changed_fields(changed_fields, key, data, inspected)
|
||||
return changed_fields
|
||||
|
||||
def _delta(self):
|
||||
"""Returns the delta (set, unset) of the changes for a document.
|
||||
|
@ -713,6 +713,27 @@ class DeltaTest(unittest.TestCase):
|
||||
self.assertEqual({}, removals)
|
||||
self.assertTrue('employees' in updates)
|
||||
|
||||
def test_nested_nested_fields_mark_as_changed(self):
|
||||
class EmbeddedDoc(EmbeddedDocument):
|
||||
name = StringField()
|
||||
|
||||
class MyDoc(Document):
|
||||
subs = MapField(MapField(EmbeddedDocumentField(EmbeddedDoc)))
|
||||
name = StringField()
|
||||
|
||||
MyDoc.drop_collection()
|
||||
|
||||
mydoc = MyDoc(name='testcase1', subs={'a': {'b': EmbeddedDoc(name='foo')}}).save()
|
||||
|
||||
mydoc = MyDoc.objects.first()
|
||||
subdoc = mydoc.subs['a']['b']
|
||||
subdoc.name = 'bar'
|
||||
|
||||
self.assertEqual(["name"], subdoc._get_changed_fields())
|
||||
self.assertEqual(["subs.a.b.name"], mydoc._get_changed_fields())
|
||||
|
||||
mydoc._clear_changed_fields()
|
||||
self.assertEqual([], mydoc._get_changed_fields())
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user