From 159b0828286378543b3d5437b149d041eadfb53b Mon Sep 17 00:00:00 2001 From: reallistic Date: Thu, 24 Sep 2015 16:31:38 -0700 Subject: [PATCH] Recursively create mongo query for embeddeddocument elemMatch --- docs/changelog.rst | 1 + mongoengine/queryset/transform.py | 21 +++++++++++++-------- tests/queryset/queryset.py | 9 +++++++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index b340aab0..bb626f33 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -11,6 +11,7 @@ Changes in 0.10.1 - DEV - Fix Document.reload for DynamicDocument. #1050 - StrictDict & SemiStrictDict are shadowed at init time. #1105 - Remove test dependencies (nose and rednose) from install dependencies list. #1079 +- Recursively build query when using elemMatch operator. #1130 Changes in 0.10.0 ================= diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 03f3acf0..1f18c429 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -26,12 +26,12 @@ MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS) -def query(_doc_cls=None, **query): +def query(_doc_cls=None, **kwargs): """Transform a query from Django-style format to Mongo format. """ mongo_query = {} merge_query = defaultdict(list) - for key, value in sorted(query.items()): + for key, value in sorted(kwargs.items()): if key == "__raw__": mongo_query.update(value) continue @@ -105,13 +105,18 @@ def query(_doc_cls=None, **query): if op: if op in GEO_OPERATORS: value = _geo_operator(field, op, value) - elif op in CUSTOM_OPERATORS: - if op in ('elem_match', 'match'): - value = field.prepare_query_value(op, value) - value = {"$elemMatch": value} + elif op in ('match', 'elemMatch'): + ListField = _import_class('ListField') + EmbeddedDocumentField = _import_class('EmbeddedDocumentField') + if (isinstance(value, dict) and isinstance(field, ListField) and + isinstance(field.field, EmbeddedDocumentField)): + value = query(field.field.document_type, **value) else: - NotImplementedError("Custom method '%s' has not " - "been implemented" % op) + value = field.prepare_query_value(op, value) + value = {"$elemMatch": value} + elif op in CUSTOM_OPERATORS: + NotImplementedError("Custom method '%s' has not " + "been implemented" % op) elif op not in STRING_OPERATORS: value = {'$' + op: value} diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 944c6fc1..3eff7cea 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -4116,6 +4116,15 @@ class QuerySetTest(unittest.TestCase): ak = list(Bar.objects(foo__match=Foo(shape="square", color="purple"))) self.assertEqual([b1], ak) + ak = list( + Bar.objects(foo__elemMatch={'shape': "square", "color__exists": True})) + self.assertEqual([b1, b2], ak) + + ak = list( + Bar.objects(foo__match={'shape': "square", "color__exists": True})) + self.assertEqual([b1, b2], ak) + + def test_upsert_includes_cls(self): """Upserts should include _cls information for inheritable classes """