Fixes bug where an EmbeddedDocument that shares the same id of its parent Document could be missing updates when .save was called

Fixes #1768, Fixes #1685
This commit is contained in:
Bastien Gérard 2018-09-08 22:20:49 +02:00
parent d17cac8210
commit f89214f9cf
3 changed files with 79 additions and 29 deletions

View File

@ -503,7 +503,13 @@ class BaseDocument(object):
self._changed_fields = [] self._changed_fields = []
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): def _nestable_types_changed_fields(self, changed_fields, base_key, data):
"""Inspect nested data for changed fields
:param changed_fields: Previously collected changed fields
:param base_key: The base key that must be used to prepend changes to this data
:param data: data to inspect for changes
"""
# Loop list / dict fields as they contain documents # Loop list / dict fields as they contain documents
# Determine the iterator to use # Determine the iterator to use
if not hasattr(data, 'items'): if not hasattr(data, 'items'):
@ -511,57 +517,48 @@ class BaseDocument(object):
else: else:
iterator = data.iteritems() iterator = data.iteritems()
for index, value in iterator: for index_or_key, value in iterator:
list_key = '%s%s.' % (key, index) item_key = '%s%s.' % (base_key, index_or_key)
# don't check anything lower if this key is already marked # don't check anything lower if this key is already marked
# as changed. # as changed.
if list_key[:-1] in changed_fields: if item_key[:-1] in changed_fields:
continue continue
if hasattr(value, '_get_changed_fields'): if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected) changed = value._get_changed_fields()
changed_fields += ['%s%s' % (list_key, k) changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
for k in changed if k]
elif isinstance(value, (list, tuple, dict)): elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields( self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected) changed_fields, item_key, value)
def _get_changed_fields(self, inspected=None): def _get_changed_fields(self):
"""Return a list of all fields that have explicitly been changed. """Return a list of all fields that have explicitly been changed.
""" """
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
ReferenceField = _import_class('ReferenceField') ReferenceField = _import_class('ReferenceField')
SortedListField = _import_class('SortedListField') SortedListField = _import_class('SortedListField')
changed_fields = [] changed_fields = []
changed_fields += getattr(self, '_changed_fields', []) changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id') and isinstance(self.id, Hashable):
if self.id in inspected:
return changed_fields
inspected.add(self.id)
for field_name in self._fields_ordered: for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name) db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name key = '%s.' % db_field_name
data = self._data.get(field_name, None) data = self._data.get(field_name, None)
field = self._fields.get(field_name) field = self._fields.get(field_name)
if hasattr(data, 'id'): if db_field_name in changed_fields:
if data.id in inspected: # Whole field already marked as changed, no need to go further
continue continue
if isinstance(field, ReferenceField):
if isinstance(field, ReferenceField): # Don't follow referenced documents
continue continue
elif (
isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and if isinstance(data, EmbeddedDocument):
db_field_name not in changed_fields
):
# Find all embedded fields that have been changed # Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected) changed = data._get_changed_fields()
changed_fields += ['%s%s' % (key, k) for k in changed if k] changed_fields += ['%s%s' % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and elif isinstance(data, (list, tuple, dict)):
db_field_name not in changed_fields):
if (hasattr(field, 'field') and if (hasattr(field, 'field') and
isinstance(field.field, ReferenceField)): isinstance(field.field, ReferenceField)):
continue continue
@ -572,7 +569,7 @@ class BaseDocument(object):
continue continue
self._nestable_types_changed_fields( self._nestable_types_changed_fields(
changed_fields, key, data, inspected) changed_fields, key, data)
return changed_fields return changed_fields
def _delta(self): def _delta(self):

View File

@ -1422,6 +1422,60 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(person.age, 21) self.assertEqual(person.age, 21)
self.assertEqual(person.active, False) self.assertEqual(person.active, False)
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
# Refers to Issue #1685
class EmbeddedChildModel(EmbeddedDocument):
id = DictField(primary_key=True)
class ParentModel(Document):
child = EmbeddedDocumentField(
EmbeddedChildModel)
emb = EmbeddedChildModel(id={'1': [1]})
ParentModel(children=emb)._get_changed_fields()
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
class User(Document):
id = IntField(primary_key=True)
name = StringField()
class Message(Document):
id = IntField(primary_key=True)
author = ReferenceField(User)
Message.drop_collection()
# All objects share the same id, but each in a different collection
user = User(id=1, name='user-name').save()
message = Message(id=1, author=user).save()
message.author.name = 'tutu'
self.assertEqual(message._get_changed_fields(), [])
self.assertEqual(user._get_changed_fields(), ['name'])
def test__get_changed_fields_same_ids_embedded(self):
# Refers to Issue #1768
class User(EmbeddedDocument):
id = IntField()
name = StringField()
class Message(Document):
id = IntField(primary_key=True)
author = EmbeddedDocumentField(User)
Message.drop_collection()
# All objects share the same id, but each in a different collection
user = User(id=1, name='user-name')#.save()
message = Message(id=1, author=user).save()
message.author.name = 'tutu'
self.assertEqual(message._get_changed_fields(), ['author.name'])
message.save()
message_fetched = Message.objects.with_id(message.id)
self.assertEqual(message_fetched.author.name, 'tutu')
def test_query_count_when_saving(self): def test_query_count_when_saving(self):
"""Ensure references don't cause extra fetches when saving""" """Ensure references don't cause extra fetches when saving"""
class Organization(Document): class Organization(Document):

View File

@ -1029,7 +1029,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(type(foo.bar), Bar) self.assertEqual(type(foo.bar), Bar)
self.assertEqual(type(foo.baz), Baz) self.assertEqual(type(foo.baz), Baz)
def test_document_reload_reference_integrity(self): def test_document_reload_reference_integrity(self):
""" """
Ensure reloading a document with multiple similar id Ensure reloading a document with multiple similar id