fix for nested MapFields
When using nested MapFields from a document loaded from the database, the nested dictionaries aren't converted to BaseDict, so changes aren't marked. This also includes a change when marking a field as changed to ensure that nested fields aren't included in a $set query if a parent is already marked as changed. Not sure if this could occur but it prevents breakage if it does.
This commit is contained in:
		| @@ -5,8 +5,7 @@ __all__ = ("BaseDict", "BaseList") | ||||
|  | ||||
|  | ||||
| class BaseDict(dict): | ||||
|     """A special dict so we can watch any changes | ||||
|     """ | ||||
|     """A special dict so we can watch any changes""" | ||||
|  | ||||
|     _dereferenced = False | ||||
|     _instance = None | ||||
| @@ -21,28 +20,32 @@ class BaseDict(dict): | ||||
|         self._name = name | ||||
|         return super(BaseDict, self).__init__(dict_items) | ||||
|  | ||||
|     def __getitem__(self, *args, **kwargs): | ||||
|         value = super(BaseDict, self).__getitem__(*args, **kwargs) | ||||
|     def __getitem__(self, key): | ||||
|         value = super(BaseDict, self).__getitem__(key) | ||||
|  | ||||
|         EmbeddedDocument = _import_class('EmbeddedDocument') | ||||
|         if isinstance(value, EmbeddedDocument) and value._instance is None: | ||||
|             value._instance = self._instance | ||||
|         elif isinstance(value, dict): | ||||
|             value = BaseDict(value, None, '%s.%s' % (self._name, key)) | ||||
|             super(BaseDict, self).__setitem__(key, value) | ||||
|             value._instance = self._instance | ||||
|         return value | ||||
|  | ||||
|     def __setitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__setitem__(*args, **kwargs) | ||||
|     def __setitem__(self, key, value): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__setitem__(key, value) | ||||
|  | ||||
|     def __delete__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__delete__(*args, **kwargs) | ||||
|  | ||||
|     def __delitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).__delitem__(*args, **kwargs) | ||||
|     def __delitem__(self, key): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__delitem__(key) | ||||
|  | ||||
|     def __delattr__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|     def __delattr__(self, key): | ||||
|         self._mark_as_changed(key) | ||||
|         return super(BaseDict, self).__delattr__(*args, **kwargs) | ||||
|  | ||||
|     def __getstate__(self): | ||||
| @@ -70,9 +73,12 @@ class BaseDict(dict): | ||||
|         self._mark_as_changed() | ||||
|         return super(BaseDict, self).update(*args, **kwargs) | ||||
|  | ||||
|     def _mark_as_changed(self): | ||||
|     def _mark_as_changed(self, key=None): | ||||
|         if hasattr(self._instance, '_mark_as_changed'): | ||||
|             self._instance._mark_as_changed(self._name) | ||||
|             if key: | ||||
|                 self._instance._mark_as_changed('%s.%s' % (self._name, key)) | ||||
|             else: | ||||
|                 self._instance._mark_as_changed(self._name) | ||||
|  | ||||
|  | ||||
| class BaseList(list): | ||||
|   | ||||
| @@ -370,7 +370,26 @@ class BaseDocument(object): | ||||
|         """ | ||||
|         if not key: | ||||
|             return | ||||
|  | ||||
|         key = self._db_field_map.get(key, key) | ||||
|         if not hasattr(self, '_changed_fields'): | ||||
|             return | ||||
|  | ||||
|         # FIXME: would _delta be a better place for this? | ||||
|         #  | ||||
|         # We want to go through and check that a parent key doesn't exist already | ||||
|         # when adding a nested key. | ||||
|         key_parts = key.split('.') | ||||
|         partial = key_parts[0] | ||||
|  | ||||
|         if partial in self._changed_fields: | ||||
|             return | ||||
|  | ||||
|         for part in key_parts[1:]: | ||||
|             partial += '.' + part | ||||
|             if partial in self._changed_fields: | ||||
|                 return | ||||
|  | ||||
|         if (hasattr(self, '_changed_fields') and | ||||
|            key not in self._changed_fields): | ||||
|             self._changed_fields.append(key) | ||||
| @@ -405,6 +424,10 @@ class BaseDocument(object): | ||||
|  | ||||
|         for index, value in iterator: | ||||
|             list_key = "%s%s." % (key, index) | ||||
|             # don't check anything lower if this key is already marked | ||||
|             # as changed. | ||||
|             if list_key[:-1] in changed_fields: | ||||
|                 continue | ||||
|             if hasattr(value, '_get_changed_fields'): | ||||
|                 changed = value._get_changed_fields(inspected) | ||||
|                 changed_fields += ["%s%s" % (list_key, k) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user