diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index ef0ed04d..703b6e5f 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -765,8 +765,22 @@ class QuerySet(object): key = '.'.join(parts) if op is None or key not in mongo_query: mongo_query[key] = value - elif key in mongo_query and isinstance(mongo_query[key], dict): - mongo_query[key].update(value) + elif key in mongo_query: + if isinstance(mongo_query[key], dict) and isinstance(value, dict): + mongo_query[key].update(value) + elif isinstance(mongo_query[key], list): + mongo_query[key].append(value) + else: + mongo_query[key] = [mongo_query[key], value] + + for k, v in mongo_query.items(): + if isinstance(v, list): + value = [{k:val} for val in v] + if '$and' in mongo_query.keys(): + mongo_query['$and'].append(value) + else: + mongo_query['$and'] = value + del mongo_query[k] return mongo_query diff --git a/tests/test_queryset.py b/tests/test_queryset.py index b4ae805b..02c97e47 100644 --- a/tests/test_queryset.py +++ b/tests/test_queryset.py @@ -827,7 +827,11 @@ class QuerySetTest(unittest.TestCase): def test_filter_chaining(self): """Ensure filters can be chained together. """ + class Blog(Document): + id = StringField(unique=True, primary_key=True) + class BlogPost(Document): + blog = ReferenceField(Blog) title = StringField() is_published = BooleanField() published_date = DateTimeField() @@ -836,13 +840,24 @@ class QuerySetTest(unittest.TestCase): def published(doc_cls, queryset): return queryset(is_published=True) - blog_post_1 = BlogPost(title="Blog Post #1", + Blog.drop_collection() + BlogPost.drop_collection() + + blog_1 = Blog(id="1") + blog_2 = Blog(id="2") + blog_3 = Blog(id="3") + + blog_1.save() + blog_2.save() + blog_3.save() + + blog_post_1 = BlogPost(blog=blog_1, title="Blog Post #1", is_published = True, published_date=datetime(2010, 1, 5, 0, 0 ,0)) - blog_post_2 = BlogPost(title="Blog Post #2", + blog_post_2 = BlogPost(blog=blog_2, title="Blog Post #2", is_published = True, published_date=datetime(2010, 1, 6, 0, 0 ,0)) - blog_post_3 = BlogPost(title="Blog Post #3", + blog_post_3 = BlogPost(blog=blog_3, title="Blog Post #3", is_published = True, published_date=datetime(2010, 1, 7, 0, 0 ,0)) @@ -856,7 +871,14 @@ class QuerySetTest(unittest.TestCase): published_date__lt=datetime(2010, 1, 7, 0, 0 ,0)) self.assertEqual(published_posts.count(), 2) + + blog_posts = BlogPost.objects + blog_posts = blog_posts.filter(blog__in=[blog_1, blog_2]) + blog_posts = blog_posts.filter(blog=blog_3) + self.assertEqual(blog_posts.count(), 0) + BlogPost.drop_collection() + Blog.drop_collection() def test_ordering(self): """Ensure default ordering is applied and can be overridden.