Ported fix for Circular Reference bug to Master
Ready for a 0.5.2 release
This commit is contained in:
parent
591149b1f0
commit
452bbcc19b
2
AUTHORS
2
AUTHORS
@ -67,3 +67,5 @@ that much better:
|
|||||||
* Gareth Lloyd
|
* Gareth Lloyd
|
||||||
* Albert Choi
|
* Albert Choi
|
||||||
* John Arnfield
|
* John Arnfield
|
||||||
|
* Julien Rebetez
|
||||||
|
|
||||||
|
@ -2,6 +2,11 @@
|
|||||||
Changelog
|
Changelog
|
||||||
=========
|
=========
|
||||||
|
|
||||||
|
Changes in v0.5.2
|
||||||
|
=================
|
||||||
|
|
||||||
|
- A Robust Circular reference bugfix
|
||||||
|
|
||||||
Changes in v0.5.1
|
Changes in v0.5.1
|
||||||
=================
|
=================
|
||||||
|
|
||||||
|
@ -724,18 +724,32 @@ class BaseDocument(object):
|
|||||||
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
|
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
|
||||||
self._changed_fields.append(key)
|
self._changed_fields.append(key)
|
||||||
|
|
||||||
def _get_changed_fields(self, key=''):
|
def _get_changed_fields(self, key='', inspected=None):
|
||||||
"""Returns a list of all fields that have explicitly been changed.
|
"""Returns a list of all fields that have explicitly been changed.
|
||||||
"""
|
"""
|
||||||
from mongoengine import EmbeddedDocument
|
from mongoengine import EmbeddedDocument
|
||||||
_changed_fields = []
|
_changed_fields = []
|
||||||
_changed_fields += getattr(self, '_changed_fields', [])
|
_changed_fields += getattr(self, '_changed_fields', [])
|
||||||
for field_name in self._fields:
|
|
||||||
|
inspected = inspected or set()
|
||||||
|
if hasattr(self, 'id'):
|
||||||
|
if self.id in inspected:
|
||||||
|
return _changed_fields
|
||||||
|
inspected.add(self.id)
|
||||||
|
|
||||||
|
field_list = self._fields.copy()
|
||||||
|
|
||||||
|
for field_name in field_list:
|
||||||
db_field_name = self._db_field_map.get(field_name, field_name)
|
db_field_name = self._db_field_map.get(field_name, field_name)
|
||||||
key = '%s.' % db_field_name
|
key = '%s.' % db_field_name
|
||||||
field = getattr(self, field_name, None)
|
field = getattr(self, field_name, None)
|
||||||
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
|
if hasattr(field, 'id'):
|
||||||
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
|
if field.id in inspected:
|
||||||
|
continue
|
||||||
|
inspected.add(field.id)
|
||||||
|
|
||||||
|
if isinstance(field, (EmbeddedDocument,)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
|
||||||
|
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
|
||||||
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
|
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
|
||||||
# Determine the iterator to use
|
# Determine the iterator to use
|
||||||
if not hasattr(field, 'items'):
|
if not hasattr(field, 'items'):
|
||||||
@ -746,8 +760,7 @@ class BaseDocument(object):
|
|||||||
if not hasattr(value, '_get_changed_fields'):
|
if not hasattr(value, '_get_changed_fields'):
|
||||||
continue
|
continue
|
||||||
list_key = "%s%s." % (key, index)
|
list_key = "%s%s." % (key, index)
|
||||||
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
|
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
|
||||||
|
|
||||||
return _changed_fields
|
return _changed_fields
|
||||||
|
|
||||||
def _delta(self):
|
def _delta(self):
|
||||||
|
@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
|
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
|
||||||
|
|
||||||
|
def test_circular_tree_reference(self):
|
||||||
|
"""Ensure you can handle circular references with more than one level
|
||||||
|
"""
|
||||||
|
class Other(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
friends = ListField(ReferenceField('Person'))
|
||||||
|
|
||||||
|
class Person(Document):
|
||||||
|
name = StringField()
|
||||||
|
other = EmbeddedDocumentField(Other, default=lambda: Other())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Person: %s>" % self.name
|
||||||
|
|
||||||
|
Person.drop_collection()
|
||||||
|
paul = Person(name="Paul")
|
||||||
|
paul.save()
|
||||||
|
maria = Person(name="Maria")
|
||||||
|
maria.save()
|
||||||
|
julia = Person(name='Julia')
|
||||||
|
julia.save()
|
||||||
|
anna = Person(name='Anna')
|
||||||
|
anna.save()
|
||||||
|
|
||||||
|
paul.other.friends = [maria, julia, anna]
|
||||||
|
paul.other.name = "Paul's friends"
|
||||||
|
paul.save()
|
||||||
|
|
||||||
|
maria.other.friends = [paul, julia, anna]
|
||||||
|
maria.other.name = "Maria's friends"
|
||||||
|
maria.save()
|
||||||
|
|
||||||
|
julia.other.friends = [paul, maria, anna]
|
||||||
|
julia.other.name = "Julia's friends"
|
||||||
|
julia.save()
|
||||||
|
|
||||||
|
anna.other.friends = [paul, maria, julia]
|
||||||
|
anna.other.name = "Anna's friends"
|
||||||
|
anna.save()
|
||||||
|
|
||||||
|
self.assertEquals(
|
||||||
|
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
|
||||||
|
"%s" % Person.objects()
|
||||||
|
)
|
||||||
|
|
||||||
def test_generic_reference(self):
|
def test_generic_reference(self):
|
||||||
|
|
||||||
class UserA(Document):
|
class UserA(Document):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user