Added clean method to documents for pre validation data cleaning (MongoEngine/mongoengine#60)

This commit is contained in:
Ross Lawley 2012-11-07 12:12:28 +00:00
parent 7073b9d395
commit 1986e82783
8 changed files with 150 additions and 25 deletions

View File

@ -4,6 +4,7 @@ Changelog
Changes in 0.8
==============
- Added clean method to documents for pre validation data cleaning (MongoEngine/mongoengine#60)
- Added support setting for read prefrence at a query level (MongoEngine/mongoengine#157)
- Added _instance to EmbeddedDocuments pointing to the parent (MongoEngine/mongoengine#139)
- Inheritance is off by default (MongoEngine/mongoengine#122)

View File

@ -38,6 +38,34 @@ already exist, then any changes will be updated atomically. For example::
.. seealso::
:ref:`guide-atomic-updates`
Pre save data validation and cleaning
-------------------------------------
MongoEngine allows you to create custom cleaning rules for your documents when
calling :meth:`~mongoengine.Document.save`. By providing a custom
:meth:`~mongoengine.Document.clean` method you can do any pre validation / data
cleaning.
This might be useful if you want to ensure a default value based on other
document values for example::
class Essay(Document):
status = StringField(choices=('Published', 'Draft'), required=True)
pub_date = DateTimeField()
def clean(self):
"""Ensures that only published essays have a `pub_date` and
automatically sets the pub_date if published and not set"""
if self.status == 'Draft' and self.pub_date is not None:
msg = 'Draft entries should not have a publication date.'
raise ValidationError(msg)
# Set the pub_date for published items if not set.
if self.status == 'Published' and self.pub_date is None:
self.pub_date = datetime.now()
.. note::
Cleaning is only called if validation is turned on and when calling
:meth:`~mongoengine.Document.save`.
Cascading Saves
---------------
If your document contains :class:`~mongoengine.ReferenceField` or

View File

@ -15,7 +15,9 @@ from .common import get_document, ALLOW_INHERITANCE
from .datastructures import BaseDict, BaseList
from .fields import ComplexBaseField
__all__ = ('BaseDocument', )
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
NON_FIELD_ERRORS = '__all__'
class BaseDocument(object):
@ -82,11 +84,6 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields'):
self._mark_as_changed(name)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
if (self._is_document and not self._created and
name in self._meta.get('shard_key', tuple()) and
self._data.get(name) != value):
@ -94,6 +91,11 @@ class BaseDocument(object):
msg = "Shard Keys are immutable. Tried to update %s" % name
raise OperationError(msg)
# Check if the user has created a new instance of a class
if (self._is_document and self._initialised
and self._created and name == self._meta['id_field']):
super(BaseDocument, self).__setattr__('_created', False)
super(BaseDocument, self).__setattr__(name, value)
def __getstate__(self):
@ -171,6 +173,16 @@ class BaseDocument(object):
else:
return hash(self.pk)
def clean(self):
"""
Hook for doing document level data cleaning before validation is run.
Any ValidationError raised by this method will not be associated with
a particular field; it will have a special-case association with the
field defined by NON_FIELD_ERRORS.
"""
pass
def to_mongo(self):
"""Return data dictionary ready for use with MongoDB.
"""
@ -203,20 +215,33 @@ class BaseDocument(object):
data[name] = field.to_mongo(self._data.get(name, None))
return data
def validate(self):
def validate(self, clean=True):
"""Ensure that all fields' values are valid and that required fields
are present.
"""
# Ensure that each field is matched to a valid value
errors = {}
if clean:
try:
self.clean()
except ValidationError, error:
errors[NON_FIELD_ERRORS] = error
# Get a list of tuples of field names and their current values
fields = [(field, self._data.get(name))
for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value
errors = {}
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
GenericEmbeddedDocumentField = _import_class("GenericEmbeddedDocumentField")
for field, value in fields:
if value is not None:
try:
field._validate(value)
if isinstance(field, (EmbeddedDocumentField,
GenericEmbeddedDocumentField)):
field._validate(value, clean=clean)
else:
field._validate(value)
except ValidationError, error:
errors[field.name] = error.errors or error
except (ValueError, AttributeError, AssertionError), error:
@ -224,6 +249,7 @@ class BaseDocument(object):
elif field.required and not getattr(field, '_auto_gen', False):
errors[field.name] = ValidationError('Field is required',
field_name=field.name)
if errors:
raise ValidationError('ValidationError', errors=errors)

View File

@ -105,12 +105,12 @@ class BaseField(object):
"""
return value
def validate(self, value):
def validate(self, value, clean=True):
"""Perform validation on a value.
"""
pass
def _validate(self, value):
def _validate(self, value, **kwargs):
Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
# check choices
@ -138,7 +138,7 @@ class BaseField(object):
raise ValueError('validation argument for "%s" must be a '
'callable.' % self.name)
self.validate(value)
self.validate(value, **kwargs)
class ComplexBaseField(BaseField):

View File

@ -9,8 +9,8 @@ def _import_class(cls_name):
doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument',
'MapReduceDocument')
field_classes = ('DictField', 'DynamicField', 'EmbeddedDocumentField',
'GenericReferenceField', 'GeoPointField',
'ReferenceField', 'StringField')
'GenericReferenceField', 'GenericEmbeddedDocumentField',
'GeoPointField', 'ReferenceField', 'StringField')
queryset_classes = ('OperationError',)
deref_classes = ('DeReference',)

View File

@ -100,8 +100,8 @@ class Document(BaseDocument):
Automatic index creation can be disabled by specifying
attr:`auto_create_index` in the :attr:`meta` dictionary. If this is set to
False then indexes will not be created by MongoEngine. This is useful in
production systems where index creation is performed as part of a deployment
system.
production systems where index creation is performed as part of a
deployment system.
By default, _cls will be added to the start of every index (that
doesn't contain a list) if allow_inheritance is True. This can be
@ -165,7 +165,7 @@ class Document(BaseDocument):
cls._collection = db[collection_name]
return cls._collection
def save(self, safe=True, force_insert=False, validate=True,
def save(self, safe=True, force_insert=False, validate=True, clean=True,
write_options=None, cascade=None, cascade_kwargs=None,
_refs=None):
"""Save the :class:`~mongoengine.Document` to the database. If the
@ -179,6 +179,8 @@ class Document(BaseDocument):
:param force_insert: only try to create a new document, don't allow
updates of existing documents
:param validate: validates the document; set to ``False`` to skip.
:param clean: call the document clean method, requires `validate` to be
True.
:param write_options: Extra keyword arguments are passed down to
:meth:`~pymongo.collection.Collection.save` OR
:meth:`~pymongo.collection.Collection.insert`
@ -208,7 +210,7 @@ class Document(BaseDocument):
signals.pre_save.send(self.__class__, document=self)
if validate:
self.validate()
self.validate(clean=clean)
if not write_options:
write_options = {}

View File

@ -461,7 +461,7 @@ class EmbeddedDocumentField(BaseField):
return value
return self.document_type.to_mongo(value)
def validate(self, value):
def validate(self, value, clean=True):
"""Make sure that the document instance is an instance of the
EmbeddedDocument subclass provided when the document was defined.
"""
@ -469,7 +469,7 @@ class EmbeddedDocumentField(BaseField):
if not isinstance(value, self.document_type):
self.error('Invalid embedded document instance provided to an '
'EmbeddedDocumentField')
self.document_type.validate(value)
self.document_type.validate(value, clean)
def lookup_member(self, member_name):
return self.document_type._fields.get(member_name)
@ -499,12 +499,12 @@ class GenericEmbeddedDocumentField(BaseField):
return value
def validate(self, value):
def validate(self, value, clean=True):
if not isinstance(value, EmbeddedDocument):
self.error('Invalid embedded document instance provided to an '
'GenericEmbeddedDocumentField')
value.validate()
value.validate(clean=clean)
def to_mongo(self, document):
if document is None:

View File

@ -490,6 +490,76 @@ class InstanceTest(unittest.TestCase):
self.assertTrue('id' in keys)
self.assertTrue('e' in keys)
def test_document_clean(self):
class TestDocument(Document):
status = StringField()
pub_date = DateTimeField()
def clean(self):
if self.status == 'draft' and self.pub_date is not None:
msg = 'Draft entries may not have a publication date.'
raise ValidationError(msg)
# Set the pub_date for published items if not set.
if self.status == 'published' and self.pub_date is None:
self.pub_date = datetime.now()
TestDocument.drop_collection()
t = TestDocument(status="draft", pub_date=datetime.now())
try:
t.save()
except ValidationError, e:
expect_msg = "Draft entries may not have a publication date."
self.assertTrue(expect_msg in e.message)
self.assertEqual(e.to_dict(), {'__all__': expect_msg})
t = TestDocument(status="published")
t.save(clean=False)
self.assertEquals(t.pub_date, None)
t = TestDocument(status="published")
t.save(clean=True)
self.assertEquals(type(t.pub_date), datetime)
def test_document_embedded_clean(self):
class TestEmbeddedDocument(EmbeddedDocument):
x = IntField(required=True)
y = IntField(required=True)
z = IntField(required=True)
meta = {'allow_inheritance': False}
def clean(self):
if self.z:
if self.z != self.x + self.y:
raise ValidationError('Value of z != x + y')
else:
self.z = self.x + self.y
class TestDocument(Document):
doc = EmbeddedDocumentField(TestEmbeddedDocument)
status = StringField()
TestDocument.drop_collection()
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25, z=15))
try:
t.save()
except ValidationError, e:
expect_msg = "Value of z != x + y"
self.assertTrue(expect_msg in e.message)
self.assertEqual(e.to_dict(), {'doc': {'__all__': expect_msg}})
t = TestDocument(doc=TestEmbeddedDocument(x=10, y=25)).save()
self.assertEquals(t.doc.z, 35)
# Asserts not raises
t = TestDocument(doc=TestEmbeddedDocument(x=15, y=35, z=5))
t.save(clean=False)
def test_save(self):
"""Ensure that a document may be saved in the database.
"""
@ -1935,7 +2005,5 @@ class ValidatorErrorTest(unittest.TestCase):
self.assertRaises(OperationError, change_shard_key)
if __name__ == '__main__':
unittest.main()