fix for nested MapFields
When using nested MapFields from a document loaded from the database, the nested dictionaries aren't converted to BaseDict, so changes aren't marked. This also includes a change when marking a field as changed to ensure that nested fields aren't included in a $set query if a parent is already marked as changed. Not sure if this could occur but it prevents breakage if it does.
This commit is contained in:
parent
516591fe88
commit
2f6890c78a
@ -5,8 +5,7 @@ __all__ = ("BaseDict", "BaseList")
|
|||||||
|
|
||||||
|
|
||||||
class BaseDict(dict):
|
class BaseDict(dict):
|
||||||
"""A special dict so we can watch any changes
|
"""A special dict so we can watch any changes"""
|
||||||
"""
|
|
||||||
|
|
||||||
_dereferenced = False
|
_dereferenced = False
|
||||||
_instance = None
|
_instance = None
|
||||||
@ -21,28 +20,32 @@ class BaseDict(dict):
|
|||||||
self._name = name
|
self._name = name
|
||||||
return super(BaseDict, self).__init__(dict_items)
|
return super(BaseDict, self).__init__(dict_items)
|
||||||
|
|
||||||
def __getitem__(self, *args, **kwargs):
|
def __getitem__(self, key):
|
||||||
value = super(BaseDict, self).__getitem__(*args, **kwargs)
|
value = super(BaseDict, self).__getitem__(key)
|
||||||
|
|
||||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
if isinstance(value, EmbeddedDocument) and value._instance is None:
|
||||||
value._instance = self._instance
|
value._instance = self._instance
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
value = BaseDict(value, None, '%s.%s' % (self._name, key))
|
||||||
|
super(BaseDict, self).__setitem__(key, value)
|
||||||
|
value._instance = self._instance
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def __setitem__(self, *args, **kwargs):
|
def __setitem__(self, key, value):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed(key)
|
||||||
return super(BaseDict, self).__setitem__(*args, **kwargs)
|
return super(BaseDict, self).__setitem__(key, value)
|
||||||
|
|
||||||
def __delete__(self, *args, **kwargs):
|
def __delete__(self, *args, **kwargs):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed()
|
||||||
return super(BaseDict, self).__delete__(*args, **kwargs)
|
return super(BaseDict, self).__delete__(*args, **kwargs)
|
||||||
|
|
||||||
def __delitem__(self, *args, **kwargs):
|
def __delitem__(self, key):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed(key)
|
||||||
return super(BaseDict, self).__delitem__(*args, **kwargs)
|
return super(BaseDict, self).__delitem__(key)
|
||||||
|
|
||||||
def __delattr__(self, *args, **kwargs):
|
def __delattr__(self, key):
|
||||||
self._mark_as_changed()
|
self._mark_as_changed(key)
|
||||||
return super(BaseDict, self).__delattr__(*args, **kwargs)
|
return super(BaseDict, self).__delattr__(*args, **kwargs)
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
@ -70,9 +73,12 @@ class BaseDict(dict):
|
|||||||
self._mark_as_changed()
|
self._mark_as_changed()
|
||||||
return super(BaseDict, self).update(*args, **kwargs)
|
return super(BaseDict, self).update(*args, **kwargs)
|
||||||
|
|
||||||
def _mark_as_changed(self):
|
def _mark_as_changed(self, key=None):
|
||||||
if hasattr(self._instance, '_mark_as_changed'):
|
if hasattr(self._instance, '_mark_as_changed'):
|
||||||
self._instance._mark_as_changed(self._name)
|
if key:
|
||||||
|
self._instance._mark_as_changed('%s.%s' % (self._name, key))
|
||||||
|
else:
|
||||||
|
self._instance._mark_as_changed(self._name)
|
||||||
|
|
||||||
|
|
||||||
class BaseList(list):
|
class BaseList(list):
|
||||||
|
@ -370,7 +370,26 @@ class BaseDocument(object):
|
|||||||
"""
|
"""
|
||||||
if not key:
|
if not key:
|
||||||
return
|
return
|
||||||
|
|
||||||
key = self._db_field_map.get(key, key)
|
key = self._db_field_map.get(key, key)
|
||||||
|
if not hasattr(self, '_changed_fields'):
|
||||||
|
return
|
||||||
|
|
||||||
|
# FIXME: would _delta be a better place for this?
|
||||||
|
#
|
||||||
|
# We want to go through and check that a parent key doesn't exist already
|
||||||
|
# when adding a nested key.
|
||||||
|
key_parts = key.split('.')
|
||||||
|
partial = key_parts[0]
|
||||||
|
|
||||||
|
if partial in self._changed_fields:
|
||||||
|
return
|
||||||
|
|
||||||
|
for part in key_parts[1:]:
|
||||||
|
partial += '.' + part
|
||||||
|
if partial in self._changed_fields:
|
||||||
|
return
|
||||||
|
|
||||||
if (hasattr(self, '_changed_fields') and
|
if (hasattr(self, '_changed_fields') and
|
||||||
key not in self._changed_fields):
|
key not in self._changed_fields):
|
||||||
self._changed_fields.append(key)
|
self._changed_fields.append(key)
|
||||||
@ -405,6 +424,10 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
for index, value in iterator:
|
for index, value in iterator:
|
||||||
list_key = "%s%s." % (key, index)
|
list_key = "%s%s." % (key, index)
|
||||||
|
# don't check anything lower if this key is already marked
|
||||||
|
# as changed.
|
||||||
|
if list_key[:-1] in changed_fields:
|
||||||
|
continue
|
||||||
if hasattr(value, '_get_changed_fields'):
|
if hasattr(value, '_get_changed_fields'):
|
||||||
changed = value._get_changed_fields(inspected)
|
changed = value._get_changed_fields(inspected)
|
||||||
changed_fields += ["%s%s" % (list_key, k)
|
changed_fields += ["%s%s" % (list_key, k)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user