From 118c0deb7a7bd48170eb49624f79f737419c5342 Mon Sep 17 00:00:00 2001 From: Alistair Roche Date: Tue, 24 May 2011 11:31:44 +0100 Subject: [PATCH] Fixed list-indexing syntax; created tests. --- mongoengine/queryset.py | 16 +++++++++++++- tests/queryset.py | 49 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 239e146b..e6c93353 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -524,6 +524,15 @@ class QuerySet(object): fields = [] field = None for field_name in parts: + # Handle ListField indexing: + if field_name.isdigit(): + try: + field = field.field + except AttributeError, err: + raise InvalidQueryError( + "Can't use index on unsubscriptable field (%s)" % err) + fields.append(field_name) + continue if field is None: # Look up first field from the document if field_name == 'pk': @@ -1072,7 +1081,12 @@ class QuerySet(object): if _doc_cls: # Switch field names to proper names [set in Field(name='foo')] fields = QuerySet._lookup_field(_doc_cls, parts) - parts = [field.db_field for field in fields] + parts = [] + for field in fields: + if isinstance(field, str): + parts.append(field) + else: + parts.append(field.db_field) # Convert value to proper value field = fields[-1] diff --git a/tests/queryset.py b/tests/queryset.py index 5a2c46cb..b0693f66 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -211,6 +211,55 @@ class QuerySetTest(unittest.TestCase): Blog.drop_collection() + def test_update_array_position(self): + """Ensure that updating by array position works. + + Check update() and update_one() can take syntax like: + set__posts__1__comments__1__name="testc" + Check that it only works for ListFields. + """ + class Comment(EmbeddedDocument): + name = StringField() + + class Post(EmbeddedDocument): + comments = ListField(EmbeddedDocumentField(Comment)) + + class Blog(Document): + tags = ListField(StringField()) + posts = ListField(EmbeddedDocumentField(Post)) + + Blog.drop_collection() + + comment1 = Comment(name='testa') + comment2 = Comment(name='testb') + post1 = Post(comments=[comment1, comment2]) + post2 = Post(comments=[comment2, comment2]) + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update all of the first comments of second posts of all blogs + blog = Blog.objects().update(set__posts__1__comments__0__name="testc") + testc_blogs = Blog.objects(posts__1__comments__0__name="testc") + self.assertEqual(len(testc_blogs), 2) + + Blog.drop_collection() + + blog1 = Blog.objects.create(posts=[post1, post2]) + blog2 = Blog.objects.create(posts=[post2, post1]) + + # Update only the first blog returned by the query + blog = Blog.objects().update_one( + set__posts__1__comments__1__name="testc") + testc_blogs = Blog.objects(posts__1__comments__1__name="testc") + self.assertEqual(len(testc_blogs), 1) + + # Check that using this indexing syntax on a non-list fails + def non_list_indexing(): + Blog.objects().update(set__posts__1__comments__0__name__1="asdf") + self.assertRaises(InvalidQueryError, non_list_indexing) + + Blog.drop_collection() + def test_get_or_create(self): """Ensure that ``get_or_create`` returns one result or creates a new document.