diff --git a/docs/changelog.rst b/docs/changelog.rst index 9f109cc3..64142fde 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added recursive validation error of documents / complex fields - Fixed breaking during queryset iteration - Added pre and post bulk-insert signals - Added ImageField - requires PIL diff --git a/mongoengine/base.py b/mongoengine/base.py index 20388e91..8eaa05b1 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -4,7 +4,6 @@ from queryset import DO_NOTHING from mongoengine import signals -import weakref import sys import pymongo import pymongo.objectid @@ -20,8 +19,56 @@ class InvalidDocumentError(Exception): pass -class ValidationError(Exception): - pass +class ValidationError(AssertionError): + """Validation exception. + """ + errors = {} + field_name = None + _message = None + + def __init__(self, message="", **kwargs): + self.errors = kwargs.get('errors', {}) + self.field_name = kwargs.get('field_name') + self.message = message + + def __str__(self): + return self.message + + def __repr__(self): + return '%s(%s,)' % (self.__class__.__name__, self.message) + + def __getattribute__(self, name): + message = super(ValidationError, self).__getattribute__(name) + if name == 'message' and self.field_name: + return message + ' ("%s")' % self.field_name + else: + return message + + def _get_message(self): + return self._message + + def _set_message(self, message): + self._message = message + + message = property(_get_message, _set_message) + + @property + def schema(self): + def get_schema(source): + errors_dict = {} + if not source: + return errors_dict + if isinstance(source, dict): + for field_name, error in source.iteritems(): + errors_dict[field_name] = get_schema(error) + elif isinstance(source, ValidationError) and source.errors: + return get_schema(source.errors) + else: + return unicode(source) + return errors_dict + if not self.errors: + return {} + return get_schema(self.errors) _document_registry = {} @@ -51,6 +98,8 @@ class BaseField(object): .. versionchanged:: 0.5 - added verbose and help text """ + name = None + # Fields may have _types inserted into indexes by default _index_with_types = True _geo_index = False @@ -117,6 +166,12 @@ class BaseField(object): instance._data[self.name] = value instance._mark_as_changed(self.name) + def error(self, message="", errors=None, field_name=None): + """Raises a ValidationError. + """ + field_name = field_name if field_name else self.name + raise ValidationError(message, errors=errors, field_name=field_name) + def to_python(self, value): """Convert a MongoDB-compatible type to a Python type. """ @@ -142,15 +197,13 @@ class BaseField(object): if self.choices is not None: option_keys = [option_key for option_key, option_value in self.choices] if value not in option_keys: - raise ValidationError('Value must be one of %s ("%s")' % - (unicode(option_keys), self.name)) + self.error('Value must be one of %s' % unicode(option_keys)) # check validation argument if self.validation is not None: if callable(self.validation): if not self.validation(value): - raise ValidationError('Value does not match custom ' - 'validation method ("%s")' % self.name) + self.error('Value does not match custom validation method') else: raise ValueError('validation argument for "%s" must be a ' 'callable.' % self.name) @@ -198,7 +251,7 @@ class ComplexBaseField(BaseField): if not hasattr(value, 'items'): try: is_list = True - value = dict([(k,v) for k,v in enumerate(value)]) + value = dict([(k, v) for k, v in enumerate(value)]) except TypeError: # Not iterable return the value return value @@ -206,13 +259,12 @@ class ComplexBaseField(BaseField): value_dict = dict([(key, self.field.to_python(item)) for key, item in value.items()]) else: value_dict = {} - for k,v in value.items(): + for k, v in value.items(): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: - raise ValidationError('You can only reference ' - 'documents once they have been saved ' - 'to the database ("%s")' % self.name) + self.error('You can only reference documents once they' + ' have been saved to the database') collection = v._get_collection_name() value_dict[k] = pymongo.dbref.DBRef(collection, v.pk) elif hasattr(v, 'to_python'): @@ -221,7 +273,7 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_python(v) if is_list: # Convert back to a list - return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] + return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict def to_mongo(self, value): @@ -239,7 +291,7 @@ class ComplexBaseField(BaseField): if not hasattr(value, 'items'): try: is_list = True - value = dict([(k,v) for k,v in enumerate(value)]) + value = dict([(k, v) for k, v in enumerate(value)]) except TypeError: # Not iterable return the value return value @@ -247,13 +299,12 @@ class ComplexBaseField(BaseField): value_dict = dict([(key, self.field.to_mongo(item)) for key, item in value.items()]) else: value_dict = {} - for k,v in value.items(): + for k, v in value.items(): if isinstance(v, Document): # We need the id from the saved object to create the DBRef if v.pk is None: - raise ValidationError('You can only reference ' - 'documents once they have been saved ' - 'to the database ("%s")' % self.name) + self.error('You can only reference documents once they' + ' have been saved to the database') # If its a document that is not inheritable it won't have # _types / _cls data so make it a generic reference allows @@ -271,26 +322,33 @@ class ComplexBaseField(BaseField): value_dict[k] = self.to_mongo(v) if is_list: # Convert back to a list - return [v for k,v in sorted(value_dict.items(), key=operator.itemgetter(0))] + return [v for k, v in sorted(value_dict.items(), key=operator.itemgetter(0))] return value_dict def validate(self, value): - """If field provided ensure the value is valid. + """If field is provided ensure the value is valid. """ + errors = {} if self.field: - try: - if hasattr(value, 'iteritems'): - [self.field.validate(v) for k,v in value.iteritems()] - else: - [self.field.validate(v) for v in value] - except Exception, err: - raise ValidationError('Invalid %s item (%s) ("%s")' % ( - self.field.__class__.__name__, str(v), self.name)) - + if hasattr(value, 'iteritems'): + sequence = value.iteritems() + else: + sequence = enumerate(value) + for k, v in sequence: + try: + self.field.validate(v) + except (ValidationError, AssertionError), error: + if hasattr(error, 'errors'): + errors[k] = error.errors + else: + errors[k] = error + if errors: + field_class = self.field.__class__.__name__ + self.error('Invalid %s item (%s)' % (field_class, value), + errors=errors) # Don't allow empty values if required if self.required and not value: - raise ValidationError('Field "%s" is required and cannot be empty' % - self.name) + self.error('Field is required and cannot be empty') def prepare_query_value(self, op, value): return self.to_mongo(value) @@ -345,6 +403,7 @@ class BaseDynamicField(BaseField): def lookup_member(self, member_name): return member_name + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. """ @@ -357,8 +416,8 @@ class ObjectIdField(BaseField): try: return pymongo.objectid.ObjectId(unicode(value)) except Exception, e: - #e.message attribute has been deprecated since Python 2.6 - raise ValidationError(unicode(e)) + # e.message attribute has been deprecated since Python 2.6 + self.error(unicode(e)) return value def prepare_query_value(self, op, value): @@ -368,7 +427,7 @@ class ObjectIdField(BaseField): try: pymongo.objectid.ObjectId(unicode(value)) except: - raise ValidationError('Invalid Object ID ("%s")' % self.name) + self.error('Invalid Object ID') class DocumentMetaclass(type): @@ -394,7 +453,7 @@ class DocumentMetaclass(type): superclasses[base._class_name] = base superclasses.update(base._superclasses) else: # Add any mixin fields - attrs.update(dict([(k,v) for k,v in base.__dict__.items() + attrs.update(dict([(k, v) for k, v in base.__dict__.items() if issubclass(v.__class__, BaseField)])) if hasattr(base, '_meta') and not base._meta.get('abstract'): @@ -489,7 +548,7 @@ class TopLevelDocumentMetaclass(DocumentMetaclass): ('meta' in attrs and attrs['meta'].get('abstract', False))): # Make sure no base class was non-abstract non_abstract_bases = [b for b in bases - if hasattr(b,'_meta') and not b._meta.get('abstract', False)] + if hasattr(b, '_meta') and not b._meta.get('abstract', False)] if non_abstract_bases: raise ValueError("Abstract document cannot have non-abstract base") return super_new(cls, name, bases, attrs) @@ -666,7 +725,7 @@ class BaseDocument(object): signals.post_init.send(self.__class__, document=self) def __setattr__(self, name, value): - # Handle dynamic data only if an intialised dynamic document + # Handle dynamic data only if an initialised dynamic document if self._dynamic and getattr(self, '_initialised', False): field = None @@ -709,7 +768,8 @@ class BaseDocument(object): data[k] = self.__expand_dynamic_values(key, v) if is_list: # Convert back to a list - value = [v for k, v in sorted(data.items(), key=operator.itemgetter(0))] + data_items = sorted(data.items(), key=operator.itemgetter(0)) + value = [v for k, v in data_items] else: value = data @@ -730,15 +790,21 @@ class BaseDocument(object): for name, field in self._fields.items()] # Ensure that each field is matched to a valid value + errors = {} for field, value in fields: if value is not None: try: field._validate(value) - except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field named "%s" of type "%s": %s' - % (field.name, field.__class__.__name__, value)) + except ValidationError, error: + errors[field.name] = error.errors or error + except (ValueError, AttributeError, AssertionError), error: + errors[field.name] = error elif field.required: - raise ValidationError('Field "%s" is required' % field.name) + errors[field.name] = ValidationError('Field is required', + field_name=field.name) + if errors: + raise ValidationError('Errors encountered validating document', + errors=errors) def to_mongo(self): """Return data dictionary ready for use with MongoDB. @@ -812,7 +878,6 @@ class BaseDocument(object): """.strip() % class_name) cls = subclasses[class_name] - present_fields = data.keys() for field_name, field in cls._fields.items(): if field.db_field in data: value = data[field.db_field] @@ -963,8 +1028,7 @@ class BaseDocument(object): return geo_indices def __getstate__(self): - self_dict = self.__dict__ - removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices] + removals = ["get_%s_display" % k for k, v in self._fields.items() if v.choices] for k in removals: if hasattr(self, k): delattr(self, k) @@ -1039,7 +1103,7 @@ class BaseDocument(object): def __hash__(self): if self.pk is None: # For new object - return super(BaseDocument,self).__hash__() + return super(BaseDocument, self).__hash__() else: return hash(self.pk) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 5700fe41..0bfd54d2 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -54,20 +54,17 @@ class StringField(BaseField): return unicode(value) def validate(self, value): - assert isinstance(value, (str, unicode)) + if not isinstance(value, (str, unicode)): + self.error('StringField only accepts string values') if self.max_length is not None and len(value) > self.max_length: - raise ValidationError('String value is too long ("%s")' % - self.name) + self.error('String value is too long') if self.min_length is not None and len(value) < self.min_length: - raise ValidationError('String value is too short ("%s")' % - self.name) + self.error('String value is too short') if self.regex is not None and self.regex.match(value) is None: - message = 'String value did not match validation regex ("%s")' % \ - self.name - raise ValidationError(message) + self.error('String value did not match validation regex') def lookup_member(self, member_name): return None @@ -117,18 +114,15 @@ class URLField(StringField): def validate(self, value): if not URLField.URL_REGEX.match(value): - raise ValidationError('Invalid URL: %s ("%s")' % (value, - self.name)) + self.error('Invalid URL: %s' % value) if self.verify_exists: import urllib2 try: request = urllib2.Request(value) - response = urllib2.urlopen(request) + urllib2.urlopen(request) except Exception, e: - message = 'This URL appears to be a broken link: %s ("%s")' % ( - e, self.name) - raise ValidationError(message) + self.error('This URL appears to be a broken link: %s' % e) class EmailField(StringField): @@ -145,8 +139,7 @@ class EmailField(StringField): def validate(self, value): if not EmailField.EMAIL_REGEX.match(value): - raise ValidationError('Invalid Mail-address: %s ("%s")' % (value, - self.name)) + self.error('Invalid Mail-address: %s' % value) class IntField(BaseField): @@ -164,16 +157,13 @@ class IntField(BaseField): try: value = int(value) except: - raise ValidationError('%s could not be converted to int ("%s")' % ( - value, self.name)) + self.error('%s could not be converted to int' % value) if self.min_value is not None and value < self.min_value: - raise ValidationError('Integer value is too small ("%s")' % - self.name) + self.error('Integer value is too small') if self.max_value is not None and value > self.max_value: - raise ValidationError('Integer value is too large ("%s")' % - self.name) + self.error('Integer value is too large') def prepare_query_value(self, op, value): return int(value) @@ -193,15 +183,14 @@ class FloatField(BaseField): def validate(self, value): if isinstance(value, int): value = float(value) - assert isinstance(value, float) + if not isinstance(value, float): + self.error('FoatField only accepts float values') if self.min_value is not None and value < self.min_value: - raise ValidationError('Float value is too small ("%s")' % - self.name) + self.error('Float value is too small') if self.max_value is not None and value > self.max_value: - raise ValidationError('Float value is too large ("%s")' % - self.name) + self.error('Float value is too large') def prepare_query_value(self, op, value): return float(value) @@ -232,16 +221,13 @@ class DecimalField(BaseField): try: value = decimal.Decimal(value) except Exception, exc: - raise ValidationError('Could not convert value to decimal: %s' - '("%s")' % (exc, self.name)) + self.error('Could not convert value to decimal: %s' % exc) if self.min_value is not None and value < self.min_value: - raise ValidationError('Decimal value is too small ("%s")' % - self.name) + self.error('Decimal value is too small') if self.max_value is not None and value > self.max_value: - raise ValidationError('Decimal value is too large ("%s")' % - self.name) + self.error('Decimal value is too large') class BooleanField(BaseField): @@ -254,7 +240,8 @@ class BooleanField(BaseField): return bool(value) def validate(self, value): - assert isinstance(value, bool) + if not isinstance(value, bool): + self.error('BooleanField only accepts boolean values') class DateTimeField(BaseField): @@ -267,7 +254,8 @@ class DateTimeField(BaseField): """ def validate(self, value): - assert isinstance(value, (datetime.datetime, datetime.date)) + if not isinstance(value, (datetime.datetime, datetime.date)): + self.error(u'cannot parse date "%s"' % value) def to_mongo(self, value): return self.prepare_query_value(None, value) @@ -388,8 +376,8 @@ class ComplexDateTimeField(StringField): def validate(self, value): if not isinstance(value, datetime.datetime): - raise ValidationError('Only datetime objects may used in a ' - 'ComplexDateTimeField ("%s")' % self.name) + self.error('Only datetime objects may used in a ' + 'ComplexDateTimeField') def to_python(self, value): return self._convert_from_string(value) @@ -409,9 +397,8 @@ class EmbeddedDocumentField(BaseField): def __init__(self, document_type, **kwargs): if not isinstance(document_type, basestring): if not issubclass(document_type, EmbeddedDocument): - raise ValidationError('Invalid embedded document class ' - 'provided to an EmbeddedDocumentField ' - '("%s")' % self.name) + self.error('Invalid embedded document class provided to an ' + 'EmbeddedDocumentField') self.document_type_obj = document_type super(EmbeddedDocumentField, self).__init__(**kwargs) @@ -440,9 +427,8 @@ class EmbeddedDocumentField(BaseField): """ # Using isinstance also works for subclasses of self.document if not isinstance(value, self.document_type): - raise ValidationError('Invalid embedded document instance ' - 'provided to an EmbeddedDocumentField ' - '("%s")' % self.name) + self.error('Invalid embedded document instance provided to an ' + 'EmbeddedDocumentField') self.document_type.validate(value) def lookup_member(self, member_name): @@ -471,9 +457,8 @@ class GenericEmbeddedDocumentField(BaseField): def validate(self, value): if not isinstance(value, EmbeddedDocument): - raise ValidationError('Invalid embedded document instance ' - 'provided to an GenericEmbeddedDocumentField ' - '("%s")' % self.name) + self.error('Invalid embedded document instance provided to an ' + 'GenericEmbeddedDocumentField') value.validate() @@ -508,8 +493,7 @@ class ListField(ComplexBaseField): """ if (not isinstance(value, (list, tuple)) or isinstance(value, basestring)): - raise ValidationError('Only lists and tuples may be used in a ' - 'list field ("%s")' % self.name) + self.error('Only lists and tuples may be used in a list field') super(ListField, self).validate(value) def prepare_query_value(self, op, value): @@ -557,7 +541,8 @@ class DictField(ComplexBaseField): def __init__(self, basecls=None, field=None, *args, **kwargs): self.field = field self.basecls = basecls or BaseField - assert issubclass(self.basecls, BaseField) + if not issubclass(self.basecls, BaseField): + self.error('DictField only accepts dict values') kwargs.setdefault('default', lambda: {}) super(DictField, self).__init__(*args, **kwargs) @@ -565,13 +550,11 @@ class DictField(ComplexBaseField): """Make sure that a list of valid fields is being used. """ if not isinstance(value, dict): - raise ValidationError('Only dictionaries may be used in a ' - 'DictField ("%s")' % self.name) + self.error('Only dictionaries may be used in a DictField') if any(('.' in k or '$' in k) for k in value): - raise ValidationError('Invalid dictionary key name - keys may not ' - 'contain "." or "$" characters ("%s")' % - self.name) + self.error('Invalid dictionary key name - keys may not contain "."' + ' or "$" characters') super(DictField, self).validate(value) def lookup_member(self, member_name): @@ -598,12 +581,11 @@ class MapField(DictField): def __init__(self, field=None, *args, **kwargs): if not isinstance(field, BaseField): - raise ValidationError('Argument to MapField constructor must be ' - 'a valid field') + self.error('Argument to MapField constructor must be a valid ' + 'field') super(MapField, self).__init__(field=field, *args, **kwargs) - class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). @@ -629,8 +611,8 @@ class ReferenceField(BaseField): """ if not isinstance(document_type, basestring): if not issubclass(document_type, (Document, basestring)): - raise ValidationError('Argument to ReferenceField constructor ' - 'must be a document class or a string') + self.error('Argument to ReferenceField constructor must be a ' + 'document class or a string') self.document_type_obj = document_type self.reverse_delete_rule = reverse_delete_rule super(ReferenceField, self).__init__(**kwargs) @@ -669,8 +651,8 @@ class ReferenceField(BaseField): # We need the id from the saved object to create the DBRef id_ = document.id if id_ is None: - raise ValidationError('You can only reference documents once ' - 'they have been saved to the database') + self.error('You can only reference documents once they have' + ' been saved to the database') else: id_ = document @@ -685,13 +667,12 @@ class ReferenceField(BaseField): return self.to_mongo(value) def validate(self, value): - assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) + if not isinstance(value, (self.document_type, pymongo.dbref.DBRef)): + self.error('A ReferenceField only accepts DBRef') if isinstance(value, Document) and value.id is None: - raise ValidationError('You can only reference documents once ' - 'they have been saved to the database ' - '("%s")' % self.name) - + self.error('You can only reference documents once they have been ' + 'saved to the database') def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -719,14 +700,12 @@ class GenericReferenceField(BaseField): def validate(self, value): if not isinstance(value, (Document, pymongo.dbref.DBRef)): - raise ValidationError('GenericReferences can only contain ' - 'documents ("%s")' % self.name) + self.error('GenericReferences can only contain documents') # We need the id from the saved object to create the DBRef if isinstance(value, Document) and value.id is None: - raise ValidationError('You can only reference documents once ' - 'they have been saved to the database ' - '("%s")' % self.name) + self.error('You can only reference documents once they have been' + ' saved to the database') def dereference(self, value): doc_cls = get_document(value['_cls']) @@ -747,9 +726,8 @@ class GenericReferenceField(BaseField): # We need the id from the saved object to create the DBRef id_ = document.id if id_ is None: - raise ValidationError('You can only reference documents once ' - 'they have been saved to the database ' - '("%s")' % self.name) + self.error('You can only reference documents once they have' + ' been saved to the database') else: id_ = document @@ -781,11 +759,11 @@ class BinaryField(BaseField): return str(value) def validate(self, value): - assert isinstance(value, str) + if not isinstance(value, str): + self.error('BinaryField only accepts string values') if self.max_bytes is not None and len(value) > self.max_bytes: - raise ValidationError('Binary value is too long ("%s")' % - self.name) + self.error('Binary value is too long') class GridFSError(Exception): @@ -950,8 +928,10 @@ class FileField(BaseField): def validate(self, value): if value.grid_id is not None: - assert isinstance(value, self.proxy_class) - assert isinstance(value.grid_id, pymongo.objectid.ObjectId) + if not isinstance(value, self.proxy_class): + self.error('FileField only accepts GridFSProxy values') + if not isinstance(value.grid_id, pymongo.objectid.ObjectId): + self.error('Invalid GridFSProxy value') class ImageGridFsProxy(GridFSProxy): @@ -1125,16 +1105,14 @@ class GeoPointField(BaseField): """Make sure that a geo-value is of type (x, y) """ if not isinstance(value, (list, tuple)): - raise ValidationError('GeoPointField can only accept tuples or ' - 'lists of (x, y) ("%s")' % self.name) + self.error('GeoPointField can only accept tuples or lists ' + 'of (x, y)') if not len(value) == 2: - raise ValidationError('Value must be a two-dimensional point ' - '("%s")' % self.name) + self.error('Value must be a two-dimensional point') if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): - raise ValidationError('Both values in point must be float or int ' - '("%s")' % self.name) + self.error('Both values in point must be float or int') class SequenceField(IntField): @@ -1221,4 +1199,4 @@ class UUIDField(BaseField): try: value = uuid.UUID(value) except Exception, exc: - raise ValidationError('Could not convert to UUID: %s' % exc) + self.error('Could not convert to UUID: %s' % exc) diff --git a/tests/fields.py b/tests/fields.py index 20cdf197..6920648d 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -359,27 +359,27 @@ class FieldTest(unittest.TestCase): logs = LogEntry.objects.order_by("date") count = logs.count() i = 0 - while i == count-1: - self.assertTrue(logs[i].date <= logs[i+1].date) - i +=1 + while i == count - 1: + self.assertTrue(logs[i].date <= logs[i + 1].date) + i += 1 logs = LogEntry.objects.order_by("-date") count = logs.count() i = 0 - while i == count-1: - self.assertTrue(logs[i].date >= logs[i+1].date) - i +=1 + while i == count - 1: + self.assertTrue(logs[i].date >= logs[i + 1].date) + i += 1 # Test searching - logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980,1,1)) + logs = LogEntry.objects.filter(date__gte=datetime.datetime(1980, 1, 1)) self.assertEqual(logs.count(), 30) - logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980,1,1)) + logs = LogEntry.objects.filter(date__lte=datetime.datetime(1980, 1, 1)) self.assertEqual(logs.count(), 30) logs = LogEntry.objects.filter( - date__lte=datetime.datetime(2011,1,1), - date__gte=datetime.datetime(2000,1,1), + date__lte=datetime.datetime(2011, 1, 1), + date__gte=datetime.datetime(2000, 1, 1), ) self.assertEqual(logs.count(), 10) @@ -1130,7 +1130,6 @@ class FieldTest(unittest.TestCase): Post.drop_collection() User.drop_collection() - def test_generic_reference_document_not_registered(self): """Ensure dereferencing out of the document registry throws a `NotRegistered` error. @@ -1157,7 +1156,7 @@ class FieldTest(unittest.TestCase): user = User.objects.first() try: user.bookmarks - raise AssertionError, "Link was removed from the registry" + raise AssertionError("Link was removed from the registry") except NotRegistered: pass @@ -1357,7 +1356,7 @@ class FieldTest(unittest.TestCase): # Make sure FileField is optional and not required class DemoFile(Document): file = FileField() - d = DemoFile.objects.create() + DemoFile.objects.create() def test_file_uniqueness(self): """Ensure that each instance of a FileField is unique @@ -1617,7 +1616,6 @@ class FieldTest(unittest.TestCase): c = self.db['mongoengine.counters'].find_one({'_id': 'animal.id'}) self.assertEqual(c['next'], 10) - def test_generic_embedded_document(self): class Car(EmbeddedDocument): name = StringField() @@ -1643,5 +1641,88 @@ class FieldTest(unittest.TestCase): person = Person.objects.first() self.assertTrue(isinstance(person.like, Dish)) + def test_recursive_validation(self): + """Ensure that a validation result schema is available. + """ + class Author(EmbeddedDocument): + name = StringField(required=True) + + class Comment(EmbeddedDocument): + author = EmbeddedDocumentField(Author, required=True) + content = StringField(required=True) + + class Post(Document): + title = StringField(required=True) + comments = ListField(EmbeddedDocumentField(Comment)) + + bob = Author(name='Bob') + post = Post(title='hello world') + post.comments.append(Comment(content='hello', author=bob)) + post.comments.append(Comment(author=bob)) + + try: + post.validate() + except ValidationError, error: + pass + + # ValidationError.errors property + self.assertTrue(hasattr(error, 'errors')) + self.assertTrue(isinstance(error.errors, dict)) + self.assertTrue('comments' in error.errors) + self.assertTrue(1 in error.errors['comments']) + self.assertTrue(isinstance(error.errors['comments'][1]['content'], + ValidationError)) + + # ValidationError.schema property + schema = error.schema + self.assertTrue(isinstance(schema, dict)) + self.assertTrue('comments' in schema) + self.assertTrue(1 in schema['comments']) + self.assertTrue('content' in schema['comments'][1]) + self.assertEquals(schema['comments'][1]['content'], + u'Field is required ("content")') + + post.comments[1].content = 'here we go' + post.validate() + + +class ValidatorErrorTest(unittest.TestCase): + + def test_schema(self): + """Ensure a ValidationError handles error schema correctly. + """ + error = ValidationError('root') + self.assertEquals(error.schema, {}) + + # 1st level error schema + error.errors = {'1st': ValidationError('bad 1st'), } + self.assertTrue('1st' in error.schema) + self.assertEquals(error.schema['1st'], 'bad 1st') + + # 2nd level error schema + error.errors = {'1st': ValidationError('bad 1st', errors={ + '2nd': ValidationError('bad 2nd'), + })} + self.assertTrue('1st' in error.schema) + self.assertTrue(isinstance(error.schema['1st'], dict)) + self.assertTrue('2nd' in error.schema['1st']) + self.assertEquals(error.schema['1st']['2nd'], 'bad 2nd') + + # moar levels + error.errors = {'1st': ValidationError('bad 1st', errors={ + '2nd': ValidationError('bad 2nd', errors={ + '3rd': ValidationError('bad 3rd', errors={ + '4th': ValidationError('Inception'), + }), + }), + })} + self.assertTrue('1st' in error.schema) + self.assertTrue('2nd' in error.schema['1st']) + self.assertTrue('3rd' in error.schema['1st']['2nd']) + self.assertTrue('4th' in error.schema['1st']['2nd']['3rd']) + self.assertEquals(error.schema['1st']['2nd']['3rd']['4th'], + 'Inception') + + if __name__ == '__main__': unittest.main()