diff --git a/mongoengine/base.py b/mongoengine/base.py index cd6d8bab..40b03dfa 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -35,17 +35,8 @@ class BaseField(object): return value def __set__(self, instance, value): - """Descriptor for assigning a value to a field in a document. Do any - necessary conversion between Python and MongoDB types. + """Descriptor for assigning a value to a field in a document. """ - if value is not None: - try: - self.validate(value) - except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field of type "' + - self.__class__.__name__ + '"') - elif self.required: - raise ValidationError('Field "%s" is required' % self.name) instance._data[self.name] = value def to_python(self, value): @@ -183,8 +174,6 @@ class BaseDocument(object): else: # Use default value if present value = getattr(self, attr_name, None) - if value is None and attr_value.required: - raise ValidationError('Field "%s" is required' % attr_name) setattr(self, attr_name, value) @classmethod diff --git a/mongoengine/document.py b/mongoengine/document.py index b9093caa..c031c860 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,4 +1,5 @@ -from base import DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument +from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, + ValidationError) from connection import _get_db @@ -44,6 +45,7 @@ class Document(BaseDocument): document already exists, it will be updated, otherwise it will be created. """ + self.validate() object_id = self.__class__.objects._collection.save(self.to_mongo()) self.id = self._fields['id'].to_python(object_id) @@ -54,6 +56,25 @@ class Document(BaseDocument): object_id = self._fields['id'].to_mongo(self.id) self.__class__.objects(id=object_id).delete() + def validate(self): + """Ensure that all fields' values are valid and that required fields + are present. + """ + # Get a list of tuples of field names and their current values + fields = [(field, getattr(self, name)) + for name, field in self._fields.items()] + + # Ensure that each field is matched to a valid value + for field, value in fields: + if value is not None: + try: + field.validate(value) + except (ValueError, AttributeError, AssertionError), e: + raise ValidationError('Invalid value for field of type "' + + field.__class__.__name__ + '"') + elif field.required: + raise ValidationError('Field "%s" is required' % field.name) + @classmethod def drop_collection(cls): """Drops the entire collection associated with this diff --git a/tests/fields.py b/tests/fields.py index 9bad1ea1..b580dc20 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -31,13 +31,10 @@ class FieldTest(unittest.TestCase): age = IntField(required=True) userid = StringField() - self.assertRaises(ValidationError, Person, name="Test User") - self.assertRaises(ValidationError, Person, age=30) - - person = Person(name="Test User", age=30, userid="testuser") - self.assertRaises(ValidationError, person.__setattr__, 'name', None) - self.assertRaises(ValidationError, person.__setattr__, 'age', None) - person.userid = None + person = Person(name="Test User") + self.assertRaises(ValidationError, person.validate) + person = Person(age=30) + self.assertRaises(ValidationError, person.validate) def test_object_id_validation(self): """Ensure that invalid values cannot be assigned to string fields. @@ -47,9 +44,15 @@ class FieldTest(unittest.TestCase): person = Person(name='Test User') self.assertEqual(person.id, None) - self.assertRaises(ValidationError, person.__setattr__, 'id', 47) - self.assertRaises(ValidationError, person.__setattr__, 'id', 'abc') + + person.id = 47 + self.assertRaises(ValidationError, person.validate) + + person.id = 'abc' + self.assertRaises(ValidationError, person.validate) + person.id = '497ce96f395f2f052a494fd4' + person.validate() def test_string_validation(self): """Ensure that invalid values cannot be assigned to string fields. @@ -58,20 +61,23 @@ class FieldTest(unittest.TestCase): name = StringField(max_length=20) userid = StringField(r'[0-9a-z_]+$') - person = Person() - self.assertRaises(ValidationError, person.__setattr__, 'name', 34) + person = Person(name=34) + self.assertRaises(ValidationError, person.validate) # Test regex validation on userid - self.assertRaises(ValidationError, person.__setattr__, 'userid', - 'test.User') + person = Person(userid='test.User') + self.assertRaises(ValidationError, person.validate) + person.userid = 'test_user' self.assertEqual(person.userid, 'test_user') + person.validate() # Test max length validation on name - self.assertRaises(ValidationError, person.__setattr__, 'name', - 'Name that is more than twenty characters') + person = Person(name='Name that is more than twenty characters') + self.assertRaises(ValidationError, person.validate) + person.name = 'Shorter name' - self.assertEqual(person.name, 'Shorter name') + person.validate() def test_int_validation(self): """Ensure that invalid values cannot be assigned to int fields. @@ -81,9 +87,14 @@ class FieldTest(unittest.TestCase): person = Person() person.age = 50 - self.assertRaises(ValidationError, person.__setattr__, 'age', -1) - self.assertRaises(ValidationError, person.__setattr__, 'age', 120) - self.assertRaises(ValidationError, person.__setattr__, 'age', 'ten') + person.validate() + + person.age = -1 + self.assertRaises(ValidationError, person.validate) + person.age = 120 + self.assertRaises(ValidationError, person.validate) + person.age = 'ten' + self.assertRaises(ValidationError, person.validate) def test_float_validation(self): """Ensure that invalid values cannot be assigned to float fields. @@ -93,9 +104,14 @@ class FieldTest(unittest.TestCase): person = Person() person.height = 1.89 - self.assertRaises(ValidationError, person.__setattr__, 'height', 2) - self.assertRaises(ValidationError, person.__setattr__, 'height', 0.01) - self.assertRaises(ValidationError, person.__setattr__, 'height', 4.0) + person.validate() + + person.height = 2 + self.assertRaises(ValidationError, person.validate) + person.height = 0.01 + self.assertRaises(ValidationError, person.validate) + person.height = 4.0 + self.assertRaises(ValidationError, person.validate) def test_datetime_validation(self): """Ensure that invalid values cannot be assigned to datetime fields. @@ -104,9 +120,13 @@ class FieldTest(unittest.TestCase): time = DateTimeField() log = LogEntry() - self.assertRaises(ValidationError, log.__setattr__, 'time', -1) - self.assertRaises(ValidationError, log.__setattr__, 'time', '1pm') log.time = datetime.datetime.now() + log.validate() + + log.time = -1 + self.assertRaises(ValidationError, log.validate) + log.time = '1pm' + self.assertRaises(ValidationError, log.validate) def test_list_validation(self): """Ensure that a list field only accepts lists with valid elements. @@ -120,16 +140,26 @@ class FieldTest(unittest.TestCase): tags = ListField(StringField()) post = BlogPost(content='Went for a walk today...') - self.assertRaises(ValidationError, post.__setattr__, 'tags', 'fun') - self.assertRaises(ValidationError, post.__setattr__, 'tags', [1, 2]) + post.validate() + + post.tags = 'fun' + self.assertRaises(ValidationError, post.validate) + post.tags = [1, 2] + self.assertRaises(ValidationError, post.validate) + post.tags = ['fun', 'leisure'] + post.validate() post.tags = ('fun', 'leisure') + post.validate() comments = [Comment(content='Good for you'), Comment(content='Yay.')] - self.assertRaises(ValidationError, post.__setattr__, 'comments', ['a']) - self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay') - self.assertRaises(ValidationError, post.__setattr__, 'comments', 'Yay') post.comments = comments + post.validate() + + post.comments = ['a'] + self.assertRaises(ValidationError, post.validate) + post.comments = 'yay' + self.assertRaises(ValidationError, post.validate) def test_embedded_document_validation(self): """Ensure that invalid embedded documents cannot be assigned to @@ -147,12 +177,15 @@ class FieldTest(unittest.TestCase): preferences = EmbeddedDocumentField(PersonPreferences) person = Person(name='Test User') - self.assertRaises(ValidationError, person.__setattr__, 'preferences', - 'My preferences') - self.assertRaises(ValidationError, person.__setattr__, 'preferences', - Comment(content='Nice blog post...')) + person.preferences = 'My Preferences' + self.assertRaises(ValidationError, person.validate) + + person.preferences = Comment(content='Nice blog post...') + self.assertRaises(ValidationError, person.validate) + person.preferences = PersonPreferences(food='Cheese', number=47) self.assertEqual(person.preferences.food, 'Cheese') + person.validate() def test_embedded_document_inheritance(self): """Ensure that subclasses of embedded documents may be provided to @@ -194,14 +227,16 @@ class FieldTest(unittest.TestCase): # Check that an invalid object type cannot be used post2 = BlogPost(content='Chips and chilli taste good.') - self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) user.save() post1.author = user post1.save() post2.save() - self.assertRaises(ValidationError, post1.__setattr__, 'author', post2) + post1.author = post2 + self.assertRaises(ValidationError, post1.validate) User.drop_collection() BlogPost.drop_collection()