From 50923d809deab64900b594b32abf8964f7eecc28 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Sat, 3 Dec 2016 17:26:39 -0500 Subject: [PATCH] fix doc.get__display + unit test inspired by #1279 --- mongoengine/base/document.py | 23 ++++++++--------- tests/fields/fields.py | 50 +++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 35 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index eaa2019a..59f5aebc 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -121,7 +121,7 @@ class BaseDocument(object): else: self._data[key] = value - # Set any get_fieldname_display methods + # Set any get__display methods self.__set_field_display() if self._dynamic: @@ -1005,19 +1005,18 @@ class BaseDocument(object): return '.'.join(parts) def __set_field_display(self): - """Dynamically set the display value for a field with choices""" - for attr_name, field in self._fields.items(): - if field.choices: - if self._dynamic: - obj = self - else: - obj = type(self) - setattr(obj, - 'get_%s_display' % attr_name, - partial(self.__get_field_display, field=field)) + """For each field that specifies choices, create a + get__display method. + """ + fields_with_choices = [(n, f) for n, f in self._fields.items() + if f.choices] + for attr_name, field in fields_with_choices: + setattr(self, + 'get_%s_display' % attr_name, + partial(self.__get_field_display, field=field)) def __get_field_display(self, field): - """Returns the display value for a choice field""" + """Return the display value for a choice field""" value = getattr(self, field.name) if field.choices and isinstance(field.choices[0], (list, tuple)): return dict(field.choices).get(value, value) diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 36b9f4cd..2153a42e 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -1047,7 +1047,7 @@ class FieldTest(unittest.TestCase): BlogPost.drop_collection() def test_list_assignment(self): - """Ensure that list field element assignment and slicing work + """Ensure that list field element assignment and slicing work """ class BlogPost(Document): info = ListField() @@ -1057,12 +1057,12 @@ class FieldTest(unittest.TestCase): post = BlogPost() post.info = ['e1', 'e2', 3, '4', 5] post.save() - + post.info[0] = 1 post.save() post.reload() self.assertEqual(post.info[0], 1) - + post.info[1:3] = ['n2', 'n3'] post.save() post.reload() @@ -1209,7 +1209,7 @@ class FieldTest(unittest.TestCase): self.assertEqual(simple.widgets, [4]) def test_list_field_with_negative_indices(self): - + class Simple(Document): widgets = ListField() @@ -1823,7 +1823,7 @@ class FieldTest(unittest.TestCase): 'parent': "50a234ea469ac1eda42d347d"}) mongoed = p1.to_mongo() self.assertTrue(isinstance(mongoed['parent'], ObjectId)) - + def test_cached_reference_field_get_and_save(self): """ Tests #1047: CachedReferenceField creates DBRefs on to_python, but can't save them on to_mongo @@ -1835,11 +1835,11 @@ class FieldTest(unittest.TestCase): class Ocorrence(Document): person = StringField() animal = CachedReferenceField(Animal) - + Animal.drop_collection() Ocorrence.drop_collection() - - Ocorrence(person="testte", + + Ocorrence(person="testte", animal=Animal(name="Leopard", tag="heavy").save()).save() p = Ocorrence.objects.get() p.person = 'new_testte' @@ -3001,28 +3001,32 @@ class FieldTest(unittest.TestCase): ('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) style = StringField(max_length=3, choices=( - ('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') + ('S', 'Small'), ('B', 'Baggy'), ('W', 'Wide')), default='W') Shirt.drop_collection() - shirt = Shirt() + shirt1 = Shirt() + shirt2 = Shirt() - self.assertEqual(shirt.get_size_display(), None) - self.assertEqual(shirt.get_style_display(), 'Small') + # Make sure get__display returns the default value (or None) + self.assertEqual(shirt1.get_size_display(), None) + self.assertEqual(shirt1.get_style_display(), 'Wide') - shirt.size = "XXL" - shirt.style = "B" - self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') - self.assertEqual(shirt.get_style_display(), 'Baggy') + shirt1.size = 'XXL' + shirt1.style = 'B' + shirt2.size = 'M' + shirt2.style = 'S' + self.assertEqual(shirt1.get_size_display(), 'Extra Extra Large') + self.assertEqual(shirt1.get_style_display(), 'Baggy') + self.assertEqual(shirt2.get_size_display(), 'Medium') + self.assertEqual(shirt2.get_style_display(), 'Small') # Set as Z - an invalid choice - shirt.size = "Z" - shirt.style = "Z" - self.assertEqual(shirt.get_size_display(), 'Z') - self.assertEqual(shirt.get_style_display(), 'Z') - self.assertRaises(ValidationError, shirt.validate) - - Shirt.drop_collection() + shirt1.size = 'Z' + shirt1.style = 'Z' + self.assertEqual(shirt1.get_size_display(), 'Z') + self.assertEqual(shirt1.get_style_display(), 'Z') + self.assertRaises(ValidationError, shirt1.validate) def test_simple_choices_validation(self): """Ensure that value is in a container of allowed values.