Merge pull request #979 from DavidBord/fix-453
fix-#453: Queryset update doesn't go through field validation
This commit is contained in:
		| @@ -5,6 +5,7 @@ Changelog | |||||||
|  |  | ||||||
| Changes in 0.9.X - DEV | Changes in 0.9.X - DEV | ||||||
| ====================== | ====================== | ||||||
|  | - Queryset update doesn't go through field validation #453 | ||||||
| - Added support for specifying authentication source as option `authSource` in URI. #967 | - Added support for specifying authentication source as option `authSource` in URI. #967 | ||||||
| - Fixed mark_as_changed to handle higher/lower level fields changed. #927 | - Fixed mark_as_changed to handle higher/lower level fields changed. #927 | ||||||
| - ListField of embedded docs doesn't set the _instance attribute when iterating over it #914 | - ListField of embedded docs doesn't set the _instance attribute when iterating over it #914 | ||||||
|   | |||||||
| @@ -17,6 +17,11 @@ __all__ = ("BaseField", "ComplexBaseField", | |||||||
|            "ObjectIdField", "GeoJsonBaseField") |            "ObjectIdField", "GeoJsonBaseField") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | UPDATE_OPERATORS = set(['set', 'unset', 'inc', 'dec', 'pop', 'push', | ||||||
|  |                         'push_all', 'pull', 'pull_all', 'add_to_set', | ||||||
|  |                         'set_on_insert', 'min', 'max']) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BaseField(object): | class BaseField(object): | ||||||
|  |  | ||||||
|     """A base class for fields in a MongoDB document. Instances of this class |     """A base class for fields in a MongoDB document. Instances of this class | ||||||
| @@ -151,6 +156,8 @@ class BaseField(object): | |||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         """Prepare a value that is being used in a query for PyMongo. |         """Prepare a value that is being used in a query for PyMongo. | ||||||
|         """ |         """ | ||||||
|  |         if op in UPDATE_OPERATORS: | ||||||
|  |             self.validate(value) | ||||||
|         return value |         return value | ||||||
|  |  | ||||||
|     def validate(self, value, clean=True): |     def validate(self, value, clean=True): | ||||||
|   | |||||||
| @@ -108,7 +108,7 @@ class StringField(BaseField): | |||||||
|             # escape unsafe characters which could lead to a re.error |             # escape unsafe characters which could lead to a re.error | ||||||
|             value = re.escape(value) |             value = re.escape(value) | ||||||
|             value = re.compile(regex % value, flags) |             value = re.compile(regex % value, flags) | ||||||
|         return value |         return super(StringField, self).prepare_query_value(op, value) | ||||||
|  |  | ||||||
|  |  | ||||||
| class URLField(StringField): | class URLField(StringField): | ||||||
| @@ -203,7 +203,7 @@ class IntField(BaseField): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         return int(value) |         return super(IntField, self).prepare_query_value(op, int(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class LongField(BaseField): | class LongField(BaseField): | ||||||
| @@ -238,7 +238,7 @@ class LongField(BaseField): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         return long(value) |         return super(LongField, self).prepare_query_value(op, long(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class FloatField(BaseField): | class FloatField(BaseField): | ||||||
| @@ -273,7 +273,7 @@ class FloatField(BaseField): | |||||||
|         if value is None: |         if value is None: | ||||||
|             return value |             return value | ||||||
|  |  | ||||||
|         return float(value) |         return super(FloatField, self).prepare_query_value(op, float(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class DecimalField(BaseField): | class DecimalField(BaseField): | ||||||
| @@ -347,7 +347,7 @@ class DecimalField(BaseField): | |||||||
|             self.error('Decimal value is too large') |             self.error('Decimal value is too large') | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return self.to_mongo(value) |         return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class BooleanField(BaseField): | class BooleanField(BaseField): | ||||||
| @@ -434,7 +434,7 @@ class DateTimeField(BaseField): | |||||||
|                     return None |                     return None | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return self.to_mongo(value) |         return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ComplexDateTimeField(StringField): | class ComplexDateTimeField(StringField): | ||||||
| @@ -518,7 +518,7 @@ class ComplexDateTimeField(StringField): | |||||||
|         return self._convert_from_datetime(value) |         return self._convert_from_datetime(value) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return self._convert_from_datetime(value) |         return super(ComplexDateTimeField, self).prepare_query_value(op, self._convert_from_datetime(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class EmbeddedDocumentField(BaseField): | class EmbeddedDocumentField(BaseField): | ||||||
| @@ -569,6 +569,9 @@ class EmbeddedDocumentField(BaseField): | |||||||
|         return self.document_type._fields.get(member_name) |         return self.document_type._fields.get(member_name) | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|  |         if not isinstance(value, self.document_type): | ||||||
|  |             value = self.document_type._from_son(value) | ||||||
|  |         super(EmbeddedDocumentField, self).prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -585,7 +588,7 @@ class GenericEmbeddedDocumentField(BaseField): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         return self.to_mongo(value) |         return super(GenericEmbeddedDocumentField, self).prepare_query_value(op, self.to_mongo(value)) | ||||||
|  |  | ||||||
|     def to_python(self, value): |     def to_python(self, value): | ||||||
|         if isinstance(value, dict): |         if isinstance(value, dict): | ||||||
| @@ -668,7 +671,8 @@ class DynamicField(BaseField): | |||||||
|         if isinstance(value, basestring): |         if isinstance(value, basestring): | ||||||
|             from mongoengine.fields import StringField |             from mongoengine.fields import StringField | ||||||
|             return StringField().prepare_query_value(op, value) |             return StringField().prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def validate(self, value, clean=True): |     def validate(self, value, clean=True): | ||||||
|         if hasattr(value, "validate"): |         if hasattr(value, "validate"): | ||||||
| @@ -979,8 +983,10 @@ class ReferenceField(BaseField): | |||||||
|     def prepare_query_value(self, op, value): |     def prepare_query_value(self, op, value): | ||||||
|         if value is None: |         if value is None: | ||||||
|             return None |             return None | ||||||
|  |         super(ReferenceField, self).prepare_query_value(op, value) | ||||||
|         return self.to_mongo(value) |         return self.to_mongo(value) | ||||||
|  |  | ||||||
|  |  | ||||||
|     def validate(self, value): |     def validate(self, value): | ||||||
|  |  | ||||||
|         if not isinstance(value, (self.document_type, DBRef)): |         if not isinstance(value, (self.document_type, DBRef)): | ||||||
|   | |||||||
| @@ -3,6 +3,7 @@ from collections import defaultdict | |||||||
| import pymongo | import pymongo | ||||||
| from bson import SON | from bson import SON | ||||||
|  |  | ||||||
|  | from mongoengine.base.fields import UPDATE_OPERATORS | ||||||
| from mongoengine.connection import get_connection | from mongoengine.connection import get_connection | ||||||
| from mongoengine.common import _import_class | from mongoengine.common import _import_class | ||||||
| from mongoengine.errors import InvalidQueryError, LookUpError | from mongoengine.errors import InvalidQueryError, LookUpError | ||||||
| @@ -24,10 +25,6 @@ CUSTOM_OPERATORS = ('match',) | |||||||
| MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS + | ||||||
|                    STRING_OPERATORS + CUSTOM_OPERATORS) |                    STRING_OPERATORS + CUSTOM_OPERATORS) | ||||||
|  |  | ||||||
| UPDATE_OPERATORS = ('set', 'unset', 'inc', 'dec', 'pop', 'push', |  | ||||||
|                     'push_all', 'pull', 'pull_all', 'add_to_set', |  | ||||||
|                     'set_on_insert', 'min', 'max') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def query(_doc_cls=None, _field_operation=False, **query): | def query(_doc_cls=None, _field_operation=False, **query): | ||||||
|     """Transform a query from Django-style format to Mongo format. |     """Transform a query from Django-style format to Mongo format. | ||||||
|   | |||||||
| @@ -602,6 +602,20 @@ class QuerySetTest(unittest.TestCase): | |||||||
|             set__name="bobby", multi=True) |             set__name="bobby", multi=True) | ||||||
|         self.assertEqual(result, 2) |         self.assertEqual(result, 2) | ||||||
|  |  | ||||||
|  |     def test_update_validate(self): | ||||||
|  |         class EmDoc(EmbeddedDocument): | ||||||
|  |             str_f = StringField() | ||||||
|  |  | ||||||
|  |         class Doc(Document): | ||||||
|  |             str_f = StringField() | ||||||
|  |             dt_f = DateTimeField() | ||||||
|  |             cdt_f = ComplexDateTimeField() | ||||||
|  |             ed_f = EmbeddedDocumentField(EmDoc) | ||||||
|  |  | ||||||
|  |         self.assertRaises(ValidationError, Doc.objects().update, str_f=1, upsert=True) | ||||||
|  |         self.assertRaises(ValidationError, Doc.objects().update, dt_f="datetime", upsert=True) | ||||||
|  |         self.assertRaises(ValidationError, Doc.objects().update, ed_f__str_f=1, upsert=True) | ||||||
|  |  | ||||||
|     def test_upsert(self): |     def test_upsert(self): | ||||||
|         self.Person.drop_collection() |         self.Person.drop_collection() | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user