Fixed: ListField minus index assignment does not work #1119

Add code to detect '-1' as a integer.
Normalize negative index to regular list index
Added list assignment test case
This commit is contained in:
Gang Li 2015-10-12 10:34:26 -04:00
parent d4f6ef4f1b
commit 5bbfca45fa
2 changed files with 54 additions and 2 deletions

View File

@ -606,7 +606,9 @@ class BaseDocument(object):
for p in parts: for p in parts:
if isinstance(d, (ObjectId, DBRef)): if isinstance(d, (ObjectId, DBRef)):
break break
elif isinstance(d, list) and p.isdigit(): elif isinstance(d, list) and p.lstrip('-').isdigit():
if p[0] == '-':
p = str(len(d)+int(p))
try: try:
d = d[int(p)] d = d[int(p)]
except IndexError: except IndexError:
@ -640,7 +642,9 @@ class BaseDocument(object):
parts = path.split('.') parts = path.split('.')
db_field_name = parts.pop() db_field_name = parts.pop()
for p in parts: for p in parts:
if isinstance(d, list) and p.isdigit(): if isinstance(d, list) and p.lstrip('-').isdigit():
if p[0] == '-':
p = str(len(d)+int(p))
d = d[int(p)] d = d[int(p)]
elif (hasattr(d, '__getattribute__') and elif (hasattr(d, '__getattribute__') and
not isinstance(d, dict)): not isinstance(d, dict)):

View File

@ -1022,6 +1022,54 @@ class FieldTest(unittest.TestCase):
self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1) self.assertEqual(BlogPost.objects(info=['1', '2', '3', '4', '1', '2', '3', '4']).count(), 1)
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_assignment(self):
"""Ensure that list field element assignment and slicing work
"""
class BlogPost(Document):
info = ListField()
BlogPost.drop_collection()
post = BlogPost()
post.info = ['e1', 'e2', 3, '4', 5]
post.save()
post.info[0] = 1
post.save()
post.reload()
self.assertEqual(post.info[0], 1)
post.info[1:3] = ['n2', 'n3']
post.save()
post.reload()
self.assertEqual(post.info, [1, 'n2', 'n3', '4', 5])
post.info[-1] = 'n5'
post.save()
post.reload()
self.assertEqual(post.info, [1, 'n2', 'n3', '4', 'n5'])
post.info[-2] = 4
post.save()
post.reload()
self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5'])
post.info[1:-1] = [2]
post.save()
post.reload()
self.assertEqual(post.info, [1, 2, 'n5'])
post.info[:-1] = [1, 'n2', 'n3', 4]
post.save()
post.reload()
self.assertEqual(post.info, [1, 'n2', 'n3', 4, 'n5'])
post.info[-4:3] = [2, 3]
post.save()
post.reload()
self.assertEqual(post.info, [1, 2, 3, 4, 'n5'])
def test_list_field_passed_in_value(self): def test_list_field_passed_in_value(self):
class Foo(Document): class Foo(Document):
bars = ListField(ReferenceField("Bar")) bars = ListField(ReferenceField("Bar"))