Merge branch 'validation-schema' of https://github.com/n1k0/mongoengine into validation-schema

Conflicts:
	mongoengine/base.py
	mongoengine/fields.py
This commit is contained in:
Ross Lawley 2011-11-01 01:45:32 -07:00
commit 558b8123b5
4 changed files with 270 additions and 146 deletions

View File

@ -5,6 +5,7 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Added recursive validation error of documents / complex fields
- Fixed breaking during queryset iteration - Fixed breaking during queryset iteration
- Added pre and post bulk-insert signals - Added pre and post bulk-insert signals
- Added ImageField - requires PIL - Added ImageField - requires PIL

View File

@ -4,7 +4,6 @@ from queryset import DO_NOTHING
from mongoengine import signals from mongoengine import signals
import weakref
import sys import sys
import pymongo import pymongo
import pymongo.objectid import pymongo.objectid
@ -20,8 +19,56 @@ class InvalidDocumentError(Exception):
pass pass
class ValidationError(Exception): class ValidationError(AssertionError):
pass """Validation exception.
"""
errors = {}
field_name = None
_message = None
def __init__(self, message="", **kwargs):
self.errors = kwargs.get('errors', {})
self.field_name = kwargs.get('field_name')
self.message = message
def __str__(self):
return self.message
def __repr__(self):
return '%s(%s,)' % (self.__class__.__name__, self.message)
def __getattribute__(self, name):
message = super(ValidationError, self).__getattribute__(name)
if name == 'message' and self.field_name:
return message + ' ("%s")' % self.field_name
else:
return message
def _get_message(self):
return self._message
def _set_message(self, message):
self._message = message
message = property(_get_message, _set_message)
@property
def schema(self):
def get_schema(source):
errors_dict = {}
if not source:
return errors_dict
if isinstance(source, dict):
for field_name, error in source.iteritems():
errors_dict[field_name] = get_schema(error)
elif isinstance(source, ValidationError) and source.errors:
return get_schema(source.errors)
else:
return unicode(source)
return errors_dict
if not self.errors:
return {}
return get_schema(self.errors)
_document_registry = {} _document_registry = {}
@ -51,6 +98,8 @@ class BaseField(object):
.. versionchanged:: 0.5 - added verbose and help text .. versionchanged:: 0.5 - added verbose and help text
""" """
name = None
# Fields may have _types inserted into indexes by default # Fields may have _types inserted into indexes by default
_index_with_types = True _index_with_types = True
_geo_index = False _geo_index = False
@ -117,6 +166,12 @@ class BaseField(object):
instance._data[self.name] = value instance._data[self.name] = value
instance._mark_as_changed(self.name) instance._mark_as_changed(self.name)
def error(self, message="", errors=None, field_name=None):
"""Raises a ValidationError.
"""
field_name = field_name if field_name else self.name
raise ValidationError(message, errors=errors, field_name=field_name)
def to_python(self, value): def to_python(self, value):
"""Convert a MongoDB-compatible type to a Python type. """Convert a MongoDB-compatible type to a Python type.
""" """
@ -142,15 +197,13 @@ class BaseField(object):
if self.choices is not None: if self.choices is not None:
option_keys = [option_key for option_key, option_value in self.choices] option_keys = [option_key for option_key, option_value in self.choices]
if value not in option_keys: if value not in option_keys:
raise ValidationError('Value must be one of %s ("%s")' % self.error('Value must be one of %s' % unicode(option_keys))
(unicode(option_keys), self.name))
# check validation argument # check validation argument
if self.validation is not None: if self.validation is not None:
if callable(self.validation): if callable(self.validation):
if not self.validation(value): if not self.validation(value):
raise ValidationError('Value does not match custom ' self.error('Value does not match custom validation method')
'validation method ("%s")' % self.name)
else: else:
raise ValueError('validation argument for "%s" must be a ' raise ValueError('validation argument for "%s" must be a '
'callable.' % self.name) 'callable.' % self.name)
@ -198,7 +251,7 @@ class ComplexBaseField(BaseField):
if not hasattr(value, 'items'): if not hasattr(value, 'items'):
try: try:
is_list = True is_list = True
value = dict([(k,v) for k,v in enumerate(value)]) value = dict([(k, v) for k, v in enumerate(value)])
except TypeError: # Not iterable return the value except TypeError: # Not iterable return the value
return value return value
@ -206,13 +259,12 @@ class ComplexBaseField(BaseField):
value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()])
else: else:
value_dict = {} value_dict = {}
for k,v in value.items(): for k, v in value.items():
if isinstance(v, Document): if isinstance(v, Document):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
if v.pk is None: if v.pk is None:
raise ValidationError('You can only reference ' self.error('You can only reference documents once they'
'documents once they have been saved ' ' have been saved to the database')
'to the database ("%s")' % self.name)
collection = v._get_collection_name() collection = v._get_collection_name()
value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) value_dict[k] = pymongo.dbref.DBRef(collection, v.pk)
elif hasattr(v, 'to_python'): elif hasattr(v, 'to_python'):
@ -221,7 +273,7 @@ class ComplexBaseField(BaseField):
value_dict[k] = self.to_python(v) value_dict[k] = self.to_python(v)
if is_list: # Convert back to a list if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))]
return value_dict return value_dict
def to_mongo(self, value): def to_mongo(self, value):
@ -239,7 +291,7 @@ class ComplexBaseField(BaseField):
if not hasattr(value, 'items'): if not hasattr(value, 'items'):
try: try:
is_list = True is_list = True
value = dict([(k,v) for k,v in enumerate(value)]) value = dict([(k, v) for k, v in enumerate(value)])
except TypeError: # Not iterable return the value except TypeError: # Not iterable return the value
return value return value
@ -247,13 +299,12 @@ class ComplexBaseField(BaseField):
value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()])
else: else:
value_dict = {} value_dict = {}
for k,v in value.items(): for k, v in value.items():
if isinstance(v, Document): if isinstance(v, Document):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
if v.pk is None: if v.pk is None:
raise ValidationError('You can only reference ' self.error('You can only reference documents once they'
'documents once they have been saved ' ' have been saved to the database')
'to the database ("%s")' % self.name)
# If its a document that is not inheritable it won't have # If its a document that is not inheritable it won't have
# _types / _cls data so make it a generic reference allows # _types / _cls data so make it a generic reference allows
@ -271,26 +322,33 @@ class ComplexBaseField(BaseField):
value_dict[k] = self.to_mongo(v) value_dict[k] = self.to_mongo(v)
if is_list: # Convert back to a list if is_list: # Convert back to a list
return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))]
return value_dict return value_dict
def validate(self, value): def validate(self, value):
"""If field provided ensure the value is valid. """If field is provided ensure the value is valid.
""" """
errors = {}
if self.field: if self.field:
try:
if hasattr(value, 'iteritems'): if hasattr(value, 'iteritems'):
[self.field.validate(v) for k,v in value.iteritems()] sequence = value.iteritems()
else: else:
[self.field.validate(v) for v in value] sequence = enumerate(value)
except Exception, err: for k, v in sequence:
raise ValidationError('Invalid %s item (%s) ("%s")' % ( try:
self.field.__class__.__name__, str(v), self.name)) self.field.validate(v)
except (ValidationError, AssertionError), error:
if hasattr(error, 'errors'):
errors[k] = error.errors
else:
errors[k] = error
if errors:
field_class = self.field.__class__.__name__
self.error('Invalid %s item (%s)' % (field_class, value),
errors=errors)
# Don't allow empty values if required # Don't allow empty values if required
if self.required and not value: if self.required and not value:
raise ValidationError('Field "%s" is required and cannot be empty' % self.error('Field is required and cannot be empty')
self.name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return self.to_mongo(value) return self.to_mongo(value)
@ -345,6 +403,7 @@ class BaseDynamicField(BaseField):
def lookup_member(self, member_name): def lookup_member(self, member_name):
return member_name return member_name
class ObjectIdField(BaseField): class ObjectIdField(BaseField):
"""An field wrapper around MongoDB's ObjectIds. """An field wrapper around MongoDB's ObjectIds.
""" """
@ -357,8 +416,8 @@ class ObjectIdField(BaseField):
try: try:
return pymongo.objectid.ObjectId(unicode(value)) return pymongo.objectid.ObjectId(unicode(value))
except Exception, e: except Exception, e:
#e.message attribute has been deprecated since Python 2.6 # e.message attribute has been deprecated since Python 2.6
raise ValidationError(unicode(e)) self.error(unicode(e))
return value return value
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -368,7 +427,7 @@ class ObjectIdField(BaseField):
try: try:
pymongo.objectid.ObjectId(unicode(value)) pymongo.objectid.ObjectId(unicode(value))
except: except:
raise ValidationError('Invalid Object ID ("%s")' % self.name) self.error('Invalid Object ID')
class DocumentMetaclass(type): class DocumentMetaclass(type):
@ -394,7 +453,7 @@ class DocumentMetaclass(type):
superclasses[base._class_name] = base superclasses[base._class_name] = base
superclasses.update(base._superclasses) superclasses.update(base._superclasses)
else: # Add any mixin fields else: # Add any mixin fields
attrs.update(dict([(k,v) for k,v in base.__dict__.items() attrs.update(dict([(k, v) for k, v in base.__dict__.items()
if issubclass(v.__class__, BaseField)])) if issubclass(v.__class__, BaseField)]))
if hasattr(base, '_meta') and not base._meta.get('abstract'): if hasattr(base, '_meta') and not base._meta.get('abstract'):
@ -489,7 +548,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
('meta' in attrs and attrs['meta'].get('abstract', False))): ('meta' in attrs and attrs['meta'].get('abstract', False))):
# Make sure no base class was non-abstract # Make sure no base class was non-abstract
non_abstract_bases = [b for b in bases non_abstract_bases = [b for b in bases
if hasattr(b,'_meta') and not b._meta.get('abstract', False)] if hasattr(b, '_meta') and not b._meta.get('abstract', False)]
if non_abstract_bases: if non_abstract_bases:
raise ValueError("Abstract document cannot have non-abstract base") raise ValueError("Abstract document cannot have non-abstract base")
return super_new(cls, name, bases, attrs) return super_new(cls, name, bases, attrs)
@ -666,7 +725,7 @@ class BaseDocument(object):
signals.post_init.send(self.__class__, document=self) signals.post_init.send(self.__class__, document=self)
def __setattr__(self, name, value): def __setattr__(self, name, value):
# Handle dynamic data only if an intialised dynamic document # Handle dynamic data only if an initialised dynamic document
if self._dynamic and getattr(self, '_initialised', False): if self._dynamic and getattr(self, '_initialised', False):
field = None field = None
@ -709,7 +768,8 @@ class BaseDocument(object):
data[k] = self.__expand_dynamic_values(key, v) data[k] = self.__expand_dynamic_values(key, v)
if is_list: # Convert back to a list if is_list: # Convert back to a list
value = [v for k, v in sorted(data.items(), key=operator.itemgetter(0))] data_items = sorted(data.items(), key=operator.itemgetter(0))
value = [v for k, v in data_items]
else: else:
value = data value = data
@ -730,15 +790,21 @@ class BaseDocument(object):
for name, field in self._fields.items()] for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value # Ensure that each field is matched to a valid value
errors = {}
for field, value in fields: for field, value in fields:
if value is not None: if value is not None:
try: try:
field._validate(value) field._validate(value)
except (ValueError, AttributeError, AssertionError), e: except ValidationError, error:
raise ValidationError('Invalid value for field named "%s" of type "%s": %s' errors[field.name] = error.errors or error
% (field.name, field.__class__.__name__, value)) except (ValueError, AttributeError, AssertionError), error:
errors[field.name] = error
elif field.required: elif field.required:
raise ValidationError('Field "%s" is required' % field.name) errors[field.name] = ValidationError('Field is required',
field_name=field.name)
if errors:
raise ValidationError('Errors encountered validating document',
errors=errors)
def to_mongo(self): def to_mongo(self):
"""Return data dictionary ready for use with MongoDB. """Return data dictionary ready for use with MongoDB.
@ -812,7 +878,6 @@ class BaseDocument(object):
""".strip() % class_name) """.strip() % class_name)
cls = subclasses[class_name] cls = subclasses[class_name]
present_fields = data.keys()
for field_name, field in cls._fields.items(): for field_name, field in cls._fields.items():
if field.db_field in data: if field.db_field in data:
value = data[field.db_field] value = data[field.db_field]
@ -963,8 +1028,7 @@ class BaseDocument(object):
return geo_indices return geo_indices
def __getstate__(self): def __getstate__(self):
self_dict = self.__dict__ removals = ["get_%s_display" % k for k, v in self._fields.items() if v.choices]
removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices]
for k in removals: for k in removals:
if hasattr(self, k): if hasattr(self, k):
delattr(self, k) delattr(self, k)
@ -1039,7 +1103,7 @@ class BaseDocument(object):
def __hash__(self): def __hash__(self):
if self.pk is None: if self.pk is None:
# For new object # For new object
return super(BaseDocument,self).__hash__() return super(BaseDocument, self).__hash__()
else: else:
return hash(self.pk) return hash(self.pk)

View File

@ -54,20 +54,17 @@ class StringField(BaseField):
return unicode(value) return unicode(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, (str, unicode)) if not isinstance(value, (str, unicode)):
self.error('StringField only accepts string values')
if self.max_length is not None and len(value) > self.max_length: if self.max_length is not None and len(value) > self.max_length:
raise ValidationError('String value is too long ("%s")' % self.error('String value is too long')
self.name)
if self.min_length is not None and len(value) < self.min_length: if self.min_length is not None and len(value) < self.min_length:
raise ValidationError('String value is too short ("%s")' % self.error('String value is too short')
self.name)
if self.regex is not None and self.regex.match(value) is None: if self.regex is not None and self.regex.match(value) is None:
message = 'String value did not match validation regex ("%s")' % \ self.error('String value did not match validation regex')
self.name
raise ValidationError(message)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return None return None
@ -117,18 +114,15 @@ class URLField(StringField):
def validate(self, value): def validate(self, value):
if not URLField.URL_REGEX.match(value): if not URLField.URL_REGEX.match(value):
raise ValidationError('Invalid URL: %s ("%s")' % (value, self.error('Invalid URL: %s' % value)
self.name))
if self.verify_exists: if self.verify_exists:
import urllib2 import urllib2
try: try:
request = urllib2.Request(value) request = urllib2.Request(value)
response = urllib2.urlopen(request) urllib2.urlopen(request)
except Exception, e: except Exception, e:
message = 'This URL appears to be a broken link: %s ("%s")' % ( self.error('This URL appears to be a broken link: %s' % e)
e, self.name)
raise ValidationError(message)
class EmailField(StringField): class EmailField(StringField):
@ -145,8 +139,7 @@ class EmailField(StringField):
def validate(self, value): def validate(self, value):
if not EmailField.EMAIL_REGEX.match(value): if not EmailField.EMAIL_REGEX.match(value):
raise ValidationError('Invalid Mail-address: %s ("%s")' % (value, self.error('Invalid Mail-address: %s' % value)
self.name))
class IntField(BaseField): class IntField(BaseField):
@ -164,16 +157,13 @@ class IntField(BaseField):
try: try:
value = int(value) value = int(value)
except: except:
raise ValidationError('%s could not be converted to int ("%s")' % ( self.error('%s could not be converted to int' % value)
value, self.name))
if self.min_value is not None and value < self.min_value: if self.min_value is not None and value < self.min_value:
raise ValidationError('Integer value is too small ("%s")' % self.error('Integer value is too small')
self.name)
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Integer value is too large ("%s")' % self.error('Integer value is too large')
self.name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return int(value) return int(value)
@ -193,15 +183,14 @@ class FloatField(BaseField):
def validate(self, value): def validate(self, value):
if isinstance(value, int): if isinstance(value, int):
value = float(value) value = float(value)
assert isinstance(value, float) if not isinstance(value, float):
self.error('FoatField only accepts float values')
if self.min_value is not None and value < self.min_value: if self.min_value is not None and value < self.min_value:
raise ValidationError('Float value is too small ("%s")' % self.error('Float value is too small')
self.name)
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Float value is too large ("%s")' % self.error('Float value is too large')
self.name)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
return float(value) return float(value)
@ -232,16 +221,13 @@ class DecimalField(BaseField):
try: try:
value = decimal.Decimal(value) value = decimal.Decimal(value)
except Exception, exc: except Exception, exc:
raise ValidationError('Could not convert value to decimal: %s' self.error('Could not convert value to decimal: %s' % exc)
'("%s")' % (exc, self.name))
if self.min_value is not None and value < self.min_value: if self.min_value is not None and value < self.min_value:
raise ValidationError('Decimal value is too small ("%s")' % self.error('Decimal value is too small')
self.name)
if self.max_value is not None and value > self.max_value: if self.max_value is not None and value > self.max_value:
raise ValidationError('Decimal value is too large ("%s")' % self.error('Decimal value is too large')
self.name)
class BooleanField(BaseField): class BooleanField(BaseField):
@ -254,7 +240,8 @@ class BooleanField(BaseField):
return bool(value) return bool(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, bool) if not isinstance(value, bool):
self.error('BooleanField only accepts boolean values')
class DateTimeField(BaseField): class DateTimeField(BaseField):
@ -267,7 +254,8 @@ class DateTimeField(BaseField):
""" """
def validate(self, value): def validate(self, value):
assert isinstance(value, (datetime.datetime, datetime.date)) if not isinstance(value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value)
def to_mongo(self, value): def to_mongo(self, value):
return self.prepare_query_value(None, value) return self.prepare_query_value(None, value)
@ -388,8 +376,8 @@ class ComplexDateTimeField(StringField):
def validate(self, value): def validate(self, value):
if not isinstance(value, datetime.datetime): if not isinstance(value, datetime.datetime):
raise ValidationError('Only datetime objects may used in a ' self.error('Only datetime objects may used in a '
'ComplexDateTimeField ("%s")' % self.name) 'ComplexDateTimeField')
def to_python(self, value): def to_python(self, value):
return self._convert_from_string(value) return self._convert_from_string(value)
@ -409,9 +397,8 @@ class EmbeddedDocumentField(BaseField):
def __init__(self, document_type, **kwargs): def __init__(self, document_type, **kwargs):
if not isinstance(document_type, basestring): if not isinstance(document_type, basestring):
if not issubclass(document_type, EmbeddedDocument): if not issubclass(document_type, EmbeddedDocument):
raise ValidationError('Invalid embedded document class ' self.error('Invalid embedded document class provided to an '
'provided to an EmbeddedDocumentField ' 'EmbeddedDocumentField')
'("%s")' % self.name)
self.document_type_obj = document_type self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs) super(EmbeddedDocumentField, self).__init__(**kwargs)
@ -440,9 +427,8 @@ class EmbeddedDocumentField(BaseField):
""" """
# Using isinstance also works for subclasses of self.document # Using isinstance also works for subclasses of self.document
if not isinstance(value, self.document_type): if not isinstance(value, self.document_type):
raise ValidationError('Invalid embedded document instance ' self.error('Invalid embedded document instance provided to an '
'provided to an EmbeddedDocumentField ' 'EmbeddedDocumentField')
'("%s")' % self.name)
self.document_type.validate(value) self.document_type.validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
@ -471,9 +457,8 @@ class GenericEmbeddedDocumentField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, EmbeddedDocument): if not isinstance(value, EmbeddedDocument):
raise ValidationError('Invalid embedded document instance ' self.error('Invalid embedded document instance provided to an '
'provided to an GenericEmbeddedDocumentField ' 'GenericEmbeddedDocumentField')
'("%s")' % self.name)
value.validate() value.validate()
@ -508,8 +493,7 @@ class ListField(ComplexBaseField):
""" """
if (not isinstance(value, (list, tuple)) or if (not isinstance(value, (list, tuple)) or
isinstance(value, basestring)): isinstance(value, basestring)):
raise ValidationError('Only lists and tuples may be used in a ' self.error('Only lists and tuples may be used in a list field')
'list field ("%s")' % self.name)
super(ListField, self).validate(value) super(ListField, self).validate(value)
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -557,7 +541,8 @@ class DictField(ComplexBaseField):
def __init__(self, basecls=None, field=None, *args, **kwargs): def __init__(self, basecls=None, field=None, *args, **kwargs):
self.field = field self.field = field
self.basecls = basecls or BaseField self.basecls = basecls or BaseField
assert issubclass(self.basecls, BaseField) if not issubclass(self.basecls, BaseField):
self.error('DictField only accepts dict values')
kwargs.setdefault('default', lambda: {}) kwargs.setdefault('default', lambda: {})
super(DictField, self).__init__(*args, **kwargs) super(DictField, self).__init__(*args, **kwargs)
@ -565,13 +550,11 @@ class DictField(ComplexBaseField):
"""Make sure that a list of valid fields is being used. """Make sure that a list of valid fields is being used.
""" """
if not isinstance(value, dict): if not isinstance(value, dict):
raise ValidationError('Only dictionaries may be used in a ' self.error('Only dictionaries may be used in a DictField')
'DictField ("%s")' % self.name)
if any(('.' in k or '$' in k) for k in value): if any(('.' in k or '$' in k) for k in value):
raise ValidationError('Invalid dictionary key name - keys may not ' self.error('Invalid dictionary key name - keys may not contain "."'
'contain "." or "$" characters ("%s")' % ' or "$" characters')
self.name)
super(DictField, self).validate(value) super(DictField, self).validate(value)
def lookup_member(self, member_name): def lookup_member(self, member_name):
@ -598,12 +581,11 @@ class MapField(DictField):
def __init__(self, field=None, *args, **kwargs): def __init__(self, field=None, *args, **kwargs):
if not isinstance(field, BaseField): if not isinstance(field, BaseField):
raise ValidationError('Argument to MapField constructor must be ' self.error('Argument to MapField constructor must be a valid '
'a valid field') 'field')
super(MapField, self).__init__(field=field, *args, **kwargs) super(MapField, self).__init__(field=field, *args, **kwargs)
class ReferenceField(BaseField): class ReferenceField(BaseField):
"""A reference to a document that will be automatically dereferenced on """A reference to a document that will be automatically dereferenced on
access (lazily). access (lazily).
@ -629,8 +611,8 @@ class ReferenceField(BaseField):
""" """
if not isinstance(document_type, basestring): if not isinstance(document_type, basestring):
if not issubclass(document_type, (Document, basestring)): if not issubclass(document_type, (Document, basestring)):
raise ValidationError('Argument to ReferenceField constructor ' self.error('Argument to ReferenceField constructor must be a '
'must be a document class or a string') 'document class or a string')
self.document_type_obj = document_type self.document_type_obj = document_type
self.reverse_delete_rule = reverse_delete_rule self.reverse_delete_rule = reverse_delete_rule
super(ReferenceField, self).__init__(**kwargs) super(ReferenceField, self).__init__(**kwargs)
@ -669,8 +651,8 @@ class ReferenceField(BaseField):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
id_ = document.id id_ = document.id
if id_ is None: if id_ is None:
raise ValidationError('You can only reference documents once ' self.error('You can only reference documents once they have'
'they have been saved to the database') ' been saved to the database')
else: else:
id_ = document id_ = document
@ -685,13 +667,12 @@ class ReferenceField(BaseField):
return self.to_mongo(value) return self.to_mongo(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) if not isinstance(value, (self.document_type, pymongo.dbref.DBRef)):
self.error('A ReferenceField only accepts DBRef')
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
raise ValidationError('You can only reference documents once ' self.error('You can only reference documents once they have been '
'they have been saved to the database ' 'saved to the database')
'("%s")' % self.name)
def lookup_member(self, member_name): def lookup_member(self, member_name):
return self.document_type._fields.get(member_name) return self.document_type._fields.get(member_name)
@ -719,14 +700,12 @@ class GenericReferenceField(BaseField):
def validate(self, value): def validate(self, value):
if not isinstance(value, (Document, pymongo.dbref.DBRef)): if not isinstance(value, (Document, pymongo.dbref.DBRef)):
raise ValidationError('GenericReferences can only contain ' self.error('GenericReferences can only contain documents')
'documents ("%s")' % self.name)
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
if isinstance(value, Document) and value.id is None: if isinstance(value, Document) and value.id is None:
raise ValidationError('You can only reference documents once ' self.error('You can only reference documents once they have been'
'they have been saved to the database ' ' saved to the database')
'("%s")' % self.name)
def dereference(self, value): def dereference(self, value):
doc_cls = get_document(value['_cls']) doc_cls = get_document(value['_cls'])
@ -747,9 +726,8 @@ class GenericReferenceField(BaseField):
# We need the id from the saved object to create the DBRef # We need the id from the saved object to create the DBRef
id_ = document.id id_ = document.id
if id_ is None: if id_ is None:
raise ValidationError('You can only reference documents once ' self.error('You can only reference documents once they have'
'they have been saved to the database ' ' been saved to the database')
'("%s")' % self.name)
else: else:
id_ = document id_ = document
@ -781,11 +759,11 @@ class BinaryField(BaseField):
return str(value) return str(value)
def validate(self, value): def validate(self, value):
assert isinstance(value, str) if not isinstance(value, str):
self.error('BinaryField only accepts string values')
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
raise ValidationError('Binary value is too long ("%s")' % self.error('Binary value is too long')
self.name)
class GridFSError(Exception): class GridFSError(Exception):
@ -950,8 +928,10 @@ class FileField(BaseField):
def validate(self, value): def validate(self, value):
if value.grid_id is not None: if value.grid_id is not None:
assert isinstance(value, self.proxy_class) if not isinstance(value, self.proxy_class):
assert isinstance(value.grid_id, pymongo.objectid.ObjectId) self.error('FileField only accepts GridFSProxy values')
if not isinstance(value.grid_id, pymongo.objectid.ObjectId):
self.error('Invalid GridFSProxy value')
class ImageGridFsProxy(GridFSProxy): class ImageGridFsProxy(GridFSProxy):
@ -1125,16 +1105,14 @@ class GeoPointField(BaseField):
"""Make sure that a geo-value is of type (x, y) """Make sure that a geo-value is of type (x, y)
""" """
if not isinstance(value, (list, tuple)): if not isinstance(value, (list, tuple)):
raise ValidationError('GeoPointField can only accept tuples or ' self.error('GeoPointField can only accept tuples or lists '
'lists of (x, y) ("%s")' % self.name) 'of (x, y)')
if not len(value) == 2: if not len(value) == 2:
raise ValidationError('Value must be a two-dimensional point ' self.error('Value must be a two-dimensional point')
'("%s")' % self.name)
if (not isinstance(value[0], (float, int)) and if (not isinstance(value[0], (float, int)) and
not isinstance(value[1], (float, int))): not isinstance(value[1], (float, int))):
raise ValidationError('Both values in point must be float or int ' self.error('Both values in point must be float or int')
'("%s")' % self.name)
class SequenceField(IntField): class SequenceField(IntField):
@ -1221,4 +1199,4 @@ class UUIDField(BaseField):
try: try:
value = uuid.UUID(value) value = uuid.UUID(value)
except Exception, exc: except Exception, exc:
raise ValidationError('Could not convert to UUID: %s' % exc) self.error('Could not convert to UUID: %s' % exc)

View File

@ -359,27 +359,27 @@ class FieldTest(unittest.TestCase):
logs = LogEntry.objects.order_by("date") logs = LogEntry.objects.order_by("date")
count = logs.count() count = logs.count()
i = 0 i = 0
while i == count-1: while i == count - 1:
self.assertTrue(logs[i].date <= logs[i+1].date) self.assertTrue(logs[i].date <= logs[i + 1].date)
i +=1 i += 1
logs = LogEntry.objects.order_by("-date") logs = LogEntry.objects.order_by("-date")
count = logs.count() count = logs.count()
i = 0 i = 0
while i == count-1: while i == count - 1:
self.assertTrue(logs[i].date >= logs[i+1].date) self.assertTrue(logs[i].date >= logs[i + 1].date)
i +=1 i += 1
# Test searching # Test searching
logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1)) logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 30) self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1)) logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1))
self.assertEqual(logs.count(), 30) self.assertEqual(logs.count(), 30)
logs = LogEntry.objects.filter( logs = LogEntry.objects.filter(
date__lte=datetime.datetime(2011,1,1), date__lte=datetime.datetime(2011, 1, 1),
date__gte=datetime.datetime(2000,1,1), date__gte=datetime.datetime(2000, 1, 1),
) )
self.assertEqual(logs.count(), 10) self.assertEqual(logs.count(), 10)
@ -1130,7 +1130,6 @@ class FieldTest(unittest.TestCase):
Post.drop_collection() Post.drop_collection()
User.drop_collection() User.drop_collection()
def test_generic_reference_document_not_registered(self): def test_generic_reference_document_not_registered(self):
"""Ensure dereferencing out of the document registry throws a """Ensure dereferencing out of the document registry throws a
`NotRegistered` error. `NotRegistered` error.
@ -1157,7 +1156,7 @@ class FieldTest(unittest.TestCase):
user = User.objects.first() user = User.objects.first()
try: try:
user.bookmarks user.bookmarks
raise AssertionError, "Link was removed from the registry" raise AssertionError("Link was removed from the registry")
except NotRegistered: except NotRegistered:
pass pass
@ -1357,7 +1356,7 @@ class FieldTest(unittest.TestCase):
# Make sure FileField is optional and not required # Make sure FileField is optional and not required
class DemoFile(Document): class DemoFile(Document):
file = FileField() file = FileField()
d = DemoFile.objects.create() DemoFile.objects.create()
def test_file_uniqueness(self): def test_file_uniqueness(self):
"""Ensure that each instance of a FileField is unique """Ensure that each instance of a FileField is unique
@ -1617,7 +1616,6 @@ class FieldTest(unittest.TestCase):
c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'})
self.assertEqual(c['next'], 10) self.assertEqual(c['next'], 10)
def test_generic_embedded_document(self): def test_generic_embedded_document(self):
class Car(EmbeddedDocument): class Car(EmbeddedDocument):
name = StringField() name = StringField()
@ -1643,5 +1641,88 @@ class FieldTest(unittest.TestCase):
person = Person.objects.first() person = Person.objects.first()
self.assertTrue(isinstance(person.like, Dish)) self.assertTrue(isinstance(person.like, Dish))
def test_recursive_validation(self):
"""Ensure that a validation result schema is available.
"""
class Author(EmbeddedDocument):
name = StringField(required=True)
class Comment(EmbeddedDocument):
author = EmbeddedDocumentField(Author, required=True)
content = StringField(required=True)
class Post(Document):
title = StringField(required=True)
comments = ListField(EmbeddedDocumentField(Comment))
bob = Author(name='Bob')
post = Post(title='hello world')
post.comments.append(Comment(content='hello', author=bob))
post.comments.append(Comment(author=bob))
try:
post.validate()
except ValidationError, error:
pass
# ValidationError.errors property
self.assertTrue(hasattr(error, 'errors'))
self.assertTrue(isinstance(error.errors, dict))
self.assertTrue('comments' in error.errors)
self.assertTrue(1 in error.errors['comments'])
self.assertTrue(isinstance(error.errors['comments'][1]['content'],
ValidationError))
# ValidationError.schema property
schema = error.schema
self.assertTrue(isinstance(schema, dict))
self.assertTrue('comments' in schema)
self.assertTrue(1 in schema['comments'])
self.assertTrue('content' in schema['comments'][1])
self.assertEquals(schema['comments'][1]['content'],
u'Field is required ("content")')
post.comments[1].content = 'here we go'
post.validate()
class ValidatorErrorTest(unittest.TestCase):
def test_schema(self):
"""Ensure a ValidationError handles error schema correctly.
"""
error = ValidationError('root')
self.assertEquals(error.schema, {})
# 1st level error schema
error.errors = {'1st': ValidationError('bad 1st'), }
self.assertTrue('1st' in error.schema)
self.assertEquals(error.schema['1st'], 'bad 1st')
# 2nd level error schema
error.errors = {'1st': ValidationError('bad 1st', errors={
'2nd': ValidationError('bad 2nd'),
})}
self.assertTrue('1st' in error.schema)
self.assertTrue(isinstance(error.schema['1st'], dict))
self.assertTrue('2nd' in error.schema['1st'])
self.assertEquals(error.schema['1st']['2nd'], 'bad 2nd')
# moar levels
error.errors = {'1st': ValidationError('bad 1st', errors={
'2nd': ValidationError('bad 2nd', errors={
'3rd': ValidationError('bad 3rd', errors={
'4th': ValidationError('Inception'),
}),
}),
})}
self.assertTrue('1st' in error.schema)
self.assertTrue('2nd' in error.schema['1st'])
self.assertTrue('3rd' in error.schema['1st']['2nd'])
self.assertTrue('4th' in error.schema['1st']['2nd']['3rd'])
self.assertEquals(error.schema['1st']['2nd']['3rd']['4th'],
'Inception')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()