diff --git a/docs/changelog.rst b/docs/changelog.rst index 15331014..abbc1e4c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Changelog Changes in dev ============== +- Fixed tree based circular reference bug - Add field name to validation exception messages - Added UUID field - Improved efficiency of .get() @@ -17,7 +18,7 @@ Changes in dev Changes in v0.5.1 ================= -- Circular reference bugfix +- Fixed simple circular reference bug Changes in v0.5 =============== diff --git a/mongoengine/base.py b/mongoengine/base.py index adf5eee9..e8a3fe56 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -837,13 +837,19 @@ class BaseDocument(object): if hasattr(self, '_changed_fields') and key not in self._changed_fields: self._changed_fields.append(key) - def _get_changed_fields(self, key=''): + def _get_changed_fields(self, key='', inspected=[]): """Returns a list of all fields that have explicitly been changed. """ from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) + inspected = inspected or [] + if hasattr(self, 'id'): + if self.id in inspected: + return _changed_fields + inspected.append(self.id) + field_list = self._fields.copy() if self._dynamic: field_list.update(self._dynamic_fields) @@ -852,8 +858,13 @@ class BaseDocument(object): db_field_name = self._db_field_map.get(field_name, field_name) key = '%s.' % db_field_name field = getattr(self, field_name, None) + if hasattr(field, 'id'): + if field.id in inspected: + continue + inspected.append(field.id) + if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed - _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] + _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k] elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents # Determine the iterator to use if not hasattr(field, 'items'): @@ -864,7 +875,7 @@ class BaseDocument(object): if not hasattr(value, '_get_changed_fields'): continue list_key = "%s%s." % (key, index) - _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] + _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k] return _changed_fields def _delta(self): @@ -941,17 +952,17 @@ class BaseDocument(object): return set_data, unset_data @classmethod - def _geo_indices(cls, inspected_classes=None): - inspected_classes = inspected_classes or [] + def _geo_indices(cls, inspected=None): + inspected = inspected or [] geo_indices = [] - inspected_classes.append(cls) + inspected.append(cls) for field in cls._fields.values(): if hasattr(field, 'document_type'): field_cls = field.document_type - if field_cls in inspected_classes: + if field_cls in inspected: continue if hasattr(field_cls, '_geo_indices'): - geo_indices += field_cls._geo_indices(inspected_classes) + geo_indices += field_cls._geo_indices(inspected) elif field._geo_index: geo_indices.append(field) return geo_indices diff --git a/mongoengine/document.py b/mongoengine/document.py index e9b8871e..ce001d2a 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -185,18 +185,18 @@ class Document(BaseDocument): id_field = self._meta['id_field'] self[id_field] = self._fields[id_field].to_python(object_id) - def reset_changed_fields(doc, inspected_docs=None): + def reset_changed_fields(doc, inspected=None): """Loop through and reset changed fields lists""" - inspected_docs = inspected_docs or [] - inspected_docs.append(doc) + inspected = inspected or [] + inspected.append(doc) if hasattr(doc, '_changed_fields'): doc._changed_fields = [] for field_name in doc._fields: field = getattr(doc, field_name) - if field not in inspected_docs and hasattr(field, '_changed_fields'): - reset_changed_fields(field, inspected_docs) + if field not in inspected and hasattr(field, '_changed_fields'): + reset_changed_fields(field, inspected) reset_changed_fields(self) self._changed_fields = [] diff --git a/tests/dereference.py b/tests/dereference.py index b85ca179..088db98e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase): self.assertEquals("[, ]", "%s" % Person.objects()) + def test_circular_tree_reference(self): + """Ensure you can handle circular references with more than one level + """ + class Other(EmbeddedDocument): + name = StringField() + friends = ListField(ReferenceField('Person')) + + class Person(Document): + name = StringField() + other = EmbeddedDocumentField(Other, default=lambda: Other()) + + def __repr__(self): + return "" % self.name + + Person.drop_collection() + paul = Person(name="Paul") + paul.save() + maria = Person(name="Maria") + maria.save() + julia = Person(name='Julia') + julia.save() + anna = Person(name='Anna') + anna.save() + + paul.other.friends = [maria, julia, anna] + paul.other.name = "Paul's friends" + paul.save() + + maria.other.friends = [paul, julia, anna] + maria.other.name = "Maria's friends" + maria.save() + + julia.other.friends = [paul, maria, anna] + julia.other.name = "Julia's friends" + julia.save() + + anna.other.friends = [paul, maria, julia] + anna.other.name = "Anna's friends" + anna.save() + + self.assertEquals( + "[, , , ]", + "%s" % Person.objects() + ) + def test_generic_reference(self): class UserA(Document):