Fixes tree based circular references

Thanks to jpfarias for the fix.
Also normalised the other circular checks.
This commit is contained in:
Ross Lawley 2011-10-10 09:16:32 -07:00
parent 8797565606
commit 7b1860d17b
4 changed files with 71 additions and 14 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in dev
==============
- Fixed tree based circular reference bug
- Add field name to validation exception messages
- Added UUID field
- Improved efficiency of .get()
@ -17,7 +18,7 @@ Changes in dev
Changes in v0.5.1
=================
- Circular reference bugfix
- Fixed simple circular reference bug
Changes in v0.5
===============

View File

@ -837,13 +837,19 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key)
def _get_changed_fields(self, key=''):
def _get_changed_fields(self, key='', inspected=[]):
"""Returns a list of all fields that have explicitly been changed.
"""
from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or []
if hasattr(self, 'id'):
if self.id in inspected:
return _changed_fields
inspected.append(self.id)
field_list = self._fields.copy()
if self._dynamic:
field_list.update(self._dynamic_fields)
@ -852,8 +858,13 @@ class BaseDocument(object):
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
field = getattr(self, field_name, None)
if hasattr(field, 'id'):
if field.id in inspected:
continue
inspected.append(field.id)
if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(field, 'items'):
@ -864,7 +875,7 @@ class BaseDocument(object):
if not hasattr(value, '_get_changed_fields'):
continue
list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
return _changed_fields
def _delta(self):
@ -941,17 +952,17 @@ class BaseDocument(object):
return set_data, unset_data
@classmethod
def _geo_indices(cls, inspected_classes=None):
inspected_classes = inspected_classes or []
def _geo_indices(cls, inspected=None):
inspected = inspected or []
geo_indices = []
inspected_classes.append(cls)
inspected.append(cls)
for field in cls._fields.values():
if hasattr(field, 'document_type'):
field_cls = field.document_type
if field_cls in inspected_classes:
if field_cls in inspected:
continue
if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected_classes)
geo_indices += field_cls._geo_indices(inspected)
elif field._geo_index:
geo_indices.append(field)
return geo_indices

View File

@ -185,18 +185,18 @@ class Document(BaseDocument):
id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id)
def reset_changed_fields(doc, inspected_docs=None):
def reset_changed_fields(doc, inspected=None):
"""Loop through and reset changed fields lists"""
inspected_docs = inspected_docs or []
inspected_docs.append(doc)
inspected = inspected or []
inspected.append(doc)
if hasattr(doc, '_changed_fields'):
doc._changed_fields = []
for field_name in doc._fields:
field = getattr(doc, field_name)
if field not in inspected_docs and hasattr(field, '_changed_fields'):
reset_changed_fields(field, inspected_docs)
if field not in inspected and hasattr(field, '_changed_fields'):
reset_changed_fields(field, inspected)
reset_changed_fields(self)
self._changed_fields = []

View File

@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase):
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
def test_circular_tree_reference(self):
"""Ensure you can handle circular references with more than one level
"""
class Other(EmbeddedDocument):
name = StringField()
friends = ListField(ReferenceField('Person'))
class Person(Document):
name = StringField()
other = EmbeddedDocumentField(Other, default=lambda: Other())
def __repr__(self):
return "<Person: %s>" % self.name
Person.drop_collection()
paul = Person(name="Paul")
paul.save()
maria = Person(name="Maria")
maria.save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()
paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"
paul.save()
maria.other.friends = [paul, julia, anna]
maria.other.name = "Maria's friends"
maria.save()
julia.other.friends = [paul, maria, anna]
julia.other.name = "Julia's friends"
julia.save()
anna.other.friends = [paul, maria, julia]
anna.other.name = "Anna's friends"
anna.save()
self.assertEquals(
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
"%s" % Person.objects()
)
def test_generic_reference(self):
class UserA(Document):