from bson import DBRef, SON import six from six import iteritems from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList, TopLevelDocumentMetaclass, get_document) from mongoengine.base.datastructures import LazyReference from mongoengine.connection import get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.fields import DictField, ListField, MapField, ReferenceField from mongoengine.queryset import QuerySet class DeReference(object): def __call__(self, items, max_depth=1, instance=None, name=None): """ Cheaply dereferences the items to a set depth. Also handles the conversion of complex data types. :param items: The iterable (dict, list, queryset) to be dereferenced. :param max_depth: The maximum depth to recurse to :param instance: The owning instance used for tracking changes by :class:`~mongoengine.base.ComplexBaseField` :param name: The name of the field, used for tracking changes by :class:`~mongoengine.base.ComplexBaseField` :param get: A boolean determining if being called by __get__ """ if items is None or isinstance(items, six.string_types): return items # cheapest way to convert a queryset to a list # list(queryset) uses a count() query to determine length if isinstance(items, QuerySet): items = [i for i in items] self.max_depth = max_depth doc_type = None if instance and isinstance(instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)): doc_type = instance._fields.get(name) while hasattr(doc_type, 'field'): doc_type = doc_type.field if isinstance(doc_type, ReferenceField): field = doc_type doc_type = doc_type.document_type is_list = not hasattr(items, 'items') if is_list and all([i.__class__ == doc_type for i in items]): return items elif not is_list and all( [i.__class__ == doc_type for i in items.values()]): return items elif not field.dbref: # We must turn the ObjectIds into DBRefs # Recursively dig into the sub items of a list/dict # to turn the ObjectIds into DBRefs def _get_items_from_list(items): new_items = [] for v in items: value = v if isinstance(v, dict): value = _get_items_from_dict(v) elif isinstance(v, list): value = _get_items_from_list(v) elif not isinstance(v, (DBRef, Document)): value = field.to_python(v) new_items.append(value) return new_items def _get_items_from_dict(items): new_items = {} for k, v in iteritems(items): value = v if isinstance(v, list): value = _get_items_from_list(v) elif isinstance(v, dict): value = _get_items_from_dict(v) elif not isinstance(v, (DBRef, Document)): value = field.to_python(v) new_items[k] = value return new_items if not hasattr(items, 'items'): items = _get_items_from_list(items) else: items = _get_items_from_dict(items) self.reference_map = self._find_references(items) self.object_map = self._fetch_objects(doc_type=doc_type) return self._attach_objects(items, 0, instance, name) def _find_references(self, items, depth=0): """ Recursively finds all db references to be dereferenced :param items: The iterable (dict, list, queryset) :param depth: The current depth of recursion """ reference_map = {} if not items or depth >= self.max_depth: return reference_map # Determine the iterator to use if isinstance(items, dict): iterator = items.values() else: iterator = items # Recursively find dbreferences depth += 1 for item in iterator: if isinstance(item, (Document, EmbeddedDocument)): for field_name, field in iteritems(item._fields): v = item._data.get(field_name, None) if isinstance(v, LazyReference): # LazyReference inherits DBRef but should not be dereferenced here ! continue elif isinstance(v, DBRef): reference_map.setdefault(field.document_type, set()).add(v.id) elif isinstance(v, (dict, SON)) and '_ref' in v: reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: field_cls = getattr(getattr(field, 'field', None), 'document_type', None) references = self._find_references(v, depth) for key, refs in iteritems(references): if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)): key = field_cls reference_map.setdefault(key, set()).update(refs) elif isinstance(item, LazyReference): # LazyReference inherits DBRef but should not be dereferenced here ! continue elif isinstance(item, DBRef): reference_map.setdefault(item.collection, set()).add(item.id) elif isinstance(item, (dict, SON)) and '_ref' in item: reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id) elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth: references = self._find_references(item, depth - 1) for key, refs in iteritems(references): reference_map.setdefault(key, set()).update(refs) return reference_map def _fetch_objects(self, doc_type=None): """Fetch all references and convert to their document objects """ object_map = {} for collection, dbrefs in iteritems(self.reference_map): # we use getattr instead of hasattr because hasattr swallows any exception under python2 # so it could hide nasty things without raising exceptions (cfr bug #1688)) ref_document_cls_exists = (getattr(collection, 'objects', None) is not None) if ref_document_cls_exists: col_name = collection._get_collection_name() refs = [dbref for dbref in dbrefs if (col_name, dbref) not in object_map] references = collection.objects.in_bulk(refs) for key, doc in iteritems(references): object_map[(col_name, key)] = doc else: # Generic reference: use the refs data to convert to document if isinstance(doc_type, (ListField, DictField, MapField)): continue refs = [dbref for dbref in dbrefs if (collection, dbref) not in object_map] if doc_type: references = doc_type._get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: doc = doc_type._from_son(ref) object_map[(collection, doc.id)] = doc else: references = get_db()[collection].find({'_id': {'$in': refs}}) for ref in references: if '_cls' in ref: doc = get_document(ref['_cls'])._from_son(ref) elif doc_type is None: doc = get_document( ''.join(x.capitalize() for x in collection.split('_')))._from_son(ref) else: doc = doc_type._from_son(ref) object_map[(collection, doc.id)] = doc return object_map def _attach_objects(self, items, depth=0, instance=None, name=None): """ Recursively finds all db references to be dereferenced :param items: The iterable (dict, list, queryset) :param depth: The current depth of recursion :param instance: The owning instance used for tracking changes by :class:`~mongoengine.base.ComplexBaseField` :param name: The name of the field, used for tracking changes by :class:`~mongoengine.base.ComplexBaseField` """ if not items: if isinstance(items, (BaseDict, BaseList)): return items if instance: if isinstance(items, dict): return BaseDict(items, instance, name) else: return BaseList(items, instance, name) if isinstance(items, (dict, SON)): if '_ref' in items: return self.object_map.get( (items['_ref'].collection, items['_ref'].id), items) elif '_cls' in items: doc = get_document(items['_cls'])._from_son(items) _cls = doc._data.pop('_cls', None) del items['_cls'] doc._data = self._attach_objects(doc._data, depth, doc, None) if _cls is not None: doc._data['_cls'] = _cls return doc if not hasattr(items, 'items'): is_list = True list_type = BaseList if isinstance(items, EmbeddedDocumentList): list_type = EmbeddedDocumentList as_tuple = isinstance(items, tuple) iterator = enumerate(items) data = [] else: is_list = False iterator = iteritems(items) data = {} depth += 1 for k, v in iterator: if is_list: data.append(v) else: data[k] = v if k in self.object_map and not is_list: data[k] = self.object_map[k] elif isinstance(v, (Document, EmbeddedDocument)): for field_name in v._fields: v = data[k]._data.get(field_name, None) if isinstance(v, DBRef): data[k]._data[field_name] = self.object_map.get( (v.collection, v.id), v) elif isinstance(v, (dict, SON)) and '_ref' in v: data[k]._data[field_name] = self.object_map.get( (v['_ref'].collection, v['_ref'].id), v) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = '%s.%s' % (name, k) if name else name data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name) elif isinstance(v, DBRef) and hasattr(v, 'id'): data[k] = self.object_map.get((v.collection, v.id), v) if instance and name: if is_list: return tuple(data) if as_tuple else list_type(data, instance, name) return BaseDict(data, instance, name) depth += 1 return data