From a3c46fec0778dd6dba5aa9d693eb9cecac530f0c Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Sun, 3 Oct 2010 21:26:26 +0100 Subject: [PATCH] Compilation of combinations - simple $or now works --- mongoengine/queryset.py | 32 +++++++++++++++++++------------- tests/queryset.py | 14 ++++++++++++++ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ad3c2de1..b3fe29f5 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -43,6 +43,20 @@ class QNodeVisitor(object): def visit_query(self, query): return query + def _query_conjunction(self, queries): + query_ops = set() + combined_query = {} + for query in queries: + ops = set(query.keys()) + intersection = ops.intersection(query_ops) + if intersection: + msg = 'Duplicate query contitions: ' + raise InvalidQueryError(msg + ', '.join(intersection)) + + query_ops.update(ops) + combined_query.update(copy.deepcopy(query)) + return combined_query + class SimplificationVisitor(QNodeVisitor): @@ -53,18 +67,8 @@ class SimplificationVisitor(QNodeVisitor): if any(not isinstance(node, NewQ) for node in combination.children): return combination - query_ops = set() - query = {} - for node in combination.children: - ops = set(node.query.keys()) - intersection = ops.intersection(query_ops) - if intersection: - msg = 'Duplicate query contitions: ' - raise InvalidQueryError(msg + ', '.join(intersection)) - - query_ops.update(ops) - query.update(copy.deepcopy(node.query)) - return NewQ(**query) + queries = [node.query for node in combination.children] + return NewQ(**self._query_conjunction(queries)) class QueryCompilerVisitor(QNodeVisitor): @@ -74,7 +78,9 @@ class QueryCompilerVisitor(QNodeVisitor): def visit_combination(self, combination): if combination.operation == combination.OR: - return combination + return {'$or': combination.children} + elif combination.operation == combination.AND: + return self._query_conjunction(combination.children) return combination def visit_query(self, query): diff --git a/tests/queryset.py b/tests/queryset.py index 60952513..6d3114e5 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1432,5 +1432,19 @@ class NewQTest(unittest.TestCase): query = (q1 & q2).to_query(TestDoc) self.assertEqual(query, {'x': {'$lt': 7, '$gt': 3}}) + def test_or_combination(self): + class TestDoc(Document): + x = IntField() + + q1 = NewQ(x__lt=3) + q2 = NewQ(x__gt=7) + query = (q1 | q2).to_query(TestDoc) + self.assertEqual(query, { + '$or': [ + {'x': {'$lt': 3}}, + {'x': {'$gt': 7}}, + ] + }) + if __name__ == '__main__': unittest.main()