diff --git a/docs/changelog.rst b/docs/changelog.rst index 5ea1e4f0..ca18d3e9 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.8 ============== +- Added clean method to documents for pre-validation data cleaning (MongoEngine/mongoengine#60) - Added support setting for read prefrence at a query level (MongoEngine/mongoengine#157) - Added _instance to EmbeddedDocuments pointing to the parent (MongoEngine/mongoengine#139) - Inheritance is off by default (MongoEngine/mongoengine#122) diff --git a/docs/guide/document-instances.rst b/docs/guide/document-instances.rst index 54fa804b..b3bf687b 100644 --- a/docs/guide/document-instances.rst +++ b/docs/guide/document-instances.rst @@ -38,6 +38,34 @@ already exist, then any changes will be updated atomically. For example:: .. seealso:: :ref:`guide-atomic-updates` +Pre-save data validation and cleaning +------------------------------------- +MongoEngine allows you to create custom cleaning rules for your documents when +calling :meth:`~mongoengine.Document.save`. By providing a custom +:meth:`~mongoengine.Document.clean` method you can do any pre-validation / data +cleaning. + +This might be useful if you want to ensure a default value based on other +document values, for example:: + + class Essay(Document): + status = StringField(choices=('Published', 'Draft'), required=True) + pub_date = DateTimeField() + + def clean(self): + """Ensures that only published essays have a `pub_date` and + automatically sets the pub_date if published and not set""" + if self.status == 'Draft' and self.pub_date is not None: + msg = 'Draft entries should not have a publication date.' + raise ValidationError(msg) + # Set the pub_date for published items if not set. + if self.status == 'Published' and self.pub_date is None: + self.pub_date = datetime.now() + +.. note:: + Cleaning is only called when validation is turned on and when calling + :meth:`~mongoengine.Document.save`. 
+ Cascading Saves --------------- If your document contains :class:`~mongoengine.ReferenceField` or diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index bc509af2..46f53205 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -15,7 +15,9 @@ from .common import get_document, ALLOW_INHERITANCE from .datastructures import BaseDict, BaseList from .fields import ComplexBaseField -__all__ = ('BaseDocument', ) +__all__ = ('BaseDocument', 'NON_FIELD_ERRORS') + +NON_FIELD_ERRORS = '__all__' class BaseDocument(object): @@ -82,11 +84,6 @@ class BaseDocument(object): if hasattr(self, '_changed_fields'): self._mark_as_changed(name) - # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): - super(BaseDocument, self).__setattr__('_created', False) - if (self._is_document and not self._created and name in self._meta.get('shard_key', tuple()) and self._data.get(name) != value): @@ -94,6 +91,11 @@ class BaseDocument(object): msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) + # Check if the user has created a new instance of a class + if (self._is_document and self._initialised + and self._created and name == self._meta['id_field']): + super(BaseDocument, self).__setattr__('_created', False) + super(BaseDocument, self).__setattr__(name, value) def __getstate__(self): @@ -171,6 +173,16 @@ class BaseDocument(object): else: return hash(self.pk) + def clean(self): + """ + Hook for doing document level data cleaning before validation is run. + + Any ValidationError raised by this method will not be associated with + a particular field; it will have a special-case association with the + field defined by NON_FIELD_ERRORS. + """ + pass + def to_mongo(self): """Return data dictionary ready for use with MongoDB. 
""" @@ -203,20 +215,33 @@ class BaseDocument(object): data[name] = field.to_mongo(self._data.get(name, None)) return data - def validate(self): + def validate(self, clean=True): """Ensure that all fields' values are valid and that required fields are present. """ + # Ensure that each field is matched to a valid value + errors = {} + if clean: + try: + self.clean() + except ValidationError, error: + errors[NON_FIELD_ERRORS] = error + # Get a list of tuples of field names and their current values fields = [(field, self._data.get(name)) for name, field in self._fields.items()] - # Ensure that each field is matched to a valid value - errors = {} + EmbeddedDocumentField = _import_class("EmbeddedDocumentField") + GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField") + for field, value in fields: if value is not None: try: - field._validate(value) + if isinstance(field, (EmbeddedDocumentField, + GenericEmbeddedDocumentField)): + field._validate(value, clean=clean) + else: + field._validate(value) except ValidationError, error: errors[field.name] = error.errors or error except (ValueError, AttributeError, AssertionError), error: @@ -224,6 +249,7 @@ class BaseDocument(object): elif field.required and not getattr(field, '_auto_gen', False): errors[field.name] = ValidationError('Field is required', field_name=field.name) + if errors: raise ValidationError('ValidationError', errors=errors) diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index fc1a0767..11719b55 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -105,12 +105,12 @@ class BaseField(object): """ return value - def validate(self, value): + def validate(self, value, clean=True): """Perform validation on a value. 
""" pass - def _validate(self, value): + def _validate(self, value, **kwargs): Document = _import_class('Document') EmbeddedDocument = _import_class('EmbeddedDocument') # check choices @@ -138,7 +138,7 @@ class BaseField(object): raise ValueError('validation argument for "%s" must be a ' 'callable.' % self.name) - self.validate(value) + self.validate(value, **kwargs) class ComplexBaseField(BaseField): diff --git a/mongoengine/common.py b/mongoengine/common.py index c284777e..c76801ce 100644 --- a/mongoengine/common.py +++ b/mongoengine/common.py @@ -9,8 +9,8 @@ def _import_class(cls_name): doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', 'MapReduceDocument') field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', - 'GenericReferenceField', 'GeoPointField', - 'ReferenceField', 'StringField') + 'GenericReferenceField', 'GenericEmbeddedDocumentField', + 'GeoPointField', 'ReferenceField', 'StringField') queryset_classes = ('OperationError',) deref_classes = ('DeReference',) diff --git a/mongoengine/document.py b/mongoengine/document.py index adbdcca2..fcf82563 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -100,8 +100,8 @@ class Document(BaseDocument): Automatic index creation can be disabled by specifying attr:`auto_create_index` in the :attr:`meta` dictionary. If this is set to False then indexes will not be created by MongoEngine. This is useful in - production systems where index creation is performed as part of a deployment - system. + production systems where index creation is performed as part of a + deployment system. By default, _cls will be added to the start of every index (that doesn't contain a list) if allow_inheritance is True. 
This can be @@ -165,7 +165,7 @@ class Document(BaseDocument): cls._collection = db[collection_name] return cls._collection - def save(self, safe=True, force_insert=False, validate=True, + def save(self, safe=True, force_insert=False, validate=True, clean=True, write_options=None, cascade=None, cascade_kwargs=None, _refs=None): """Save the :class:`~mongoengine.Document` to the database. If the @@ -179,6 +179,8 @@ class Document(BaseDocument): :param force_insert: only try to create a new document, don't allow updates of existing documents :param validate: validates the document; set to ``False`` to skip. + :param clean: call the document clean method, requires `validate` to be + True. :param write_options: Extra keyword arguments are passed down to :meth:`~pymongo.collection.Collection.save` OR :meth:`~pymongo.collection.Collection.insert` @@ -208,7 +210,7 @@ class Document(BaseDocument): signals.pre_save.send(self.__class__, document=self) if validate: - self.validate() + self.validate(clean=clean) if not write_options: write_options = {} diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 94e11556..8aa7f641 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -461,7 +461,7 @@ class EmbeddedDocumentField(BaseField): return value return self.document_type.to_mongo(value) - def validate(self, value): + def validate(self, value, clean=True): """Make sure that the document instance is an instance of the EmbeddedDocument subclass provided when the document was defined. 
""" @@ -469,7 +469,7 @@ class EmbeddedDocumentField(BaseField): if not isinstance(value, self.document_type): self.error('Invalid embedded document instance provided to an ' 'EmbeddedDocumentField') - self.document_type.validate(value) + self.document_type.validate(value, clean) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -499,12 +499,12 @@ class GenericEmbeddedDocumentField(BaseField): return value - def validate(self, value): + def validate(self, value, clean=True): if not isinstance(value, EmbeddedDocument): self.error('Invalid embedded document instance provided to an ' 'GenericEmbeddedDocumentField') - value.validate() + value.validate(clean=clean) def to_mongo(self, document): if document is None: diff --git a/tests/document/instance.py b/tests/document/instance.py index 48ddc10d..2e07eb26 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -490,6 +490,76 @@ class InstanceTest(unittest.TestCase): self.assertTrue('id' in keys) self.assertTrue('e' in keys) + def test_document_clean(self): + class TestDocument(Document): + status = StringField() + pub_date = DateTimeField() + + def clean(self): + if self.status == 'draft' and self.pub_date is not None: + msg = 'Draft entries may not have a publication date.' + raise ValidationError(msg) + # Set the pub_date for published items if not set. + if self.status == 'published' and self.pub_date is None: + self.pub_date = datetime.now() + + TestDocument.drop_collection() + + t = TestDocument(status="draft", pub_date=datetime.now()) + + try: + t.save() + except ValidationError, e: + expect_msg = "Draft entries may not have a publication date." 
+ self.assertTrue(expect_msg in e.message) + self.assertEqual(e.to_dict(), {'__all__': expect_msg}) + + t = TestDocument(status="published") + t.save(clean=False) + + self.assertEquals(t.pub_date, None) + + t = TestDocument(status="published") + t.save(clean=True) + + self.assertEquals(type(t.pub_date), datetime) + + def test_document_embedded_clean(self): + class TestEmbeddedDocument(EmbeddedDocument): + x = IntField(required=True) + y = IntField(required=True) + z = IntField(required=True) + + meta = {'allow_inheritance': False} + + def clean(self): + if self.z: + if self.z != self.x + self.y: + raise ValidationError('Value of z != x + y') + else: + self.z = self.x + self.y + + class TestDocument(Document): + doc = EmbeddedDocumentField(TestEmbeddedDocument) + status = StringField() + + TestDocument.drop_collection() + + t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15)) + try: + t.save() + except ValidationError, e: + expect_msg = "Value of z != x + y" + self.assertTrue(expect_msg in e.message) + self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}}) + + t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save() + self.assertEquals(t.doc.z, 35) + + # Asserts not raises + t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5)) + t.save(clean=False) + def test_save(self): """Ensure that a document may be saved in the database. """ @@ -1935,7 +2005,5 @@ class ValidatorErrorTest(unittest.TestCase): self.assertRaises(OperationError, change_shard_key) - - if __name__ == '__main__': unittest.main()