Merge pull request #979 from DavidBord/fix-453
fix-#453: Queryset update doesn't go through field validation
This commit is contained in:
commit
422ca87a12
@ -5,6 +5,7 @@ Changelog
|
|||||||
|
|
||||||
Changes in 0.9.X - DEV
|
Changes in 0.9.X - DEV
|
||||||
======================
|
======================
|
||||||
|
- Queryset update doesn't go through field validation #453
|
||||||
- Added support for specifying authentication source as option `authSource` in URI. #967
|
- Added support for specifying authentication source as option `authSource` in URI. #967
|
||||||
- Fixed mark_as_changed to handle higher/lower level fields changed. #927
|
- Fixed mark_as_changed to handle higher/lower level fields changed. #927
|
||||||
- ListField of embedded docs doesn't set the _instance attribute when iterating over it #914
|
- ListField of embedded docs doesn't set the _instance attribute when iterating over it #914
|
||||||
|
@ -17,6 +17,11 @@ __all__ = ("BaseField", "ComplexBaseField",
|
|||||||
"ObjectIdField", "GeoJsonBaseField")
|
"ObjectIdField", "GeoJsonBaseField")
|
||||||
|
|
||||||
|
|
||||||
|
UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push',
|
||||||
|
'push_all', 'pull', 'pull_all', 'add_to_set',
|
||||||
|
'set_on_insert', 'min', 'max'])
|
||||||
|
|
||||||
|
|
||||||
class BaseField(object):
|
class BaseField(object):
|
||||||
|
|
||||||
"""A base class for fields in a MongoDB document. Instances of this class
|
"""A base class for fields in a MongoDB document. Instances of this class
|
||||||
@ -151,6 +156,8 @@ class BaseField(object):
|
|||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
"""Prepare a value that is being used in a query for PyMongo.
|
"""Prepare a value that is being used in a query for PyMongo.
|
||||||
"""
|
"""
|
||||||
|
if op in UPDATE_OPERATORS:
|
||||||
|
self.validate(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def validate(self, value, clean=True):
|
def validate(self, value, clean=True):
|
||||||
|
@ -108,7 +108,7 @@ class StringField(BaseField):
|
|||||||
# escape unsafe characters which could lead to a re.error
|
# escape unsafe characters which could lead to a re.error
|
||||||
value = re.escape(value)
|
value = re.escape(value)
|
||||||
value = re.compile(regex % value, flags)
|
value = re.compile(regex % value, flags)
|
||||||
return value
|
return super(StringField, self).prepare_query_value(op, value)
|
||||||
|
|
||||||
|
|
||||||
class URLField(StringField):
|
class URLField(StringField):
|
||||||
@ -203,7 +203,7 @@ class IntField(BaseField):
|
|||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return int(value)
|
return super(IntField, self).prepare_query_value(op, int(value))
|
||||||
|
|
||||||
|
|
||||||
class LongField(BaseField):
|
class LongField(BaseField):
|
||||||
@ -238,7 +238,7 @@ class LongField(BaseField):
|
|||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return long(value)
|
return super(LongField, self).prepare_query_value(op, long(value))
|
||||||
|
|
||||||
|
|
||||||
class FloatField(BaseField):
|
class FloatField(BaseField):
|
||||||
@ -273,7 +273,7 @@ class FloatField(BaseField):
|
|||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
return float(value)
|
return super(FloatField, self).prepare_query_value(op, float(value))
|
||||||
|
|
||||||
|
|
||||||
class DecimalField(BaseField):
|
class DecimalField(BaseField):
|
||||||
@ -347,7 +347,7 @@ class DecimalField(BaseField):
|
|||||||
self.error('Decimal value is too large')
|
self.error('Decimal value is too large')
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value))
|
||||||
|
|
||||||
|
|
||||||
class BooleanField(BaseField):
|
class BooleanField(BaseField):
|
||||||
@ -434,7 +434,7 @@ class DateTimeField(BaseField):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value))
|
||||||
|
|
||||||
|
|
||||||
class ComplexDateTimeField(StringField):
|
class ComplexDateTimeField(StringField):
|
||||||
@ -518,7 +518,7 @@ class ComplexDateTimeField(StringField):
|
|||||||
return self._convert_from_datetime(value)
|
return self._convert_from_datetime(value)
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self._convert_from_datetime(value)
|
return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value))
|
||||||
|
|
||||||
|
|
||||||
class EmbeddedDocumentField(BaseField):
|
class EmbeddedDocumentField(BaseField):
|
||||||
@ -569,6 +569,9 @@ class EmbeddedDocumentField(BaseField):
|
|||||||
return self.document_type._fields.get(member_name)
|
return self.document_type._fields.get(member_name)
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
|
if not isinstance(value, self.document_type):
|
||||||
|
value = self.document_type._from_son(value)
|
||||||
|
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
|
||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
|
|
||||||
@ -585,7 +588,7 @@ class GenericEmbeddedDocumentField(BaseField):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
return self.to_mongo(value)
|
return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value))
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
@ -668,7 +671,8 @@ class DynamicField(BaseField):
|
|||||||
if isinstance(value, basestring):
|
if isinstance(value, basestring):
|
||||||
from mongoengine.fields import StringField
|
from mongoengine.fields import StringField
|
||||||
return StringField().prepare_query_value(op, value)
|
return StringField().prepare_query_value(op, value)
|
||||||
return self.to_mongo(value)
|
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
|
||||||
|
|
||||||
|
|
||||||
def validate(self, value, clean=True):
|
def validate(self, value, clean=True):
|
||||||
if hasattr(value, "validate"):
|
if hasattr(value, "validate"):
|
||||||
@ -979,8 +983,10 @@ class ReferenceField(BaseField):
|
|||||||
def prepare_query_value(self, op, value):
|
def prepare_query_value(self, op, value):
|
||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
|
super(ReferenceField, self).prepare_query_value(op, value)
|
||||||
return self.to_mongo(value)
|
return self.to_mongo(value)
|
||||||
|
|
||||||
|
|
||||||
def validate(self, value):
|
def validate(self, value):
|
||||||
|
|
||||||
if not isinstance(value, (self.document_type, DBRef)):
|
if not isinstance(value, (self.document_type, DBRef)):
|
||||||
|
@ -3,6 +3,7 @@ from collections import defaultdict
|
|||||||
import pymongo
|
import pymongo
|
||||||
from bson import SON
|
from bson import SON
|
||||||
|
|
||||||
|
from mongoengine.base.fields import UPDATE_OPERATORS
|
||||||
from mongoengine.connection import get_connection
|
from mongoengine.connection import get_connection
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
from mongoengine.errors import InvalidQueryError, LookUpError
|
from mongoengine.errors import InvalidQueryError, LookUpError
|
||||||
@ -24,10 +25,6 @@ CUSTOM_OPERATORS = ('match',)
|
|||||||
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
|
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
|
||||||
STRING_OPERATORS + CUSTOM_OPERATORS)
|
STRING_OPERATORS + CUSTOM_OPERATORS)
|
||||||
|
|
||||||
UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push',
|
|
||||||
'push_all', 'pull', 'pull_all', 'add_to_set',
|
|
||||||
'set_on_insert', 'min', 'max')
|
|
||||||
|
|
||||||
|
|
||||||
def query(_doc_cls=None, _field_operation=False, **query):
|
def query(_doc_cls=None, _field_operation=False, **query):
|
||||||
"""Transform a query from Django-style format to Mongo format.
|
"""Transform a query from Django-style format to Mongo format.
|
||||||
|
@ -602,6 +602,20 @@ class QuerySetTest(unittest.TestCase):
|
|||||||
set__name="bobby", multi=True)
|
set__name="bobby", multi=True)
|
||||||
self.assertEqual(result, 2)
|
self.assertEqual(result, 2)
|
||||||
|
|
||||||
|
def test_update_validate(self):
|
||||||
|
class EmDoc(EmbeddedDocument):
|
||||||
|
str_f = StringField()
|
||||||
|
|
||||||
|
class Doc(Document):
|
||||||
|
str_f = StringField()
|
||||||
|
dt_f = DateTimeField()
|
||||||
|
cdt_f = ComplexDateTimeField()
|
||||||
|
ed_f = EmbeddedDocumentField(EmDoc)
|
||||||
|
|
||||||
|
self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True)
|
||||||
|
self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True)
|
||||||
|
self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True)
|
||||||
|
|
||||||
def test_upsert(self):
|
def test_upsert(self):
|
||||||
self.Person.drop_collection()
|
self.Person.drop_collection()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user