From 0462f186807a4b710f3e35e063048e1bad9a9cb7 Mon Sep 17 00:00:00 2001 From: Bob Cribbs Date: Tue, 19 Aug 2014 21:29:13 +0300 Subject: [PATCH] Allow atomic update for the entire `DictField` --- AUTHORS | 1 + mongoengine/fields.py | 4 ++++ tests/fields/fields.py | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+) diff --git a/AUTHORS b/AUTHORS index 81ec2f76..326caec6 100644 --- a/AUTHORS +++ b/AUTHORS @@ -206,3 +206,4 @@ that much better: * Clay McClure (https://github.com/claymation) * Bruno Rocha (https://github.com/rochacbruno) * Norberto Leite (https://github.com/nleite) + * Bob Cribbs (https://github.com/bocribbz) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 7bbc221a..8b3cf4c7 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -826,6 +826,10 @@ class DictField(ComplexBaseField): return StringField().prepare_query_value(op, value) if hasattr(self.field, 'field'): + if op in ('set', 'unset') and isinstance(value, dict): + return dict( + (k, self.field.prepare_query_value(op, v)) + for k, v in value.items()) return self.field.prepare_query_value(op, value) return super(DictField, self).prepare_query_value(op, value) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 0af22a34..7d906917 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -18,6 +18,7 @@ from bson import Binary, DBRef, ObjectId from mongoengine import * from mongoengine.connection import get_db from mongoengine.base import _document_registry +from mongoengine.base.datastructures import BaseDict from mongoengine.errors import NotRegistered from mongoengine.python_support import PY3, b, bin_type @@ -1251,6 +1252,30 @@ class FieldTest(unittest.TestCase): Simple.drop_collection() + def test_atomic_update_dict_field(self): + """Ensure that the entire DictField can be atomically updated.""" + + + class Simple(Document): + mapping = DictField(field=ListField(IntField(required=True))) + + Simple.drop_collection() + + e = Simple() + e.mapping['someints'] = [1, 2] + e.save() + e.update(set__mapping={"ints": [3, 4]}) + e.reload() + self.assertEqual(BaseDict, type(e.mapping)) + self.assertEqual({"ints": [3, 4]}, e.mapping) + + def create_invalid_mapping(): + e.update(set__mapping={"somestrings": ["foo", "bar",]}) + + self.assertRaises(ValueError, create_invalid_mapping) + + Simple.drop_collection() + def test_mapfield(self): """Ensure that the MapField handles the declared type."""