fix for nested MapFields
When using nested MapFields from a document loaded from the database, the nested dictionaries aren't converted to BaseDict, so changes aren't marked. This also includes a change when marking a field as changed to ensure that nested fields aren't included in a $set query if a parent is already marked as changed. Not sure if this could occur but it prevents breakage if it does.
This commit is contained in:
parent
516591fe88
commit
2f6890c78a
@ -5,8 +5,7 @@ __all__ = ("BaseDict", "BaseList")
|
||||
|
||||
|
||||
class BaseDict(dict):
|
||||
"""A special dict so we can watch any changes
|
||||
"""
|
||||
"""A special dict so we can watch any changes"""
|
||||
|
||||
_dereferenced = False
|
||||
_instance = None
|
||||
@ -21,28 +20,32 @@ class BaseDict(dict):
|
||||
self._name = name
|
||||
return super(BaseDict, self).__init__(dict_items)
|
||||
|
||||
def __getitem__(self, *args, **kwargs):
|
||||
value = super(BaseDict, self).__getitem__(*args, **kwargs)
|
||||
def __getitem__(self, key):
|
||||
value = super(BaseDict, self).__getitem__(key)
|
||||
|
||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||
value._instance = self._instance
|
||||
elif isinstance(value, dict):
|
||||
value = BaseDict(value, None, '%s.%s' % (self._name, key))
|
||||
super(BaseDict, self).__setitem__(key, value)
|
||||
value._instance = self._instance
|
||||
return value
|
||||
|
||||
def __setitem__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
return super(BaseDict, self).__setitem__(*args, **kwargs)
|
||||
def __setitem__(self, key, value):
|
||||
self._mark_as_changed(key)
|
||||
return super(BaseDict, self).__setitem__(key, value)
|
||||
|
||||
def __delete__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
return super(BaseDict, self).__delete__(*args, **kwargs)
|
||||
|
||||
def __delitem__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
return super(BaseDict, self).__delitem__(*args, **kwargs)
|
||||
def __delitem__(self, key):
|
||||
self._mark_as_changed(key)
|
||||
return super(BaseDict, self).__delitem__(key)
|
||||
|
||||
def __delattr__(self, *args, **kwargs):
|
||||
self._mark_as_changed()
|
||||
def __delattr__(self, key):
|
||||
self._mark_as_changed(key)
|
||||
return super(BaseDict, self).__delattr__(*args, **kwargs)
|
||||
|
||||
def __getstate__(self):
|
||||
@ -70,9 +73,12 @@ class BaseDict(dict):
|
||||
self._mark_as_changed()
|
||||
return super(BaseDict, self).update(*args, **kwargs)
|
||||
|
||||
def _mark_as_changed(self):
|
||||
def _mark_as_changed(self, key=None):
|
||||
if hasattr(self._instance, '_mark_as_changed'):
|
||||
self._instance._mark_as_changed(self._name)
|
||||
if key:
|
||||
self._instance._mark_as_changed('%s.%s' % (self._name, key))
|
||||
else:
|
||||
self._instance._mark_as_changed(self._name)
|
||||
|
||||
|
||||
class BaseList(list):
|
||||
|
@ -370,7 +370,26 @@ class BaseDocument(object):
|
||||
"""
|
||||
if not key:
|
||||
return
|
||||
|
||||
key = self._db_field_map.get(key, key)
|
||||
if not hasattr(self, '_changed_fields'):
|
||||
return
|
||||
|
||||
# FIXME: would _delta be a better place for this?
|
||||
#
|
||||
# We want to go through and check that a parent key doesn't exist already
|
||||
# when adding a nested key.
|
||||
key_parts = key.split('.')
|
||||
partial = key_parts[0]
|
||||
|
||||
if partial in self._changed_fields:
|
||||
return
|
||||
|
||||
for part in key_parts[1:]:
|
||||
partial += '.' + part
|
||||
if partial in self._changed_fields:
|
||||
return
|
||||
|
||||
if (hasattr(self, '_changed_fields') and
|
||||
key not in self._changed_fields):
|
||||
self._changed_fields.append(key)
|
||||
@ -405,6 +424,10 @@ class BaseDocument(object):
|
||||
|
||||
for index, value in iterator:
|
||||
list_key = "%s%s." % (key, index)
|
||||
# don't check anything lower if this key is already marked
|
||||
# as changed.
|
||||
if list_key[:-1] in changed_fields:
|
||||
continue
|
||||
if hasattr(value, '_get_changed_fields'):
|
||||
changed = value._get_changed_fields(inspected)
|
||||
changed_fields += ["%s%s" % (list_key, k)
|
||||
|
Loading…
x
Reference in New Issue
Block a user