From a512ccca28ea2b9645b72f7e2a6593409dd02f2b Mon Sep 17 00:00:00 2001 From: David Bordeynik Date: Wed, 29 Apr 2015 15:11:48 +0300 Subject: [PATCH] fix-#453: Queryset update doesn't go through field validation --- docs/changelog.rst | 1 + mongoengine/base/fields.py | 7 +++++++ mongoengine/fields.py | 24 +++++++++++++++--------- mongoengine/queryset/transform.py | 5 +---- tests/queryset/queryset.py | 14 ++++++++++++++ 5 files changed, 38 insertions(+), 13 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index ed08c391..2a6f7802 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in 0.9.X - DEV ====================== +- Queryset update doesn't go through field validation #453 - Added support for specifying authentication source as option `authSource` in URI. #967 - Fixed mark_as_changed to handle higher/lower level fields changed. #927 - ListField of embedded docs doesn't set the _instance attribute when iterating over it #914 diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index aa16804e..6de2bfc6 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -17,6 +17,11 @@ __all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField") +UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', + 'push_all', 'pull', 'pull_all', 'add_to_set', + 'set_on_insert', 'min', 'max']) + + class BaseField(object): """A base class for fields in a MongoDB document. Instances of this class @@ -150,6 +155,8 @@ class BaseField(object): def prepare_query_value(self, op, value): """Prepare a value that is being used in a query for PyMongo. """ + if op in UPDATE_OPERATORS: + self.validate(value) return value def validate(self, value, clean=True): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index dbed6b99..bd6c88d6 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -108,7 +108,7 @@ class StringField(BaseField): # escape unsafe characters which could lead to a re.error value = re.escape(value) value = re.compile(regex % value, flags) - return value + return super(StringField, self).prepare_query_value(op, value) class URLField(StringField): @@ -203,7 +203,7 @@ class IntField(BaseField): if value is None: return value - return int(value) + return super(IntField, self).prepare_query_value(op, int(value)) class LongField(BaseField): @@ -238,7 +238,7 @@ class LongField(BaseField): if value is None: return value - return long(value) + return super(LongField, self).prepare_query_value(op, long(value)) class FloatField(BaseField): @@ -273,7 +273,7 @@ class FloatField(BaseField): if value is None: return value - return float(value) + return super(FloatField, self).prepare_query_value(op, float(value)) class DecimalField(BaseField): @@ -347,7 +347,7 @@ class DecimalField(BaseField): self.error('Decimal value is too large') def prepare_query_value(self, op, value): - return self.to_mongo(value) + return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value)) class BooleanField(BaseField): @@ -434,7 +434,7 @@ class DateTimeField(BaseField): return None def prepare_query_value(self, op, value): - return self.to_mongo(value) + return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value)) class ComplexDateTimeField(StringField): @@ -518,7 +518,7 @@ class ComplexDateTimeField(StringField): return self._convert_from_datetime(value) def prepare_query_value(self, op, value): - return self._convert_from_datetime(value) + return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value)) class EmbeddedDocumentField(BaseField): @@ -569,6 +569,9 @@ class EmbeddedDocumentField(BaseField): return self.document_type._fields.get(member_name) def prepare_query_value(self, op, value): + if not isinstance(value, self.document_type): + value = self.document_type._from_son(value) + super(EmbeddedDocumentField, self).prepare_query_value(op, value) return self.to_mongo(value) @@ -585,7 +588,7 @@ class GenericEmbeddedDocumentField(BaseField): """ def prepare_query_value(self, op, value): - return self.to_mongo(value) + return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value)) def to_python(self, value): if isinstance(value, dict): @@ -668,7 +671,8 @@ class DynamicField(BaseField): if isinstance(value, basestring): from mongoengine.fields import StringField return StringField().prepare_query_value(op, value) - return self.to_mongo(value) + return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) + def validate(self, value, clean=True): if hasattr(value, "validate"): @@ -979,8 +983,10 @@ class ReferenceField(BaseField): def prepare_query_value(self, op, value): if value is None: return None + super(ReferenceField, self).prepare_query_value(op, value) return self.to_mongo(value) + def validate(self, value): if not isinstance(value, (self.document_type, DBRef)): diff --git a/mongoengine/queryset/transform.py b/mongoengine/queryset/transform.py index 03a09dc5..007cf865 100644 --- a/mongoengine/queryset/transform.py +++ b/mongoengine/queryset/transform.py @@ -3,6 +3,7 @@ from collections import defaultdict import pymongo from bson import SON +from mongoengine.base.fields import UPDATE_OPERATORS from mongoengine.connection import get_connection from mongoengine.common import _import_class from mongoengine.errors import InvalidQueryError, LookUpError @@ -24,10 +25,6 @@ CUSTOM_OPERATORS = ('match',) MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS) -UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', - 'push_all', 'pull', 'pull_all', 'add_to_set', - 'set_on_insert', 'min', 'max') - def query(_doc_cls=None, _field_operation=False, **query): """Transform a query from Django-style format to Mongo format. diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index 6cbac495..ac282c44 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -602,6 +602,20 @@ class QuerySetTest(unittest.TestCase): set__name="bobby", multi=True) self.assertEqual(result, 2) + def test_update_validate(self): + class EmDoc(EmbeddedDocument): + str_f = StringField() + + class Doc(Document): + str_f = StringField() + dt_f = DateTimeField() + cdt_f = ComplexDateTimeField() + ed_f = EmbeddedDocumentField(EmDoc) + + self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True) + self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True) + self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) + def test_upsert(self): self.Person.drop_collection()