Merge remote-tracking branch 'origin/pr/333' into 333

This commit is contained in:
Ross Lawley 2013-06-04 10:22:54 +00:00
commit 985bfd22de
2 changed files with 10 additions and 7 deletions

View File

@ -23,6 +23,9 @@ class QNodeVisitor(object):
return query return query
class DuplicateQueryConditionsError(InvalidQueryError):
pass
class SimplificationVisitor(QNodeVisitor): class SimplificationVisitor(QNodeVisitor):
"""Simplifies query trees by combinging unnecessary 'and' connection nodes """Simplifies query trees by combinging unnecessary 'and' connection nodes
into a single Q-object. into a single Q-object.
@ -33,7 +36,10 @@ class SimplificationVisitor(QNodeVisitor):
# The simplification only applies to 'simple' queries # The simplification only applies to 'simple' queries
if all(isinstance(node, Q) for node in combination.children): if all(isinstance(node, Q) for node in combination.children):
queries = [n.query for n in combination.children] queries = [n.query for n in combination.children]
return Q(**self._query_conjunction(queries)) try:
return Q(**self._query_conjunction(queries))
except DuplicateQueryConditionsError:
pass
return combination return combination
def _query_conjunction(self, queries): def _query_conjunction(self, queries):
@ -47,8 +53,7 @@ class SimplificationVisitor(QNodeVisitor):
# to a single field # to a single field
intersection = ops.intersection(query_ops) intersection = ops.intersection(query_ops)
if intersection: if intersection:
msg = 'Duplicate query conditions: ' raise DuplicateQueryConditionsError()
raise InvalidQueryError(msg + ', '.join(intersection))
query_ops.update(ops) query_ops.update(ops)
combined_query.update(copy.deepcopy(query)) combined_query.update(copy.deepcopy(query))

View File

@ -69,10 +69,8 @@ class QTest(unittest.TestCase):
y = StringField() y = StringField()
# Check than an error is raised when conflicting queries are anded # Check than an error is raised when conflicting queries are anded
def invalid_combination(): query = (Q(x__lt=7) & Q(x__lt=3)).to_query(TestDoc)
query = Q(x__lt=7) & Q(x__lt=3) self.assertEqual(query, {'$and': [ {'x': {'$lt': 7}}, {'x': {'$lt': 3}} ]})
query.to_query(TestDoc)
self.assertRaises(InvalidQueryError, invalid_combination)
# Check normal cases work without an error # Check normal cases work without an error
query = Q(x__lt=7) & Q(x__gt=3) query = Q(x__lt=7) & Q(x__gt=3)