support position in 'push' #1565

This commit is contained in:
Erdenezul Batmunkh 2017-06-15 06:08:40 +00:00
parent 2f1fe5468e
commit 6903eed4e7
2 changed files with 36 additions and 3 deletions

View File

@ -284,9 +284,11 @@ def update(_doc_cls=None, **update):
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
if op in (None, 'set', 'push', 'pull'):
if op == 'push' and isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif op in (None, 'set', 'push', 'pull'):
if field.required or value is not None:
value = field.prepare_query_value(op, value)
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
value = [field.prepare_query_value(op, v) for v in value]
elif op in ('addToSet', 'setOnInsert'):
@ -302,6 +304,10 @@ def update(_doc_cls=None, **update):
value = {match: value}
key = '.'.join(parts)
position = None
if parts[-1].isdigit() and isinstance(value, (list, tuple, set)):
key = parts[0]
position = int(parts[-1])
if not op:
raise InvalidQueryError('Updates must supply an operation '
@ -333,10 +339,14 @@ def update(_doc_cls=None, **update):
value = {key: value}
elif op == 'addToSet' and isinstance(value, list):
value = {key: {'$each': value}}
elif op == 'push' and isinstance(value, list):
if position is not None:
value = {key: {'$each': value, '$position': position}}
else:
value = {key: {'$each': value}}
else:
value = {key: value}
key = '$' + op
if key not in mongo_update:
mongo_update[key] = value
elif key in mongo_update and isinstance(mongo_update[key], dict):

View File

@ -1903,6 +1903,29 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection()
def test_update_push_with_position(self):
"""Ensure that the 'push' update with position works properly.
"""
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost(slug="test")
post.save()
BlogPost.objects.filter(id=post.id).update(push__tags="code")
BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"])
post.reload()
self.assertEqual(post.tags[0], "mongodb")
self.assertEqual(post.tags[1], "python")
self.assertEqual(post.tags[2], "code")
BlogPost.objects.filter(id=post.id).update(set__tags__2="java")
post.reload()
self.assertEqual(post.tags[2], "java")
def test_update_push_and_pull_add_to_set(self):
"""Ensure that the 'pull' update operation works correctly.
"""