Merge pull request #979 from DavidBord/fix-453

fix-#453: Queryset update doesn't go through field validation
This commit is contained in:
David Bordeynik 2015-05-02 20:26:56 +03:00
commit 422ca87a12
5 changed files with 38 additions and 13 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in 0.9.X - DEV
======================
- Queryset update doesn't go through field validation #453
- Added support for specifying authentication source as option `authSource` in URI. #967
- Fixed mark_as_changed to handle higher/lower level fields changed. #927
- ListField of embedded docs doesn't set the _instance attribute when iterating over it #914

View File

@ -17,6 +17,11 @@ __all__ = ("BaseField", "ComplexBaseField",
"ObjectIdField", "GeoJsonBaseField")
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert', 'min', 'max'])
class BaseField(object):
"""A base class for fields in a MongoDB document. Instances of this class
@ -151,6 +156,8 @@ class BaseField(object):
def prepare_query_value(self, op, value):
"""Prepare a value that is being used in a query for PyMongo.
"""
if op in UPDATE_OPERATORS:
self.validate(value)
return value
def validate(self, value, clean=True):

View File

@ -108,7 +108,7 @@ class StringField(BaseField):
# escape unsafe characters which could lead to a re.error
value = re.escape(value)
value = re.compile(regex % value, flags)
return value
return super(StringField, self).prepare_query_value(op, value)
class URLField(StringField):
@ -203,7 +203,7 @@ class IntField(BaseField):
if value is None:
return value
return int(value)
return super(IntField, self).prepare_query_value(op, int(value))
class LongField(BaseField):
@ -238,7 +238,7 @@ class LongField(BaseField):
if value is None:
return value
return long(value)
return super(LongField, self).prepare_query_value(op, long(value))
class FloatField(BaseField):
@ -273,7 +273,7 @@ class FloatField(BaseField):
if value is None:
return value
return float(value)
return super(FloatField, self).prepare_query_value(op, float(value))
class DecimalField(BaseField):
@ -347,7 +347,7 @@ class DecimalField(BaseField):
self.error('Decimal value is too large')
def prepare_query_value(self, op, value):
return self.to_mongo(value)
return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value))
class BooleanField(BaseField):
@ -434,7 +434,7 @@ class DateTimeField(BaseField):
return None
def prepare_query_value(self, op, value):
return self.to_mongo(value)
return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value))
class ComplexDateTimeField(StringField):
@ -518,7 +518,7 @@ class ComplexDateTimeField(StringField):
return self._convert_from_datetime(value)
def prepare_query_value(self, op, value):
return self._convert_from_datetime(value)
return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value))
class EmbeddedDocumentField(BaseField):
@ -569,6 +569,9 @@ class EmbeddedDocumentField(BaseField):
return self.document_type._fields.get(member_name)
def prepare_query_value(self, op, value):
if not isinstance(value, self.document_type):
value = self.document_type._from_son(value)
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
return self.to_mongo(value)
@ -585,7 +588,7 @@ class GenericEmbeddedDocumentField(BaseField):
"""
def prepare_query_value(self, op, value):
return self.to_mongo(value)
return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value))
def to_python(self, value):
if isinstance(value, dict):
@ -668,7 +671,8 @@ class DynamicField(BaseField):
if isinstance(value, basestring):
from mongoengine.fields import StringField
return StringField().prepare_query_value(op, value)
return self.to_mongo(value)
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
def validate(self, value, clean=True):
if hasattr(value, "validate"):
@ -979,8 +983,10 @@ class ReferenceField(BaseField):
def prepare_query_value(self, op, value):
if value is None:
return None
super(ReferenceField, self).prepare_query_value(op, value)
return self.to_mongo(value)
def validate(self, value):
if not isinstance(value, (self.document_type, DBRef)):

View File

@ -3,6 +3,7 @@ from collections import defaultdict
import pymongo
from bson import SON
from mongoengine.base.fields import UPDATE_OPERATORS
from mongoengine.connection import get_connection
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError, LookUpError
@ -24,10 +25,6 @@ CUSTOM_OPERATORS = ('match',)
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS)
UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
'push_all', 'pull', 'pull_all', 'add_to_set',
'set_on_insert', 'min', 'max')
def query(_doc_cls=None, _field_operation=False, **query):
"""Transform a query from Django-style format to Mongo format.

View File

@ -602,6 +602,20 @@ class QuerySetTest(unittest.TestCase):
set__name="bobby", multi=True)
self.assertEqual(result, 2)
def test_update_validate(self):
class EmDoc(EmbeddedDocument):
str_f = StringField()
class Doc(Document):
str_f = StringField()
dt_f = DateTimeField()
cdt_f = ComplexDateTimeField()
ed_f = EmbeddedDocumentField(EmDoc)
self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True)
self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True)
self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
def test_upsert(self):
self.Person.drop_collection()