From e90f6a2fa3766872b44edad29a21527659984ae3 Mon Sep 17 00:00:00 2001 From: Andy Yankovsky Date: Thu, 14 Sep 2017 20:28:15 +0300 Subject: [PATCH] Fix update via pull__something__in=[] --- mongoengine/queryset/transform.py | 46 +++++++++++++++++++------------ tests/queryset/transform.py | 7 +++++ 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index a9907ada..1f70a48b 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -101,21 +101,8 @@ def query(_doc_cls=None, **kwargs): value = value['_id'] elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict): - # Raise an error if the in/nin/all/near param is not iterable. We need a - # special check for BaseDocument, because - although it's iterable - using - # it as such in the context of this method is most definitely a mistake. - BaseDocument = _import_class('BaseDocument') - if isinstance(value, BaseDocument): - raise TypeError("When using the `in`, `nin`, `all`, or " - "`near`-operators you can\'t use a " - "`Document`, you must wrap your object " - "in a list (object -> [object]).") - elif not hasattr(value, '__iter__'): - raise TypeError("The `in`, `nin`, `all`, or " - "`near`-operators must be applied to an " - "iterable (e.g. a list).") - else: - value = [field.prepare_query_value(op, v) for v in value] + # Raise an error if the in/nin/all/near param is not iterable. + value = _prepare_query_for_iterable(field, op, value) # If we're querying a GenericReferenceField, we need to alter the # key depending on the value: @@ -284,9 +271,15 @@ def update(_doc_cls=None, **update): if isinstance(field, GeoJsonBaseField): value = field.to_mongo(value) - if op == 'push' and isinstance(value, (list, tuple, set)): + if op == 'pull': + if field.required or value is not None: + if match == 'in' and not isinstance(value, dict): + value = _prepare_query_for_iterable(field, op, value) + else: + value = field.prepare_query_value(op, value) + elif op == 'push' and isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] - elif op in (None, 'set', 'push', 'pull'): + elif op in (None, 'set', 'push'): if field.required or value is not None: value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): @@ -439,3 +432,22 @@ def _infer_geometry(value): raise InvalidQueryError('Invalid $geometry data. Can be either a ' 'dictionary or (nested) lists of coordinate(s)') + + +def _prepare_query_for_iterable(field, op, value): + # We need a special check for BaseDocument, because - although it's iterable - using + # it as such in the context of this method is most definitely a mistake. + BaseDocument = _import_class('BaseDocument') + + if isinstance(value, BaseDocument): + raise TypeError("When using the `in`, `nin`, `all`, or " + "`near`-operators you can\'t use a " + "`Document`, you must wrap your object " + "in a list (object -> [object]).") + + if not hasattr(value, '__iter__'): + raise TypeError("The `in`, `nin`, `all`, or " + "`near`-operators must be applied to an " + "iterable (e.g. a list).") + + return [field.prepare_query_value(op, v) for v in value] diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 20ab0b3f..a043a647 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -28,12 +28,16 @@ class TransformTest(unittest.TestCase): {'name': {'$exists': True}}) def test_transform_update(self): + class LisDoc(Document): + foo = ListField(StringField()) + class DicDoc(Document): dictField = DictField() class Doc(Document): pass + LisDoc.drop_collection() DicDoc.drop_collection() Doc.drop_collection() @@ -51,6 +55,9 @@ class TransformTest(unittest.TestCase): update = transform.update(DicDoc, pull__dictField__test=doc) self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) + update = transform.update(LisDoc, pull__foo__in=['a']) + self.assertEqual(update, {'$pull': {'foo': {'$in': ['a']}}}) + def test_query_field_name(self): """Ensure that the correct field name is used when querying. """