From db7f93cff3d2a96f4d0095174cf5e8ad5908c23a Mon Sep 17 00:00:00 2001 From: Damien Churchill Date: Wed, 12 Mar 2014 15:07:40 +0000 Subject: [PATCH] improved update queries for BaseDict & BaseList Migrate changes to include updating single elements of ListFields as well as MapFields by adding the same changes to BaseList. This is done by ensuring all BaseDicts and BaseLists have the correct name from the base of the nearest (Embedded)Document, then marking changes with their key or index when they are changed. Tests also all fixed up. --- mongoengine/base/datastructures.py | 47 ++++++++++++++++++++---------- mongoengine/base/document.py | 30 ++++++++----------- mongoengine/dereference.py | 3 +- tests/document/delta.py | 37 +++++++++++------------ tests/document/instance.py | 4 +-- 5 files changed, 64 insertions(+), 57 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index aa69e0d1..50a4daa9 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -20,19 +20,23 @@ class BaseDict(dict): self._name = name return super(BaseDict, self).__init__(dict_items) - def __getitem__(self, key): + def __getitem__(self, key, *args, **kwargs): value = super(BaseDict, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance - elif isinstance(value, dict): + elif not isinstance(value, BaseDict) and isinstance(value, dict): value = BaseDict(value, None, '%s.%s' % (self._name, key)) super(BaseDict, self).__setitem__(key, value) value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseDict, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, key, value): + def __setitem__(self, key, value, *args, **kwargs): self._mark_as_changed(key) return super(BaseDict, self).__setitem__(key, value) @@ -40,13 +44,13 @@ class BaseDict(dict): self._mark_as_changed() return super(BaseDict, self).__delete__(*args, **kwargs) - def __delitem__(self, key): + def __delitem__(self, key, *args, **kwargs): self._mark_as_changed(key) return super(BaseDict, self).__delitem__(key) - def __delattr__(self, key): + def __delattr__(self, key, *args, **kwargs): self._mark_as_changed(key) - return super(BaseDict, self).__delattr__(*args, **kwargs) + return super(BaseDict, self).__delattr__(key) def __getstate__(self): self.instance = None @@ -98,21 +102,29 @@ class BaseList(list): self._name = name return super(BaseList, self).__init__(list_items) - def __getitem__(self, *args, **kwargs): - value = super(BaseList, self).__getitem__(*args, **kwargs) + def __getitem__(self, key, *args, **kwargs): + value = super(BaseList, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance + elif not isinstance(value, BaseDict) and isinstance(value, dict): + value = BaseDict(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance + elif not isinstance(value, BaseList) and isinstance(value, list): + value = BaseList(value, None, '%s.%s' % (self._name, key)) + super(BaseList, self).__setitem__(key, value) + value._instance = self._instance return value - def __setitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setitem__(*args, **kwargs) + def __setitem__(self, key, value, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseList, self).__setitem__(key, value) - def __delitem__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delitem__(*args, **kwargs) + def __delitem__(self, key, *args, **kwargs): + self._mark_as_changed(key) + return super(BaseList, self).__delitem__(key) def __setslice__(self, *args, **kwargs): self._mark_as_changed() @@ -159,6 +171,9 @@ class BaseList(list): self._mark_as_changed() return super(BaseList, self).sort(*args, **kwargs) - def _mark_as_changed(self): + def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): - self._instance._mark_as_changed(self._name) + if key: + self._instance._mark_as_changed('%s.%s' % (self._name, key)) + else: + self._instance._mark_as_changed(self._name) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 79128c13..05d4d791 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -371,27 +371,17 @@ class BaseDocument(object): if not key: return - key = self._db_field_map.get(key, key) if not hasattr(self, '_changed_fields'): return - # FIXME: would _delta be a better place for this? - # - # We want to go through and check that a parent key doesn't exist already - # when adding a nested key. - key_parts = key.split('.') - partial = key_parts[0] + if '.' in key: + key, rest = key.split('.', 1) + key = self._db_field_map.get(key, key) + key = '%s.%s' % (key, rest) + else: + key = self._db_field_map.get(key, key) - if partial in self._changed_fields: - return - - for part in key_parts[1:]: - partial += '.' + part - if partial in self._changed_fields: - return - - if (hasattr(self, '_changed_fields') and - key not in self._changed_fields): + if key not in self._changed_fields: self._changed_fields.append(key) def _clear_changed_fields(self): @@ -443,6 +433,7 @@ class BaseDocument(object): ReferenceField = _import_class("ReferenceField") changed_fields = [] changed_fields += getattr(self, '_changed_fields', []) + inspected = inspected or set() if hasattr(self, 'id') and isinstance(self.id, Hashable): if self.id in inspected: @@ -495,7 +486,10 @@ class BaseDocument(object): if isinstance(d, (ObjectId, DBRef)): break elif isinstance(d, list) and p.isdigit(): - d = d[int(p)] + try: + d = d[int(p)] + except IndexError: + d = None elif hasattr(d, 'get'): d = d.get(p) new_path.append(p) diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index ceda403e..44bb6ad2 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -204,7 +204,8 @@ class DeReference(object): elif isinstance(v, (list, tuple)) and depth <= self.max_depth: data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) + item_name = '%s.%s' % (name, k) if name else name + data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) elif hasattr(v, 'id'): data[k] = self.object_map.get(v.id, v) diff --git a/tests/document/delta.py b/tests/document/delta.py index b0f5f01a..292d8255 100644 --- a/tests/document/delta.py +++ b/tests/document/delta.py @@ -207,22 +207,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['embedded_field.list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'list_field': ['1', 2, { - '_cls': 'Embedded', - 'string_field': 'hello world', - 'int_field': 1, - 'list_field': ['1', 2, {'hello': 'world'}], - 'dict_field': {'hello': 'world'}}]}, {})) - self.assertEqual(doc._delta(), ({ - 'embedded_field.list_field': ['1', 2, { + ['embedded_field.list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': { '_cls': 'Embedded', 'string_field': 'hello world', 'int_field': 1, 'list_field': ['1', 2, {'hello': 'world'}], 'dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) + self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': { + '_cls': 'Embedded', + 'string_field': 'hello world', + 'int_field': 1, + 'list_field': ['1', 2, {'hello': 'world'}], + 'dict_field': {'hello': 'world'}} + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -253,7 +252,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) + ({}, {'embedded_field.list_field.2.list_field.2.hello': 1})) doc.save() doc = doc.reload(10) @@ -548,22 +547,21 @@ class DeltaTest(unittest.TestCase): doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] self.assertEqual(doc._get_changed_fields(), - ['db_embedded_field.db_list_field']) - self.assertEqual(doc.embedded_field._delta(), ({ - 'db_list_field': ['1', 2, { + ['db_embedded_field.db_list_field.2']) + self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], - 'db_dict_field': {'hello': 'world'}}]}, {})) + 'db_dict_field': {'hello': 'world'}}}, {})) self.assertEqual(doc._delta(), ({ - 'db_embedded_field.db_list_field': ['1', 2, { + 'db_embedded_field.db_list_field.2': { '_cls': 'Embedded', 'db_string_field': 'hello world', 'db_int_field': 1, 'db_list_field': ['1', 2, {'hello': 'world'}], 'db_dict_field': {'hello': 'world'}} - ]}, {})) + }, {})) doc.save() doc = doc.reload(10) self.assertEqual(doc.embedded_field.list_field[2].string_field, @@ -594,8 +592,7 @@ class DeltaTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEqual(doc._delta(), - ({'db_embedded_field.db_list_field.2.db_list_field': - [1, 2, {}]}, {})) + ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1})) doc.save() doc = doc.reload(10) diff --git a/tests/document/instance.py b/tests/document/instance.py index 07db85a0..5d3fdc21 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -398,8 +398,8 @@ class InstanceTest(unittest.TestCase): doc.embedded_field.dict_field['woot'] = "woot" self.assertEqual(doc._get_changed_fields(), [ - 'list_field', 'dict_field', 'embedded_field.list_field', - 'embedded_field.dict_field']) + 'list_field', 'dict_field.woot', 'embedded_field.list_field', + 'embedded_field.dict_field.woot']) doc.save() doc = doc.reload(10)