diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 783aac46..c07a45e3 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField): if document is None: return None - if isinstance(document, (dict, SON)): + if isinstance(document, (dict, SON, ObjectId, DBRef)): return document id_field_name = document.__class__._meta['id_field'] diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 54371d6b..42b01fd0 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -1,6 +1,7 @@ from collections import defaultdict -from bson import SON +from bson import ObjectId, SON +from bson.dbref import DBRef import pymongo from mongoengine.base.fields import UPDATE_OPERATORS @@ -62,6 +63,7 @@ def query(_doc_cls=None, **kwargs): parts = [] CachedReferenceField = _import_class('CachedReferenceField') + GenericReferenceField = _import_class('GenericReferenceField') cleaned_fields = [] for field in fields: @@ -101,6 +103,16 @@ def query(_doc_cls=None, **kwargs): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] + # If we're querying a GenericReferenceField, we need to alter the + # key depending on the value: + # * If the value is a DBRef, the key should be "field_name._ref". + # * If the value is an ObjectId, the key should be "field_name._ref.$id". + if isinstance(field, GenericReferenceField): + if isinstance(value, DBRef): + parts[-1] += '._ref' + elif isinstance(value, ObjectId): + parts[-1] += '._ref.$id' + # if op and op not in COMPARISON_OPERATORS: if op: if op in GEO_OPERATORS: @@ -128,11 +140,13 @@ def query(_doc_cls=None, **kwargs): for i, part in indices: parts.insert(i, part) + key = '.'.join(parts) + if op is None or key not in mongo_query: mongo_query[key] = value elif key in mongo_query: - if key in mongo_query and isinstance(mongo_query[key], dict): + if isinstance(mongo_query[key], dict): mongo_query[key].update(value) # $max/minDistance needs to come last - convert to SON value_dict = mongo_query[key] diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 2153a42e..14b10561 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase): Post.drop_collection() User.drop_collection() + def test_generic_reference_filter_by_dbref(self): + """Ensure we can search for a specific generic reference by + providing its ObjectId. + """ + class Doc(Document): + ref = GenericReferenceField() + + Doc.drop_collection() + + doc1 = Doc.objects.create() + doc2 = Doc.objects.create(ref=doc1) + + doc = Doc.objects.get(ref=DBRef('doc', doc1.pk)) + self.assertEqual(doc, doc2) + + def test_generic_reference_filter_by_objectid(self): + """Ensure we can search for a specific generic reference by + providing its DBRef. + """ + class Doc(Document): + ref = GenericReferenceField() + + Doc.drop_collection() + + doc1 = Doc.objects.create() + doc2 = Doc.objects.create(ref=doc1) + + self.assertTrue(isinstance(doc1.pk, ObjectId)) + + doc = Doc.objects.get(ref=doc1.pk) + self.assertEqual(doc, doc2) + def test_binary_fields(self): """Ensure that binary fields can be stored and retrieved. """