fix doc.get_<field>_display + unit test inspired by #1279 (#1419)

This commit is contained in:
Stefan Wójcik 2016-12-04 00:34:24 -05:00 committed by GitHub
parent 0007535a46
commit eb743beaa3
2 changed files with 38 additions and 35 deletions

View File

@ -121,7 +121,7 @@ class BaseDocument(object):
else:
self._data[key] = value
# Set any get_fieldname_display methods
# Set any get_<field>_display methods
self.__set_field_display()
if self._dynamic:
@ -1005,19 +1005,18 @@ class BaseDocument(object):
return '.'.join(parts)
def __set_field_display(self):
"""Dynamically set the display value for a field with choices"""
for attr_name, field in self._fields.items():
if field.choices:
if self._dynamic:
obj = self
else:
obj = type(self)
setattr(obj,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
"""For each field that specifies choices, create a
get_<field>_display method.
"""
fields_with_choices = [(n, f) for n, f in self._fields.items()
if f.choices]
for attr_name, field in fields_with_choices:
setattr(self,
'get_%s_display' % attr_name,
partial(self.__get_field_display, field=field))
def __get_field_display(self, field):
"""Returns the display value for a choice field"""
"""Return the display value for a choice field"""
value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value)

View File

@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S')
('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
Shirt.drop_collection()
shirt = Shirt()
shirt1 = Shirt()
shirt2 = Shirt()
self.assertEqual(shirt.get_size_display(), None)
self.assertEqual(shirt.get_style_display(), 'Small')
# Make sure get_<field>_display returns the default value (or None)
self.assertEqual(shirt1.get_size_display(), None)
self.assertEqual(shirt1.get_style_display(), 'Wide')
shirt.size = "XXL"
shirt.style = "B"
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt.get_style_display(), 'Baggy')
shirt1.size = 'XXL'
shirt1.style = 'B'
shirt2.size = 'M'
shirt2.style = 'S'
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')
# Set as Z - an invalid choice
shirt.size = "Z"
shirt.style = "Z"
self.assertEqual(shirt.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate)
Shirt.drop_collection()
shirt1.size = 'Z'
shirt1.style = 'Z'
self.assertEqual(shirt1.get_size_display(), 'Z')
self.assertEqual(shirt1.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt1.validate)
def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values.