Cleaned up dereferencing
Dereferencing now respects max_depth, so should be more performant. Reload is chainable and can be passed a max_depth for dereferencing Added an Observer for ComplexBaseFields. Refs #324 #323 #289 Closes #320
This commit is contained in:
		| @@ -155,9 +155,11 @@ class BaseField(object): | ||||
|  | ||||
|         # Convert lists / values so we can watch for any changes on them | ||||
|         if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): | ||||
|             value = BaseList(value, instance=instance, name=self.name) | ||||
|             observer = DataObserver(instance, self.name) | ||||
|             value = BaseList(value, observer) | ||||
|         elif isinstance(value, dict) and not isinstance(value, BaseDict): | ||||
|             value = BaseDict(value, instance=instance, name=self.name) | ||||
|             observer = DataObserver(instance, self.name) | ||||
|             value = BaseDict(value, observer) | ||||
|         return value | ||||
|  | ||||
|     def __set__(self, instance, value): | ||||
| @@ -237,7 +239,7 @@ class ComplexBaseField(BaseField): | ||||
|  | ||||
|         from dereference import dereference | ||||
|         instance._data[self.name] = dereference( | ||||
|             instance._data.get(self.name), max_depth=1, instance=instance, name=self.name, get=True | ||||
|             instance._data.get(self.name), max_depth=1, instance=instance, name=self.name | ||||
|         ) | ||||
|         return super(ComplexBaseField, self).__get__(instance, owner) | ||||
|  | ||||
| @@ -780,9 +782,11 @@ class BaseDocument(object): | ||||
|  | ||||
|         # Convert lists / values so we can watch for any changes on them | ||||
|         if isinstance(value, (list, tuple)) and not isinstance(value, BaseList): | ||||
|             value = BaseList(value, instance=self, name=name) | ||||
|             observer = DataObserver(self, name) | ||||
|             value = BaseList(value, observer) | ||||
|         elif isinstance(value, dict) and not isinstance(value, BaseDict): | ||||
|             value = BaseDict(value, instance=self, name=name) | ||||
|             observer = DataObserver(self, name) | ||||
|             value = BaseDict(value, observer) | ||||
|  | ||||
|         return value | ||||
|  | ||||
| @@ -1122,102 +1126,113 @@ class BaseDocument(object): | ||||
|             return hash(self.pk) | ||||
|  | ||||
|  | ||||
| class DataObserver(object): | ||||
|  | ||||
|     __slots__ = ["instance", "name"] | ||||
|  | ||||
|     def __init__(self, instance, name): | ||||
|         self.instance = instance | ||||
|         self.name = name | ||||
|  | ||||
|     def updated(self): | ||||
|         self.instance._mark_as_changed(self.name) | ||||
|  | ||||
|  | ||||
| class BaseList(list): | ||||
|     """A special list so we can watch any changes | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, list_items, instance, name): | ||||
|         self.instance = instance | ||||
|         self.name = name | ||||
|     def __init__(self, list_items, observer): | ||||
|         self.observer = observer | ||||
|         super(BaseList, self).__init__(list_items) | ||||
|  | ||||
|     def __setitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseList, self).__setitem__(*args, **kwargs) | ||||
|  | ||||
|     def __delitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseList, self).__delitem__(*args, **kwargs) | ||||
|  | ||||
|     def __getstate__(self): | ||||
|         self.observer = None | ||||
|         return self | ||||
|  | ||||
|     def __setstate__(self, state): | ||||
|         self = state | ||||
|  | ||||
|     def append(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).append(*args, **kwargs) | ||||
|  | ||||
|     def extend(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).extend(*args, **kwargs) | ||||
|  | ||||
|     def insert(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).insert(*args, **kwargs) | ||||
|  | ||||
|     def pop(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).pop(*args, **kwargs) | ||||
|  | ||||
|     def remove(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).remove(*args, **kwargs) | ||||
|  | ||||
|     def reverse(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).reverse(*args, **kwargs) | ||||
|  | ||||
|     def sort(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         return super(BaseList, self).sort(*args, **kwargs) | ||||
|  | ||||
|     def _mark_as_changed(self): | ||||
|         """Marks a list as changed if has an instance and a name""" | ||||
|         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||
|             self.instance._mark_as_changed(self.name) | ||||
|  | ||||
|  | ||||
| class BaseDict(dict): | ||||
|     """A special dict so we can watch any changes | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, dict_items, instance, name): | ||||
|         self.instance = instance | ||||
|         self.name = name | ||||
|     def __init__(self, dict_items, observer): | ||||
|         self.observer = observer | ||||
|         super(BaseDict, self).__init__(dict_items) | ||||
|  | ||||
|     def __setitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).__setitem__(*args, **kwargs) | ||||
|  | ||||
|     def __setattr__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         super(BaseDict, self).__setattr__(*args, **kwargs) | ||||
|  | ||||
|     def __delete__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).__delete__(*args, **kwargs) | ||||
|  | ||||
|     def __delitem__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).__delitem__(*args, **kwargs) | ||||
|  | ||||
|     def __delattr__(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).__delattr__(*args, **kwargs) | ||||
|  | ||||
|     def __getstate__(self): | ||||
|         self.observer = None | ||||
|         return self | ||||
|  | ||||
|     def __setstate__(self, state): | ||||
|         self = state | ||||
|  | ||||
|     def clear(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).clear(*args, **kwargs) | ||||
|  | ||||
|     def pop(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).clear(*args, **kwargs) | ||||
|  | ||||
|     def popitem(self, *args, **kwargs): | ||||
|         self._mark_as_changed() | ||||
|         self.observer.updated() | ||||
|         super(BaseDict, self).clear(*args, **kwargs) | ||||
|  | ||||
|     def _mark_as_changed(self): | ||||
|         """Marks a dict as changed if has an instance and a name""" | ||||
|         if hasattr(self, 'instance') and hasattr(self, 'name'): | ||||
|             self.instance._mark_as_changed(self.name) | ||||
|  | ||||
| if sys.version_info < (2, 5): | ||||
|     # Prior to Python 2.5, Exception was an old-style class | ||||
|   | ||||
| @@ -1,6 +1,7 @@ | ||||
| import pymongo | ||||
|  | ||||
| from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass | ||||
| from base import (BaseDict, BaseList, DataObserver, | ||||
|                   TopLevelDocumentMetaclass, get_document) | ||||
| from fields import ReferenceField | ||||
| from connection import get_db | ||||
| from queryset import QuerySet | ||||
| @@ -9,7 +10,7 @@ from document import Document | ||||
|  | ||||
| class DeReference(object): | ||||
|  | ||||
|     def __call__(self, items, max_depth=1, instance=None, name=None, get=False): | ||||
|     def __call__(self, items, max_depth=1, instance=None, name=None): | ||||
|         """ | ||||
|         Cheaply dereferences the items to a set depth. | ||||
|         Also handles the convertion of complex data types. | ||||
| @@ -43,7 +44,7 @@ class DeReference(object): | ||||
|  | ||||
|         self.reference_map = self._find_references(items) | ||||
|         self.object_map = self._fetch_objects(doc_type=doc_type) | ||||
|         return self._attach_objects(items, 0, instance, name, get) | ||||
|         return self._attach_objects(items, 0, instance, name) | ||||
|  | ||||
|     def _find_references(self, items, depth=0): | ||||
|         """ | ||||
| @@ -53,7 +54,7 @@ class DeReference(object): | ||||
|         :param depth: The current depth of recursion | ||||
|         """ | ||||
|         reference_map = {} | ||||
|         if not items: | ||||
|         if not items or depth >= self.max_depth: | ||||
|             return reference_map | ||||
|  | ||||
|         # Determine the iterator to use | ||||
| @@ -63,6 +64,7 @@ class DeReference(object): | ||||
|             iterator = items.iteritems() | ||||
|  | ||||
|         # Recursively find dbreferences | ||||
|         depth += 1 | ||||
|         for k, item in iterator: | ||||
|             if hasattr(item, '_fields'): | ||||
|                 for field_name, field in item._fields.iteritems(): | ||||
| @@ -82,11 +84,11 @@ class DeReference(object): | ||||
|                 reference_map.setdefault(item.collection, []).append(item.id) | ||||
|             elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item: | ||||
|                 reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id) | ||||
|             elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                 references = self._find_references(item, depth) | ||||
|             elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: | ||||
|                 references = self._find_references(item, depth - 1) | ||||
|                 for key, refs in references.iteritems(): | ||||
|                     reference_map.setdefault(key, []).extend(refs) | ||||
|         depth += 1 | ||||
|  | ||||
|         return reference_map | ||||
|  | ||||
|     def _fetch_objects(self, doc_type=None): | ||||
| @@ -110,7 +112,7 @@ class DeReference(object): | ||||
|                     object_map[doc.id] = doc | ||||
|         return object_map | ||||
|  | ||||
|     def _attach_objects(self, items, depth=0, instance=None, name=None, get=False): | ||||
|     def _attach_objects(self, items, depth=0, instance=None, name=None): | ||||
|         """ | ||||
|         Recursively finds all db references to be dereferenced | ||||
|  | ||||
| @@ -120,25 +122,24 @@ class DeReference(object): | ||||
|             :class:`~mongoengine.base.ComplexBaseField` | ||||
|         :param name: The name of the field, used for tracking changes by | ||||
|             :class:`~mongoengine.base.ComplexBaseField` | ||||
|         :param get: A boolean determining if being called by __get__ | ||||
|         """ | ||||
|         if not items: | ||||
|             if isinstance(items, (BaseDict, BaseList)): | ||||
|                 return items | ||||
|  | ||||
|             if instance: | ||||
|                 observer = DataObserver(instance, name) | ||||
|                 if isinstance(items, dict): | ||||
|                     return BaseDict(items, instance=instance, name=name) | ||||
|                     return BaseDict(items, observer) | ||||
|                 else: | ||||
|                     return BaseList(items, instance=instance, name=name) | ||||
|                     return BaseList(items, observer) | ||||
|  | ||||
|         if isinstance(items, (dict, pymongo.son.SON)): | ||||
|             if '_ref' in items: | ||||
|                 return self.object_map.get(items['_ref'].id, items) | ||||
|             elif '_types' in items and '_cls' in items: | ||||
|                 doc = get_document(items['_cls'])._from_son(items) | ||||
|                 if not get: | ||||
|                     doc._data = self._attach_objects(doc._data, depth, doc, name, get) | ||||
|                 doc._data = self._attach_objects(doc._data, depth, doc, name) | ||||
|                 return doc | ||||
|  | ||||
|         if not hasattr(items, 'items'): | ||||
| @@ -150,6 +151,7 @@ class DeReference(object): | ||||
|             iterator = items.iteritems() | ||||
|             data = {} | ||||
|  | ||||
|         depth += 1 | ||||
|         for k, v in iterator: | ||||
|             if is_list: | ||||
|                 data.append(v) | ||||
| @@ -165,19 +167,20 @@ class DeReference(object): | ||||
|                         data[k]._data[field_name] = self.object_map.get(v.id, v) | ||||
|                     elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v: | ||||
|                         data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v) | ||||
|                     elif isinstance(v, dict) and depth < self.max_depth: | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) | ||||
|                     elif isinstance(v, (list, tuple)): | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get) | ||||
|             elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth: | ||||
|                 data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get) | ||||
|                     elif isinstance(v, dict) and depth <= self.max_depth: | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name) | ||||
|                     elif isinstance(v, (list, tuple)) and depth <= self.max_depth: | ||||
|                         data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name) | ||||
|             elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: | ||||
|                 data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name) | ||||
|             elif hasattr(v, 'id'): | ||||
|                 data[k] = self.object_map.get(v.id, v) | ||||
|  | ||||
|         if instance and name: | ||||
|             observer = DataObserver(instance, name) | ||||
|             if is_list: | ||||
|                 return BaseList(data, instance=instance, name=name) | ||||
|             return BaseDict(data, instance=instance, name=name) | ||||
|                 return BaseList(data, observer) | ||||
|             return BaseDict(data, observer) | ||||
|         depth += 1 | ||||
|         return data | ||||
|  | ||||
|   | ||||
| @@ -1,14 +1,13 @@ | ||||
| from mongoengine import signals | ||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||
|                   ValidationError, BaseDict, BaseList, BaseDynamicField) | ||||
|                   BaseDict, BaseList, DataObserver) | ||||
| from queryset import OperationError | ||||
| from connection import get_db | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
| __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', | ||||
|            'DynamicEmbeddedDocument', 'ValidationError', 'OperationError', | ||||
|            'InvalidCollectionError'] | ||||
|            'DynamicEmbeddedDocument', 'OperationError', 'InvalidCollectionError'] | ||||
|  | ||||
|  | ||||
| class InvalidCollectionError(Exception): | ||||
| @@ -250,31 +249,36 @@ class Document(BaseDocument): | ||||
|         self._data = dereference(self._data, max_depth) | ||||
|         return self | ||||
|  | ||||
|     def reload(self): | ||||
|     def reload(self, max_depth=1): | ||||
|         """Reloads all attributes from the database. | ||||
|  | ||||
|         .. versionadded:: 0.1.2 | ||||
|         .. versionchanged:: 0.6  Now chainable | ||||
|         """ | ||||
|         id_field = self._meta['id_field'] | ||||
|         obj = self.__class__.objects(**{id_field: self[id_field]}).first() | ||||
|  | ||||
|         obj = self.__class__.objects( | ||||
|                 **{id_field: self[id_field]} | ||||
|               ).first().select_related(max_depth=max_depth) | ||||
|         for field in self._fields: | ||||
|             setattr(self, field, self._reload(field, obj[field])) | ||||
|         if self._dynamic: | ||||
|             for name in self._dynamic_fields.keys(): | ||||
|                 setattr(self, name, self._reload(name, obj._data[name])) | ||||
|         self._changed_fields = [] | ||||
|         self._changed_fields = obj._changed_fields | ||||
|         return obj | ||||
|  | ||||
|     def _reload(self, key, value): | ||||
|         """Used by :meth:`~mongoengine.Document.reload` to ensure the | ||||
|         correct instance is linked to self. | ||||
|         """ | ||||
|         if isinstance(value, BaseDict): | ||||
|             value = [(k, self._reload(k,v)) for k,v in value.items()] | ||||
|             value = BaseDict(value, instance=self, name=key) | ||||
|             value = [(k, self._reload(k, v)) for k, v in value.items()] | ||||
|             observer = DataObserver(self, key) | ||||
|             value = BaseDict(value, observer) | ||||
|         elif isinstance(value, BaseList): | ||||
|             value = [self._reload(key, v) for v in value] | ||||
|             value = BaseList(value, instance=self, name=key) | ||||
|             observer = DataObserver(self, key) | ||||
|             value = BaseList(value, observer) | ||||
|         elif isinstance(value, (EmbeddedDocument, DynamicEmbeddedDocument)): | ||||
|             value._changed_fields = [] | ||||
|         return value | ||||
|   | ||||
| @@ -1675,6 +1675,8 @@ class QuerySet(object): | ||||
|         .. versionadded:: 0.5 | ||||
|         """ | ||||
|         from dereference import dereference | ||||
|         # Make select related work the same for querysets | ||||
|         max_depth += 1 | ||||
|         return dereference(self, max_depth=max_depth) | ||||
|  | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user