fix doc.get_<field>_display + unit test inspired by #1279

This commit is contained in:
Stefan Wojcik 2016-12-03 17:26:39 -05:00
parent 088c5f49d9
commit 50923d809d
2 changed files with 38 additions and 35 deletions

View File

@ -121,7 +121,7 @@ class BaseDocument(object):
else: else:
self._data[key] = value self._data[key] = value
# Set any get_fieldname_display methods # Set any get_<field>_display methods
self.__set_field_display() self.__set_field_display()
if self._dynamic: if self._dynamic:
@ -1005,19 +1005,18 @@ class BaseDocument(object):
return '.'.join(parts) return '.'.join(parts)
def __set_field_display(self): def __set_field_display(self):
"""Dynamically set the display value for a field with choices""" """For each field that specifies choices, create a
for attr_name, field in self._fields.items(): get_<field>_display method.
if field.choices: """
if self._dynamic: fields_with_choices = [(n, f) for n, f in self._fields.items()
obj = self if f.choices]
else: for attr_name, field in fields_with_choices:
obj = type(self) setattr(self,
setattr(obj, 'get_%s_display' % attr_name,
'get_%s_display' % attr_name, partial(self.__get_field_display, field=field))
partial(self.__get_field_display, field=field))
def __get_field_display(self, field): def __get_field_display(self, field):
"""Returns the display value for a choice field""" """Return the display value for a choice field"""
value = getattr(self, field.name) value = getattr(self, field.name)
if field.choices and isinstance(field.choices[0], (list, tuple)): if field.choices and isinstance(field.choices[0], (list, tuple)):
return dict(field.choices).get(value, value) return dict(field.choices).get(value, value)

View File

@ -1047,7 +1047,7 @@ class FieldTest(unittest.TestCase):
BlogPost.drop_collection() BlogPost.drop_collection()
def test_list_assignment(self): def test_list_assignment(self):
"""Ensure that list field element assignment and slicing work """Ensure that list field element assignment and slicing work
""" """
class BlogPost(Document): class BlogPost(Document):
info = ListField() info = ListField()
@ -1057,12 +1057,12 @@ class FieldTest(unittest.TestCase):
post = BlogPost() post = BlogPost()
post.info = ['e1', 'e2', 3, '4', 5] post.info = ['e1', 'e2', 3, '4', 5]
post.save() post.save()
post.info[0] = 1 post.info[0] = 1
post.save() post.save()
post.reload() post.reload()
self.assertEqual(post.info[0], 1) self.assertEqual(post.info[0], 1)
post.info[1:3] = ['n2', 'n3'] post.info[1:3] = ['n2', 'n3']
post.save() post.save()
post.reload() post.reload()
@ -1209,7 +1209,7 @@ class FieldTest(unittest.TestCase):
self.assertEqual(simple.widgets, [4]) self.assertEqual(simple.widgets, [4])
def test_list_field_with_negative_indices(self): def test_list_field_with_negative_indices(self):
class Simple(Document): class Simple(Document):
widgets = ListField() widgets = ListField()
@ -1823,7 +1823,7 @@ class FieldTest(unittest.TestCase):
'parent': "50a234ea469ac1eda42d347d"}) 'parent': "50a234ea469ac1eda42d347d"})
mongoed = p1.to_mongo() mongoed = p1.to_mongo()
self.assertTrue(isinstance(mongoed['parent'], ObjectId)) self.assertTrue(isinstance(mongoed['parent'], ObjectId))
def test_cached_reference_field_get_and_save(self): def test_cached_reference_field_get_and_save(self):
""" """
Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo
@ -1835,11 +1835,11 @@ class FieldTest(unittest.TestCase):
class Ocorrence(Document): class Ocorrence(Document):
person = StringField() person = StringField()
animal = CachedReferenceField(Animal) animal = CachedReferenceField(Animal)
Animal.drop_collection() Animal.drop_collection()
Ocorrence.drop_collection() Ocorrence.drop_collection()
Ocorrence(person="testte", Ocorrence(person="testte",
animal=Animal(name="Leopard", tag="heavy").save()).save() animal=Animal(name="Leopard", tag="heavy").save()).save()
p = Ocorrence.objects.get() p = Ocorrence.objects.get()
p.person = 'new_testte' p.person = 'new_testte'
@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase):
('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'),
('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large')))
style = StringField(max_length=3, choices=( style = StringField(max_length=3, choices=(
('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W')
Shirt.drop_collection() Shirt.drop_collection()
shirt = Shirt() shirt1 = Shirt()
shirt2 = Shirt()
self.assertEqual(shirt.get_size_display(), None) # Make sure get_<field>_display returns the default value (or None)
self.assertEqual(shirt.get_style_display(), 'Small') self.assertEqual(shirt1.get_size_display(), None)
self.assertEqual(shirt1.get_style_display(), 'Wide')
shirt.size = "XXL" shirt1.size = 'XXL'
shirt.style = "B" shirt1.style = 'B'
self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') shirt2.size = 'M'
self.assertEqual(shirt.get_style_display(), 'Baggy') shirt2.style = 'S'
self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large')
self.assertEqual(shirt1.get_style_display(), 'Baggy')
self.assertEqual(shirt2.get_size_display(), 'Medium')
self.assertEqual(shirt2.get_style_display(), 'Small')
# Set as Z - an invalid choice # Set as Z - an invalid choice
shirt.size = "Z" shirt1.size = 'Z'
shirt.style = "Z" shirt1.style = 'Z'
self.assertEqual(shirt.get_size_display(), 'Z') self.assertEqual(shirt1.get_size_display(), 'Z')
self.assertEqual(shirt.get_style_display(), 'Z') self.assertEqual(shirt1.get_style_display(), 'Z')
self.assertRaises(ValidationError, shirt.validate) self.assertRaises(ValidationError, shirt1.validate)
Shirt.drop_collection()
def test_simple_choices_validation(self): def test_simple_choices_validation(self):
"""Ensure that value is in a container of allowed values. """Ensure that value is in a container of allowed values.