Add ability to filter the generic reference field by ObjectId and DBRef (#1425)
This commit is contained in:
parent
25e0f12976
commit
1b9432824b
@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
|
|||||||
if document is None:
|
if document is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if isinstance(document, (dict, SON)):
|
if isinstance(document, (dict, SON, ObjectId, DBRef)):
|
||||||
return document
|
return document
|
||||||
|
|
||||||
id_field_name = document.__class__._meta['id_field']
|
id_field_name = document.__class__._meta['id_field']
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
from bson import SON
|
from bson import ObjectId, SON
|
||||||
|
from bson.dbref import DBRef
|
||||||
import pymongo
|
import pymongo
|
||||||
|
|
||||||
from mongoengine.base.fields import UPDATE_OPERATORS
|
from mongoengine.base.fields import UPDATE_OPERATORS
|
||||||
@ -26,6 +27,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
|
|||||||
STRING_OPERATORS + CUSTOM_OPERATORS)
|
STRING_OPERATORS + CUSTOM_OPERATORS)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO make this less complex
|
||||||
def query(_doc_cls=None, **kwargs):
|
def query(_doc_cls=None, **kwargs):
|
||||||
"""Transform a query from Django-style format to Mongo format.
|
"""Transform a query from Django-style format to Mongo format.
|
||||||
"""
|
"""
|
||||||
@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
CachedReferenceField = _import_class('CachedReferenceField')
|
CachedReferenceField = _import_class('CachedReferenceField')
|
||||||
|
GenericReferenceField = _import_class('GenericReferenceField')
|
||||||
|
|
||||||
cleaned_fields = []
|
cleaned_fields = []
|
||||||
for field in fields:
|
for field in fields:
|
||||||
@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
# 'in', 'nin' and 'all' require a list of values
|
# 'in', 'nin' and 'all' require a list of values
|
||||||
value = [field.prepare_query_value(op, v) for v in value]
|
value = [field.prepare_query_value(op, v) for v in value]
|
||||||
|
|
||||||
|
# If we're querying a GenericReferenceField, we need to alter the
|
||||||
|
# key depending on the value:
|
||||||
|
# * If the value is a DBRef, the key should be "field_name._ref".
|
||||||
|
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
|
||||||
|
if isinstance(field, GenericReferenceField):
|
||||||
|
if isinstance(value, DBRef):
|
||||||
|
parts[-1] += '._ref'
|
||||||
|
elif isinstance(value, ObjectId):
|
||||||
|
parts[-1] += '._ref.$id'
|
||||||
|
|
||||||
# if op and op not in COMPARISON_OPERATORS:
|
# if op and op not in COMPARISON_OPERATORS:
|
||||||
if op:
|
if op:
|
||||||
if op in GEO_OPERATORS:
|
if op in GEO_OPERATORS:
|
||||||
@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs):
|
|||||||
|
|
||||||
for i, part in indices:
|
for i, part in indices:
|
||||||
parts.insert(i, part)
|
parts.insert(i, part)
|
||||||
|
|
||||||
key = '.'.join(parts)
|
key = '.'.join(parts)
|
||||||
|
|
||||||
if op is None or key not in mongo_query:
|
if op is None or key not in mongo_query:
|
||||||
mongo_query[key] = value
|
mongo_query[key] = value
|
||||||
elif key in mongo_query:
|
elif key in mongo_query:
|
||||||
if key in mongo_query and isinstance(mongo_query[key], dict):
|
if isinstance(mongo_query[key], dict):
|
||||||
mongo_query[key].update(value)
|
mongo_query[key].update(value)
|
||||||
# $max/minDistance needs to come last - convert to SON
|
# $max/minDistance needs to come last - convert to SON
|
||||||
value_dict = mongo_query[key]
|
value_dict = mongo_query[key]
|
||||||
|
@ -9,5 +9,5 @@ tests = tests
|
|||||||
[flake8]
|
[flake8]
|
||||||
ignore=E501,F401,F403,F405,I201
|
ignore=E501,F401,F403,F405,I201
|
||||||
exclude=build,dist,docs,venv,.tox,.eggs,tests
|
exclude=build,dist,docs,venv,.tox,.eggs,tests
|
||||||
max-complexity=42
|
max-complexity=45
|
||||||
application-import-names=mongoengine,tests
|
application-import-names=mongoengine,tests
|
||||||
|
@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase):
|
|||||||
Post.drop_collection()
|
Post.drop_collection()
|
||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
|
|
||||||
|
def test_generic_reference_filter_by_dbref(self):
|
||||||
|
"""Ensure we can search for a specific generic reference by
|
||||||
|
providing its ObjectId.
|
||||||
|
"""
|
||||||
|
class Doc(Document):
|
||||||
|
ref = GenericReferenceField()
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
|
||||||
|
doc1 = Doc.objects.create()
|
||||||
|
doc2 = Doc.objects.create(ref=doc1)
|
||||||
|
|
||||||
|
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
|
||||||
|
self.assertEqual(doc, doc2)
|
||||||
|
|
||||||
|
def test_generic_reference_filter_by_objectid(self):
|
||||||
|
"""Ensure we can search for a specific generic reference by
|
||||||
|
providing its DBRef.
|
||||||
|
"""
|
||||||
|
class Doc(Document):
|
||||||
|
ref = GenericReferenceField()
|
||||||
|
|
||||||
|
Doc.drop_collection()
|
||||||
|
|
||||||
|
doc1 = Doc.objects.create()
|
||||||
|
doc2 = Doc.objects.create(ref=doc1)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(doc1.pk, ObjectId))
|
||||||
|
|
||||||
|
doc = Doc.objects.get(ref=doc1.pk)
|
||||||
|
self.assertEqual(doc, doc2)
|
||||||
|
|
||||||
def test_binary_fields(self):
|
def test_binary_fields(self):
|
||||||
"""Ensure that binary fields can be stored and retrieved.
|
"""Ensure that binary fields can be stored and retrieved.
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user