Fixes bug where an EmbeddedDocument that shares the same id of its parent Document could be missing updates when .save was called
Fixes #1768, Fixes #1685
This commit is contained in:
parent
d17cac8210
commit
f89214f9cf
@ -503,7 +503,13 @@ class BaseDocument(object):
|
|||||||
|
|
||||||
self._changed_fields = []
|
self._changed_fields = []
|
||||||
|
|
||||||
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
|
def _nestable_types_changed_fields(self, changed_fields, base_key, data):
|
||||||
|
"""Inspect nested data for changed fields
|
||||||
|
|
||||||
|
:param changed_fields: Previously collected changed fields
|
||||||
|
:param base_key: The base key that must be used to prepend changes to this data
|
||||||
|
:param data: data to inspect for changes
|
||||||
|
"""
|
||||||
# Loop list / dict fields as they contain documents
|
# Loop list / dict fields as they contain documents
|
||||||
# Determine the iterator to use
|
# Determine the iterator to use
|
||||||
if not hasattr(data, 'items'):
|
if not hasattr(data, 'items'):
|
||||||
@ -511,57 +517,48 @@ class BaseDocument(object):
|
|||||||
else:
|
else:
|
||||||
iterator = data.iteritems()
|
iterator = data.iteritems()
|
||||||
|
|
||||||
for index, value in iterator:
|
for index_or_key, value in iterator:
|
||||||
list_key = '%s%s.' % (key, index)
|
item_key = '%s%s.' % (base_key, index_or_key)
|
||||||
# don't check anything lower if this key is already marked
|
# don't check anything lower if this key is already marked
|
||||||
# as changed.
|
# as changed.
|
||||||
if list_key[:-1] in changed_fields:
|
if item_key[:-1] in changed_fields:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if hasattr(value, '_get_changed_fields'):
|
if hasattr(value, '_get_changed_fields'):
|
||||||
changed = value._get_changed_fields(inspected)
|
changed = value._get_changed_fields()
|
||||||
changed_fields += ['%s%s' % (list_key, k)
|
changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
|
||||||
for k in changed if k]
|
|
||||||
elif isinstance(value, (list, tuple, dict)):
|
elif isinstance(value, (list, tuple, dict)):
|
||||||
self._nestable_types_changed_fields(
|
self._nestable_types_changed_fields(
|
||||||
changed_fields, list_key, value, inspected)
|
changed_fields, item_key, value)
|
||||||
|
|
||||||
def _get_changed_fields(self, inspected=None):
|
def _get_changed_fields(self):
|
||||||
"""Return a list of all fields that have explicitly been changed.
|
"""Return a list of all fields that have explicitly been changed.
|
||||||
"""
|
"""
|
||||||
EmbeddedDocument = _import_class('EmbeddedDocument')
|
EmbeddedDocument = _import_class('EmbeddedDocument')
|
||||||
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
|
|
||||||
ReferenceField = _import_class('ReferenceField')
|
ReferenceField = _import_class('ReferenceField')
|
||||||
SortedListField = _import_class('SortedListField')
|
SortedListField = _import_class('SortedListField')
|
||||||
|
|
||||||
changed_fields = []
|
changed_fields = []
|
||||||
changed_fields += getattr(self, '_changed_fields', [])
|
changed_fields += getattr(self, '_changed_fields', [])
|
||||||
|
|
||||||
inspected = inspected or set()
|
|
||||||
if hasattr(self, 'id') and isinstance(self.id, Hashable):
|
|
||||||
if self.id in inspected:
|
|
||||||
return changed_fields
|
|
||||||
inspected.add(self.id)
|
|
||||||
|
|
||||||
for field_name in self._fields_ordered:
|
for field_name in self._fields_ordered:
|
||||||
db_field_name = self._db_field_map.get(field_name, field_name)
|
db_field_name = self._db_field_map.get(field_name, field_name)
|
||||||
key = '%s.' % db_field_name
|
key = '%s.' % db_field_name
|
||||||
data = self._data.get(field_name, None)
|
data = self._data.get(field_name, None)
|
||||||
field = self._fields.get(field_name)
|
field = self._fields.get(field_name)
|
||||||
|
|
||||||
if hasattr(data, 'id'):
|
if db_field_name in changed_fields:
|
||||||
if data.id in inspected:
|
# Whole field already marked as changed, no need to go further
|
||||||
continue
|
|
||||||
if isinstance(field, ReferenceField):
|
|
||||||
continue
|
continue
|
||||||
elif (
|
|
||||||
isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and
|
if isinstance(field, ReferenceField): # Don't follow referenced documents
|
||||||
db_field_name not in changed_fields
|
continue
|
||||||
):
|
|
||||||
|
if isinstance(data, EmbeddedDocument):
|
||||||
# Find all embedded fields that have been changed
|
# Find all embedded fields that have been changed
|
||||||
changed = data._get_changed_fields(inspected)
|
changed = data._get_changed_fields()
|
||||||
changed_fields += ['%s%s' % (key, k) for k in changed if k]
|
changed_fields += ['%s%s' % (key, k) for k in changed if k]
|
||||||
elif (isinstance(data, (list, tuple, dict)) and
|
elif isinstance(data, (list, tuple, dict)):
|
||||||
db_field_name not in changed_fields):
|
|
||||||
if (hasattr(field, 'field') and
|
if (hasattr(field, 'field') and
|
||||||
isinstance(field.field, ReferenceField)):
|
isinstance(field.field, ReferenceField)):
|
||||||
continue
|
continue
|
||||||
@ -572,7 +569,7 @@ class BaseDocument(object):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
self._nestable_types_changed_fields(
|
self._nestable_types_changed_fields(
|
||||||
changed_fields, key, data, inspected)
|
changed_fields, key, data)
|
||||||
return changed_fields
|
return changed_fields
|
||||||
|
|
||||||
def _delta(self):
|
def _delta(self):
|
||||||
|
@ -1422,6 +1422,60 @@ class InstanceTest(MongoDBTestCase):
|
|||||||
self.assertEqual(person.age, 21)
|
self.assertEqual(person.age, 21)
|
||||||
self.assertEqual(person.active, False)
|
self.assertEqual(person.active, False)
|
||||||
|
|
||||||
|
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
|
||||||
|
# Refers to Issue #1685
|
||||||
|
class EmbeddedChildModel(EmbeddedDocument):
|
||||||
|
id = DictField(primary_key=True)
|
||||||
|
|
||||||
|
class ParentModel(Document):
|
||||||
|
child = EmbeddedDocumentField(
|
||||||
|
EmbeddedChildModel)
|
||||||
|
|
||||||
|
emb = EmbeddedChildModel(id={'1': [1]})
|
||||||
|
ParentModel(children=emb)._get_changed_fields()
|
||||||
|
|
||||||
|
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
|
||||||
|
class User(Document):
|
||||||
|
id = IntField(primary_key=True)
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Message(Document):
|
||||||
|
id = IntField(primary_key=True)
|
||||||
|
author = ReferenceField(User)
|
||||||
|
|
||||||
|
Message.drop_collection()
|
||||||
|
|
||||||
|
# All objects share the same id, but each in a different collection
|
||||||
|
user = User(id=1, name='user-name').save()
|
||||||
|
message = Message(id=1, author=user).save()
|
||||||
|
|
||||||
|
message.author.name = 'tutu'
|
||||||
|
self.assertEqual(message._get_changed_fields(), [])
|
||||||
|
self.assertEqual(user._get_changed_fields(), ['name'])
|
||||||
|
|
||||||
|
def test__get_changed_fields_same_ids_embedded(self):
|
||||||
|
# Refers to Issue #1768
|
||||||
|
class User(EmbeddedDocument):
|
||||||
|
id = IntField()
|
||||||
|
name = StringField()
|
||||||
|
|
||||||
|
class Message(Document):
|
||||||
|
id = IntField(primary_key=True)
|
||||||
|
author = EmbeddedDocumentField(User)
|
||||||
|
|
||||||
|
Message.drop_collection()
|
||||||
|
|
||||||
|
# All objects share the same id, but each in a different collection
|
||||||
|
user = User(id=1, name='user-name')#.save()
|
||||||
|
message = Message(id=1, author=user).save()
|
||||||
|
|
||||||
|
message.author.name = 'tutu'
|
||||||
|
self.assertEqual(message._get_changed_fields(), ['author.name'])
|
||||||
|
message.save()
|
||||||
|
|
||||||
|
message_fetched = Message.objects.with_id(message.id)
|
||||||
|
self.assertEqual(message_fetched.author.name, 'tutu')
|
||||||
|
|
||||||
def test_query_count_when_saving(self):
|
def test_query_count_when_saving(self):
|
||||||
"""Ensure references don't cause extra fetches when saving"""
|
"""Ensure references don't cause extra fetches when saving"""
|
||||||
class Organization(Document):
|
class Organization(Document):
|
||||||
|
@ -1029,7 +1029,6 @@ class FieldTest(unittest.TestCase):
|
|||||||
self.assertEqual(type(foo.bar), Bar)
|
self.assertEqual(type(foo.bar), Bar)
|
||||||
self.assertEqual(type(foo.baz), Baz)
|
self.assertEqual(type(foo.baz), Baz)
|
||||||
|
|
||||||
|
|
||||||
def test_document_reload_reference_integrity(self):
|
def test_document_reload_reference_integrity(self):
|
||||||
"""
|
"""
|
||||||
Ensure reloading a document with multiple similar id
|
Ensure reloading a document with multiple similar id
|
||||||
|
Loading…
x
Reference in New Issue
Block a user