Fixed list-indexing syntax; created tests.

This commit is contained in:
Alistair Roche 2011-05-24 11:31:44 +01:00
parent 1b72ea9cc1
commit 118c0deb7a
2 changed files with 64 additions and 1 deletions

View File

@ -524,6 +524,15 @@ class QuerySet(object):
fields = [] fields = []
field = None field = None
for field_name in parts: for field_name in parts:
# Handle ListField indexing:
if field_name.isdigit():
try:
field = field.field
except AttributeError, err:
raise InvalidQueryError(
"Can't use index on unsubscriptable field (%s)" % err)
fields.append(field_name)
continue
if field is None: if field is None:
# Look up first field from the document # Look up first field from the document
if field_name == 'pk': if field_name == 'pk':
@ -1072,7 +1081,12 @@ class QuerySet(object):
if _doc_cls: if _doc_cls:
# Switch field names to proper names [set in Field(name='foo')] # Switch field names to proper names [set in Field(name='foo')]
fields = QuerySet._lookup_field(_doc_cls, parts) fields = QuerySet._lookup_field(_doc_cls, parts)
parts = [field.db_field for field in fields] parts = []
for field in fields:
if isinstance(field, str):
parts.append(field)
else:
parts.append(field.db_field)
# Convert value to proper value # Convert value to proper value
field = fields[-1] field = fields[-1]

View File

@ -211,6 +211,55 @@ class QuerySetTest(unittest.TestCase):
Blog.drop_collection() Blog.drop_collection()
def test_update_array_position(self):
"""Ensure that updating by array position works.
Check update() and update_one() can take syntax like:
set__posts__1__comments__1__name="testc"
Check that it only works for ListFields.
"""
class Comment(EmbeddedDocument):
name = StringField()
class Post(EmbeddedDocument):
comments = ListField(EmbeddedDocumentField(Comment))
class Blog(Document):
tags = ListField(StringField())
posts = ListField(EmbeddedDocumentField(Post))
Blog.drop_collection()
comment1 = Comment(name='testa')
comment2 = Comment(name='testb')
post1 = Post(comments=[comment1, comment2])
post2 = Post(comments=[comment2, comment2])
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update all of the first comments of second posts of all blogs
blog = Blog.objects().update(set__posts__1__comments__0__name="testc")
testc_blogs = Blog.objects(posts__1__comments__0__name="testc")
self.assertEqual(len(testc_blogs), 2)
Blog.drop_collection()
blog1 = Blog.objects.create(posts=[post1, post2])
blog2 = Blog.objects.create(posts=[post2, post1])
# Update only the first blog returned by the query
blog = Blog.objects().update_one(
set__posts__1__comments__1__name="testc")
testc_blogs = Blog.objects(posts__1__comments__1__name="testc")
self.assertEqual(len(testc_blogs), 1)
# Check that using this indexing syntax on a non-list fails
def non_list_indexing():
Blog.objects().update(set__posts__1__comments__0__name__1="asdf")
self.assertRaises(InvalidQueryError, non_list_indexing)
Blog.drop_collection()
def test_get_or_create(self): def test_get_or_create(self):
"""Ensure that ``get_or_create`` returns one result or creates a new """Ensure that ``get_or_create`` returns one result or creates a new
document. document.