diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 3de10a69..7241efbd 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -281,7 +281,7 @@ def update(_doc_cls=None, **update): if op == 'pull': if field.required or value is not None: - if match == 'in' and not isinstance(value, dict): + if match in ('in', 'nin') and not isinstance(value, dict): value = _prepare_query_for_iterable(field, op, value) else: value = field.prepare_query_value(op, value) diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 31b1641e..0b88193e 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -2193,6 +2193,40 @@ class QuerySetTest(unittest.TestCase): Site.objects(id=s.id).update_one( pull_all__collaborators__helpful__name=['Ross']) + def test_pull_from_nested_embedded_using_in_nin(self): + """Ensure that the 'pull' update operation works on embedded documents using 'in' and 'nin' operators. + """ + + class User(EmbeddedDocument): + name = StringField() + + def __unicode__(self): + return '%s' % self.name + + class Collaborator(EmbeddedDocument): + helpful = ListField(EmbeddedDocumentField(User)) + unhelpful = ListField(EmbeddedDocumentField(User)) + + class Site(Document): + name = StringField(max_length=75, unique=True, required=True) + collaborators = EmbeddedDocumentField(Collaborator) + + Site.drop_collection() + + a = User(name='Esteban') + b = User(name='Frank') + x = User(name='Harry') + y = User(name='John') + + s = Site(name="test", collaborators=Collaborator( + helpful=[a, b], unhelpful=[x, y])).save() + + Site.objects(id=s.id).update_one(pull__collaborators__helpful__name__in=['Esteban']) # Pull a + self.assertEqual(Site.objects.first().collaborators['helpful'], [b]) + + Site.objects(id=s.id).update_one(pull__collaborators__unhelpful__name__nin=['John']) # Pull x + self.assertEqual(Site.objects.first().collaborators['unhelpful'], [y]) + def test_pull_from_nested_mapfield(self): class Collaborator(EmbeddedDocument): diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 8064f09c..b2bc1d6c 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -283,6 +283,11 @@ class TransformTest(unittest.TestCase): update = transform.update(MainDoc, pull__content__heading='xyz') self.assertEqual(update, {'$pull': {'content.heading': 'xyz'}}) + update = transform.update(MainDoc, pull__content__text__word__in=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$in': ['foo', 'bar']}}}}) + + update = transform.update(MainDoc, pull__content__text__word__nin=['foo', 'bar']) + self.assertEqual(update, {'$pull': {'content.text': {'word': {'$nin': ['foo', 'bar']}}}}) if __name__ == '__main__': unittest.main()