Ported fix for Circular Reference bug to Master

Ready for a 0.5.2 release
This commit is contained in:
Ross Lawley 2011-10-12 00:30:12 -07:00
parent 591149b1f0
commit 452bbcc19b
4 changed files with 71 additions and 6 deletions

View File

@ -67,3 +67,5 @@ that much better:
* Gareth Lloyd * Gareth Lloyd
* Albert Choi * Albert Choi
* John Arnfield * John Arnfield
* Julien Rebetez

View File

@ -2,6 +2,11 @@
Changelog Changelog
========= =========
Changes in v0.5.2
=================
- A Robust Circular reference bugfix
Changes in v0.5.1 Changes in v0.5.1
================= =================

View File

@ -724,18 +724,32 @@ class BaseDocument(object):
if hasattr(self, '_changed_fields') and key not in self._changed_fields: if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key) self._changed_fields.append(key)
def _get_changed_fields(self, key=''): def _get_changed_fields(self, key='', inspected=None):
"""Returns a list of all fields that have explicitly been changed. """Returns a list of all fields that have explicitly been changed.
""" """
from mongoengine import EmbeddedDocument from mongoengine import EmbeddedDocument
_changed_fields = [] _changed_fields = []
_changed_fields += getattr(self, '_changed_fields', []) _changed_fields += getattr(self, '_changed_fields', [])
for field_name in self._fields:
inspected = inspected or set()
if hasattr(self, 'id'):
if self.id in inspected:
return _changed_fields
inspected.add(self.id)
field_list = self._fields.copy()
for field_name in field_list:
db_field_name = self._db_field_map.get(field_name, field_name) db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name key = '%s.' % db_field_name
field = getattr(self, field_name, None) field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed if hasattr(field, 'id'):
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] if field.id in inspected:
continue
inspected.add(field.id)
if isinstance(field, (EmbeddedDocument,)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
# Determine the iterator to use # Determine the iterator to use
if not hasattr(field, 'items'): if not hasattr(field, 'items'):
@ -746,8 +760,7 @@ class BaseDocument(object):
if not hasattr(value, '_get_changed_fields'): if not hasattr(value, '_get_changed_fields'):
continue continue
list_key = "%s%s." % (key, index) list_key = "%s%s." % (key, index)
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k] _changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
return _changed_fields return _changed_fields
def _delta(self): def _delta(self):

View File

@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase):
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects()) self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
def test_circular_tree_reference(self):
"""Ensure you can handle circular references with more than one level
"""
class Other(EmbeddedDocument):
name = StringField()
friends = ListField(ReferenceField('Person'))
class Person(Document):
name = StringField()
other = EmbeddedDocumentField(Other, default=lambda: Other())
def __repr__(self):
return "<Person: %s>" % self.name
Person.drop_collection()
paul = Person(name="Paul")
paul.save()
maria = Person(name="Maria")
maria.save()
julia = Person(name='Julia')
julia.save()
anna = Person(name='Anna')
anna.save()
paul.other.friends = [maria, julia, anna]
paul.other.name = "Paul's friends"
paul.save()
maria.other.friends = [paul, julia, anna]
maria.other.name = "Maria's friends"
maria.save()
julia.other.friends = [paul, maria, anna]
julia.other.name = "Julia's friends"
julia.save()
anna.other.friends = [paul, maria, julia]
anna.other.name = "Anna's friends"
anna.save()
self.assertEquals(
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
"%s" % Person.objects()
)
def test_generic_reference(self): def test_generic_reference(self):
class UserA(Document): class UserA(Document):