From 76524b7498a87a9b3e40f89c4efd26376c13e92d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Malthe=20J=C3=B8rgensen?= Date: Tue, 13 Dec 2016 05:27:25 +0100 Subject: [PATCH] Raise TypeError when `__in`-operator used with a Document (#1237) --- mongoengine/common.py | 5 +++- mongoengine/queryset/transform.py | 17 +++++++++-- setup.cfg | 2 +- tests/queryset/queryset.py | 50 +++++++++++++++++++++++++++++++ 4 files changed, 70 insertions(+), 4 deletions(-) diff --git a/mongoengine/common.py b/mongoengine/common.py index 3e63e98e..bde7e78c 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -34,7 +34,10 @@ def _import_class(cls_name): queryset_classes = ('OperationError',) deref_classes = ('DeReference',) - if cls_name in doc_classes: + if cls_name == 'BaseDocument': + from mongoengine.base import document as module + import_classes = ['BaseDocument'] + elif cls_name in doc_classes: from mongoengine import document as module import_classes = doc_classes elif cls_name in field_classes: diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index af59917c..61d43490 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -101,8 +101,21 @@ def query(_doc_cls=None, **kwargs): value = value['_id'] elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): - # 'in', 'nin' and 'all' require a list of values - value = [field.prepare_query_value(op, v) for v in value] + # Raise an error if the in/nin/all/near param is not iterable. We need a + # special check for BaseDocument, because - although it's iterable - using + # it as such in the context of this method is most definitely a mistake. + BaseDocument = _import_class('BaseDocument') + if isinstance(value, BaseDocument): + raise TypeError("When using the `in`, `nin`, `all`, or " + "`near`-operators you can\'t use a " + "`Document`, you must wrap your object " + "in a list (object -> [object]).") + elif not hasattr(value, '__iter__'): + raise TypeError("The `in`, `nin`, `all`, or " + "`near`-operators must be applied to an " + "iterable (e.g. a list).") + else: + value = [field.prepare_query_value(op, v) for v in value] # If we're querying a GenericReferenceField, we need to alter the # key depending on the value: diff --git a/setup.cfg b/setup.cfg index 1887c476..eabe3271 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,5 +7,5 @@ cover-package=mongoengine [flake8] ignore=E501,F401,F403,F405,I201 exclude=build,dist,docs,venv,venv3,.tox,.eggs,tests -max-complexity=45 +max-complexity=47 application-import-names=mongoengine,tests diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index e4c71de7..28b831cd 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -4963,6 +4963,56 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(i, 249) self.assertEqual(j, 249) + def test_in_operator_on_non_iterable(self): + """Ensure that using the `__in` operator on a non-iterable raises an + error. + """ + class User(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + authors = ListField(ReferenceField(User)) + + User.drop_collection() + BlogPost.drop_collection() + + author = User(name='Test User') + author.save() + post = BlogPost(content='Had a good coffee today...', authors=[author]) + post.save() + + blog_posts = BlogPost.objects(authors__in=[author]) + self.assertEqual(list(blog_posts), [post]) + + # Using the `__in`-operator with a non-iterable should raise a TypeError + self.assertRaises(TypeError, BlogPost.objects(authors__in=author.id).count) + + def test_in_operator_on_document(self): + """Ensure that using the `__in` operator on a `Document` raises an + error. + """ + class User(Document): + name = StringField() + + class BlogPost(Document): + content = StringField() + authors = ListField(ReferenceField(User)) + + User.drop_collection() + BlogPost.drop_collection() + + author = User(name='Test User') + author.save() + post = BlogPost(content='Had a good coffee today...', authors=[author]) + post.save() + + blog_posts = BlogPost.objects(authors__in=[author]) + self.assertEqual(list(blog_posts), [post]) + + # Using the `__in`-operator with a `Document` should raise a TypeError + self.assertRaises(TypeError, BlogPost.objects(authors__in=author).count) + if __name__ == '__main__': unittest.main()