Cleaned up dereferencing

Dereferencing now respects max_depth, so it should be more performant.
Reload is chainable and can be passed a max_depth for dereferencing.
Added an Observer for ComplexBaseFields.

Refs #324 #323 #289
Closes #320
Ross Lawley
2011-11-25 08:28:20 -08:00
parent 5e553ffaf7
commit 83fff80b0f
6 changed files with 122 additions and 97 deletions
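
As a rough usage sketch of the reload change described in the commit message: the Person and Organisation documents below are invented for illustration; only the chainable reload() and its max_depth argument come from this commit.

    from mongoengine import Document, StringField, ReferenceField

    class Organisation(Document):      # illustrative document, not part of this commit
        name = StringField()

    class Person(Document):            # illustrative document, not part of this commit
        name = StringField()
        employer = ReferenceField(Organisation)

    person = Person.objects.first()

    # reload() now returns the document itself, so it can be chained, and it
    # accepts a max_depth that is forwarded to the dereferencing step.
    fresh = person.reload(max_depth=1)
    print(fresh.name)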

mongoengine/dereference.py

@@ -1,6 +1,7 @@
 import pymongo
-from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
+from base import (BaseDict, BaseList, DataObserver,
+                  TopLevelDocumentMetaclass, get_document)
 from fields import ReferenceField
 from connection import get_db
 from queryset import QuerySet
@@ -9,7 +10,7 @@ from document import Document
 class DeReference(object):

-    def __call__(self, items, max_depth=1, instance=None, name=None, get=False):
+    def __call__(self, items, max_depth=1, instance=None, name=None):
         """
         Cheaply dereferences the items to a set depth.
         Also handles the convertion of complex data types.
@@ -43,7 +44,7 @@ class DeReference(object):
         self.reference_map = self._find_references(items)
         self.object_map = self._fetch_objects(doc_type=doc_type)
-        return self._attach_objects(items, 0, instance, name, get)
+        return self._attach_objects(items, 0, instance, name)

     def _find_references(self, items, depth=0):
         """
@@ -53,7 +54,7 @@ class DeReference(object):
         :param depth: The current depth of recursion
         """
         reference_map = {}
-        if not items:
+        if not items or depth >= self.max_depth:
             return reference_map

         # Determine the iterator to use
@@ -63,6 +64,7 @@ class DeReference(object):
             iterator = items.iteritems()

         # Recursively find dbreferences
+        depth += 1
         for k, item in iterator:
             if hasattr(item, '_fields'):
                 for field_name, field in item._fields.iteritems():
@@ -82,11 +84,11 @@ class DeReference(object):
                 reference_map.setdefault(item.collection, []).append(item.id)
             elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item:
                 reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id)
-            elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth:
-                references = self._find_references(item, depth)
+            elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
+                references = self._find_references(item, depth - 1)
                 for key, refs in references.iteritems():
                     reference_map.setdefault(key, []).extend(refs)
-            depth += 1
         return reference_map

     def _fetch_objects(self, doc_type=None):
@@ -110,7 +112,7 @@ class DeReference(object):
                 object_map[doc.id] = doc
         return object_map

-    def _attach_objects(self, items, depth=0, instance=None, name=None, get=False):
+    def _attach_objects(self, items, depth=0, instance=None, name=None):
         """
         Recursively finds all db references to be dereferenced
@@ -120,25 +122,24 @@ class DeReference(object):
              :class:`~mongoengine.base.ComplexBaseField`
         :param name: The name of the field, used for tracking changes by
              :class:`~mongoengine.base.ComplexBaseField`
-        :param get: A boolean determining if being called by __get__
         """
         if not items:
+            if isinstance(items, (BaseDict, BaseList)):
+                return items
             if instance:
+                observer = DataObserver(instance, name)
                 if isinstance(items, dict):
-                    return BaseDict(items, instance=instance, name=name)
+                    return BaseDict(items, observer)
                 else:
-                    return BaseList(items, instance=instance, name=name)
+                    return BaseList(items, observer)

         if isinstance(items, (dict, pymongo.son.SON)):
             if '_ref' in items:
                 return self.object_map.get(items['_ref'].id, items)
             elif '_types' in items and '_cls' in items:
                 doc = get_document(items['_cls'])._from_son(items)
-                if not get:
-                    doc._data = self._attach_objects(doc._data, depth, doc, name, get)
+                doc._data = self._attach_objects(doc._data, depth, doc, name)
                 return doc

         if not hasattr(items, 'items'):
@@ -150,6 +151,7 @@ class DeReference(object):
             iterator = items.iteritems()
             data = {}

+        depth += 1
         for k, v in iterator:
             if is_list:
                 data.append(v)
@@ -165,19 +167,20 @@ class DeReference(object):
                         data[k]._data[field_name] = self.object_map.get(v.id, v)
                     elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
                         data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
-                    elif isinstance(v, dict) and depth < self.max_depth:
-                        data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
-                    elif isinstance(v, (list, tuple)):
-                        data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
-            elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth:
-                data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
+                    elif isinstance(v, dict) and depth <= self.max_depth:
+                        data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name)
+                    elif isinstance(v, (list, tuple)) and depth <= self.max_depth:
+                        data[k]._data[field_name] = self._attach_objects(v, depth - 1, instance=instance, name=name)
+            elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
+                data[k] = self._attach_objects(v, depth - 1, instance=instance, name=name)
             elif hasattr(v, 'id'):
                 data[k] = self.object_map.get(v.id, v)

         if instance and name:
+            observer = DataObserver(instance, name)
             if is_list:
-                return BaseList(data, instance=instance, name=name)
-            return BaseDict(data, instance=instance, name=name)
+                return BaseList(data, observer)
+            return BaseDict(data, observer)

         depth += 1
         return data
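
The DataObserver used above is defined in mongoengine/base.py, one of the other changed files not shown on this page. Below is a minimal sketch of the pattern the new BaseDict(items, observer) / BaseList(items, observer) calls appear to rely on; the method name and body are assumptions, not the actual implementation.

    class DataObserver(object):
        """Couples a complex container (dict/list field) to the document
        that owns it, so mutations can be reported back as changes."""

        def __init__(self, instance, name):
            self.instance = instance   # owning Document instance
            self.name = name           # name of the field on that document

        def changed(self):
            # Hypothetical notification hook: a BaseDict/BaseList would call
            # something like this when mutated, so the owning document can
            # record the field as dirty (mongoengine tracks this via
            # _mark_as_changed on the document).
            if self.instance is not None:
                self.instance._mark_as_changed(self.name)

Passing a single observer object, rather than separate instance and name keyword arguments, is the refactor that the _attach_objects hunks above switch to.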