diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 51a88bdb..d5336117 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1261,7 +1261,7 @@ class GenericReferenceField(BaseField): if document is None: return None - if isinstance(document, (dict, SON)): + if isinstance(document, (dict, SON, ObjectId, DBRef)): return document id_field_name = document.__class__._meta['id_field'] diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 4e5553c9..af59917c 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -1,6 +1,7 @@ from collections import defaultdict -from bson import SON +from bson import ObjectId, SON +from bson.dbref import DBRef import pymongo import six @@ -27,6 +28,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS) +# TODO make this less complex def query(_doc_cls=None, **kwargs): """Transform a query from Django-style format to Mongo format.""" mongo_query = {} @@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs): parts = [] CachedReferenceField = _import_class('CachedReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') cleaned_fields = [] for field in fields: @@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] + # If we're querying a GenericReferenceField, we need to alter the + # key depending on the value: + # * If the value is a DBRef, the key should be "field_name._ref". + # * If the value is an ObjectId, the key should be "field_name._ref.$id". + if isinstance(field, GenericReferenceField): + if isinstance(value, DBRef): + parts[-1] += '._ref' + elif isinstance(value, ObjectId): + parts[-1] += '._ref.$id' + # if op and op not in COMPARISON_OPERATORS: if op: if op in GEO_OPERATORS: @@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs): for i, part in indices: parts.insert(i, part) + key = '.'.join(parts) + if op is None or key not in mongo_query: mongo_query[key] = value elif key in mongo_query: - if key in mongo_query and isinstance(mongo_query[key], dict): + if isinstance(mongo_query[key], dict): mongo_query[key].update(value) # $max/minDistance needs to come last - convert to SON value_dict = mongo_query[key] diff --git a/setup.cfg b/setup.cfg index ac98a0f3..f883b26c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,5 @@ tests = tests [flake8] ignore=E501,F401,F403,F405,I201 -exclude=build,dist,docs,venv,venv26,venv3,.tox,.eggs,tests -max-complexity=42 -application-import-names=mongoengine,tests +exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests +max-complexity=45 diff --git a/tests/fields/fields.py b/tests/fields/fields.py index a22a2561..87beda70 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2805,6 +2805,38 @@ class FieldTest(unittest.TestCase): Post.drop_collection() User.drop_collection() + def test_generic_reference_filter_by_dbref(self): + """Ensure we can search for a specific generic reference by + providing its ObjectId. + """ + class Doc(Document): + ref = GenericReferenceField() + + Doc.drop_collection() + + doc1 = Doc.objects.create() + doc2 = Doc.objects.create(ref=doc1) + + doc = Doc.objects.get(ref=DBRef('doc', doc1.pk)) + self.assertEqual(doc, doc2) + + def test_generic_reference_filter_by_objectid(self): + """Ensure we can search for a specific generic reference by + providing its DBRef. + """ + class Doc(Document): + ref = GenericReferenceField() + + Doc.drop_collection() + + doc1 = Doc.objects.create() + doc2 = Doc.objects.create(ref=doc1) + + self.assertTrue(isinstance(doc1.pk, ObjectId)) + + doc = Doc.objects.get(ref=doc1.pk) + self.assertEqual(doc, doc2) + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """