From 83fff80b0fa45a81a7fa3b219bbb91984884902d Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 25 Nov 2011 08:28:20 -0800 Subject: [PATCH] Cleaned up dereferencing Dereferencing now respects max_depth, so should be more performant. Reload is chainable and can be passed a max_depth for dereferencing Added an Observer for ComplexBaseFields. Refs #324 #323 #289 Closes #320 --- docs/changelog.rst | 1 + mongoengine/base.py | 95 ++++++++++++++++++++++---------------- mongoengine/dereference.py | 45 +++++++++--------- mongoengine/document.py | 24 ++++++---- mongoengine/queryset.py | 2 + tests/document.py | 52 ++++++++++----------- 6 files changed, 122 insertions(+), 97 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 6ce81e5b..df897f37 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Fixed dereferencing - max_depth now taken into account - Fixed document mutation saving issue - Fixed positional operator when replacing embedded documents - Added Non-Django Style choices back (you can have either) diff --git a/mongoengine/base.py b/mongoengine/base.py index 4198e8b2..c801c2ee 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -155,9 +155,11 @@ class BaseField(object): # Convert lists / values so we can watch for any changes on them if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): - value = BaseList(value, instance=instance, name=self.name) + observer = DataObserver(instance, self.name) + value = BaseList(value, observer) elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance=instance, name=self.name) + observer = DataObserver(instance, self.name) + value = BaseDict(value, observer) return value def __set__(self, instance, value): @@ -237,7 +239,7 @@ class ComplexBaseField(BaseField): from dereference import dereference instance._data[self.name] = dereference( - instance._data.get(self.name), max_depth=1, instance=instance, name=self.name, get=True + instance._data.get(self.name), max_depth=1, instance=instance, name=self.name ) return super(ComplexBaseField, self).__get__(instance, owner) @@ -780,9 +782,11 @@ class BaseDocument(object): # Convert lists / values so we can watch for any changes on them if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): - value = BaseList(value, instance=self, name=name) + observer = DataObserver(self, name) + value = BaseList(value, observer) elif isinstance(value, dict) and not isinstance(value, BaseDict): - value = BaseDict(value, instance=self, name=name) + observer = DataObserver(self, name) + value = BaseDict(value, observer) return value @@ -1122,102 +1126,113 @@ class BaseDocument(object): return hash(self.pk) +class DataObserver(object): + + __slots__ = ["instance", "name"] + + def __init__(self, instance, name): + self.instance = instance + self.name = name + + def updated(self): + self.instance._mark_as_changed(self.name) + + class BaseList(list): """A special list so we can watch any changes """ - def __init__(self, list_items, instance, name): - self.instance = instance - self.name = name + def __init__(self, list_items, observer): + self.observer = observer super(BaseList, self).__init__(list_items) def __setitem__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseList, self).__setitem__(*args, **kwargs) def __delitem__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseList, self).__delitem__(*args, **kwargs) + def __getstate__(self): + self.observer = None + return self + + def __setstate__(self, state): + self = state + def append(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).append(*args, **kwargs) def extend(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).extend(*args, **kwargs) def insert(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).insert(*args, **kwargs) def pop(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).pop(*args, **kwargs) def remove(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).remove(*args, **kwargs) def reverse(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).reverse(*args, **kwargs) def sort(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() return super(BaseList, self).sort(*args, **kwargs) - def _mark_as_changed(self): - """Marks a list as changed if has an instance and a name""" - if hasattr(self, 'instance') and hasattr(self, 'name'): - self.instance._mark_as_changed(self.name) - class BaseDict(dict): """A special dict so we can watch any changes """ - def __init__(self, dict_items, instance, name): - self.instance = instance - self.name = name + def __init__(self, dict_items, observer): + self.observer = observer super(BaseDict, self).__init__(dict_items) def __setitem__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).__setitem__(*args, **kwargs) - def __setattr__(self, *args, **kwargs): - self._mark_as_changed() - super(BaseDict, self).__setattr__(*args, **kwargs) - def __delete__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).__delete__(*args, **kwargs) def __delitem__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).__delitem__(*args, **kwargs) def __delattr__(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).__delattr__(*args, **kwargs) + def __getstate__(self): + self.observer = None + return self + + def __setstate__(self, state): + self = state + def clear(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).clear(*args, **kwargs) def pop(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).clear(*args, **kwargs) def popitem(self, *args, **kwargs): - self._mark_as_changed() + self.observer.updated() super(BaseDict, self).clear(*args, **kwargs) - def _mark_as_changed(self): - """Marks a dict as changed if has an instance and a name""" - if hasattr(self, 'instance') and hasattr(self, 'name'): - self.instance._mark_as_changed(self.name) if sys.version_info < (2, 5): # Prior to Python 2.5, Exception was an old-style class diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index d817a037..4e595b19 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,6 +1,7 @@ import pymongo -from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass +from base import (BaseDict, BaseList, DataObserver, + TopLevelDocumentMetaclass, get_document) from fields import ReferenceField from connection import get_db from queryset import QuerySet @@ -9,7 +10,7 @@ from document import Document class DeReference(object): - def __call__(self, items, max_depth=1, instance=None, name=None, get=False): + def __call__(self, items, max_depth=1, instance=None, name=None): """ Cheaply dereferences the items to a set depth. Also handles the convertion of complex data types. @@ -43,7 +44,7 @@ class DeReference(object): self.reference_map = self._find_references(items) self.object_map = self._fetch_objects(doc_type=doc_type) - return self._attach_objects(items, 0, instance, name, get) + return self._attach_objects(items, 0, instance, name) def _find_references(self, items, depth=0): """ @@ -53,7 +54,7 @@ class DeReference(object): :param depth: The current depth of recursion """ reference_map = {} - if not items: + if not items or depth >= self.max_depth: return reference_map # Determine the iterator to use @@ -63,6 +64,7 @@ class DeReference(object): iterator = items.iteritems() # Recursively find dbreferences + depth += 1 for k, item in iterator: if hasattr(item, '_fields'): for field_name, field in item._fields.iteritems(): @@ -82,11 +84,11 @@ class DeReference(object): reference_map.setdefault(item.collection, []).append(item.id) elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) - elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth: - references = self._find_references(item, depth) + elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: + references = self._find_references(item, depth - 1) for key, refs in references.iteritems(): reference_map.setdefault(key, []).extend(refs) - depth += 1 + return reference_map def _fetch_objects(self, doc_type=None): @@ -110,7 +112,7 @@ class DeReference(object): object_map[doc.id] = doc return object_map - def _attach_objects(self, items, depth=0, instance=None, name=None, get=False): + def _attach_objects(self, items, depth=0, instance=None, name=None): """ Recursively finds all db references to be dereferenced @@ -120,25 +122,24 @@ class DeReference(object): :class:`~mongoengine.base.ComplexBaseField` :param name: The name of the field, used for tracking changes by :class:`~mongoengine.base.ComplexBaseField` - :param get: A boolean determining if being called by __get__ """ if not items: if isinstance(items, (BaseDict, BaseList)): return items if instance: + observer = DataObserver(instance, name) if isinstance(items, dict): - return BaseDict(items, instance=instance, name=name) + return BaseDict(items, observer) else: - return BaseList(items, instance=instance, name=name) + return BaseList(items, observer) if isinstance(items, (dict, pymongo.son.SON)): if '_ref' in items: return self.object_map.get(items['_ref'].id, items) elif '_types' in items and '_cls' in items: doc = get_document(items['_cls'])._from_son(items) - if not get: - doc._data = self._attach_objects(doc._data, depth, doc, name, get) + doc._data = self._attach_objects(doc._data, depth, doc, name) return doc if not hasattr(items, 'items'): @@ -150,6 +151,7 @@ class DeReference(object): iterator = items.iteritems() data = {} + depth += 1 for k, v in iterator: if is_list: data.append(v) @@ -165,19 +167,20 @@ class DeReference(object): data[k]._data[field_name] = self.object_map.get(v.id, v) elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) - elif isinstance(v, dict) and depth < self.max_depth: - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) - elif isinstance(v, (list, tuple)): - data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) - elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth: - data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get) + elif isinstance(v, dict) and depth <= self.max_depth: + data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name) + elif isinstance(v, (list, tuple)) and depth <= self.max_depth: + data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name) + elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: + data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) elif hasattr(v, 'id'): data[k] = self.object_map.get(v.id, v) if instance and name: + observer = DataObserver(instance, name) if is_list: - return BaseList(data, instance=instance, name=name) - return BaseDict(data, instance=instance, name=name) + return BaseList(data, observer) + return BaseDict(data, observer) depth += 1 return data diff --git a/mongoengine/document.py b/mongoengine/document.py index f3893ddc..76685ae2 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,14 +1,13 @@ from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, - ValidationError, BaseDict, BaseList, BaseDynamicField) + BaseDict, BaseList, DataObserver) from queryset import OperationError from connection import get_db import pymongo __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', - 'DynamicEmbeddedDocument', 'ValidationError', 'OperationError', - 'InvalidCollectionError'] + 'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError'] class InvalidCollectionError(Exception): @@ -250,31 +249,36 @@ class Document(BaseDocument): self._data = dereference(self._data, max_depth) return self - def reload(self): + def reload(self, max_depth=1): """Reloads all attributes from the database. .. versionadded:: 0.1.2 + .. versionchanged:: 0.6 Now chainable """ id_field = self._meta['id_field'] - obj = self.__class__.objects(**{id_field: self[id_field]}).first() - + obj = self.__class__.objects( + **{id_field: self[id_field]} + ).first().select_related(max_depth=max_depth) for field in self._fields: setattr(self, field, self._reload(field, obj[field])) if self._dynamic: for name in self._dynamic_fields.keys(): setattr(self, name, self._reload(name, obj._data[name])) - self._changed_fields = [] + self._changed_fields = obj._changed_fields + return obj def _reload(self, key, value): """Used by :meth:`~mongoengine.Document.reload` to ensure the correct instance is linked to self. """ if isinstance(value, BaseDict): - value = [(k, self._reload(k,v)) for k,v in value.items()] - value = BaseDict(value, instance=self, name=key) + value = [(k, self._reload(k, v)) for k, v in value.items()] + observer = DataObserver(self, key) + value = BaseDict(value, observer) elif isinstance(value, BaseList): value = [self._reload(key, v) for v in value] - value = BaseList(value, instance=self, name=key) + observer = DataObserver(self, key) + value = BaseList(value, observer) elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): value._changed_fields = [] return value diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 4185e39d..a9b3ea99 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1675,6 +1675,8 @@ class QuerySet(object): .. versionadded:: 0.5 """ from dereference import dereference + # Make select related work the same for querysets + max_depth += 1 return dereference(self, max_depth=max_depth) diff --git a/tests/document.py b/tests/document.py index dca5fed9..a1cfcf42 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1069,7 +1069,7 @@ class DocumentTest(unittest.TestCase): doc.embedded_field = embedded_1 doc.save() - doc.reload() + doc = doc.reload(10) doc.list_field.append(1) doc.dict_field['woot'] = "woot" doc.embedded_field.list_field.append(1) @@ -1080,7 +1080,7 @@ class DocumentTest(unittest.TestCase): 'embedded_field.dict_field']) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc._get_changed_fields(), []) self.assertEquals(len(doc.list_field), 4) self.assertEquals(len(doc.dict_field), 2) @@ -1502,14 +1502,14 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc._delta(), ({'embedded_field': embedded_delta}, {})) doc.save() - doc.reload() + doc = doc.reload(10) doc.embedded_field.dict_field = {} self.assertEquals(doc._get_changed_fields(), ['embedded_field.dict_field']) self.assertEquals(doc.embedded_field._delta(), ({}, {'dict_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.dict_field': 1})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.dict_field, {}) doc.embedded_field.list_field = [] @@ -1517,7 +1517,7 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc.embedded_field._delta(), ({}, {'list_field': 1})) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field': 1})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field, []) embedded_2 = Embedded() @@ -1550,7 +1550,7 @@ class DocumentTest(unittest.TestCase): }] }, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[0], '1') self.assertEquals(doc.embedded_field.list_field[1], 2) @@ -1562,7 +1562,7 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc.embedded_field._delta(), ({'list_field.2.string_field': 'world'}, {})) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.string_field': 'world'}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') # Test multiple assignments @@ -1587,40 +1587,40 @@ class DocumentTest(unittest.TestCase): 'dict_field': {'hello': 'world'}} ]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [2, {'hello': 'world'}, 1]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) doc.embedded_field.list_field[2].list_field.sort() doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEquals(doc._delta(), ({'embedded_field.list_field.2.list_field': [1, 2, {}]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) del(doc.embedded_field.list_field[2].list_field) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) doc.save() - doc.reload() + doc = doc.reload(10) doc.dict_field['Embedded'] = embedded_1 doc.save() - doc.reload() + doc = doc.reload(10) doc.dict_field['Embedded'].string_field = 'Hello World' self.assertEquals(doc._get_changed_fields(), ['dict_field.Embedded.string_field']) @@ -1684,7 +1684,7 @@ class DocumentTest(unittest.TestCase): doc.dict_field = {'hello': 'world'} doc.list_field = ['1', 2, {'hello': 'world'}] doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.string_field, 'hello') self.assertEquals(doc.int_field, 1) @@ -1735,14 +1735,14 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) doc.save() - doc.reload() + doc = doc.reload(10) doc.embedded_field.dict_field = {} self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.dict_field, {}) doc.embedded_field.list_field = [] @@ -1750,7 +1750,7 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field, []) embedded_2 = Embedded() @@ -1783,7 +1783,7 @@ class DocumentTest(unittest.TestCase): }] }, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[0], '1') self.assertEquals(doc.embedded_field.list_field[1], 2) @@ -1795,7 +1795,7 @@ class DocumentTest(unittest.TestCase): self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') # Test multiple assignments @@ -1820,30 +1820,30 @@ class DocumentTest(unittest.TestCase): 'db_dict_field': {'hello': 'world'}} ]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') # Test list native methods doc.embedded_field.list_field[2].list_field.pop(0) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) doc.embedded_field.list_field[2].list_field.append(1) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) doc.embedded_field.list_field[2].list_field.sort() doc.save() - doc.reload() + doc = doc.reload(10) self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) del(doc.embedded_field.list_field[2].list_field[2]['hello']) self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {})) doc.save() - doc.reload() + doc = doc.reload(10) del(doc.embedded_field.list_field[2].list_field) self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) @@ -2344,7 +2344,7 @@ class DocumentTest(unittest.TestCase): resurrected.string = "Two" resurrected.save() - pickle_doc.reload() + pickle_doc = pickle_doc.reload() self.assertEquals(resurrected, pickle_doc) def test_throw_invalid_document_error(self):