add ability to filter the generic reference field by ObjectId and DBRef
This commit is contained in:
parent
15714ef855
commit
2904ce091b
@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
|
|||||||
if document is None:
|
if document is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(document, (dict, SON)):
|
if isinstance(document, (dict, SON, ObjectId, DBRef)):
|
||||||
return document
|
return document
|
||||||
|
|
||||||
id_field_name = document.__class__._meta['id_field']
|
id_field_name = document.__class__._meta['id_field']
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from bson import SON
|
from bson import ObjectId, SON
|
||||||
|
from bson.dbref import DBRef
|
||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
from mongoengine.base.fields import UPDATE_OPERATORS
|
from mongoengine.base.fields import UPDATE_OPERATORS
|
||||||
@ -62,6 +63,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
CachedReferenceField = _import_class('CachedReferenceField')
|
CachedReferenceField = _import_class('CachedReferenceField')
|
||||||
|
GenericReferenceField = _import_class('GenericReferenceField')
|
||||||
|
|
||||||
cleaned_fields = []
|
cleaned_fields = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
@ -101,6 +103,16 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
# 'in', 'nin' and 'all' require a list of values
|
# 'in', 'nin' and 'all' require a list of values
|
||||||
value = [field.prepare_query_value(op, v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
|
# If we're querying a GenericReferenceField, we need to alter the
|
||||||
|
# key depending on the value:
|
||||||
|
# * If the value is a DBRef, the key should be "field_name._ref".
|
||||||
|
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
|
||||||
|
if isinstance(field, GenericReferenceField):
|
||||||
|
if isinstance(value, DBRef):
|
||||||
|
parts[-1] += '._ref'
|
||||||
|
elif isinstance(value, ObjectId):
|
||||||
|
parts[-1] += '._ref.$id'
|
||||||
|
|
||||||
# if op and op not in COMPARISON_OPERATORS:
|
# if op and op not in COMPARISON_OPERATORS:
|
||||||
if op:
|
if op:
|
||||||
if op in GEO_OPERATORS:
|
if op in GEO_OPERATORS:
|
||||||
@ -128,11 +140,13 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
|
|
||||||
for i, part in indices:
|
for i, part in indices:
|
||||||
parts.insert(i, part)
|
parts.insert(i, part)
|
||||||
|
|
||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
|
|
||||||
if op is None or key not in mongo_query:
|
if op is None or key not in mongo_query:
|
||||||
mongo_query[key] = value
|
mongo_query[key] = value
|
||||||
elif key in mongo_query:
|
elif key in mongo_query:
|
||||||
if key in mongo_query and isinstance(mongo_query[key], dict):
|
if isinstance(mongo_query[key], dict):
|
||||||
mongo_query[key].update(value)
|
mongo_query[key].update(value)
|
||||||
# $max/minDistance needs to come last - convert to SON
|
# $max/minDistance needs to come last - convert to SON
|
||||||
value_dict = mongo_query[key]
|
value_dict = mongo_query[key]
|
||||||
|
@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase):
|
|||||||
Post.drop_collection()
|
Post.drop_collection()
|
||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
|
|
||||||
|
def test_generic_reference_filter_by_dbref(self):
|
||||||
|
"""Ensure we can search for a specific generic reference by
|
||||||
|
providing its ObjectId.
|
||||||
|
"""
|
||||||
|
class Doc(Document):
|
||||||
|
ref = GenericReferenceField()
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
|
||||||
|
doc1 = Doc.objects.create()
|
||||||
|
doc2 = Doc.objects.create(ref=doc1)
|
||||||
|
|
||||||
|
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
|
||||||
|
self.assertEqual(doc, doc2)
|
||||||
|
|
||||||
|
def test_generic_reference_filter_by_objectid(self):
|
||||||
|
"""Ensure we can search for a specific generic reference by
|
||||||
|
providing its DBRef.
|
||||||
|
"""
|
||||||
|
class Doc(Document):
|
||||||
|
ref = GenericReferenceField()
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
|
||||||
|
doc1 = Doc.objects.create()
|
||||||
|
doc2 = Doc.objects.create(ref=doc1)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(doc1.pk, ObjectId))
|
||||||
|
|
||||||
|
doc = Doc.objects.get(ref=doc1.pk)
|
||||||
|
self.assertEqual(doc, doc2)
|
||||||
|
|
||||||
def test_binary_fields(self):
|
def test_binary_fields(self):
|
||||||
"""Ensure that binary fields can be stored and retrieved.
|
"""Ensure that binary fields can be stored and retrieved.
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user