From ee0c75a26da5d83a6d8042f0317f8a3e92a39bb3 Mon Sep 17 00:00:00 2001 From: Don Spaulding Date: Thu, 15 Apr 2010 17:59:35 -0500 Subject: [PATCH] Add choices keyword argument to BaseField.__init__() --- mongoengine/base.py | 14 +++++++++++--- tests/fields.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 55323ddc..78f06500 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -24,7 +24,8 @@ class BaseField(object): _index_with_types = True def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False): + unique=False, unique_with=None, primary_key=False, + choices=None): self.db_field = (db_field or name) if not primary_key else '_id' if name: import warnings @@ -36,6 +37,7 @@ class BaseField(object): self.unique = bool(unique or unique_with) self.unique_with = unique_with self.primary_key = primary_key + self.choices = choices def __get__(self, instance, owner): """Descriptor for retrieving a value from a field in a document. Do @@ -79,6 +81,12 @@ class BaseField(object): """ pass + def _validate(self, value): + if self.choices is not None: + if value not in self.choices: + raise ValidationError("Value must be one of %s."%unicode(self.choices)) + self.validate(value) + class ObjectIdField(BaseField): """An field wrapper around MongoDB's ObjectIds. @@ -314,7 +322,7 @@ class BaseDocument(object): for field, value in fields: if value is not None: try: - field.validate(value) + field._validate(value) except (ValueError, AttributeError, AssertionError), e: raise ValidationError('Invalid value for field of type "' + field.__class__.__name__ + '"') @@ -441,4 +449,4 @@ if sys.version_info < (2, 5): return types.ClassType(name, parents, {}) else: def subclass_exception(name, parents, module): - return type(name, parents, {'__module__': module}) \ No newline at end of file + return type(name, parents, {'__module__': module}) diff --git a/tests/fields.py b/tests/fields.py index 7e68155c..4050e264 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -588,5 +588,26 @@ class FieldTest(unittest.TestCase): AttachmentRequired.drop_collection() AttachmentSizeLimit.drop_collection() + def test_choices_validation(self): + """Ensure that value is in a container of allowed values. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=('S','M','L','XL','XXL')) + + Shirt.drop_collection() + + shirt = Shirt() + shirt.validate() + + shirt.size = "S" + shirt.validate() + + shirt.size = "XS" + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + + + if __name__ == '__main__': unittest.main()