Raise TypeError when __in-operator used with a Document (#1237)

This commit is contained in:
Malthe Jørgensen 2016-12-13 05:27:25 +01:00 committed by Stefan Wójcik
parent 65914fb2b2
commit 76524b7498
4 changed files with 70 additions and 4 deletions

View File

@ -34,7 +34,10 @@ def _import_class(cls_name):
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)
if cls_name in doc_classes: if cls_name == 'BaseDocument':
from mongoengine.base import document as module
import_classes = ['BaseDocument']
elif cls_name in doc_classes:
from mongoengine import document as module from mongoengine import document as module
import_classes = doc_classes import_classes = doc_classes
elif cls_name in field_classes: elif cls_name in field_classes:

View File

@ -101,7 +101,20 @@ def query(_doc_cls=None, **kwargs):
value = value['_id'] value = value['_id']
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
# 'in', 'nin' and 'all' require a list of values # Raise an error if the in/nin/all/near param is not iterable. We need a
# special check for BaseDocument, because - although it's iterable - using
# it as such in the context of this method is most definitely a mistake.
BaseDocument = _import_class('BaseDocument')
if isinstance(value, BaseDocument):
raise TypeError("When using the `in`, `nin`, `all`, or "
"`near`-operators you can\'t use a "
"`Document`, you must wrap your object "
"in a list (object -> [object]).")
elif not hasattr(value, '__iter__'):
raise TypeError("The `in`, `nin`, `all`, or "
"`near`-operators must be applied to an "
"iterable (e.g. a list).")
else:
value = [field.prepare_query_value(op, v) for v in value] value = [field.prepare_query_value(op, v) for v in value]
# If we're querying a GenericReferenceField, we need to alter the # If we're querying a GenericReferenceField, we need to alter the

View File

@ -7,5 +7,5 @@ cover-package=mongoengine
[flake8] [flake8]
ignore=E501,F401,F403,F405,I201 ignore=E501,F401,F403,F405,I201
exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests
max-complexity=45 max-complexity=47
application-import-names=mongoengine,tests application-import-names=mongoengine,tests

View File

@ -4963,6 +4963,56 @@ class QuerySetTest(unittest.TestCase):
self.assertEqual(i, 249) self.assertEqual(i, 249)
self.assertEqual(j, 249) self.assertEqual(j, 249)
def test_in_operator_on_non_iterable(self):
"""Ensure that using the `__in` operator on a non-iterable raises an
error.
"""
class User(Document):
name = StringField()
class BlogPost(Document):
content = StringField()
authors = ListField(ReferenceField(User))
User.drop_collection()
BlogPost.drop_collection()
author = User(name='Test User')
author.save()
post = BlogPost(content='Had a good coffee today...', authors=[author])
post.save()
blog_posts = BlogPost.objects(authors__in=[author])
self.assertEqual(list(blog_posts), [post])
# Using the `__in`-operator with a non-iterable should raise a TypeError
self.assertRaises(TypeError, BlogPost.objects(authors__in=author.id).count)
def test_in_operator_on_document(self):
"""Ensure that using the `__in` operator on a `Document` raises an
error.
"""
class User(Document):
name = StringField()
class BlogPost(Document):
content = StringField()
authors = ListField(ReferenceField(User))
User.drop_collection()
BlogPost.drop_collection()
author = User(name='Test User')
author.save()
post = BlogPost(content='Had a good coffee today...', authors=[author])
post.save()
blog_posts = BlogPost.objects(authors__in=[author])
self.assertEqual(list(blog_posts), [post])
# Using the `__in`-operator with a `Document` should raise a TypeError
self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()