diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 76791e27..3c276869 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -47,6 +47,40 @@ are as follows: * :class:`~mongoengine.ReferenceField` * :class:`~mongoengine.GenericReferenceField` +Field arguments +--------------- +Each field type can be customized by keyword arguments. The following keyword +arguments can be set on all fields: + +:attr:`db_field` (Default: None) + The MongoDB field name. + +:attr:`name` (Default: None) + The mongoengine field name. + +:attr:`required` (Default: False) + If set to True and the field is not set on the document instance, a + :class:`~mongoengine.base.ValidationError` will be raised when the document is + validated. + +:attr:`default` (Default: None) + A value to use when no value is set for this field. + +:attr:`unique` (Default: False) + When True, no documents in the collection will have the same value for this + field. + +:attr:`unique_with` (Default: None) + A field name (or list of field names) that when taken together with this + field, will not have two documents in the collection with the same value. + +:attr:`primary_key` (Default: False) + When True, use this field as a primary key for the collection. + +:attr:`choices` (Default: None) + An iterable of choices to which the value of this field should be limited. + + List fields ----------- MongoDB allows the storage of lists of items. To add a list of items to a diff --git a/mongoengine/base.py b/mongoengine/base.py index 43a54eea..8c814ac1 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -25,7 +25,7 @@ class BaseField(object): def __init__(self, db_field=None, name=None, required=False, default=None, unique=False, unique_with=None, primary_key=False, validation=None, - create_default=False): + create_default=False, choices=None): self.db_field = (db_field or name) if not primary_key else '_id' if name: import warnings @@ -39,6 +39,7 @@ class BaseField(object): self.primary_key = primary_key self.validation = validation self.create_default = create_default + self.choices = choices def __get__(self, instance, owner): """Descriptor for retrieving a value from a field in a document. Do @@ -84,11 +85,23 @@ class BaseField(object): def validate(self, value): """Perform validation on a value. """ + pass + + def _validate(self, value): + # check choices + if self.choices is not None: + if value not in self.choices: + raise ValidationError("Value must be one of %s."%unicode(self.choices)) + + # check validation argument if self.validation is not None: - if (isinstance(self.validation, list) or isinstance(self.validation, tuple)) and value not in self.validation: - raise ValidationError('Value not in validation list.') - elif callable(self.validation) and not self.validation(value): - raise ValidationError('Value does not match custom validation method.') + if callable(self.validation): + if not self.validation(value): + raise ValidationError('Value does not match custom validation method.') + else: + raise ValueError('validation argument must be a callable.') + + self.validate(value) class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. @@ -324,7 +337,7 @@ class BaseDocument(object): for field, value in fields: if value is not None: try: - field.validate(value) + field._validate(value) except (ValueError, AttributeError, AssertionError), e: raise ValidationError('Invalid value for field of type "' + field.__class__.__name__ + '"') @@ -451,4 +464,4 @@ if sys.version_info < (2, 5): return types.ClassType(name, parents, {}) else: def subclass_exception(name, parents, module): - return type(name, parents, {'__module__': module}) \ No newline at end of file + return type(name, parents, {'__module__': module}) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index e4d95d53..5b2f14a4 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -43,8 +43,6 @@ class StringField(BaseField): if self.regex is not None and self.regex.match(value) is None: message = 'String value did not match validation regex' raise ValidationError(message) - - super(StringField, self).validate(value) def lookup_member(self, member_name): return None @@ -99,8 +97,6 @@ class URLField(StringField): except Exception, e: message = 'This URL appears to be a broken link: %s' % e raise ValidationError(message) - - super(URLField, self).validate(value) class EmailField(StringField): """A field that validates input as an E-Mail-Address. @@ -138,8 +134,6 @@ class IntField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Integer value is too large') - - super(IntField, self).validate(value) class FloatField(BaseField): """An floating point number field. @@ -162,9 +156,6 @@ class FloatField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Float value is too large') - - super(FloatField, self).validate(value) - class DecimalField(BaseField): """A fixed-point decimal number field. @@ -198,9 +189,6 @@ class DecimalField(BaseField): if self.max_value is not None and value > self.max_value: raise ValidationError('Decimal value is too large') - - super(DecimalField, self).validate(value) - class BooleanField(BaseField): """A boolean field type. @@ -213,9 +201,6 @@ class BooleanField(BaseField): def validate(self, value): assert isinstance(value, bool) - - super(BooleanField, self).validate(value) - class DateTimeField(BaseField): """A datetime field. @@ -223,9 +208,6 @@ class DateTimeField(BaseField): def validate(self, value): assert isinstance(value, datetime.datetime) - - super(DateTimeField, self).validate(value) - class EmbeddedDocumentField(BaseField): """An embedded document field. Only valid values are subclasses of @@ -256,8 +238,6 @@ class EmbeddedDocumentField(BaseField): raise ValidationError('Invalid embedded document instance ' 'provided to an EmbeddedDocumentField') self.document.validate(value) - - super(EmbeddedDocumentField, self).validate(value) def lookup_member(self, member_name): return self.document._fields.get(member_name) @@ -334,8 +314,6 @@ class ListField(BaseField): [self.field.validate(item) for item in value] except Exception, err: raise ValidationError('Invalid ListField item (%s)' % str(err)) - - super(ListField, self).validate(value) def prepare_query_value(self, op, value): if op in ('set', 'unset'): @@ -380,8 +358,6 @@ class DictField(BaseField): if any(('.' in k or '$' in k) for k in value): raise ValidationError('Invalid dictionary key name - keys may not ' 'contain "." or "$" characters') - - super(DictField, self).validate(value) def lookup_member(self, member_name): return BaseField(db_field=member_name) @@ -397,8 +373,6 @@ class GeoLocationField(DictField): if len(value) <> 2: raise ValidationError('GeoLocationField must have exactly two elements (x, y)') - - super(GeoLocationField, self).validate(value) def to_mongo(self, value): return {'x': value[0], 'y': value[1]} @@ -468,8 +442,6 @@ class ReferenceField(BaseField): def validate(self, value): assert isinstance(value, (self.document_type, pymongo.dbref.DBRef)) - - super(ReferenceField, self).validate(value) def lookup_member(self, member_name): return self.document_type._fields.get(member_name) @@ -539,6 +511,4 @@ class BinaryField(BaseField): assert isinstance(value, str) if self.max_bytes is not None and len(value) > self.max_bytes: - raise ValidationError('Binary value is too long') - - super(BinaryField, self).validate(value) + raise ValidationError('Binary value is too long') \ No newline at end of file diff --git a/tests/fields.py b/tests/fields.py index 7e68155c..4050e264 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -588,5 +588,26 @@ class FieldTest(unittest.TestCase): AttachmentRequired.drop_collection() AttachmentSizeLimit.drop_collection() + def test_choices_validation(self): + """Ensure that value is in a container of allowed values. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=('S','M','L','XL','XXL')) + + Shirt.drop_collection() + + shirt = Shirt() + shirt.validate() + + shirt.size = "S" + shirt.validate() + + shirt.size = "XS" + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + + + if __name__ == '__main__': unittest.main()