Merge branch 'master' into support_multiple_operator

This commit is contained in:
erdenezul 2017-07-31 22:53:01 +08:00 committed by GitHub
commit 9a6aa8f8c6
9 changed files with 156 additions and 22 deletions

View File

@ -243,4 +243,5 @@ that much better:
* Victor Varvaryuk * Victor Varvaryuk
* Stanislav Kaledin (https://github.com/sallyruthstruik) * Stanislav Kaledin (https://github.com/sallyruthstruik)
* Dmitry Yantsen (https://github.com/mrTable) * Dmitry Yantsen (https://github.com/mrTable)
* Erdenezul Batmunkh (https://github.com/erdenezul) * Renjianxin (https://github.com/Davidrjx)
* Erdenezul Batmunkh (https://github.com/erdenezul)

View File

@ -565,6 +565,15 @@ cannot use the `$` syntax in keyword arguments it has been mapped to `S`::
>>> post.tags >>> post.tags
['database', 'mongodb'] ['database', 'mongodb']
From MongoDB version 2.6, push operator supports $position value which allows
to push values with index.
>>> post = BlogPost(title="Test", tags=["mongo"])
>>> post.save()
>>> post.update(push__tags__0=["database", "code"])
>>> post.reload()
>>> post.tags
['database', 'code', 'mongo']
.. note:: .. note::
Currently only top level lists are handled, future versions of mongodb / Currently only top level lists are handled, future versions of mongodb /
pymongo plan to support nested positional operators. See `The $ positional pymongo plan to support nested positional operators. See `The $ positional

View File

@ -146,13 +146,14 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
raise MongoEngineConnectionError(msg) raise MongoEngineConnectionError(msg)
def _clean_settings(settings_dict): def _clean_settings(settings_dict):
irrelevant_fields = set([ # set literal more efficient than calling set function
'name', 'username', 'password', 'authentication_source', irrelevant_fields_set = {
'authentication_mechanism' 'name', 'username', 'password',
]) 'authentication_source', 'authentication_mechanism'
}
return { return {
k: v for k, v in settings_dict.items() k: v for k, v in settings_dict.items()
if k not in irrelevant_fields if k not in irrelevant_fields_set
} }
# Retrieve a copy of the connection settings associated with the requested # Retrieve a copy of the connection settings associated with the requested

View File

@ -1722,25 +1722,33 @@ class BaseQuerySet(object):
return frequencies return frequencies
def _fields_to_dbfields(self, fields): def _fields_to_dbfields(self, fields):
"""Translate fields paths to its db equivalents""" """Translate fields' paths to their db equivalents."""
ret = []
subclasses = [] subclasses = []
document = self._document if self._document._meta['allow_inheritance']:
if document._meta['allow_inheritance']:
subclasses = [get_document(x) subclasses = [get_document(x)
for x in document._subclasses][1:] for x in self._document._subclasses][1:]
db_field_paths = []
for field in fields: for field in fields:
field_parts = field.split('.')
try: try:
field = '.'.join(f.db_field for f in field = '.'.join(
document._lookup_field(field.split('.'))) f if isinstance(f, six.string_types) else f.db_field
ret.append(field) for f in self._document._lookup_field(field_parts)
)
db_field_paths.append(field)
except LookUpError as err: except LookUpError as err:
found = False found = False
# If a field path wasn't found on the main document, go
# through its subclasses and see if it exists on any of them.
for subdoc in subclasses: for subdoc in subclasses:
try: try:
subfield = '.'.join(f.db_field for f in subfield = '.'.join(
subdoc._lookup_field(field.split('.'))) f if isinstance(f, six.string_types) else f.db_field
ret.append(subfield) for f in subdoc._lookup_field(field_parts)
)
db_field_paths.append(subfield)
found = True found = True
break break
except LookUpError: except LookUpError:
@ -1748,7 +1756,8 @@ class BaseQuerySet(object):
if not found: if not found:
raise err raise err
return ret
return db_field_paths
def _get_order_by(self, keys): def _get_order_by(self, keys):
"""Given a list of MongoEngine-style sort keys, return a list """Given a list of MongoEngine-style sort keys, return a list

View File

@ -284,7 +284,9 @@ def update(_doc_cls=None, **update):
if isinstance(field, GeoJsonBaseField): if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value) value = field.to_mongo(value)
if op in (None, 'set', 'push', 'pull'): if op == 'push' and isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif op in (None, 'set', 'push', 'pull'):
if field.required or value is not None: if field.required or value is not None:
value = field.prepare_query_value(op, value) value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'): elif op in ('pushAll', 'pullAll'):
@ -333,10 +335,22 @@ def update(_doc_cls=None, **update):
value = {key: value} value = {key: value}
elif op == 'addToSet' and isinstance(value, list): elif op == 'addToSet' and isinstance(value, list):
value = {key: {'$each': value}} value = {key: {'$each': value}}
elif op == 'push':
if parts[-1].isdigit():
key = parts[0]
position = int(parts[-1])
# $position expects an iterable. If pushing a single value,
# wrap it in a list.
if not isinstance(value, (set, tuple, list)):
value = [value]
value = {key: {'$each': value, '$position': position}}
elif isinstance(value, list):
value = {key: {'$each': value}}
else:
value = {key: value}
else: else:
value = {key: value} value = {key: value}
key = '$' + op key = '$' + op
if key not in mongo_update: if key not in mongo_update:
mongo_update[key] = value mongo_update[key] = value
elif key in mongo_update and isinstance(mongo_update[key], dict): elif key in mongo_update and isinstance(mongo_update[key], dict):

View File

@ -22,6 +22,8 @@ from mongoengine.queryset import NULLIFY, Q
from mongoengine.context_managers import switch_db, query_counter from mongoengine.context_managers import switch_db, query_counter
from mongoengine import signals from mongoengine import signals
from tests.utils import needs_mongodb_v26
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__),
'../fields/mongoengine.png') '../fields/mongoengine.png')
@ -826,6 +828,22 @@ class InstanceTest(unittest.TestCase):
self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())]) self.assertDbEqual([dict(other_doc.to_mongo()), dict(doc.to_mongo())])
@needs_mongodb_v26
def test_modify_with_positional_push(self):
class BlogPost(Document):
tags = ListField(StringField())
post = BlogPost.objects.create(tags=['python'])
self.assertEqual(post.tags, ['python'])
post.modify(push__tags__0=['code', 'mongo'])
self.assertEqual(post.tags, ['code', 'mongo', 'python'])
# Assert same order of the list items is maintained in the db
self.assertEqual(
BlogPost._get_collection().find_one({'_id': post.pk})['tags'],
['code', 'mongo', 'python']
)
def test_save(self): def test_save(self):
"""Ensure that a document may be saved in the database.""" """Ensure that a document may be saved in the database."""
@ -3149,6 +3167,22 @@ class InstanceTest(unittest.TestCase):
person.update(set__height=2.0) person.update(set__height=2.0)
@needs_mongodb_v26
def test_push_with_position(self):
"""Ensure that push with position works properly for an instance."""
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
blog = BlogPost()
blog.slug = "ABC"
blog.tags = ["python"]
blog.save()
blog.update(push__tags__0=["mongodb", "code"])
blog.reload()
self.assertEqual(blog.tags, ['mongodb', 'code', 'python'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -197,14 +197,18 @@ class OnlyExcludeAllTest(unittest.TestCase):
title = StringField() title = StringField()
text = StringField() text = StringField()
class VariousData(EmbeddedDocument):
some = BooleanField()
class BlogPost(Document): class BlogPost(Document):
content = StringField() content = StringField()
author = EmbeddedDocumentField(User) author = EmbeddedDocumentField(User)
comments = ListField(EmbeddedDocumentField(Comment)) comments = ListField(EmbeddedDocumentField(Comment))
various = MapField(field=EmbeddedDocumentField(VariousData))
BlogPost.drop_collection() BlogPost.drop_collection()
post = BlogPost(content='Had a good coffee today...') post = BlogPost(content='Had a good coffee today...', various={'test_dynamic':{'some': True}})
post.author = User(name='Test User') post.author = User(name='Test User')
post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')] post.comments = [Comment(title='I aggree', text='Great post!'), Comment(title='Coffee', text='I hate coffee')]
post.save() post.save()
@ -215,6 +219,9 @@ class OnlyExcludeAllTest(unittest.TestCase):
self.assertEqual(obj.author.name, 'Test User') self.assertEqual(obj.author.name, 'Test User')
self.assertEqual(obj.comments, []) self.assertEqual(obj.comments, [])
obj = BlogPost.objects.only('various.test_dynamic.some').get()
self.assertEqual(obj.various["test_dynamic"].some, True)
obj = BlogPost.objects.only('content', 'comments.title',).get() obj = BlogPost.objects.only('content', 'comments.title',).get()
self.assertEqual(obj.content, 'Had a good coffee today...') self.assertEqual(obj.content, 'Had a good coffee today...')
self.assertEqual(obj.author, None) self.assertEqual(obj.author, None)

View File

@ -1,6 +1,8 @@
import unittest import unittest
from mongoengine import connect, Document, IntField from mongoengine import connect, Document, IntField, StringField, ListField
from tests.utils import needs_mongodb_v26
__all__ = ("FindAndModifyTest",) __all__ = ("FindAndModifyTest",)
@ -94,6 +96,37 @@ class FindAndModifyTest(unittest.TestCase):
self.assertEqual(old_doc.to_mongo(), {"_id": 1}) self.assertEqual(old_doc.to_mongo(), {"_id": 1})
self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}]) self.assertDbEqual([{"_id": 0, "value": 0}, {"_id": 1, "value": -1}])
@needs_mongodb_v26
def test_modify_with_push(self):
class BlogPost(Document):
tags = ListField(StringField())
BlogPost.drop_collection()
blog = BlogPost.objects.create()
# Push a new tag via modify with new=False (default).
BlogPost(id=blog.id).modify(push__tags='code')
self.assertEqual(blog.tags, [])
blog.reload()
self.assertEqual(blog.tags, ['code'])
# Push a new tag via modify with new=True.
blog = BlogPost.objects(id=blog.id).modify(push__tags='java', new=True)
self.assertEqual(blog.tags, ['code', 'java'])
# Push a new tag with a positional argument.
blog = BlogPost.objects(id=blog.id).modify(
push__tags__0='python',
new=True)
self.assertEqual(blog.tags, ['python', 'code', 'java'])
# Push multiple new tags with a positional argument.
blog = BlogPost.objects(id=blog.id).modify(
push__tags__1=['go', 'rust'],
new=True)
self.assertEqual(blog.tags, ['python', 'go', 'rust', 'code', 'java'])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1917,6 +1917,32 @@ class QuerySetTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
@needs_mongodb_v26
def test_update_push_with_position(self):
"""Ensure that the 'push' update with position works properly.
"""
class BlogPost(Document):
slug = StringField()
tags = ListField(StringField())
BlogPost.drop_collection()
post = BlogPost.objects.create(slug="test")
BlogPost.objects.filter(id=post.id).update(push__tags="code")
BlogPost.objects.filter(id=post.id).update(push__tags__0=["mongodb", "python"])
post.reload()
self.assertEqual(post.tags, ['mongodb', 'python', 'code'])
BlogPost.objects.filter(id=post.id).update(set__tags__2="java")
post.reload()
self.assertEqual(post.tags, ['mongodb', 'python', 'java'])
#test push with singular value
BlogPost.objects.filter(id=post.id).update(push__tags__0='scala')
post.reload()
self.assertEqual(post.tags, ['scala', 'mongodb', 'python', 'java'])
def test_update_push_and_pull_add_to_set(self): def test_update_push_and_pull_add_to_set(self):
"""Ensure that the 'pull' update operation works correctly. """Ensure that the 'pull' update operation works correctly.
""" """