diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index e44ec2c9..6a4c6bd9 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -2,9 +2,7 @@ import copy import numbers from functools import partial -from bson import ObjectId, json_util -from bson.dbref import DBRef -from bson.son import SON +from bson import DBRef, ObjectId, SON, json_util import pymongo import six diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 44266613..619b5d1f 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -52,26 +52,40 @@ class DeReference(object): [i.__class__ == doc_type for i in items.values()]): return items elif not field.dbref: + # We must turn the ObjectIds into DBRefs + + # Recursively dig into the sub items of a list/dict + # to turn the ObjectIds into DBRefs + def _get_items_from_list(items): + new_items = [] + for v in items: + value = v + if isinstance(v, dict): + value = _get_items_from_dict(v) + elif isinstance(v, list): + value = _get_items_from_list(v) + elif not isinstance(v, (DBRef, Document)): + value = field.to_python(v) + new_items.append(value) + return new_items + + def _get_items_from_dict(items): + new_items = {} + for k, v in items.iteritems(): + value = v + if isinstance(v, list): + value = _get_items_from_list(v) + elif isinstance(v, dict): + value = _get_items_from_dict(v) + elif not isinstance(v, (DBRef, Document)): + value = field.to_python(v) + new_items[k] = value + return new_items + if not hasattr(items, 'items'): - - def _get_items(items): - new_items = [] - for v in items: - if isinstance(v, list): - new_items.append(_get_items(v)) - elif not isinstance(v, (DBRef, Document)): - new_items.append(field.to_python(v)) - else: - new_items.append(v) - return new_items - - items = _get_items(items) + items = _get_items_from_list(items) else: - items = { - k: (v if isinstance(v, (DBRef, Document)) - else field.to_python(v)) - for k, v in items.iteritems() - } + items = _get_items_from_dict(items) self.reference_map = self._find_references(items) self.object_map = self._fetch_objects(doc_type=doc_type) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index a1b586ce..0f325849 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1853,6 +1853,48 @@ class FieldTest(MongoDBTestCase): with self.assertRaises(ValueError): e.update(set__mapping={"somestrings": ["foo", "bar", ]}) + def test_dictfield_with_referencefield_complex_nesting_cases(self): + """Ensure complex nesting inside DictField handles dereferencing of ReferenceField(dbref=True | False)""" + # Relates to Issue #1453 + class Doc(Document): + s = StringField() + + class Simple(Document): + mapping0 = DictField(ReferenceField(Doc, dbref=True)) + mapping1 = DictField(ReferenceField(Doc, dbref=False)) + mapping2 = DictField(ListField(ReferenceField(Doc, dbref=True))) + mapping3 = DictField(ListField(ReferenceField(Doc, dbref=False))) + mapping4 = DictField(DictField(field=ReferenceField(Doc, dbref=True))) + mapping5 = DictField(DictField(field=ReferenceField(Doc, dbref=False))) + mapping6 = DictField(ListField(DictField(ReferenceField(Doc, dbref=True)))) + mapping7 = DictField(ListField(DictField(ReferenceField(Doc, dbref=False)))) + mapping8 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=True))))) + mapping9 = DictField(ListField(DictField(ListField(ReferenceField(Doc, dbref=False))))) + + Doc.drop_collection() + Simple.drop_collection() + + d = Doc(s='aa').save() + e = Simple() + e.mapping0['someint'] = e.mapping1['someint'] = d + e.mapping2['someint'] = e.mapping3['someint'] = [d] + e.mapping4['someint'] = e.mapping5['someint'] = {'d': d} + e.mapping6['someint'] = e.mapping7['someint'] = [{'d': d}] + e.mapping8['someint'] = e.mapping9['someint'] = [{'d': [d]}] + e.save() + + s = Simple.objects.first() + self.assertIsInstance(s.mapping0['someint'], Doc) + self.assertIsInstance(s.mapping1['someint'], Doc) + self.assertIsInstance(s.mapping2['someint'][0], Doc) + self.assertIsInstance(s.mapping3['someint'][0], Doc) + self.assertIsInstance(s.mapping4['someint']['d'], Doc) + self.assertIsInstance(s.mapping5['someint']['d'], Doc) + self.assertIsInstance(s.mapping6['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping7['someint'][0]['d'], Doc) + self.assertIsInstance(s.mapping8['someint'][0]['d'][0], Doc) + self.assertIsInstance(s.mapping9['someint'][0]['d'][0], Doc) + def test_mapfield(self): """Ensure that the MapField handles the declared type.""" class Simple(Document):