Make chained querysets work if constraining the same fields.

Refs hmarr/mongoengine#554
This commit is contained in:
Anthony Nemitz 2012-07-27 23:06:47 -07:00 committed by Ross Lawley
parent f99b7a811b
commit 12c8b5c0b9
2 changed files with 41 additions and 5 deletions

View File

@ -765,8 +765,22 @@ class QuerySet(object):
key = '.'.join(parts) key = '.'.join(parts)
if op is None or key not in mongo_query: if op is None or key not in mongo_query:
mongo_query[key] = value mongo_query[key] = value
elif key in mongo_query and isinstance(mongo_query[key], dict): elif key in mongo_query:
mongo_query[key].update(value) if isinstance(mongo_query[key], dict) and isinstance(value, dict):
mongo_query[key].update(value)
elif isinstance(mongo_query[key], list):
mongo_query[key].append(value)
else:
mongo_query[key] = [mongo_query[key], value]
for k, v in mongo_query.items():
if isinstance(v, list):
value = [{k:val} for val in v]
if '$and' in mongo_query.keys():
mongo_query['$and'].append(value)
else:
mongo_query['$and'] = value
del mongo_query[k]
return mongo_query return mongo_query

View File

@ -827,7 +827,11 @@ class QuerySetTest(unittest.TestCase):
def test_filter_chaining(self): def test_filter_chaining(self):
"""Ensure filters can be chained together. """Ensure filters can be chained together.
""" """
class Blog(Document):
id = StringField(unique=True, primary_key=True)
class BlogPost(Document): class BlogPost(Document):
blog = ReferenceField(Blog)
title = StringField() title = StringField()
is_published = BooleanField() is_published = BooleanField()
published_date = DateTimeField() published_date = DateTimeField()
@ -836,13 +840,24 @@ class QuerySetTest(unittest.TestCase):
def published(doc_cls, queryset): def published(doc_cls, queryset):
return queryset(is_published=True) return queryset(is_published=True)
blog_post_1 = BlogPost(title="Blog Post #1", Blog.drop_collection()
BlogPost.drop_collection()
blog_1 = Blog(id="1")
blog_2 = Blog(id="2")
blog_3 = Blog(id="3")
blog_1.save()
blog_2.save()
blog_3.save()
blog_post_1 = BlogPost(blog=blog_1, title="Blog Post #1",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 5, 0, 0 ,0)) published_date=datetime(2010, 1, 5, 0, 0 ,0))
blog_post_2 = BlogPost(title="Blog Post #2", blog_post_2 = BlogPost(blog=blog_2, title="Blog Post #2",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 6, 0, 0 ,0)) published_date=datetime(2010, 1, 6, 0, 0 ,0))
blog_post_3 = BlogPost(title="Blog Post #3", blog_post_3 = BlogPost(blog=blog_3, title="Blog Post #3",
is_published = True, is_published = True,
published_date=datetime(2010, 1, 7, 0, 0 ,0)) published_date=datetime(2010, 1, 7, 0, 0 ,0))
@ -856,7 +871,14 @@ class QuerySetTest(unittest.TestCase):
published_date__lt=datetime(2010, 1, 7, 0, 0 ,0)) published_date__lt=datetime(2010, 1, 7, 0, 0 ,0))
self.assertEqual(published_posts.count(), 2) self.assertEqual(published_posts.count(), 2)
blog_posts = BlogPost.objects
blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2])
blog_posts = blog_posts.filter(blog=blog_3)
self.assertEqual(blog_posts.count(), 0)
BlogPost.drop_collection() BlogPost.drop_collection()
Blog.drop_collection()
def test_ordering(self): def test_ordering(self):
"""Ensure default ordering is applied and can be overridden. """Ensure default ordering is applied and can be overridden.