Added clean method to documents for pre validation data cleaning (MongoEngine/mongoengine#60)

This commit is contained in:
Ross Lawley 2012-11-07 12:12:28 +00:00
parent 7073b9d395
commit 1986e82783
8 changed files with 150 additions and 25 deletions

View File

@ -4,6 +4,7 @@ Changelog
Changes in 0.8 Changes in 0.8
============== ==============
- Added clean method to documents for pre validation data cleaning (MongoEngine/mongoengine#60)
- Added support for setting the read preference at a query level (MongoEngine/mongoengine#157) - Added support for setting the read preference at a query level (MongoEngine/mongoengine#157)
- Added _instance to EmbeddedDocuments pointing to the parent (MongoEngine/mongoengine#139) - Added _instance to EmbeddedDocuments pointing to the parent (MongoEngine/mongoengine#139)
- Inheritance is off by default (MongoEngine/mongoengine#122) - Inheritance is off by default (MongoEngine/mongoengine#122)

View File

@ -38,6 +38,34 @@ already exist, then any changes will be updated atomically. For example::
.. seealso:: .. seealso::
:ref:`guide-atomic-updates` :ref:`guide-atomic-updates`
Pre save data validation and cleaning
-------------------------------------
MongoEngine allows you to create custom cleaning rules for your documents
that are applied when calling :meth:`~mongoengine.Document.save`. By providing
a custom :meth:`~mongoengine.Document.clean` method you can do any
pre-validation data cleaning.
This might be useful if you want to ensure a default value based on other
document values, for example::
class Essay(Document):
status = StringField(choices=('Published', 'Draft'), required=True)
pub_date = DateTimeField()
def clean(self):
"""Ensures that only published essays have a `pub_date` and
automatically sets the pub_date if published and not set"""
if self.status == 'Draft' and self.pub_date is not None:
msg = 'Draft entries should not have a publication date.'
raise ValidationError(msg)
# Set the pub_date for published items if not set.
if self.status == 'Published' and self.pub_date is None:
self.pub_date = datetime.now()
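.. note::
Cleaning is only called if validation is turned on, i.e. when calling
:meth:`~mongoengine.Document.save` with ``validate=True`` (the default) or
when calling :meth:`~mongoengine.Document.validate` directly. Pass
``clean=False`` to either method to skip the cleaning step.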
Cascading Saves Cascading Saves
--------------- ---------------
If your document contains :class:`~mongoengine.ReferenceField` or If your document contains :class:`~mongoengine.ReferenceField` or

View File

@ -15,7 +15,9 @@ from .common import get_document, ALLOW_INHERITANCE
from .datastructures import BaseDict, BaseList from .datastructures import BaseDict, BaseList
from .fields import ComplexBaseField from .fields import ComplexBaseField
__all__ = ('BaseDocument', ) __all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
class BaseDocument(object): class BaseDocument(object):
@ -82,11 +84,6 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields'): if hasattr(self, '_changed_fields'):
self._mark_as_changed(name) self._mark_as_changed(name)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
if (self._is_document and not self._created and if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value): self._data.get(name) != value):
@ -94,6 +91,11 @@ class BaseDocument(object):
msg = "Shard Keys are immutable. Tried to update %s" % name msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg) raise OperationError(msg)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value) super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self): def __getstate__(self):
@ -171,6 +173,16 @@ class BaseDocument(object):
else: else:
return hash(self.pk) return hash(self.pk)
def clean(self):
    """Document-level hook for cleaning data before validation runs.

    The default implementation does nothing; subclasses override it to
    normalise or cross-check their field values.  ``validate`` calls this
    hook first (unless invoked with ``clean=False``) and stores any
    ValidationError raised here under the NON_FIELD_ERRORS key instead of
    under the name of an individual field.
    """
def to_mongo(self): def to_mongo(self):
"""Return data dictionary ready for use with MongoDB. """Return data dictionary ready for use with MongoDB.
""" """
@ -203,20 +215,33 @@ class BaseDocument(object):
data[name] = field.to_mongo(self._data.get(name, None)) data[name] = field.to_mongo(self._data.get(name, None))
return data return data
def validate(self): def validate(self, clean=True):
"""Ensure that all fields' values are valid and that required fields """Ensure that all fields' values are valid and that required fields
are present. are present.
""" """
# Ensure that each field is matched to a valid value
errors = {}
if clean:
try:
self.clean()
except ValidationError, error:
errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values # Get a list of tuples of field names and their current values
fields = [(field, self._data.get(name)) fields = [(field, self._data.get(name))
for name, field in self._fields.items()] for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
errors = {} GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
for field, value in fields: for field, value in fields:
if value is not None: if value is not None:
try: try:
field._validate(value) if isinstance(field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)):
field._validate(value, clean=clean)
else:
field._validate(value)
except ValidationError, error: except ValidationError, error:
errors[field.name] = error.errors or error errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError), error: except (ValueError, AttributeError, AssertionError), error:
@ -224,6 +249,7 @@ class BaseDocument(object):
elif field.required and not getattr(field, '_auto_gen', False): elif field.required and not getattr(field, '_auto_gen', False):
errors[field.name] = ValidationError('Field is required', errors[field.name] = ValidationError('Field is required',
field_name=field.name) field_name=field.name)
if errors: if errors:
raise ValidationError('ValidationError', errors=errors) raise ValidationError('ValidationError', errors=errors)

View File

@ -105,12 +105,12 @@ class BaseField(object):
""" """
return value return value
def validate(self, value): def validate(self, value, clean=True):
"""Perform validation on a value. """Perform validation on a value.
""" """
pass pass
def _validate(self, value): def _validate(self, value, **kwargs):
Document = _import_class('Document') Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
# check choices # check choices
@ -138,7 +138,7 @@ class BaseField(object):
raise ValueError('validation argument for "%s" must be a ' raise ValueError('validation argument for "%s" must be a '
'callable.' % self.name) 'callable.' % self.name)
self.validate(value) self.validate(value, **kwargs)
class ComplexBaseField(BaseField): class ComplexBaseField(BaseField):

View File

@ -9,8 +9,8 @@ def _import_class(cls_name):
doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument', doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument',
'MapReduceDocument') 'MapReduceDocument')
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField', field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'GenericReferenceField', 'GeoPointField', 'GenericReferenceField', 'GenericEmbeddedDocumentField',
'ReferenceField', 'StringField') 'GeoPointField', 'ReferenceField', 'StringField')
queryset_classes = ('OperationError',) queryset_classes = ('OperationError',)
deref_classes = ('DeReference',) deref_classes = ('DeReference',)

View File

@ -100,8 +100,8 @@ class Document(BaseDocument):
Automatic index creation can be disabled by specifying Automatic index creation can be disabled by specifying
attr:`auto_create_index` in the :attr:`meta` dictionary. If this is set to attr:`auto_create_index` in the :attr:`meta` dictionary. If this is set to
False then indexes will not be created by MongoEngine. This is useful in False then indexes will not be created by MongoEngine. This is useful in
production systems where index creation is performed as part of a deployment production systems where index creation is performed as part of a
system. deployment system.
By default, _cls will be added to the start of every index (that By default, _cls will be added to the start of every index (that
doesn't contain a list) if allow_inheritance is True. This can be doesn't contain a list) if allow_inheritance is True. This can be
@ -165,7 +165,7 @@ class Document(BaseDocument):
cls._collection = db[collection_name] cls._collection = db[collection_name]
return cls._collection return cls._collection
def save(self, safe=True, force_insert=False, validate=True, def save(self, safe=True, force_insert=False, validate=True, clean=True,
write_options=None, cascade=None, cascade_kwargs=None, write_options=None, cascade=None, cascade_kwargs=None,
_refs=None): _refs=None):
"""Save the :class:`~mongoengine.Document` to the database. If the """Save the :class:`~mongoengine.Document` to the database. If the
@ -179,6 +179,8 @@ class Document(BaseDocument):
:param force_insert: only try to create a new document, don't allow :param force_insert: only try to create a new document, don't allow
updates of existing documents updates of existing documents
:param validate: validates the document; set to ``False`` to skip. :param validate: validates the document; set to ``False`` to skip.
:param clean: call the document clean method, requires `validate` to be
True.
:param write_options: Extra keyword arguments are passed down to :param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.save` OR :meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert` :meth:`~pymongo.collection.Collection.insert`
@ -208,7 +210,7 @@ class Document(BaseDocument):
signals.pre_save.send(self.__class__, document=self) signals.pre_save.send(self.__class__, document=self)
if validate: if validate:
self.validate() self.validate(clean=clean)
if not write_options: if not write_options:
write_options = {} write_options = {}

View File

@ -461,7 +461,7 @@ class EmbeddedDocumentField(BaseField):
return value return value
return self.document_type.to_mongo(value) return self.document_type.to_mongo(value)
def validate(self, value): def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the """Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined. EmbeddedDocument subclass provided when the document was defined.
""" """
@ -469,7 +469,7 @@ class EmbeddedDocumentField(BaseField):
if not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
self.error('Invalid embedded document instance provided to an ' self.error('Invalid embedded document instance provided to an '
'EmbeddedDocumentField') 'EmbeddedDocumentField')
self.document_type.validate(value) self.document_type.validate(value, clean)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -499,12 +499,12 @@ class GenericEmbeddedDocumentField(BaseField):
return value return value
def validate(self, value): def validate(self, value, clean=True):
if not isinstance(value, EmbeddedDocument): if not isinstance(value, EmbeddedDocument):
self.error('Invalid embedded document instance provided to an ' self.error('Invalid embedded document instance provided to an '
'GenericEmbeddedDocumentField') 'GenericEmbeddedDocumentField')
value.validate() value.validate(clean=clean)
def to_mongo(self, document): def to_mongo(self, document):
if document is None: if document is None:

View File

@ -490,6 +490,76 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('id' in keys) self.assertTrue('id' in keys)
self.assertTrue('e' in keys) self.assertTrue('e' in keys)
def test_document_clean(self):
    """Ensure a document's clean() hook runs on save and can be skipped.

    Covers three cases: clean() raising a ValidationError (reported under
    the '__all__' key), clean() being bypassed with ``clean=False``, and
    clean() filling in a default value when enabled.
    """
    class TestDocument(Document):
        status = StringField()
        pub_date = DateTimeField()

        def clean(self):
            if self.status == 'draft' and self.pub_date is not None:
                msg = 'Draft entries may not have a publication date.'
                raise ValidationError(msg)
            # Set the pub_date for published items if not set.
            if self.status == 'published' and self.pub_date is None:
                self.pub_date = datetime.now()

    TestDocument.drop_collection()

    expect_msg = "Draft entries may not have a publication date."
    t = TestDocument(status="draft", pub_date=datetime.now())
    try:
        t.save()
    except ValidationError as e:
        self.assertTrue(expect_msg in e.message)
        self.assertEqual(e.to_dict(), {'__all__': expect_msg})
    else:
        # Without this branch the test passed vacuously whenever save()
        # did not raise at all.
        self.fail("ValidationError not raised for a draft with a pub_date")

    # clean=False must leave the document untouched.
    t = TestDocument(status="published")
    t.save(clean=False)
    self.assertEqual(t.pub_date, None)

    # clean=True lets the hook populate the missing publication date.
    t = TestDocument(status="published")
    t.save(clean=True)
    self.assertEqual(type(t.pub_date), datetime)
def test_document_embedded_clean(self):
    """Ensure clean() on an embedded document runs when the parent is saved.

    Covers three cases: the embedded clean() raising (reported nested under
    the embedding field name and '__all__'), the embedded clean() deriving
    a missing value, and ``clean=False`` being propagated to the embedded
    document so that inconsistent data is accepted.
    """
    class TestEmbeddedDocument(EmbeddedDocument):
        x = IntField(required=True)
        y = IntField(required=True)
        z = IntField(required=True)

        meta = {'allow_inheritance': False}

        def clean(self):
            # Compare against None rather than truthiness so that an
            # explicit z of 0 is still checked instead of overwritten.
            if self.z is not None:
                if self.z != self.x + self.y:
                    raise ValidationError('Value of z != x + y')
            else:
                self.z = self.x + self.y

    class TestDocument(Document):
        doc = EmbeddedDocumentField(TestEmbeddedDocument)
        status = StringField()

    TestDocument.drop_collection()

    expect_msg = "Value of z != x + y"
    t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15))
    try:
        t.save()
    except ValidationError as e:
        self.assertTrue(expect_msg in e.message)
        self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
    else:
        # Without this branch the test passed vacuously whenever save()
        # did not raise at all.
        self.fail("ValidationError not raised for inconsistent z value")

    t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save()
    self.assertEqual(t.doc.z, 35)

    # Asserts not raises: with cleaning disabled the mismatch is accepted.
    t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
    t.save(clean=False)
def test_save(self): def test_save(self):
"""Ensure that a document may be saved in the database. """Ensure that a document may be saved in the database.
""" """
@ -1935,7 +2005,5 @@ class ValidatorErrorTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key) self.assertRaises(OperationError, change_shard_key)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()