Made field validation lazier

This commit is contained in:
Harry Marr 2010-01-03 22:37:55 +00:00
parent 3574198210
commit b01596c942
3 changed files with 93 additions and 48 deletions

View File

@ -35,17 +35,8 @@ class BaseField(object):
return value return value
def __set__(self, instance, value): def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document. Do any """Descriptor for assigning a value to a field in a document.
necessary conversion between Python and MongoDB types.
""" """
if value is not None:
try:
self.validate(value)
except (ValueError, AttributeError, AssertionError), e:
raise ValidationError('Invalid value for field of type "' +
self.__class__.__name__ + '"')
elif self.required:
raise ValidationError('Field "%s" is required' % self.name)
instance._data[self.name] = value instance._data[self.name] = value
def to_python(self, value): def to_python(self, value):
@ -183,8 +174,6 @@ class BaseDocument(object):
else: else:
# Use default value if present # Use default value if present
value = getattr(self, attr_name, None) value = getattr(self, attr_name, None)
if value is None and attr_value.required:
raise ValidationError('Field "%s" is required' % attr_name)
setattr(self, attr_name, value) setattr(self, attr_name, value)
@classmethod @classmethod

View File

@ -1,4 +1,5 @@
from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
ValidationError)
from connection import _get_db from connection import _get_db
@ -44,6 +45,7 @@ class Document(BaseDocument):
document already exists, it will be updated, otherwise it will be document already exists, it will be updated, otherwise it will be
created. created.
""" """
self.validate()
object_id = self.__class__.objects._collection.save(self.to_mongo()) object_id = self.__class__.objects._collection.save(self.to_mongo())
self.id = self._fields['id'].to_python(object_id) self.id = self._fields['id'].to_python(object_id)
@ -54,6 +56,25 @@ class Document(BaseDocument):
object_id = self._fields['id'].to_mongo(self.id) object_id = self._fields['id'].to_mongo(self.id)
self.__class__.objects(id=object_id).delete() self.__class__.objects(id=object_id).delete()
def validate(self):
"""Ensure that all fields' values are valid and that required fields
are present.
"""
# Get a list of tuples of field names and their current values
fields = [(field, getattr(self, name))
for name, field in self._fields.items()]
# Ensure that each field is matched to a valid value
for field, value in fields:
if value is not None:
try:
field.validate(value)
except (ValueError, AttributeError, AssertionError), e:
raise ValidationError('Invalid value for field of type "' +
field.__class__.__name__ + '"')
elif field.required:
raise ValidationError('Field "%s" is required' % field.name)
@classmethod @classmethod
def drop_collection(cls): def drop_collection(cls):
"""Drops the entire collection associated with this """Drops the entire collection associated with this

View File

@ -31,13 +31,10 @@ class FieldTest(unittest.TestCase):
age = IntField(required=True) age = IntField(required=True)
userid = StringField() userid = StringField()
self.assertRaises(ValidationError, Person, name="Test User") person = Person(name="Test User")
self.assertRaises(ValidationError, Person, age=30) self.assertRaises(ValidationError, person.validate)
person = Person(age=30)
person = Person(name="Test User", age=30, userid="testuser") self.assertRaises(ValidationError, person.validate)
self.assertRaises(ValidationError, person.__setattr__, 'name', None)
self.assertRaises(ValidationError, person.__setattr__, 'age', None)
person.userid = None
def test_object_id_validation(self): def test_object_id_validation(self):
"""Ensure that invalid values cannot be assigned to string fields. """Ensure that invalid values cannot be assigned to string fields.
@ -47,9 +44,15 @@ class FieldTest(unittest.TestCase):
person = Person(name='Test User') person = Person(name='Test User')
self.assertEqual(person.id, None) self.assertEqual(person.id, None)
self.assertRaises(ValidationError, person.__setattr__, 'id', 47)
self.assertRaises(ValidationError, person.__setattr__, 'id', 'abc') person.id = 47
self.assertRaises(ValidationError, person.validate)
person.id = 'abc'
self.assertRaises(ValidationError, person.validate)
person.id = '497ce96f395f2f052a494fd4' person.id = '497ce96f395f2f052a494fd4'
person.validate()
def test_string_validation(self): def test_string_validation(self):
"""Ensure that invalid values cannot be assigned to string fields. """Ensure that invalid values cannot be assigned to string fields.
@ -58,20 +61,23 @@ class FieldTest(unittest.TestCase):
name = StringField(max_length=20) name = StringField(max_length=20)
userid = StringField(r'[0-9a-z_]+$') userid = StringField(r'[0-9a-z_]+$')
person = Person() person = Person(name=34)
self.assertRaises(ValidationError, person.__setattr__, 'name', 34) self.assertRaises(ValidationError, person.validate)
# Test regex validation on userid # Test regex validation on userid
self.assertRaises(ValidationError, person.__setattr__, 'userid', person = Person(userid='test.User')
'test.User') self.assertRaises(ValidationError, person.validate)
person.userid = 'test_user' person.userid = 'test_user'
self.assertEqual(person.userid, 'test_user') self.assertEqual(person.userid, 'test_user')
person.validate()
# Test max length validation on name # Test max length validation on name
self.assertRaises(ValidationError, person.__setattr__, 'name', person = Person(name='Name that is more than twenty characters')
'Name that is more than twenty characters') self.assertRaises(ValidationError, person.validate)
person.name = 'Shorter name' person.name = 'Shorter name'
self.assertEqual(person.name, 'Shorter name') person.validate()
def test_int_validation(self): def test_int_validation(self):
"""Ensure that invalid values cannot be assigned to int fields. """Ensure that invalid values cannot be assigned to int fields.
@ -81,9 +87,14 @@ class FieldTest(unittest.TestCase):
person = Person() person = Person()
person.age = 50 person.age = 50
self.assertRaises(ValidationError, person.__setattr__, 'age', -1) person.validate()
self.assertRaises(ValidationError, person.__setattr__, 'age', 120)
self.assertRaises(ValidationError, person.__setattr__, 'age', 'ten') person.age = -1
self.assertRaises(ValidationError, person.validate)
person.age = 120
self.assertRaises(ValidationError, person.validate)
person.age = 'ten'
self.assertRaises(ValidationError, person.validate)
def test_float_validation(self): def test_float_validation(self):
"""Ensure that invalid values cannot be assigned to float fields. """Ensure that invalid values cannot be assigned to float fields.
@ -93,9 +104,14 @@ class FieldTest(unittest.TestCase):
person = Person() person = Person()
person.height = 1.89 person.height = 1.89
self.assertRaises(ValidationError, person.__setattr__, 'height', 2) person.validate()
self.assertRaises(ValidationError, person.__setattr__, 'height', 0.01)
self.assertRaises(ValidationError, person.__setattr__, 'height', 4.0) person.height = 2
self.assertRaises(ValidationError, person.validate)
person.height = 0.01
self.assertRaises(ValidationError, person.validate)
person.height = 4.0
self.assertRaises(ValidationError, person.validate)
def test_datetime_validation(self): def test_datetime_validation(self):
"""Ensure that invalid values cannot be assigned to datetime fields. """Ensure that invalid values cannot be assigned to datetime fields.
@ -104,9 +120,13 @@ class FieldTest(unittest.TestCase):
time = DateTimeField() time = DateTimeField()
log = LogEntry() log = LogEntry()
self.assertRaises(ValidationError, log.__setattr__, 'time', -1)
self.assertRaises(ValidationError, log.__setattr__, 'time', '1pm')
log.time = datetime.datetime.now() log.time = datetime.datetime.now()
log.validate()
log.time = -1
self.assertRaises(ValidationError, log.validate)
log.time = '1pm'
self.assertRaises(ValidationError, log.validate)
def test_list_validation(self): def test_list_validation(self):
"""Ensure that a list field only accepts lists with valid elements. """Ensure that a list field only accepts lists with valid elements.
@ -120,16 +140,26 @@ class FieldTest(unittest.TestCase):
tags = ListField(StringField()) tags = ListField(StringField())
post = BlogPost(content='Went for a walk today...') post = BlogPost(content='Went for a walk today...')
self.assertRaises(ValidationError, post.__setattr__, 'tags', 'fun') post.validate()
self.assertRaises(ValidationError, post.__setattr__, 'tags', [1, 2])
post.tags = 'fun'
self.assertRaises(ValidationError, post.validate)
post.tags = [1, 2]
self.assertRaises(ValidationError, post.validate)
post.tags = ['fun', 'leisure'] post.tags = ['fun', 'leisure']
post.validate()
post.tags = ('fun', 'leisure') post.tags = ('fun', 'leisure')
post.validate()
comments = [Comment(content='Good for you'), Comment(content='Yay.')] comments = [Comment(content='Good for you'), Comment(content='Yay.')]
self.assertRaises(ValidationError, post.__setattr__, 'comments', ['a'])
self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay')
self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay')
post.comments = comments post.comments = comments
post.validate()
post.comments = ['a']
self.assertRaises(ValidationError, post.validate)
post.comments = 'yay'
self.assertRaises(ValidationError, post.validate)
def test_embedded_document_validation(self): def test_embedded_document_validation(self):
"""Ensure that invalid embedded documents cannot be assigned to """Ensure that invalid embedded documents cannot be assigned to
@ -147,12 +177,15 @@ class FieldTest(unittest.TestCase):
preferences = EmbeddedDocumentField(PersonPreferences) preferences = EmbeddedDocumentField(PersonPreferences)
person = Person(name='Test User') person = Person(name='Test User')
self.assertRaises(ValidationError, person.__setattr__, 'preferences', person.preferences = 'My Preferences'
'My preferences') self.assertRaises(ValidationError, person.validate)
self.assertRaises(ValidationError, person.__setattr__, 'preferences',
Comment(content='Nice blog post...')) person.preferences = Comment(content='Nice blog post...')
self.assertRaises(ValidationError, person.validate)
person.preferences = PersonPreferences(food='Cheese', number=47) person.preferences = PersonPreferences(food='Cheese', number=47)
self.assertEqual(person.preferences.food, 'Cheese') self.assertEqual(person.preferences.food, 'Cheese')
person.validate()
def test_embedded_document_inheritance(self): def test_embedded_document_inheritance(self):
"""Ensure that subclasses of embedded documents may be provided to """Ensure that subclasses of embedded documents may be provided to
@ -194,14 +227,16 @@ class FieldTest(unittest.TestCase):
# Check that an invalid object type cannot be used # Check that an invalid object type cannot be used
post2 = BlogPost(content='Chips and chilli taste good.') post2 = BlogPost(content='Chips and chilli taste good.')
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) post1.author = post2
self.assertRaises(ValidationError, post1.validate)
user.save() user.save()
post1.author = user post1.author = user
post1.save() post1.save()
post2.save() post2.save()
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) post1.author = post2
self.assertRaises(ValidationError, post1.validate)
User.drop_collection() User.drop_collection()
BlogPost.drop_collection() BlogPost.drop_collection()