Made field validation lazier
This commit is contained in:
parent
3574198210
commit
b01596c942
@ -35,17 +35,8 @@ class BaseField(object):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
def __set__(self, instance, value):
|
def __set__(self, instance, value):
|
||||||
"""Descriptor for assigning a value to a field in a document. Do any
|
"""Descriptor for assigning a value to a field in a document.
|
||||||
necessary conversion between Python and MongoDB types.
|
|
||||||
"""
|
"""
|
||||||
if value is not None:
|
|
||||||
try:
|
|
||||||
self.validate(value)
|
|
||||||
except (ValueError, AttributeError, AssertionError), e:
|
|
||||||
raise ValidationError('Invalid value for field of type "' +
|
|
||||||
self.__class__.__name__ + '"')
|
|
||||||
elif self.required:
|
|
||||||
raise ValidationError('Field "%s" is required' % self.name)
|
|
||||||
instance._data[self.name] = value
|
instance._data[self.name] = value
|
||||||
|
|
||||||
def to_python(self, value):
|
def to_python(self, value):
|
||||||
@ -183,8 +174,6 @@ class BaseDocument(object):
|
|||||||
else:
|
else:
|
||||||
# Use default value if present
|
# Use default value if present
|
||||||
value = getattr(self, attr_name, None)
|
value = getattr(self, attr_name, None)
|
||||||
if value is None and attr_value.required:
|
|
||||||
raise ValidationError('Field "%s" is required' % attr_name)
|
|
||||||
setattr(self, attr_name, value)
|
setattr(self, attr_name, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument
|
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||||
|
ValidationError)
|
||||||
from connection import _get_db
|
from connection import _get_db
|
||||||
|
|
||||||
|
|
||||||
@ -44,6 +45,7 @@ class Document(BaseDocument):
|
|||||||
document already exists, it will be updated, otherwise it will be
|
document already exists, it will be updated, otherwise it will be
|
||||||
created.
|
created.
|
||||||
"""
|
"""
|
||||||
|
self.validate()
|
||||||
object_id = self.__class__.objects._collection.save(self.to_mongo())
|
object_id = self.__class__.objects._collection.save(self.to_mongo())
|
||||||
self.id = self._fields['id'].to_python(object_id)
|
self.id = self._fields['id'].to_python(object_id)
|
||||||
|
|
||||||
@ -54,6 +56,25 @@ class Document(BaseDocument):
|
|||||||
object_id = self._fields['id'].to_mongo(self.id)
|
object_id = self._fields['id'].to_mongo(self.id)
|
||||||
self.__class__.objects(id=object_id).delete()
|
self.__class__.objects(id=object_id).delete()
|
||||||
|
|
||||||
|
def validate(self):
|
||||||
|
"""Ensure that all fields' values are valid and that required fields
|
||||||
|
are present.
|
||||||
|
"""
|
||||||
|
# Get a list of tuples of field names and their current values
|
||||||
|
fields = [(field, getattr(self, name))
|
||||||
|
for name, field in self._fields.items()]
|
||||||
|
|
||||||
|
# Ensure that each field is matched to a valid value
|
||||||
|
for field, value in fields:
|
||||||
|
if value is not None:
|
||||||
|
try:
|
||||||
|
field.validate(value)
|
||||||
|
except (ValueError, AttributeError, AssertionError), e:
|
||||||
|
raise ValidationError('Invalid value for field of type "' +
|
||||||
|
field.__class__.__name__ + '"')
|
||||||
|
elif field.required:
|
||||||
|
raise ValidationError('Field "%s" is required' % field.name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def drop_collection(cls):
|
def drop_collection(cls):
|
||||||
"""Drops the entire collection associated with this
|
"""Drops the entire collection associated with this
|
||||||
|
105
tests/fields.py
105
tests/fields.py
@ -31,13 +31,10 @@ class FieldTest(unittest.TestCase):
|
|||||||
age = IntField(required=True)
|
age = IntField(required=True)
|
||||||
userid = StringField()
|
userid = StringField()
|
||||||
|
|
||||||
self.assertRaises(ValidationError, Person, name="Test User")
|
person = Person(name="Test User")
|
||||||
self.assertRaises(ValidationError, Person, age=30)
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
person = Person(age=30)
|
||||||
person = Person(name="Test User", age=30, userid="testuser")
|
self.assertRaises(ValidationError, person.validate)
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'name', None)
|
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'age', None)
|
|
||||||
person.userid = None
|
|
||||||
|
|
||||||
def test_object_id_validation(self):
|
def test_object_id_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to string fields.
|
"""Ensure that invalid values cannot be assigned to string fields.
|
||||||
@ -47,9 +44,15 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
person = Person(name='Test User')
|
person = Person(name='Test User')
|
||||||
self.assertEqual(person.id, None)
|
self.assertEqual(person.id, None)
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'id', 47)
|
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'id', 'abc')
|
person.id = 47
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
|
person.id = 'abc'
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
person.id = '497ce96f395f2f052a494fd4'
|
person.id = '497ce96f395f2f052a494fd4'
|
||||||
|
person.validate()
|
||||||
|
|
||||||
def test_string_validation(self):
|
def test_string_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to string fields.
|
"""Ensure that invalid values cannot be assigned to string fields.
|
||||||
@ -58,20 +61,23 @@ class FieldTest(unittest.TestCase):
|
|||||||
name = StringField(max_length=20)
|
name = StringField(max_length=20)
|
||||||
userid = StringField(r'[0-9a-z_]+$')
|
userid = StringField(r'[0-9a-z_]+$')
|
||||||
|
|
||||||
person = Person()
|
person = Person(name=34)
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'name', 34)
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
# Test regex validation on userid
|
# Test regex validation on userid
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'userid',
|
person = Person(userid='test.User')
|
||||||
'test.User')
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
person.userid = 'test_user'
|
person.userid = 'test_user'
|
||||||
self.assertEqual(person.userid, 'test_user')
|
self.assertEqual(person.userid, 'test_user')
|
||||||
|
person.validate()
|
||||||
|
|
||||||
# Test max length validation on name
|
# Test max length validation on name
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'name',
|
person = Person(name='Name that is more than twenty characters')
|
||||||
'Name that is more than twenty characters')
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
person.name = 'Shorter name'
|
person.name = 'Shorter name'
|
||||||
self.assertEqual(person.name, 'Shorter name')
|
person.validate()
|
||||||
|
|
||||||
def test_int_validation(self):
|
def test_int_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to int fields.
|
"""Ensure that invalid values cannot be assigned to int fields.
|
||||||
@ -81,9 +87,14 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
person = Person()
|
person = Person()
|
||||||
person.age = 50
|
person.age = 50
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'age', -1)
|
person.validate()
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'age', 120)
|
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'age', 'ten')
|
person.age = -1
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
person.age = 120
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
person.age = 'ten'
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
def test_float_validation(self):
|
def test_float_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to float fields.
|
"""Ensure that invalid values cannot be assigned to float fields.
|
||||||
@ -93,9 +104,14 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
person = Person()
|
person = Person()
|
||||||
person.height = 1.89
|
person.height = 1.89
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'height', 2)
|
person.validate()
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'height', 0.01)
|
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'height', 4.0)
|
person.height = 2
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
person.height = 0.01
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
person.height = 4.0
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
def test_datetime_validation(self):
|
def test_datetime_validation(self):
|
||||||
"""Ensure that invalid values cannot be assigned to datetime fields.
|
"""Ensure that invalid values cannot be assigned to datetime fields.
|
||||||
@ -104,9 +120,13 @@ class FieldTest(unittest.TestCase):
|
|||||||
time = DateTimeField()
|
time = DateTimeField()
|
||||||
|
|
||||||
log = LogEntry()
|
log = LogEntry()
|
||||||
self.assertRaises(ValidationError, log.__setattr__, 'time', -1)
|
|
||||||
self.assertRaises(ValidationError, log.__setattr__, 'time', '1pm')
|
|
||||||
log.time = datetime.datetime.now()
|
log.time = datetime.datetime.now()
|
||||||
|
log.validate()
|
||||||
|
|
||||||
|
log.time = -1
|
||||||
|
self.assertRaises(ValidationError, log.validate)
|
||||||
|
log.time = '1pm'
|
||||||
|
self.assertRaises(ValidationError, log.validate)
|
||||||
|
|
||||||
def test_list_validation(self):
|
def test_list_validation(self):
|
||||||
"""Ensure that a list field only accepts lists with valid elements.
|
"""Ensure that a list field only accepts lists with valid elements.
|
||||||
@ -120,16 +140,26 @@ class FieldTest(unittest.TestCase):
|
|||||||
tags = ListField(StringField())
|
tags = ListField(StringField())
|
||||||
|
|
||||||
post = BlogPost(content='Went for a walk today...')
|
post = BlogPost(content='Went for a walk today...')
|
||||||
self.assertRaises(ValidationError, post.__setattr__, 'tags', 'fun')
|
post.validate()
|
||||||
self.assertRaises(ValidationError, post.__setattr__, 'tags', [1, 2])
|
|
||||||
|
post.tags = 'fun'
|
||||||
|
self.assertRaises(ValidationError, post.validate)
|
||||||
|
post.tags = [1, 2]
|
||||||
|
self.assertRaises(ValidationError, post.validate)
|
||||||
|
|
||||||
post.tags = ['fun', 'leisure']
|
post.tags = ['fun', 'leisure']
|
||||||
|
post.validate()
|
||||||
post.tags = ('fun', 'leisure')
|
post.tags = ('fun', 'leisure')
|
||||||
|
post.validate()
|
||||||
|
|
||||||
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
comments = [Comment(content='Good for you'), Comment(content='Yay.')]
|
||||||
self.assertRaises(ValidationError, post.__setattr__, 'comments', ['a'])
|
|
||||||
self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay')
|
|
||||||
self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay')
|
|
||||||
post.comments = comments
|
post.comments = comments
|
||||||
|
post.validate()
|
||||||
|
|
||||||
|
post.comments = ['a']
|
||||||
|
self.assertRaises(ValidationError, post.validate)
|
||||||
|
post.comments = 'yay'
|
||||||
|
self.assertRaises(ValidationError, post.validate)
|
||||||
|
|
||||||
def test_embedded_document_validation(self):
|
def test_embedded_document_validation(self):
|
||||||
"""Ensure that invalid embedded documents cannot be assigned to
|
"""Ensure that invalid embedded documents cannot be assigned to
|
||||||
@ -147,12 +177,15 @@ class FieldTest(unittest.TestCase):
|
|||||||
preferences = EmbeddedDocumentField(PersonPreferences)
|
preferences = EmbeddedDocumentField(PersonPreferences)
|
||||||
|
|
||||||
person = Person(name='Test User')
|
person = Person(name='Test User')
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'preferences',
|
person.preferences = 'My Preferences'
|
||||||
'My preferences')
|
self.assertRaises(ValidationError, person.validate)
|
||||||
self.assertRaises(ValidationError, person.__setattr__, 'preferences',
|
|
||||||
Comment(content='Nice blog post...'))
|
person.preferences = Comment(content='Nice blog post...')
|
||||||
|
self.assertRaises(ValidationError, person.validate)
|
||||||
|
|
||||||
person.preferences = PersonPreferences(food='Cheese', number=47)
|
person.preferences = PersonPreferences(food='Cheese', number=47)
|
||||||
self.assertEqual(person.preferences.food, 'Cheese')
|
self.assertEqual(person.preferences.food, 'Cheese')
|
||||||
|
person.validate()
|
||||||
|
|
||||||
def test_embedded_document_inheritance(self):
|
def test_embedded_document_inheritance(self):
|
||||||
"""Ensure that subclasses of embedded documents may be provided to
|
"""Ensure that subclasses of embedded documents may be provided to
|
||||||
@ -194,14 +227,16 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
# Check that an invalid object type cannot be used
|
# Check that an invalid object type cannot be used
|
||||||
post2 = BlogPost(content='Chips and chilli taste good.')
|
post2 = BlogPost(content='Chips and chilli taste good.')
|
||||||
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2)
|
post1.author = post2
|
||||||
|
self.assertRaises(ValidationError, post1.validate)
|
||||||
|
|
||||||
user.save()
|
user.save()
|
||||||
post1.author = user
|
post1.author = user
|
||||||
post1.save()
|
post1.save()
|
||||||
|
|
||||||
post2.save()
|
post2.save()
|
||||||
self.assertRaises(ValidationError, post1.__setattr__, 'author', post2)
|
post1.author = post2
|
||||||
|
self.assertRaises(ValidationError, post1.validate)
|
||||||
|
|
||||||
User.drop_collection()
|
User.drop_collection()
|
||||||
BlogPost.drop_collection()
|
BlogPost.drop_collection()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user