From 2f6890c78a23bb9c15709aeb17d1368be08ef87c Mon Sep 17 00:00:00 2001 From: Damien Churchill Date: Mon, 16 Dec 2013 13:44:07 +0000 Subject: [PATCH] fix for nested MapFields When using nested MapFields from a document loaded from the database, the nested dictionaries aren't converted to BaseDict, so changes aren't marked. This also includes a change when marking a field as changed to ensure that nested fields aren't included in a $set query if a parent is already marked as changed. Not sure if this could occur but it prevents breakage if it does. --- mongoengine/base/datastructures.py | 34 ++++++++++++++++++------------ mongoengine/base/document.py | 23 ++++++++++++++++++++ 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 4652fb56..aa69e0d1 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -5,8 +5,7 @@ __all__ = ("BaseDict", "BaseList") class BaseDict(dict): - """A special dict so we can watch any changes - """ + """A special dict so we can watch any changes""" _dereferenced = False _instance = None @@ -21,28 +20,32 @@ class BaseDict(dict): self._name = name return super(BaseDict, self).__init__(dict_items) - def __getitem__(self, *args, **kwargs): - value = super(BaseDict, self).__getitem__(*args, **kwargs) + def __getitem__(self, key): + value = super(BaseDict, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance + elif isinstance(value, dict): + value = BaseDict(value, None, '%s.%s' % (self._name, key)) + super(BaseDict, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__setitem__(*args, **kwargs) + def __setitem__(self, key, value): + self._mark_as_changed(key) + return super(BaseDict, self).__setitem__(key, value) def __delete__(self, *args, **kwargs): self._mark_as_changed() return super(BaseDict, self).__delete__(*args, **kwargs) - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delitem__(*args, **kwargs) + def __delitem__(self, key): + self._mark_as_changed(key) + return super(BaseDict, self).__delitem__(key) - def __delattr__(self, *args, **kwargs): - self._mark_as_changed() + def __delattr__(self, key): + self._mark_as_changed(key) return super(BaseDict, self).__delattr__(*args, **kwargs) def __getstate__(self): @@ -70,9 +73,12 @@ class BaseDict(dict): self._mark_as_changed() return super(BaseDict, self).update(*args, **kwargs) - def _mark_as_changed(self): + def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) + if key: + self._instance._mark_as_changed('%s.%s' % (self._name, key)) + else: + self._instance._mark_as_changed(self._name) class BaseList(list): diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index f5eae8ff..79128c13 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -370,7 +370,26 @@ class BaseDocument(object): """ if not key: return + key = self._db_field_map.get(key, key) + if not hasattr(self, '_changed_fields'): + return + + # FIXME: would _delta be a better place for this? + # + # We want to go through and check that a parent key doesn't exist already + # when adding a nested key. + key_parts = key.split('.') + partial = key_parts[0] + + if partial in self._changed_fields: + return + + for part in key_parts[1:]: + partial += '.' + part + if partial in self._changed_fields: + return + if (hasattr(self, '_changed_fields') and key not in self._changed_fields): self._changed_fields.append(key) @@ -405,6 +424,10 @@ class BaseDocument(object): for index, value in iterator: list_key = "%s%s." % (key, index) + # don't check anything lower if this key is already marked + # as changed. + if list_key[:-1] in changed_fields: + continue if hasattr(value, '_get_changed_fields'): changed = value._get_changed_fields(inspected) changed_fields += ["%s%s" % (list_key, k)