From d9005ac2fc7ecf56bcbeb9a00259db5b9a677edb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Mon, 28 Nov 2011 14:45:57 -0200 Subject: [PATCH] added elemMatch support --- mongoengine/queryset.py | 14 +++++++++++--- tests/queryset.py | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index a9b3ea99..6025dd99 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -640,6 +640,7 @@ class QuerySet(object): match_operators = ['contains', 'icontains', 'startswith', 'istartswith', 'endswith', 'iendswith', 'exact', 'iexact'] + custom_operators = ['match'] mongo_query = {} for key, value in query.items(): @@ -652,7 +653,7 @@ class QuerySet(object): parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None - if parts[-1] in operators + match_operators + geo_operators: + if parts[-1] in operators + match_operators + geo_operators + custom_operators: op = parts.pop() negate = False @@ -685,7 +686,7 @@ class QuerySet(object): if isinstance(field, basestring): if op in match_operators and isinstance(value, basestring): from mongoengine import StringField - value = StringField().prepare_query_value(op, value) + value = StringField.prepare_query_value(op, value) else: value = field else: @@ -693,7 +694,8 @@ class QuerySet(object): elif op in ('in', 'nin', 'all', 'near'): # 'in', 'nin' and 'all' require a list of values value = [field.prepare_query_value(op, v) for v in value] - + + # if op and op not in match_operators: if op: if op in geo_operators: @@ -712,6 +714,12 @@ class QuerySet(object): else: raise NotImplementedError("Geo method '%s' has not " "been implemented" % op) + elif op in custom_operators: + if op == 'match': + value = {"$elemMatch": value} + else: + NotImplementedError("Custom method '%s' has not " + "been implemented" % op) elif op not in match_operators: value = {'$' + op: value} diff --git a/tests/queryset.py b/tests/queryset.py index 7c1c8dcc..d978cf28 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -2866,6 +2866,29 @@ class QueryFieldListTest(unittest.TestCase): q += QueryFieldList(fields=['a'], value={"$slice": 5}) self.assertEqual(q.as_dict(), {'a': {"$slice": 5}}) + def test_elem_match(self): + class Foo(EmbeddedDocument): + shape = StringField() + color = StringField() + trick = BooleanField() + meta = {'allow_inheritance': False} + + class Bar(Document): + foo = ListField(EmbeddedDocumentField(Foo)) + meta = {'allow_inheritance': False} + + Bar.drop_collection() + + b1 = Bar(foo=[Foo(shape= "square", color ="purple", thick = False), + Foo(shape= "circle", color ="red", thick = True)]) + b1.save() + + b2 = Bar(foo=[Foo(shape= "square", color ="red", thick = True), + Foo(shape= "circle", color ="purple", thick = False)]) + b2.save() + + ak = list(Bar.objects(foo__match={'shape': "square", "color": "purple"})) + self.assertEqual([b1], ak) if __name__ == '__main__': unittest.main()