From cae3f3eefffa3809dc396e65941f36f66c4bdb52 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Thu, 16 Jun 2011 12:50:45 +0100 Subject: [PATCH] Fixes pickling issue with choice fields Removes the dynamic __get_field_display partials before pickling --- mongoengine/base.py | 72 ++++++++++++++++++++++++++++++--------------- tests/document.py | 6 ++-- 2 files changed, 51 insertions(+), 27 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 49efba60..938808a8 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -614,9 +614,6 @@ class BaseDocument(object): self._data = {} # Assign default values to instance for attr_name, field in self._fields.items(): - if field.choices: # dynamically adds a way to get the display value for a field with choices - setattr(self, 'get_%s_display' % attr_name, partial(self._get_FIELD_display, field=field)) - value = getattr(self, attr_name, None) setattr(self, attr_name, value) @@ -628,9 +625,29 @@ class BaseDocument(object): except AttributeError: pass + # Set any get_fieldname_display methods + self.__set_field_display() + signals.post_init.send(self.__class__, document=self) - def _get_FIELD_display(self, field): + def __getstate__(self): + self_dict = self.__dict__ + removals = ["get_%s_display" % k for k,v in self._fields.items() if v.choices] + for k in removals: + if hasattr(self, k): + delattr(self, k) + return self.__dict__ + + def __setstate__(self, __dict__): + self.__dict__ = __dict__ + self.__set_field_display() + + def __set_field_display(self): + for attr_name, field in self._fields.items(): + if field.choices: # dynamically adds a way to get the display value for a field with choices + setattr(self, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) + + def __get_field_display(self, field): """Returns the display value for a choice field""" value = getattr(self, field.name) return dict(field.choices).get(value, value) @@ -865,42 +882,46 @@ class BaseList(list): super(BaseList, self).__init__(list_items) def __setitem__(self, *args, **kwargs): - if hasattr(self, 'instance') and hasattr(self, 'name'): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseList, self).__setitem__(*args, **kwargs) def __delitem__(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseList, self).__delitem__(*args, **kwargs) def append(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).append(*args, **kwargs) def extend(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).extend(*args, **kwargs) def insert(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).insert(*args, **kwargs) def pop(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).pop(*args, **kwargs) def remove(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).remove(*args, **kwargs) def reverse(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).reverse(*args, **kwargs) def sort(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() return super(BaseList, self).sort(*args, **kwargs) + def _mark_as_changed(self): + """Marks a list as changed if has an instance and a name""" + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + class BaseDict(dict): """A special dict so we can watch any changes @@ -912,39 +933,42 @@ class BaseDict(dict): super(BaseDict, self).__init__(dict_items) def __setitem__(self, *args, **kwargs): - if hasattr(self, 'instance') and hasattr(self, 'name'): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).__setitem__(*args, **kwargs) def __setattr__(self, *args, **kwargs): - if hasattr(self, 'instance') and hasattr(self, 'name'): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).__setattr__(*args, **kwargs) def __delete__(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).__delete__(*args, **kwargs) def __delitem__(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).__delitem__(*args, **kwargs) def __delattr__(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).__delattr__(*args, **kwargs) def clear(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).clear(*args, **kwargs) def pop(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).clear(*args, **kwargs) def popitem(self, *args, **kwargs): - self.instance._mark_as_changed(self.name) + self._mark_as_changed() super(BaseDict, self).clear(*args, **kwargs) + def _mark_as_changed(self): + """Marks a dict as changed if has an instance and a name""" + if hasattr(self, 'instance') and hasattr(self, 'name'): + self.instance._mark_as_changed(self.name) + if sys.version_info < (2, 5): # Prior to Python 2.5, Exception was an old-style class import types diff --git a/tests/document.py b/tests/document.py index 3a5419da..b33f3fe7 100644 --- a/tests/document.py +++ b/tests/document.py @@ -15,7 +15,7 @@ class PickleEmbedded(EmbeddedDocument): class PickleTest(Document): number = IntField() - string = StringField() + string = StringField(choices=(('One', '1'), ('Two', '2'))) embedded = EmbeddedDocumentField(PickleEmbedded) lists = ListField(StringField()) @@ -1516,7 +1516,7 @@ class DocumentTest(unittest.TestCase): def test_picklable(self): - pickle_doc = PickleTest(number=1, string="OH HAI", lists=['1', '2']) + pickle_doc = PickleTest(number=1, string="One", lists=['1', '2']) pickle_doc.embedded = PickleEmbedded() pickle_doc.save() @@ -1525,7 +1525,7 @@ class DocumentTest(unittest.TestCase): self.assertEquals(resurrected, pickle_doc) - resurrected.string = "Working" + resurrected.string = "Two" resurrected.save() pickle_doc.reload()