Fixes tree based circular references

Thanks to jpfarias for the fix.
Also normalised the other circular checks.
This commit is contained in:
Ross Lawley 2011-10-10 09:16:32 -07:00
parent 8797565606
commit 7b1860d17b
4 changed files with 71 additions and 14 deletions

View File

@ -6,6 +6,7 @@ Changelog
Changes in dev Changes in dev
============== ==============
- Fixed tree based circular reference bug
- Add field name to validation exception messages - Add field name to validation exception messages
- Added UUID field - Added UUID field
- Improved efficiency of .get() - Improved efficiency of .get()
@ -17,7 +18,7 @@ Changes in dev
Changes in v0.5.1 Changes in v0.5.1
================= =================
- Circular reference bugfix - Fixed simple circular reference bug
Changes in v0.5 Changes in v0.5
=============== ===============

View File

@ -837,13 +837,19 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields') and key not in self._changed_fields: if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key) self._changed_fields.append(key)
def _get_changed_fields(self, key=''): def _get_changed_fields(self, key='', inspected=[]):
"""Returns a list of all fields that have explicitly been changed. """Returns a list of all fields that have explicitly been changed.
""" """
from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument
_changed_fields = [] _changed_fields = []
_changed_fields += getattr(self, '_changed_fields', []) _changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or []
if hasattr(self, 'id'):
if self.id in inspected:
return _changed_fields
inspected.append(self.id)
field_list = self._fields.copy() field_list = self._fields.copy()
if self._dynamic: if self._dynamic:
field_list.update(self._dynamic_fields) field_list.update(self._dynamic_fields)
@ -852,8 +858,13 @@ class BaseDocument(object):
db_field_name = self._db_field_map.get(field_name, field_name) db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name key = '%s.' % db_field_name
field = getattr(self, field_name, None) field = getattr(self, field_name, None)
if hasattr(field, 'id'):
if field.id in inspected:
continue
inspected.append(field.id)
if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
# Determine the iterator to use # Determine the iterator to use
if not hasattr(field, 'items'): if not hasattr(field, 'items'):
@ -864,7 +875,7 @@ class BaseDocument(object):
if not hasattr(value, '_get_changed_fields'): if not hasattr(value, '_get_changed_fields'):
continue continue
list_key = "%s%s." % (key, index) list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
return _changed_fields return _changed_fields
def _delta(self): def _delta(self):
@ -941,17 +952,17 @@ class BaseDocument(object):
return set_data, unset_data return set_data, unset_data
@classmethod @classmethod
def _geo_indices(cls, inspected_classes=None): def _geo_indices(cls, inspected=None):
inspected_classes = inspected_classes or [] inspected = inspected or []
geo_indices = [] geo_indices = []
inspected_classes.append(cls) inspected.append(cls)
for field in cls._fields.values(): for field in cls._fields.values():
if hasattr(field, 'document_type'): if hasattr(field, 'document_type'):
field_cls = field.document_type field_cls = field.document_type
if field_cls in inspected_classes: if field_cls in inspected:
continue continue
if hasattr(field_cls, '_geo_indices'): if hasattr(field_cls, '_geo_indices'):
geo_indices += field_cls._geo_indices(inspected_classes) geo_indices += field_cls._geo_indices(inspected)
elif field._geo_index: elif field._geo_index:
geo_indices.append(field) geo_indices.append(field)
return geo_indices return geo_indices

View File

@ -185,18 +185,18 @@ class Document(BaseDocument):
id_field = self._meta['id_field'] id_field = self._meta['id_field']
self[id_field] = self._fields[id_field].to_python(object_id) self[id_field] = self._fields[id_field].to_python(object_id)
def reset_changed_fields(doc, inspected_docs=None): def reset_changed_fields(doc, inspected=None):
"""Loop through and reset changed fields lists""" """Loop through and reset changed fields lists"""
inspected_docs = inspected_docs or [] inspected = inspected or []
inspected_docs.append(doc) inspected.append(doc)
if hasattr(doc, '_changed_fields'): if hasattr(doc, '_changed_fields'):
doc._changed_fields = [] doc._changed_fields = []
for field_name in doc._fields: for field_name in doc._fields:
field = getattr(doc, field_name) field = getattr(doc, field_name)
if field not in inspected_docs and hasattr(field, '_changed_fields'): if field not in inspected and hasattr(field, '_changed_fields'):
reset_changed_fields(field, inspected_docs) reset_changed_fields(field, inspected)
reset_changed_fields(self) reset_changed_fields(self)
self._changed_fields = [] self._changed_fields = []

View File

@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase):
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects()) self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
def test_circular_tree_reference(self):
"""Ensure you can handle circular references with more than one level
"""
class Other(EmbeddedDocument):
name = StringField()
friends = ListField(ReferenceField('Person'))
class Person(Document):
name = StringField()
other = EmbeddedDocumentField(Other, default=lambda: Other())
def __repr__(self):
return "<Person: %s>" % self.name
Person.drop_collection()
paul = Person(name="Paul")
paul.save()
maria = Person(name="Maria")
maria.save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()
paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"
paul.save()
maria.other.friends = [paul, julia, anna]
maria.other.name = "Maria's friends"
maria.save()
julia.other.friends = [paul, maria, anna]
julia.other.name = "Julia's friends"
julia.save()
anna.other.friends = [paul, maria, julia]
anna.other.name = "Anna's friends"
anna.save()
self.assertEquals(
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
"%s" % Person.objects()
)
def test_generic_reference(self): def test_generic_reference(self):
class UserA(Document): class UserA(Document):