From d32dd9ff62c0984af5062a4b52f974bb009b22a3 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 8 Jun 2011 13:07:08 +0100 Subject: [PATCH] Added _get_FIELD_display() for handy choice field display lookups closes #188 --- docs/changelog.rst | 1 + mongoengine/base.py | 12 +++++++++++- tests/fields.py | 29 +++++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 0a2a273f..c76b1154 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added get_FIELD_display() method for easy choice field displaying. - Added queryset.slave_okay(enabled) method - Updated queryset.timeout(enabled) and queryset.snapshot(enabled) to be chainable - Added insert method for bulk inserts diff --git a/mongoengine/base.py b/mongoengine/base.py index 76bb1ab7..3875fea5 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -8,6 +8,7 @@ import sys import pymongo import pymongo.objectid from operator import itemgetter +from functools import partial class NotRegistered(Exception): @@ -61,6 +62,7 @@ class BaseField(object): self.primary_key = primary_key self.validation = validation self.choices = choices + # Adjust the appropriate creation counter, and save our local copy. if self.db_field == '_id': self.creation_counter = BaseField.auto_creation_counter @@ -471,7 +473,10 @@ class BaseDocument(object): self._data = {} # Assign default values to instance - for attr_name in self._fields.keys(): + for attr_name, field in self._fields.items(): + if field.choices: # dynamically adds a way to get the display value for a field with choices + setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field)) + # Use default value if present value = getattr(self, attr_name, None) setattr(self, attr_name, value) @@ -484,6 +489,11 @@ class BaseDocument(object): signals.post_init.send(self) + def _get_FIELD_display(self, field): + """Returns the display value for a choice field""" + value = getattr(self, field.name) + return dict(field.choices).get(value, value) + def validate(self): """Ensure that all fields' values are valid and that required fields are present. diff --git a/tests/fields.py b/tests/fields.py index 320e33db..d8970043 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -773,6 +773,35 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_choices_get_field_display(self): + """Test dynamic helper for returning the display value of a choices field. + """ + class Shirt(Document): + size = StringField(max_length=3, choices=(('S', 'Small'), ('M', 'Medium'), ('L', 'Large'), + ('XL', 'Extra Large'), ('XXL', 'Extra Extra Large'))) + style = StringField(max_length=3, choices=(('S', 'Small'), ('B', 'Baggy'), ('W', 'wide')), default='S') + + Shirt.drop_collection() + + shirt = Shirt() + + self.assertEqual(shirt.get_size_display(), None) + self.assertEqual(shirt.get_style_display(), 'Small') + + shirt.size = "XXL" + shirt.style = "B" + self.assertEqual(shirt.get_size_display(), 'Extra Extra Large') + self.assertEqual(shirt.get_style_display(), 'Baggy') + + # Set as Z - an invalid choice + shirt.size = "Z" + shirt.style = "Z" + self.assertEqual(shirt.get_size_display(), 'Z') + self.assertEqual(shirt.get_style_display(), 'Z') + self.assertRaises(ValidationError, shirt.validate) + + Shirt.drop_collection() + def test_file_fields(self): """Ensure that file fields can be written to and their data retrieved """