diff --git a/docs/changelog.rst b/docs/changelog.rst index 3569132d..8199e03e 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8.4 ================ +- Update transform to handle docs erroneously passed to unset (#416) - Fixed indexing - turn off _cls (#414) - Fixed dereference threading issue in ComplexField.__get__ (#412) - Fixed QuerySetNoCache.count() caching (#410) diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 352774ff..e0a7d3c6 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -203,11 +203,13 @@ def update(_doc_cls=None, **update): value = field.prepare_query_value(op, value) elif op in ('pushAll', 'pullAll'): value = [field.prepare_query_value(op, v) for v in value] - elif op == 'addToSet': + elif op in ('addToSet', 'setOnInsert'): if isinstance(value, (list, tuple, set)): value = [field.prepare_query_value(op, v) for v in value] elif field.required or value is not None: value = field.prepare_query_value(op, value) + elif op == "unset": + value = 1 if match: match = '$' + match diff --git a/tests/queryset/transform.py b/tests/queryset/transform.py index 7886965b..d2e8b784 100644 --- a/tests/queryset/transform.py +++ b/tests/queryset/transform.py @@ -31,6 +31,31 @@ class TransformTest(unittest.TestCase): self.assertEqual(transform.query(name__exists=True), {'name': {'$exists': True}}) + def test_transform_update(self): + class DicDoc(Document): + dictField = DictField() + + class Doc(Document): + pass + + DicDoc.drop_collection() + Doc.drop_collection() + + doc = Doc().save() + dic_doc = DicDoc().save() + + for k, v in (("set", "$set"), ("set_on_insert", "$setOnInsert"), ("push", "$push")): + update = transform.update(DicDoc, **{"%s__dictField__test" % k: doc}) + self.assertTrue(isinstance(update[v]["dictField.test"], dict)) + + # Update special cases + update = transform.update(DicDoc, unset__dictField__test=doc) + self.assertEqual(update["$unset"]["dictField.test"], 1) + + update = transform.update(DicDoc, pull__dictField__test=doc) + self.assertTrue(isinstance(update["$pull"]["dictField"]["test"], dict)) + + def test_query_field_name(self): """Ensure that the correct field name is used when querying. """