mongoengine/mongoengine/dereference.py
Stefan Wójcik b47669403b
Format the codebase using Black (#2109)
This commit:
1. Formats all of our existing code using `black`.
2. Adds a note about using `black` to `CONTRIBUTING.rst`.
3. Runs `black --check` as part of CI (failing builds that aren't properly formatted).
2019-06-27 13:05:54 +02:00

297 lines
13 KiB
Python

from bson import DBRef, SON
import six
from six import iteritems
from mongoengine.base import (
BaseDict,
BaseList,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.fields import DictField, ListField, MapField, ReferenceField
from mongoengine.queryset import QuerySet
class DeReference(object):
    """Bulk-dereference the ``DBRef`` values held in a container.

    A call runs three phases, keeping intermediate state on ``self``:

    1. :meth:`_find_references` groups the referenced ids by document class
       (or, for bare ``DBRef`` items, by collection name) into
       ``self.reference_map``.
    2. :meth:`_fetch_objects` loads each group in bulk into
       ``self.object_map``, keyed by ``(collection_name, id)``.
    3. :meth:`_attach_objects` rebuilds the input, replacing every reference
       found in ``self.object_map`` with the loaded document.

    NOTE(review): because ``max_depth``, ``reference_map`` and ``object_map``
    are stored on the instance, one instance presumably should not serve
    overlapping calls (e.g. from several threads).
    """

    def __call__(self, items, max_depth=1, instance=None, name=None):
        """
        Cheaply dereferences the items to a set depth.
        Also handles the conversion of complex data types.

        :param items: The iterable (dict, list, queryset) to be dereferenced.
        :param max_depth: The maximum depth to recurse to
        :param instance: The owning instance used for tracking changes by
            :class:`~mongoengine.base.ComplexBaseField`
        :param name: The name of the field, used for tracking changes by
            :class:`~mongoengine.base.ComplexBaseField`
        :return: ``items`` itself when it is ``None``, a string, or (for a
            reference field) already holds only instances of the referenced
            document class; otherwise the structure rebuilt by
            :meth:`_attach_objects`.
        """
        if items is None or isinstance(items, six.string_types):
            return items

        # cheapest way to convert a queryset to a list
        # list(queryset) uses a count() query to determine length
        if isinstance(items, QuerySet):
            items = [i for i in items]

        self.max_depth = max_depth
        doc_type = None

        if instance and isinstance(
            instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)
        ):
            # Unwrap container fields (anything exposing ``.field``) down to
            # the innermost field definition for ``name``.
            doc_type = instance._fields.get(name)
            while hasattr(doc_type, "field"):
                doc_type = doc_type.field

            if isinstance(doc_type, ReferenceField):
                field = doc_type
                doc_type = doc_type.document_type
                is_list = not hasattr(items, "items")

                # Nothing to dereference when every value is already an
                # instance of the referenced class (vacuously true when empty).
                values = items if is_list else items.values()
                if all(i.__class__ == doc_type for i in values):
                    return items

                if not field.dbref:
                    # We must turn the ObjectIds into DBRefs

                    def _to_dbrefs(value):
                        # Recursively dig into nested lists/dicts and convert
                        # every leaf that is not already a DBRef/Document.
                        if isinstance(value, dict):
                            return {k: _to_dbrefs(v) for k, v in iteritems(value)}
                        if isinstance(value, list):
                            return [_to_dbrefs(v) for v in value]
                        if isinstance(value, (DBRef, Document)):
                            return value
                        return field.to_python(value)

                    if is_list:
                        items = [_to_dbrefs(v) for v in items]
                    else:
                        items = {k: _to_dbrefs(v) for k, v in iteritems(items)}

        self.reference_map = self._find_references(items)
        self.object_map = self._fetch_objects(doc_type=doc_type)
        return self._attach_objects(items, 0, instance, name)

    def _find_references(self, items, depth=0):
        """
        Recursively collect the ids of all db references to be dereferenced.

        Descending into a document's fields increases ``depth`` by one; a
        list/dict/tuple nested directly inside ``items`` is searched at the
        same ``depth`` as ``items`` itself.

        :param items: The iterable (dict, list, queryset)
        :param depth: The current depth of recursion
        :return: dict mapping a document class -- or, for a bare ``DBRef``
            item, its collection name -- to the set of referenced ids.
        """
        reference_map = {}
        if not items or depth >= self.max_depth:
            return reference_map

        # Determine the iterator to use
        if isinstance(items, dict):
            iterator = items.values()
        else:
            iterator = items

        # Recursively find dbreferences
        depth += 1
        # (For an integer max_depth the ``depth <= self.max_depth`` checks
        # below always hold after the guard above.)
        for item in iterator:
            if isinstance(item, (Document, EmbeddedDocument)):
                for field_name, field in iteritems(item._fields):
                    v = item._data.get(field_name, None)
                    if isinstance(v, LazyReference):
                        # LazyReference inherits DBRef but should not be dereferenced here !
                        continue
                    elif isinstance(v, DBRef):
                        reference_map.setdefault(field.document_type, set()).add(v.id)
                    elif isinstance(v, (dict, SON)) and "_ref" in v:
                        # Dict carrying "_cls" and a "_ref" DBRef (presumably a
                        # generic reference): group by the named class.
                        reference_map.setdefault(get_document(v["_cls"]), set()).add(
                            v["_ref"].id
                        )
                    elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
                        # When the container field declares a document class,
                        # file every id found inside it under that class.
                        field_cls = getattr(
                            getattr(field, "field", None), "document_type", None
                        )
                        references = self._find_references(v, depth)
                        for key, refs in iteritems(references):
                            if isinstance(
                                field_cls, (Document, TopLevelDocumentMetaclass)
                            ):
                                key = field_cls
                            reference_map.setdefault(key, set()).update(refs)
            elif isinstance(item, LazyReference):
                # LazyReference inherits DBRef but should not be dereferenced here !
                continue
            elif isinstance(item, DBRef):
                reference_map.setdefault(item.collection, set()).add(item.id)
            elif isinstance(item, (dict, SON)) and "_ref" in item:
                reference_map.setdefault(get_document(item["_cls"]), set()).add(
                    item["_ref"].id
                )
            elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
                # Plain nesting does not consume depth.
                references = self._find_references(item, depth - 1)
                for key, refs in iteritems(references):
                    reference_map.setdefault(key, set()).update(refs)

        return reference_map

    def _fetch_objects(self, doc_type=None):
        """Fetch all references and convert to their document objects.

        Groups of ``self.reference_map`` keyed by something with an
        ``objects`` manager (a document class) are loaded with
        ``objects.in_bulk``. Groups keyed by a collection name are skipped
        when ``doc_type`` is a ``ListField``/``DictField``/``MapField``;
        otherwise they are queried directly and each raw result is built
        with ``doc_type`` when it is truthy, else with the result's ``_cls``,
        else with a class name derived from the collection name
        (``"blog_post"`` -> ``"BlogPost"``).

        :param doc_type: hint computed by :meth:`__call__` (``None``, the
            innermost field definition, or the referenced document class).
        :return: dict mapping ``(collection_name, id)`` to the loaded document.
        """
        object_map = {}
        for collection, dbrefs in iteritems(self.reference_map):

            # we use getattr instead of hasattr because hasattr swallows any exception under python2
            # so it could hide nasty things without raising exceptions (cfr bug #1688))
            ref_document_cls_exists = getattr(collection, "objects", None) is not None

            if ref_document_cls_exists:
                col_name = collection._get_collection_name()
                refs = [
                    dbref for dbref in dbrefs if (col_name, dbref) not in object_map
                ]
                references = collection.objects.in_bulk(refs)
                for key, doc in iteritems(references):
                    object_map[(col_name, key)] = doc
            else:  # Generic reference: use the refs data to convert to document
                if isinstance(doc_type, (ListField, DictField, MapField)):
                    continue

                refs = [
                    dbref for dbref in dbrefs if (collection, dbref) not in object_map
                ]

                if doc_type:
                    references = doc_type._get_db()[collection].find(
                        {"_id": {"$in": refs}}
                    )
                    for ref in references:
                        doc = doc_type._from_son(ref)
                        object_map[(collection, doc.id)] = doc
                else:
                    references = get_db()[collection].find({"_id": {"$in": refs}})
                    for ref in references:
                        if "_cls" in ref:
                            doc = get_document(ref["_cls"])._from_son(ref)
                        elif doc_type is None:
                            # Derive the class name from the collection name.
                            doc = get_document(
                                "".join(x.capitalize() for x in collection.split("_"))
                            )._from_son(ref)
                        else:
                            # Only reachable when doc_type is falsy but not None.
                            doc = doc_type._from_son(ref)
                        object_map[(collection, doc.id)] = doc
        return object_map

    def _attach_objects(self, items, depth=0, instance=None, name=None):
        """
        Recursively replace references in ``items`` with the documents that
        :meth:`_fetch_objects` loaded into ``self.object_map``.

        Special cases: an empty ``BaseDict``/``BaseList`` is returned as-is
        and other empty input is wrapped when ``instance`` is given; a dict
        carrying ``_ref`` is replaced by the matching fetched document (or
        returned unchanged when none was fetched); a dict carrying ``_cls``
        is turned into a document instance. Otherwise the container is
        rebuilt, and documents found inside it have their reference fields
        replaced in place.

        :param items: The iterable (dict, list, queryset)
        :param depth: The current depth of recursion
        :param instance: The owning instance used for tracking changes by
            :class:`~mongoengine.base.ComplexBaseField`
        :param name: The name of the field, used for tracking changes by
            :class:`~mongoengine.base.ComplexBaseField`
        :return: for a rebuilt container, a ``BaseDict``,
            ``BaseList``/``EmbeddedDocumentList`` or ``tuple`` (matching the
            input) when both ``instance`` and ``name`` are given, otherwise a
            plain ``dict`` or ``list``.
        """
        if not items:
            if isinstance(items, (BaseDict, BaseList)):
                return items

            if instance:
                if isinstance(items, dict):
                    return BaseDict(items, instance, name)
                else:
                    return BaseList(items, instance, name)

        if isinstance(items, (dict, SON)):
            if "_ref" in items:
                return self.object_map.get(
                    (items["_ref"].collection, items["_ref"].id), items
                )
            elif "_cls" in items:
                doc = get_document(items["_cls"])._from_son(items)
                _cls = doc._data.pop("_cls", None)
                # NOTE(review): this removes "_cls" from the caller's dict.
                del items["_cls"]
                doc._data = self._attach_objects(doc._data, depth, doc, None)
                if _cls is not None:
                    doc._data["_cls"] = _cls
                return doc

        if not hasattr(items, "items"):
            is_list = True
            list_type = BaseList
            if isinstance(items, EmbeddedDocumentList):
                list_type = EmbeddedDocumentList
            as_tuple = isinstance(items, tuple)
            iterator = enumerate(items)
            data = []
        else:
            is_list = False
            iterator = iteritems(items)
            data = {}

        depth += 1
        for k, v in iterator:
            if is_list:
                data.append(v)
            else:
                data[k] = v

            if k in self.object_map and not is_list:
                data[k] = self.object_map[k]
            elif isinstance(v, (Document, EmbeddedDocument)):
                # ``data[k]`` is ``v`` itself, so its reference fields are
                # dereferenced in place.
                for field_name in v._fields:
                    field_value = v._data.get(field_name, None)
                    if isinstance(field_value, DBRef):
                        v._data[field_name] = self.object_map.get(
                            (field_value.collection, field_value.id), field_value
                        )
                    elif isinstance(field_value, (dict, SON)) and "_ref" in field_value:
                        ref = field_value["_ref"]
                        v._data[field_name] = self.object_map.get(
                            (ref.collection, ref.id), field_value
                        )
                    elif (
                        isinstance(field_value, (dict, list, tuple))
                        and depth <= self.max_depth
                    ):
                        item_name = six.text_type("{0}.{1}.{2}").format(
                            name, k, field_name
                        )
                        v._data[field_name] = self._attach_objects(
                            field_value, depth, instance=instance, name=item_name
                        )
            elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
                # Plain nesting does not consume depth (hence ``depth - 1``).
                item_name = "%s.%s" % (name, k) if name else name
                data[k] = self._attach_objects(
                    v, depth - 1, instance=instance, name=item_name
                )
            elif isinstance(v, DBRef) and hasattr(v, "id"):
                data[k] = self.object_map.get((v.collection, v.id), v)

        if instance and name:
            if is_list:
                return tuple(data) if as_tuple else list_type(data, instance, name)
            return BaseDict(data, instance, name)
        return data