Compare commits
3 Commits
fix-strict
...
fix-generi
Author | SHA1 | Date | |
---|---|---|---|
|
e085f22b9b | ||
|
2904ce091b | ||
|
15714ef855 |
@@ -438,7 +438,7 @@ class StrictDict(object):
|
||||
__slots__ = allowed_keys_tuple
|
||||
|
||||
def __repr__(self):
|
||||
return "{%s}" % ', '.join('"{0!s}": {0!r}'.format(k) for k in self.iterkeys())
|
||||
return "{%s}" % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
|
||||
|
||||
cls._classes[allowed_keys] = SpecificStrictDict
|
||||
return cls._classes[allowed_keys]
|
||||
|
@@ -1249,7 +1249,7 @@ class GenericReferenceField(BaseField):
|
||||
if document is None:
|
||||
return None
|
||||
|
||||
if isinstance(document, (dict, SON)):
|
||||
if isinstance(document, (dict, SON, ObjectId, DBRef)):
|
||||
return document
|
||||
|
||||
id_field_name = document.__class__._meta['id_field']
|
||||
|
@@ -1,6 +1,7 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from bson import SON
|
||||
from bson import ObjectId, SON
|
||||
from bson.dbref import DBRef
|
||||
import pymongo
|
||||
|
||||
from mongoengine.base.fields import UPDATE_OPERATORS
|
||||
@@ -26,6 +27,7 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
|
||||
STRING_OPERATORS + CUSTOM_OPERATORS)
|
||||
|
||||
|
||||
# TODO make this less complex
|
||||
def query(_doc_cls=None, **kwargs):
|
||||
"""Transform a query from Django-style format to Mongo format.
|
||||
"""
|
||||
@@ -62,6 +64,7 @@ def query(_doc_cls=None, **kwargs):
|
||||
parts = []
|
||||
|
||||
CachedReferenceField = _import_class('CachedReferenceField')
|
||||
GenericReferenceField = _import_class('GenericReferenceField')
|
||||
|
||||
cleaned_fields = []
|
||||
for field in fields:
|
||||
@@ -101,6 +104,16 @@ def query(_doc_cls=None, **kwargs):
|
||||
# 'in', 'nin' and 'all' require a list of values
|
||||
value = [field.prepare_query_value(op, v) for v in value]
|
||||
|
||||
# If we're querying a GenericReferenceField, we need to alter the
|
||||
# key depending on the value:
|
||||
# * If the value is a DBRef, the key should be "field_name._ref".
|
||||
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
|
||||
if isinstance(field, GenericReferenceField):
|
||||
if isinstance(value, DBRef):
|
||||
parts[-1] += '._ref'
|
||||
elif isinstance(value, ObjectId):
|
||||
parts[-1] += '._ref.$id'
|
||||
|
||||
# if op and op not in COMPARISON_OPERATORS:
|
||||
if op:
|
||||
if op in GEO_OPERATORS:
|
||||
@@ -128,11 +141,13 @@ def query(_doc_cls=None, **kwargs):
|
||||
|
||||
for i, part in indices:
|
||||
parts.insert(i, part)
|
||||
|
||||
key = '.'.join(parts)
|
||||
|
||||
if op is None or key not in mongo_query:
|
||||
mongo_query[key] = value
|
||||
elif key in mongo_query:
|
||||
if key in mongo_query and isinstance(mongo_query[key], dict):
|
||||
if isinstance(mongo_query[key], dict):
|
||||
mongo_query[key].update(value)
|
||||
# $max/minDistance needs to come last - convert to SON
|
||||
value_dict = mongo_query[key]
|
||||
|
@@ -9,5 +9,5 @@ tests = tests
|
||||
[flake8]
|
||||
ignore=E501,F401,F403,F405,I201
|
||||
exclude=build,dist,docs,venv,.tox,.eggs,tests
|
||||
max-complexity=42
|
||||
max-complexity=45
|
||||
application-import-names=mongoengine,tests
|
||||
|
@@ -2810,6 +2810,38 @@ class FieldTest(unittest.TestCase):
|
||||
Post.drop_collection()
|
||||
User.drop_collection()
|
||||
|
||||
def test_generic_reference_filter_by_dbref(self):
|
||||
"""Ensure we can search for a specific generic reference by
|
||||
providing its ObjectId.
|
||||
"""
|
||||
class Doc(Document):
|
||||
ref = GenericReferenceField()
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
doc1 = Doc.objects.create()
|
||||
doc2 = Doc.objects.create(ref=doc1)
|
||||
|
||||
doc = Doc.objects.get(ref=DBRef('doc', doc1.pk))
|
||||
self.assertEqual(doc, doc2)
|
||||
|
||||
def test_generic_reference_filter_by_objectid(self):
|
||||
"""Ensure we can search for a specific generic reference by
|
||||
providing its DBRef.
|
||||
"""
|
||||
class Doc(Document):
|
||||
ref = GenericReferenceField()
|
||||
|
||||
Doc.drop_collection()
|
||||
|
||||
doc1 = Doc.objects.create()
|
||||
doc2 = Doc.objects.create(ref=doc1)
|
||||
|
||||
self.assertTrue(isinstance(doc1.pk, ObjectId))
|
||||
|
||||
doc = Doc.objects.get(ref=doc1.pk)
|
||||
self.assertEqual(doc, doc2)
|
||||
|
||||
def test_binary_fields(self):
|
||||
"""Ensure that binary fields can be stored and retrieved.
|
||||
"""
|
||||
|
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
|
||||
|
||||
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
|
||||
|
||||
|
||||
class TestStrictDict(unittest.TestCase):
|
||||
@@ -13,9 +14,17 @@ class TestStrictDict(unittest.TestCase):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
|
||||
|
||||
def test_repr(self):
|
||||
d = self.dtype(a=1, b=2, c=3)
|
||||
self.assertEqual(repr(d), '{"a": 1, "b": 2, "c": 3}')
|
||||
|
||||
# make sure quotes are escaped properly
|
||||
d = self.dtype(a='"', b="'", c="")
|
||||
self.assertEqual(repr(d), '{"a": \'"\', "b": "\'", "c": \'\'}')
|
||||
|
||||
def test_init_fails_on_nonexisting_attrs(self):
|
||||
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
|
||||
|
||||
|
||||
def test_eq(self):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
dd = self.dtype(a=1, b=1, c=1)
|
||||
@@ -24,7 +33,7 @@ class TestStrictDict(unittest.TestCase):
|
||||
g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
|
||||
h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
|
||||
i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
|
||||
|
||||
|
||||
self.assertEqual(d, dd)
|
||||
self.assertNotEqual(d, e)
|
||||
self.assertNotEqual(d, f)
|
||||
@@ -38,19 +47,19 @@ class TestStrictDict(unittest.TestCase):
|
||||
d.a = 1
|
||||
self.assertEqual(d.a, 1)
|
||||
self.assertRaises(AttributeError, lambda: d.b)
|
||||
|
||||
|
||||
def test_setattr_raises_on_nonexisting_attr(self):
|
||||
d = self.dtype()
|
||||
|
||||
def _f():
|
||||
d.x = 1
|
||||
self.assertRaises(AttributeError, _f)
|
||||
|
||||
|
||||
def test_setattr_getattr_special(self):
|
||||
d = self.strict_dict_class(["items"])
|
||||
d.items = 1
|
||||
self.assertEqual(d.items, 1)
|
||||
|
||||
|
||||
def test_get(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertEqual(d.get('a'), 1)
|
||||
@@ -88,7 +97,7 @@ class TestSemiSrictDict(TestStrictDict):
|
||||
def test_init_succeeds_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
|
||||
|
||||
|
||||
def test_iter_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual(list(d), ['a', 'b', 'c', 'x'])
|
||||
|
Reference in New Issue
Block a user