From 452bbcc19b2f003efb9050227455d2489e5182d2 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Wed, 12 Oct 2011 00:30:12 -0700 Subject: [PATCH] Ported fix for Circular Reference bug to Master Ready for a 0.5.2 release --- AUTHORS | 2 ++ docs/changelog.rst | 5 +++++ mongoengine/base.py | 25 ++++++++++++++++++------ tests/dereference.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 6 deletions(-) diff --git a/AUTHORS b/AUTHORS index b342830a..dd86cc5f 100644 --- a/AUTHORS +++ b/AUTHORS @@ -67,3 +67,5 @@ that much better: * Gareth Lloyd * Albert Choi * John Arnfield + * Julien Rebetez + diff --git a/docs/changelog.rst b/docs/changelog.rst index 0b93c74c..aa48b5cd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,11 @@ Changelog ========= +Changes in v0.5.2 +================= + +- A Robust Circular reference bugfix + Changes in v0.5.1 ================= diff --git a/mongoengine/base.py b/mongoengine/base.py index c4bcee1e..113cab66 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -724,18 +724,32 @@ class BaseDocument(object): if hasattr(self, '_changed_fields') and key not in self._changed_fields: self._changed_fields.append(key) - def _get_changed_fields(self, key=''): + def _get_changed_fields(self, key='', inspected=None): """Returns a list of all fields that have explicitly been changed. """ from mongoengine import EmbeddedDocument _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) - for field_name in self._fields: + + inspected = inspected or set() + if hasattr(self, 'id'): + if self.id in inspected: + return _changed_fields + inspected.add(self.id) + + field_list = self._fields.copy() + + for field_name in field_list: db_field_name = self._db_field_map.get(field_name, field_name) key = '%s.' % db_field_name field = getattr(self, field_name, None) - if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed - _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] + if hasattr(field, 'id'): + if field.id in inspected: + continue + inspected.add(field.id) + + if isinstance(field, (EmbeddedDocument,)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed + _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k] elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents # Determine the iterator to use if not hasattr(field, 'items'): @@ -746,8 +760,7 @@ class BaseDocument(object): if not hasattr(value, '_get_changed_fields'): continue list_key = "%s%s." % (key, index) - _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] - + _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k] return _changed_fields def _delta(self): diff --git a/tests/dereference.py b/tests/dereference.py index b85ca179..088db98e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase): self.assertEquals("[, ]", "%s" % Person.objects()) + def test_circular_tree_reference(self): + """Ensure you can handle circular references with more than one level + """ + class Other(EmbeddedDocument): + name = StringField() + friends = ListField(ReferenceField('Person')) + + class Person(Document): + name = StringField() + other = EmbeddedDocumentField(Other, default=lambda: Other()) + + def __repr__(self): + return "" % self.name + + Person.drop_collection() + paul = Person(name="Paul") + paul.save() + maria = Person(name="Maria") + maria.save() + julia = Person(name='Julia') + julia.save() + anna = Person(name='Anna') + anna.save() + + paul.other.friends = [maria, julia, anna] + paul.other.name = "Paul's friends" + paul.save() + + maria.other.friends = [paul, julia, anna] + maria.other.name = "Maria's friends" + maria.save() + + julia.other.friends = [paul, maria, anna] + julia.other.name = "Julia's friends" + julia.save() + + anna.other.friends = [paul, maria, julia] + anna.other.name = "Anna's friends" + anna.save() + + self.assertEquals( + "[, , , ]", + "%s" % Person.objects() + ) + def test_generic_reference(self): class UserA(Document):