diff --git a/AUTHORS b/AUTHORS index 200182b0..9ebc0054 100644 --- a/AUTHORS +++ b/AUTHORS @@ -72,4 +72,5 @@ that much better: * Paul Cunnane * Julien Rebetez * Marc Tamlyn + * Karim Allah diff --git a/docs/changelog.rst b/docs/changelog.rst index facb1b99..dbb4a728 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added Non-Django Style choices back (you can have either) - Fixed __repr__ of a sliced queryset - Added recursive validation error of documents / complex fields - Fixed breaking during queryset iteration diff --git a/mongoengine/base.py b/mongoengine/base.py index ab6ccbee..24a37974 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -195,14 +195,14 @@ class BaseField(object): def _validate(self, value): # check choices - if self.choices is not None: - if type(choices[0]) is tuple: - option_keys = [option_key for option_key, option_value in self.choices] - if value not in option_keys: - self.error('Value must be one of %s' % unicode(option_keys)) - else: - if value not in self.choices: + if self.choices: + if isinstance(self.choices[0], (list, tuple)): + option_keys = [option_key for option_key, option_value in self.choices] + if value not in option_keys: self.error('Value must be one of %s' % unicode(option_keys)) + else: + if value not in self.choices: + self.error('Value must be one of %s' % unicode(self.choices)) # check validation argument if self.validation is not None: @@ -1051,7 +1051,9 @@ class BaseDocument(object): def __get_field_display(self, field): """Returns the display value for a choice field""" value = getattr(self, field.name) - return dict(field.choices).get(value, value) + if field.choices and isinstance(field.choices[0], (list, tuple)): + return dict(field.choices).get(value, value) + return value def __iter__(self): return iter(self._fields) diff --git a/tests/fields.py b/tests/fields.py index dd68cb55..3b697917 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1280,6 +1280,53 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_simple_choices_validation(self): + """Ensure that value is in a container of allowed values. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=('S', 'M', 'L', 'XL', 'XXL')) + + Shirt.drop_collection() + + shirt = Shirt() + shirt.validate() + + shirt.size = "S" + shirt.validate() + + shirt.size = "XS" + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + + def test_simple_choices_get_field_display(self): + """Test dynamic helper for returning the display value of a choices field. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=('S', 'M', 'L', 'XL', 'XXL')) + style = StringField(max_length=3, choices=('Small', 'Baggy', 'wide'), default='Small') + + Shirt.drop_collection() + + shirt = Shirt() + + self.assertEqual(shirt.get_size_display(), None) + self.assertEqual(shirt.get_style_display(), 'Small') + + shirt.size = "XXL" + shirt.style = "Baggy" + self.assertEqual(shirt.get_size_display(), 'XXL') + self.assertEqual(shirt.get_style_display(), 'Baggy') + + # Set as Z - an invalid choice + shirt.size = "Z" + shirt.style = "Z" + self.assertEqual(shirt.get_size_display(), 'Z') + self.assertEqual(shirt.get_style_display(), 'Z') + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + def test_file_fields(self): """Ensure that file fields can be written to and their data retrieved """