From 2af5f3c56ebc7feff6a0a1cd95a77e12b425efed Mon Sep 17 00:00:00 2001 From: Harry Marr Date: Tue, 31 Aug 2010 00:24:30 +0100 Subject: [PATCH] Added support for querying by array position. Closes #36. --- mongoengine/queryset.py | 6 +++++- tests/queryset.py | 43 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 662fa8c3..b8ca125d 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -344,6 +344,8 @@ class QuerySet(object): mongo_query = {} for key, value in query.items(): parts = key.split('__') + indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()] + parts = [part for part in parts if not part.isdigit()] # Check for an operator and transform to mongo-style if there is op = None if parts[-1] in operators + match_operators + geo_operators: @@ -381,7 +383,9 @@ class QuerySet(object): "been implemented" % op) elif op not in match_operators: value = {'$' + op: value} - + + for i, part in indices: + parts.insert(i, part) key = '.'.join(parts) if op is None or key not in mongo_query: mongo_query[key] = value diff --git a/tests/queryset.py b/tests/queryset.py index e3912246..3d714be2 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -165,8 +165,49 @@ class QuerySetTest(unittest.TestCase): person = self.Person.objects.get(age__lt=30) self.assertEqual(person.name, "User A") + def test_find_array_position(self): + """Ensure that query by array position works. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() - + Blog.objects.create(tags=['a', 'b']) + self.assertEqual(len(Blog.objects(tags__0='a')), 1) + self.assertEqual(len(Blog.objects(tags__0='b')), 0) + self.assertEqual(len(Blog.objects(tags__1='a')), 0) + self.assertEqual(len(Blog.objects(tags__1='b')), 1) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + blog = Blog.objects(posts__0__comments__0__name='testa').get() + self.assertEqual(blog, blog1) + + query = Blog.objects(posts__1__comments__1__name='testb') + self.assertEqual(len(query), 2) + + query = Blog.objects(posts__1__comments__1__name='testa') + self.assertEqual(len(query), 0) + + query = Blog.objects(posts__0__comments__1__name='testa') + self.assertEqual(len(query), 0) + + Blog.drop_collection() def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new