diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index aa69e0d1..50a4daa9 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -20,19 +20,23 @@ class BaseDict(dict): self._name = name return super(BaseDict, self).__init__(dict_items) - def __getitem__(self, key): + def __getitem__(self, key, *args, **kwargs): value = super(BaseDict, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance - elif isinstance(value, dict): + elif not isinstance(value, BaseDict) and isinstance(value, dict): value = BaseDict(value, None, '%s.%s' % (self._name, key)) super(BaseDict, self).__setitem__(key, value) value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseDict, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, key, value): + def __setitem__(self, key, value, *args, **kwargs): self._mark_as_changed(key) return super(BaseDict, self).__setitem__(key, value) @@ -40,13 +44,13 @@ class BaseDict(dict): self._mark_as_changed() return super(BaseDict, self).__delete__(*args, **kwargs) - def __delitem__(self, key): + def __delitem__(self, key, *args, **kwargs): self._mark_as_changed(key) return super(BaseDict, self).__delitem__(key) - def __delattr__(self, key): + def __delattr__(self, key, *args, **kwargs): self._mark_as_changed(key) - return super(BaseDict, self).__delattr__(*args, **kwargs) + return super(BaseDict, self).__delattr__(key) def __getstate__(self): self.instance = None @@ -98,21 +102,29 @@ class BaseList(list): self._name = name return super(BaseList, self).__init__(list_items) - def __getitem__(self, *args, **kwargs): - value = super(BaseList, self).__getitem__(*args, **kwargs) + def __getitem__(self, key, *args, **kwargs): + value = super(BaseList, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance + elif not isinstance(value, BaseDict) and isinstance(value, dict): + value = BaseDict(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setitem__(*args, **kwargs) + def __setitem__(self, key, value, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseList, self).__setitem__(key, value) - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delitem__(*args, **kwargs) + def __delitem__(self, key, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseList, self).__delitem__(key) def __setslice__(self, *args, **kwargs): self._mark_as_changed() @@ -159,6 +171,9 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).sort(*args, **kwargs) - def _mark_as_changed(self): + def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) + if key: + self._instance._mark_as_changed('%s.%s' % (self._name, key)) + else: + self._instance._mark_as_changed(self._name) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 79128c13..05d4d791 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -371,27 +371,17 @@ class BaseDocument(object): if not key: return - key = self._db_field_map.get(key, key) if not hasattr(self, '_changed_fields'): return - # FIXME: would _delta be a better place for this? - # - # We want to go through and check that a parent key doesn't exist already - # when adding a nested key. - key_parts = key.split('.') - partial = key_parts[0] + if '.' in key: + key, rest = key.split('.', 1) + key = self._db_field_map.get(key, key) + key = '%s.%s' % (key, rest) + else: + key = self._db_field_map.get(key, key) - if partial in self._changed_fields: - return - - for part in key_parts[1:]: - partial += '.' + part - if partial in self._changed_fields: - return - - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): + if key not in self._changed_fields: self._changed_fields.append(key) def _clear_changed_fields(self): @@ -443,6 +433,7 @@ class BaseDocument(object): ReferenceField = _import_class("ReferenceField") changed_fields = [] changed_fields += getattr(self, '_changed_fields', []) + inspected = inspected or set() if hasattr(self, 'id') and isinstance(self.id, Hashable): if self.id in inspected: @@ -495,7 +486,10 @@ class BaseDocument(object): if isinstance(d, (ObjectId, DBRef)): break elif isinstance(d, list) and p.isdigit(): - d = d[int(p)] + try: + d = d[int(p)] + except IndexError: + d = None elif hasattr(d, 'get'): d = d.get(p) new_path.append(p) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index ceda403e..44bb6ad2 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -204,7 +204,8 @@ class DeReference(object): elif isinstance(v, (list, tuple)) and depth <= self.max_depth: data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) + item_name = '%s.%s' % (name, k) if name else name + data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) elif hasattr(v, 'id'): data[k] = self.object_map.get(v.id, v) diff --git a/tests/document/delta.py b/tests/document/delta.py index b0f5f01a..292d8255 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -207,22 +207,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { + ['embedded_field.list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': { '_cls': 'Embedded', 'string_field': 'hello world', 'int_field': 1, 'list_field': ['1', 2, {'hello': 'world'}], 'dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) + self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -253,7 +252,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + ({}, {'embedded_field.list_field.2.list_field.2.hello': 1})) doc.save() doc = doc.reload(10) @@ -548,22 +547,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { + ['db_embedded_field.db_list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}]}, {})) + 'db_dict_field': {'hello': 'world'}}}, {})) self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { + 'db_embedded_field.db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], 'db_dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -594,8 +592,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_list_field': - [1, 2, {}]}, {})) + ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1})) doc.save() doc = doc.reload(10) diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..5d3fdc21 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -398,8 +398,8 @@ class InstanceTest(unittest.TestCase): doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) + 'list_field', 'dict_field.woot', 'embedded_field.list_field', + 'embedded_field.dict_field.woot']) doc.save() doc = doc.reload(10)