improved update queries for BaseDict & BaseList

Migrate changes to include updating single elements of ListFields as
well as MapFields by adding the same changes to BaseList. This is
done by ensuring all BaseDicts and BaseLists have the correct name
from the base of the nearest (Embedded)Document, then marking changes
with their key or index when they are changed.

Tests also all fixed up.
This commit is contained in:
Damien Churchill 2014-03-12 15:07:40 +00:00
parent 85e271098f
commit db7f93cff3
5 changed files with 64 additions and 57 deletions

View File

@ -20,19 +20,23 @@ class BaseDict(dict):
self._name = name self._name = name
return super(BaseDict, self).__init__(dict_items) return super(BaseDict, self).__init__(dict_items)
def __getitem__(self, key): def __getitem__(self, key, *args, **kwargs):
value = super(BaseDict, self).__getitem__(key) value = super(BaseDict, self).__getitem__(key)
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None: if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance value._instance = self._instance
elif isinstance(value, dict): elif not isinstance(value, BaseDict) and isinstance(value, dict):
value = BaseDict(value, None, '%s.%s' % (self._name, key)) value = BaseDict(value, None, '%s.%s' % (self._name, key))
super(BaseDict, self).__setitem__(key, value) super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance value._instance = self._instance
elif not isinstance(value, BaseList) and isinstance(value, list):
value = BaseList(value, None, '%s.%s' % (self._name, key))
super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance
return value return value
def __setitem__(self, key, value): def __setitem__(self, key, value, *args, **kwargs):
self._mark_as_changed(key) self._mark_as_changed(key)
return super(BaseDict, self).__setitem__(key, value) return super(BaseDict, self).__setitem__(key, value)
@ -40,13 +44,13 @@ class BaseDict(dict):
self._mark_as_changed() self._mark_as_changed()
return super(BaseDict, self).__delete__(*args, **kwargs) return super(BaseDict, self).__delete__(*args, **kwargs)
def __delitem__(self, key): def __delitem__(self, key, *args, **kwargs):
self._mark_as_changed(key) self._mark_as_changed(key)
return super(BaseDict, self).__delitem__(key) return super(BaseDict, self).__delitem__(key)
def __delattr__(self, key): def __delattr__(self, key, *args, **kwargs):
self._mark_as_changed(key) self._mark_as_changed(key)
return super(BaseDict, self).__delattr__(*args, **kwargs) return super(BaseDict, self).__delattr__(key)
def __getstate__(self): def __getstate__(self):
self.instance = None self.instance = None
@ -98,21 +102,29 @@ class BaseList(list):
self._name = name self._name = name
return super(BaseList, self).__init__(list_items) return super(BaseList, self).__init__(list_items)
def __getitem__(self, *args, **kwargs): def __getitem__(self, key, *args, **kwargs):
value = super(BaseList, self).__getitem__(*args, **kwargs) value = super(BaseList, self).__getitem__(key)
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
if isinstance(value, EmbeddedDocument) and value._instance is None: if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance value._instance = self._instance
elif not isinstance(value, BaseDict) and isinstance(value, dict):
value = BaseDict(value, None, '%s.%s' % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
elif not isinstance(value, BaseList) and isinstance(value, list):
value = BaseList(value, None, '%s.%s' % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
return value return value
def __setitem__(self, *args, **kwargs): def __setitem__(self, key, value, *args, **kwargs):
self._mark_as_changed() self._mark_as_changed(key)
return super(BaseList, self).__setitem__(*args, **kwargs) return super(BaseList, self).__setitem__(key, value)
def __delitem__(self, *args, **kwargs): def __delitem__(self, key, *args, **kwargs):
self._mark_as_changed() self._mark_as_changed(key)
return super(BaseList, self).__delitem__(*args, **kwargs) return super(BaseList, self).__delitem__(key)
def __setslice__(self, *args, **kwargs): def __setslice__(self, *args, **kwargs):
self._mark_as_changed() self._mark_as_changed()
@ -159,6 +171,9 @@ class BaseList(list):
self._mark_as_changed() self._mark_as_changed()
return super(BaseList, self).sort(*args, **kwargs) return super(BaseList, self).sort(*args, **kwargs)
def _mark_as_changed(self): def _mark_as_changed(self, key=None):
if hasattr(self._instance, '_mark_as_changed'): if hasattr(self._instance, '_mark_as_changed'):
self._instance._mark_as_changed(self._name) if key:
self._instance._mark_as_changed('%s.%s' % (self._name, key))
else:
self._instance._mark_as_changed(self._name)

View File

@ -371,27 +371,17 @@ class BaseDocument(object):
if not key: if not key:
return return
key = self._db_field_map.get(key, key)
if not hasattr(self, '_changed_fields'): if not hasattr(self, '_changed_fields'):
return return
# FIXME: would _delta be a better place for this? if '.' in key:
# key, rest = key.split('.', 1)
# We want to go through and check that a parent key doesn't exist already key = self._db_field_map.get(key, key)
# when adding a nested key. key = '%s.%s' % (key, rest)
key_parts = key.split('.') else:
partial = key_parts[0] key = self._db_field_map.get(key, key)
if partial in self._changed_fields: if key not in self._changed_fields:
return
for part in key_parts[1:]:
partial += '.' + part
if partial in self._changed_fields:
return
if (hasattr(self, '_changed_fields') and
key not in self._changed_fields):
self._changed_fields.append(key) self._changed_fields.append(key)
def _clear_changed_fields(self): def _clear_changed_fields(self):
@ -443,6 +433,7 @@ class BaseDocument(object):
ReferenceField = _import_class("ReferenceField") ReferenceField = _import_class("ReferenceField")
changed_fields = [] changed_fields = []
changed_fields += getattr(self, '_changed_fields', []) changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set() inspected = inspected or set()
if hasattr(self, 'id') and isinstance(self.id, Hashable): if hasattr(self, 'id') and isinstance(self.id, Hashable):
if self.id in inspected: if self.id in inspected:
@ -495,7 +486,10 @@ class BaseDocument(object):
if isinstance(d, (ObjectId, DBRef)): if isinstance(d, (ObjectId, DBRef)):
break break
elif isinstance(d, list) and p.isdigit(): elif isinstance(d, list) and p.isdigit():
d = d[int(p)] try:
d = d[int(p)]
except IndexError:
d = None
elif hasattr(d, 'get'): elif hasattr(d, 'get'):
d = d.get(p) d = d.get(p)
new_path.append(p) new_path.append(p)

View File

@ -204,7 +204,8 @@ class DeReference(object):
elif isinstance(v, (list, tuple)) and depth <= self.max_depth: elif isinstance(v, (list, tuple)) and depth <= self.max_depth:
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) item_name = '%s.%s' % (name, k) if name else name
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
elif hasattr(v, 'id'): elif hasattr(v, 'id'):
data[k] = self.object_map.get(v.id, v) data[k] = self.object_map.get(v.id, v)

View File

@ -207,22 +207,21 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
['embedded_field.list_field']) ['embedded_field.list_field.2'])
self.assertEqual(doc.embedded_field._delta(), ({ self.assertEqual(doc.embedded_field._delta(), ({'list_field.2': {
'list_field': ['1', 2, {
'_cls': 'Embedded',
'string_field': 'hello world',
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}}]}, {}))
self.assertEqual(doc._delta(), ({
'embedded_field.list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
'string_field': 'hello world', 'string_field': 'hello world',
'int_field': 1, 'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}], 'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}} 'dict_field': {'hello': 'world'}}
]}, {})) }, {}))
self.assertEqual(doc._delta(), ({'embedded_field.list_field.2': {
'_cls': 'Embedded',
'string_field': 'hello world',
'int_field': 1,
'list_field': ['1', 2, {'hello': 'world'}],
'dict_field': {'hello': 'world'}}
}, {}))
doc.save() doc.save()
doc = doc.reload(10) doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field[2].string_field, self.assertEqual(doc.embedded_field.list_field[2].string_field,
@ -253,7 +252,7 @@ class DeltaTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field[2]['hello']) del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) ({}, {'embedded_field.list_field.2.list_field.2.hello': 1}))
doc.save() doc.save()
doc = doc.reload(10) doc = doc.reload(10)
@ -548,22 +547,21 @@ class DeltaTest(unittest.TestCase):
doc.embedded_field.list_field[2].string_field = 'hello world' doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEqual(doc._get_changed_fields(), self.assertEqual(doc._get_changed_fields(),
['db_embedded_field.db_list_field']) ['db_embedded_field.db_list_field.2'])
self.assertEqual(doc.embedded_field._delta(), ({ self.assertEqual(doc.embedded_field._delta(), ({'db_list_field.2': {
'db_list_field': ['1', 2, {
'_cls': 'Embedded', '_cls': 'Embedded',
'db_string_field': 'hello world', 'db_string_field': 'hello world',
'db_int_field': 1, 'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}], 'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}}]}, {})) 'db_dict_field': {'hello': 'world'}}}, {}))
self.assertEqual(doc._delta(), ({ self.assertEqual(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, { 'db_embedded_field.db_list_field.2': {
'_cls': 'Embedded', '_cls': 'Embedded',
'db_string_field': 'hello world', 'db_string_field': 'hello world',
'db_int_field': 1, 'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}], 'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}} 'db_dict_field': {'hello': 'world'}}
]}, {})) }, {}))
doc.save() doc.save()
doc = doc.reload(10) doc = doc.reload(10)
self.assertEqual(doc.embedded_field.list_field[2].string_field, self.assertEqual(doc.embedded_field.list_field[2].string_field,
@ -594,8 +592,7 @@ class DeltaTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field[2]['hello']) del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEqual(doc._delta(), self.assertEqual(doc._delta(),
({'db_embedded_field.db_list_field.2.db_list_field': ({}, {'db_embedded_field.db_list_field.2.db_list_field.2.hello': 1}))
[1, 2, {}]}, {}))
doc.save() doc.save()
doc = doc.reload(10) doc = doc.reload(10)

View File

@ -398,8 +398,8 @@ class InstanceTest(unittest.TestCase):
doc.embedded_field.dict_field['woot'] = "woot" doc.embedded_field.dict_field['woot'] = "woot"
self.assertEqual(doc._get_changed_fields(), [ self.assertEqual(doc._get_changed_fields(), [
'list_field', 'dict_field', 'embedded_field.list_field', 'list_field', 'dict_field.woot', 'embedded_field.list_field',
'embedded_field.dict_field']) 'embedded_field.dict_field.woot'])
doc.save() doc.save()
doc = doc.reload(10) doc = doc.reload(10)