Fixes tree based circular references
Thanks to jpfarias for the fix. Also normalised the other circular checks.
This commit is contained in:
parent
8797565606
commit
7b1860d17b
@ -6,6 +6,7 @@ Changelog
|
|||||||
Changes in dev
|
Changes in dev
|
||||||
==============
|
==============
|
||||||
|
|
||||||
|
- Fixed tree based circular reference bug
|
||||||
- Add field name to validation exception messages
|
- Add field name to validation exception messages
|
||||||
- Added UUID field
|
- Added UUID field
|
||||||
- Improved efficiency of .get()
|
- Improved efficiency of .get()
|
||||||
@ -17,7 +18,7 @@ Changes in dev
|
|||||||
Changes in v0.5.1
|
Changes in v0.5.1
|
||||||
=================
|
=================
|
||||||
|
|
||||||
- Circular reference bugfix
|
- Fixed simple circular reference bug
|
||||||
|
|
||||||
Changes in v0.5
|
Changes in v0.5
|
||||||
===============
|
===============
|
||||||
|
@ -837,13 +837,19 @@ class BaseDocument(object):
|
|||||||
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
|
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
|
||||||
self._changed_fields.append(key)
|
self._changed_fields.append(key)
|
||||||
|
|
||||||
def _get_changed_fields(self, key=''):
|
def _get_changed_fields(self, key='', inspected=[]):
|
||||||
"""Returns a list of all fields that have explicitly been changed.
|
"""Returns a list of all fields that have explicitly been changed.
|
||||||
"""
|
"""
|
||||||
from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument
|
from mongoengine import EmbeddedDocument, DynamicEmbeddedDocument
|
||||||
_changed_fields = []
|
_changed_fields = []
|
||||||
_changed_fields += getattr(self, '_changed_fields', [])
|
_changed_fields += getattr(self, '_changed_fields', [])
|
||||||
|
|
||||||
|
inspected = inspected or []
|
||||||
|
if hasattr(self, 'id'):
|
||||||
|
if self.id in inspected:
|
||||||
|
return _changed_fields
|
||||||
|
inspected.append(self.id)
|
||||||
|
|
||||||
field_list = self._fields.copy()
|
field_list = self._fields.copy()
|
||||||
if self._dynamic:
|
if self._dynamic:
|
||||||
field_list.update(self._dynamic_fields)
|
field_list.update(self._dynamic_fields)
|
||||||
@ -852,8 +858,13 @@ class BaseDocument(object):
|
|||||||
db_field_name = self._db_field_map.get(field_name, field_name)
|
db_field_name = self._db_field_map.get(field_name, field_name)
|
||||||
key = '%s.' % db_field_name
|
key = '%s.' % db_field_name
|
||||||
field = getattr(self, field_name, None)
|
field = getattr(self, field_name, None)
|
||||||
|
if hasattr(field, 'id'):
|
||||||
|
if field.id in inspected:
|
||||||
|
continue
|
||||||
|
inspected.append(field.id)
|
||||||
|
|
||||||
if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
|
if isinstance(field, (EmbeddedDocument, DynamicEmbeddedDocument)) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
|
||||||
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
|
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key, inspected) if k]
|
||||||
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
|
elif isinstance(field, (list, tuple, dict)) and db_field_name not in _changed_fields: # Loop list / dict fields as they contain documents
|
||||||
# Determine the iterator to use
|
# Determine the iterator to use
|
||||||
if not hasattr(field, 'items'):
|
if not hasattr(field, 'items'):
|
||||||
@ -864,7 +875,7 @@ class BaseDocument(object):
|
|||||||
if not hasattr(value, '_get_changed_fields'):
|
if not hasattr(value, '_get_changed_fields'):
|
||||||
continue
|
continue
|
||||||
list_key = "%s%s." % (key, index)
|
list_key = "%s%s." % (key, index)
|
||||||
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key) if k]
|
_changed_fields += ["%s%s" % (list_key, k) for k in value._get_changed_fields(list_key, inspected) if k]
|
||||||
return _changed_fields
|
return _changed_fields
|
||||||
|
|
||||||
def _delta(self):
|
def _delta(self):
|
||||||
@ -941,17 +952,17 @@ class BaseDocument(object):
|
|||||||
return set_data, unset_data
|
return set_data, unset_data
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _geo_indices(cls, inspected_classes=None):
|
def _geo_indices(cls, inspected=None):
|
||||||
inspected_classes = inspected_classes or []
|
inspected = inspected or []
|
||||||
geo_indices = []
|
geo_indices = []
|
||||||
inspected_classes.append(cls)
|
inspected.append(cls)
|
||||||
for field in cls._fields.values():
|
for field in cls._fields.values():
|
||||||
if hasattr(field, 'document_type'):
|
if hasattr(field, 'document_type'):
|
||||||
field_cls = field.document_type
|
field_cls = field.document_type
|
||||||
if field_cls in inspected_classes:
|
if field_cls in inspected:
|
||||||
continue
|
continue
|
||||||
if hasattr(field_cls, '_geo_indices'):
|
if hasattr(field_cls, '_geo_indices'):
|
||||||
geo_indices += field_cls._geo_indices(inspected_classes)
|
geo_indices += field_cls._geo_indices(inspected)
|
||||||
elif field._geo_index:
|
elif field._geo_index:
|
||||||
geo_indices.append(field)
|
geo_indices.append(field)
|
||||||
return geo_indices
|
return geo_indices
|
||||||
|
@ -185,18 +185,18 @@ class Document(BaseDocument):
|
|||||||
id_field = self._meta['id_field']
|
id_field = self._meta['id_field']
|
||||||
self[id_field] = self._fields[id_field].to_python(object_id)
|
self[id_field] = self._fields[id_field].to_python(object_id)
|
||||||
|
|
||||||
def reset_changed_fields(doc, inspected_docs=None):
|
def reset_changed_fields(doc, inspected=None):
|
||||||
"""Loop through and reset changed fields lists"""
|
"""Loop through and reset changed fields lists"""
|
||||||
|
|
||||||
inspected_docs = inspected_docs or []
|
inspected = inspected or []
|
||||||
inspected_docs.append(doc)
|
inspected.append(doc)
|
||||||
if hasattr(doc, '_changed_fields'):
|
if hasattr(doc, '_changed_fields'):
|
||||||
doc._changed_fields = []
|
doc._changed_fields = []
|
||||||
|
|
||||||
for field_name in doc._fields:
|
for field_name in doc._fields:
|
||||||
field = getattr(doc, field_name)
|
field = getattr(doc, field_name)
|
||||||
if field not in inspected_docs and hasattr(field, '_changed_fields'):
|
if field not in inspected and hasattr(field, '_changed_fields'):
|
||||||
reset_changed_fields(field, inspected_docs)
|
reset_changed_fields(field, inspected)
|
||||||
|
|
||||||
reset_changed_fields(self)
|
reset_changed_fields(self)
|
||||||
self._changed_fields = []
|
self._changed_fields = []
|
||||||
|
@ -188,6 +188,51 @@ class FieldTest(unittest.TestCase):
|
|||||||
|
|
||||||
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
|
self.assertEquals("[<Person: Mother>, <Person: Daughter>]", "%s" % Person.objects())
|
||||||
|
|
||||||
|
def test_circular_tree_reference(self):
|
||||||
|
"""Ensure you can handle circular references with more than one level
|
||||||
|
"""
|
||||||
|
class Other(EmbeddedDocument):
|
||||||
|
name = StringField()
|
||||||
|
friends = ListField(ReferenceField('Person'))
|
||||||
|
|
||||||
|
class Person(Document):
|
||||||
|
name = StringField()
|
||||||
|
other = EmbeddedDocumentField(Other, default=lambda: Other())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Person: %s>" % self.name
|
||||||
|
|
||||||
|
Person.drop_collection()
|
||||||
|
paul = Person(name="Paul")
|
||||||
|
paul.save()
|
||||||
|
maria = Person(name="Maria")
|
||||||
|
maria.save()
|
||||||
|
julia = Person(name='Julia')
|
||||||
|
julia.save()
|
||||||
|
anna = Person(name='Anna')
|
||||||
|
anna.save()
|
||||||
|
|
||||||
|
paul.other.friends = [maria, julia, anna]
|
||||||
|
paul.other.name = "Paul's friends"
|
||||||
|
paul.save()
|
||||||
|
|
||||||
|
maria.other.friends = [paul, julia, anna]
|
||||||
|
maria.other.name = "Maria's friends"
|
||||||
|
maria.save()
|
||||||
|
|
||||||
|
julia.other.friends = [paul, maria, anna]
|
||||||
|
julia.other.name = "Julia's friends"
|
||||||
|
julia.save()
|
||||||
|
|
||||||
|
anna.other.friends = [paul, maria, julia]
|
||||||
|
anna.other.name = "Anna's friends"
|
||||||
|
anna.save()
|
||||||
|
|
||||||
|
self.assertEquals(
|
||||||
|
"[<Person: Paul>, <Person: Maria>, <Person: Julia>, <Person: Anna>]",
|
||||||
|
"%s" % Person.objects()
|
||||||
|
)
|
||||||
|
|
||||||
def test_generic_reference(self):
|
def test_generic_reference(self):
|
||||||
|
|
||||||
class UserA(Document):
|
class UserA(Document):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user