Added select_related() and refactored dereferencing
Added a dereference class to handle both select_related / recursive dereferencing and fetching dereference. Refs #206
This commit is contained in:
171
mongoengine/dereference.py
Normal file
171
mongoengine/dereference.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import operator
|
||||
|
||||
import pymongo
|
||||
|
||||
from base import BaseDict, BaseList, get_document
|
||||
from connection import _get_db
|
||||
from queryset import QuerySet
|
||||
|
||||
|
||||
class DeReference(object):
|
||||
|
||||
def __call__(self, items, max_depth=1, instance=None, name=None, get=False):
|
||||
"""
|
||||
Cheaply dereferences the items to a set depth.
|
||||
Also handles the convertion of complex data types.
|
||||
|
||||
:param items: The iterable (dict, list, queryset) to be dereferenced.
|
||||
:param max_depth: The maximum depth to recurse to
|
||||
:param instance: The owning instance used for tracking changes by
|
||||
:class:`~mongoengine.base.ComplexBaseField`
|
||||
:param name: The name of the field, used for tracking changes by
|
||||
:class:`~mongoengine.base.ComplexBaseField`
|
||||
:param get: A boolean determining if being called by __get__
|
||||
"""
|
||||
if items is None or isinstance(items, basestring):
|
||||
return items
|
||||
|
||||
# cheapest way to convert a queryset to a list
|
||||
# list(queryset) uses a count() query to determine length
|
||||
if isinstance(items, QuerySet):
|
||||
items = [i for i in items]
|
||||
|
||||
self.max_depth = max_depth
|
||||
self.reference_map = self._find_references(items)
|
||||
self.object_map = self._fetch_objects()
|
||||
return self._attach_objects(items, 0, instance, name, get)
|
||||
|
||||
def _find_references(self, items, depth=0):
|
||||
"""
|
||||
Recursively finds all db references to be dereferenced
|
||||
|
||||
:param items: The iterable (dict, list, queryset)
|
||||
:param depth: The current depth of recursion
|
||||
"""
|
||||
reference_map = {}
|
||||
if not items:
|
||||
return reference_map
|
||||
|
||||
# Determine the iterator to use
|
||||
if not hasattr(items, 'items'):
|
||||
iterator = enumerate(items)
|
||||
else:
|
||||
iterator = items.iteritems()
|
||||
|
||||
# Recursively find dbreferences
|
||||
for k, item in iterator:
|
||||
if hasattr(item, '_fields'):
|
||||
for field_name, field in item._fields.iteritems():
|
||||
v = item._data.get(field_name, None)
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
reference_map.setdefault(field.document_type, []).append(v.id)
|
||||
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
|
||||
reference_map.setdefault(get_document(v['_cls']), []).append(v['_ref'].id)
|
||||
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
|
||||
field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
|
||||
references = self._find_references(v, depth)
|
||||
for key, refs in references.iteritems():
|
||||
if field_cls:
|
||||
key = field_cls
|
||||
reference_map.setdefault(key, []).extend(refs)
|
||||
elif isinstance(item, (pymongo.dbref.DBRef)):
|
||||
reference_map.setdefault(item.collection, []).append(item.id)
|
||||
elif isinstance(item, (dict, pymongo.son.SON)) and '_ref' in item:
|
||||
reference_map.setdefault(get_document(item['_cls']), []).append(item['_ref'].id)
|
||||
elif isinstance(item, (dict, list, tuple)) and depth <= self.max_depth:
|
||||
references = self._find_references(item, depth)
|
||||
for key, refs in references.iteritems():
|
||||
reference_map.setdefault(key, []).extend(refs)
|
||||
depth += 1
|
||||
return reference_map
|
||||
|
||||
def _fetch_objects(self):
|
||||
"""Fetch all references and convert to their document objects
|
||||
"""
|
||||
object_map = {}
|
||||
for col, dbrefs in self.reference_map.iteritems():
|
||||
keys = object_map.keys()
|
||||
refs = list(set([dbref for dbref in dbrefs if str(dbref) not in keys]))
|
||||
if hasattr(col, 'objects'): # We have a document class for the refs
|
||||
references = col.objects.in_bulk(refs)
|
||||
for key, doc in references.iteritems():
|
||||
object_map[key] = doc
|
||||
else: # Generic reference: use the refs data to convert to document
|
||||
references = _get_db()[col].find({'_id': {'$in': refs}})
|
||||
for ref in references:
|
||||
doc = get_document(ref['_cls'])._from_son(ref)
|
||||
object_map[doc.id] = doc
|
||||
return object_map
|
||||
|
||||
def _attach_objects(self, items, depth=0, instance=None, name=None, get=False):
|
||||
"""
|
||||
Recursively finds all db references to be dereferenced
|
||||
|
||||
:param items: The iterable (dict, list, queryset)
|
||||
:param depth: The current depth of recursion
|
||||
:param instance: The owning instance used for tracking changes by
|
||||
:class:`~mongoengine.base.ComplexBaseField`
|
||||
:param name: The name of the field, used for tracking changes by
|
||||
:class:`~mongoengine.base.ComplexBaseField`
|
||||
:param get: A boolean determining if being called by __get__
|
||||
"""
|
||||
if not items:
|
||||
if isinstance(items, (BaseDict, BaseList)):
|
||||
return items
|
||||
|
||||
if instance:
|
||||
if isinstance(items, dict):
|
||||
return BaseDict(items, instance=instance, name=name)
|
||||
else:
|
||||
return BaseList(items, instance=instance, name=name)
|
||||
|
||||
if isinstance(items, (dict, pymongo.son.SON)):
|
||||
if '_ref' in items:
|
||||
return self.object_map.get(items['_ref'].id, items)
|
||||
elif '_types' in items and '_cls' in items:
|
||||
doc = get_document(items['_cls'])._from_son(items)
|
||||
if not get:
|
||||
doc._data = self._attach_objects(doc._data, depth, doc, name, get)
|
||||
return doc
|
||||
|
||||
if not hasattr(items, 'items'):
|
||||
is_list = True
|
||||
iterator = enumerate(items)
|
||||
data = []
|
||||
else:
|
||||
is_list = False
|
||||
iterator = items.iteritems()
|
||||
data = {}
|
||||
|
||||
for k, v in iterator:
|
||||
if is_list:
|
||||
data.append(v)
|
||||
else:
|
||||
data[k] = v
|
||||
|
||||
if k in self.object_map:
|
||||
data[k] = self.object_map[k]
|
||||
elif hasattr(v, '_fields'):
|
||||
for field_name, field in v._fields.iteritems():
|
||||
v = data[k]._data.get(field_name, None)
|
||||
if isinstance(v, (pymongo.dbref.DBRef)):
|
||||
data[k]._data[field_name] = self.object_map.get(v.id, v)
|
||||
elif isinstance(v, (dict, pymongo.son.SON)) and '_ref' in v:
|
||||
data[k]._data[field_name] = self.object_map.get(v['_ref'].id, v)
|
||||
elif isinstance(v, dict) and depth < self.max_depth:
|
||||
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
|
||||
elif isinstance(v, (list, tuple)):
|
||||
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
|
||||
elif isinstance(v, (dict, list, tuple)) and depth < self.max_depth:
|
||||
data[k] = self._attach_objects(v, depth, instance=instance, name=name, get=get)
|
||||
elif hasattr(v, 'id'):
|
||||
data[k] = self.object_map.get(v.id, v)
|
||||
|
||||
if instance and name:
|
||||
if is_list:
|
||||
return BaseList(data, instance=instance, name=name)
|
||||
return BaseDict(data, instance=instance, name=name)
|
||||
depth += 1
|
||||
return data
|
||||
|
||||
dereference = DeReference()
|
||||
Reference in New Issue
Block a user