Merge pull request #1887 from bagerard/fix_changed_fields_issue_same_id_in_nested_doc2

Fix bug where an EmbeddedDocument with the same id as its parent would not be tracked for changes
This commit is contained in:
Bastien Gérard 2018-11-01 22:49:07 +01:00 committed by GitHub
commit 26e2fc8fd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 79 additions and 29 deletions

View File

@ -500,7 +500,13 @@ class BaseDocument(object):
self._changed_fields = []
def _nestable_types_changed_fields(self, changed_fields, key, data, inspected):
def _nestable_types_changed_fields(self, changed_fields, base_key, data):
"""Inspect nested data for changed fields
:param changed_fields: Previously collected changed fields
:param base_key: The base key that must be used to prepend changes to this data
:param data: data to inspect for changes
"""
# Loop list / dict fields as they contain documents
# Determine the iterator to use
if not hasattr(data, 'items'):
@ -508,25 +514,24 @@ class BaseDocument(object):
else:
iterator = data.iteritems()
for index, value in iterator:
list_key = '%s%s.' % (key, index)
for index_or_key, value in iterator:
item_key = '%s%s.' % (base_key, index_or_key)
# don't check anything lower if this key is already marked
# as changed.
if list_key[:-1] in changed_fields:
if item_key[:-1] in changed_fields:
continue
if hasattr(value, '_get_changed_fields'):
changed = value._get_changed_fields(inspected)
changed_fields += ['%s%s' % (list_key, k)
for k in changed if k]
changed = value._get_changed_fields()
changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
elif isinstance(value, (list, tuple, dict)):
self._nestable_types_changed_fields(
changed_fields, list_key, value, inspected)
changed_fields, item_key, value)
def _get_changed_fields(self, inspected=None):
def _get_changed_fields(self):
"""Return a list of all fields that have explicitly been changed.
"""
EmbeddedDocument = _import_class('EmbeddedDocument')
DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument')
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
SortedListField = _import_class('SortedListField')
@ -534,32 +539,24 @@ class BaseDocument(object):
changed_fields = []
changed_fields += getattr(self, '_changed_fields', [])
inspected = inspected or set()
if hasattr(self, 'id') and isinstance(self.id, Hashable):
if self.id in inspected:
return changed_fields
inspected.add(self.id)
for field_name in self._fields_ordered:
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
data = self._data.get(field_name, None)
field = self._fields.get(field_name)
if hasattr(data, 'id'):
if data.id in inspected:
continue
if isinstance(field, ReferenceField):
if db_field_name in changed_fields:
# Whole field already marked as changed, no need to go further
continue
elif (
isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and
db_field_name not in changed_fields
):
if isinstance(field, ReferenceField): # Don't follow referenced documents
continue
if isinstance(data, EmbeddedDocument):
# Find all embedded fields that have been changed
changed = data._get_changed_fields(inspected)
changed = data._get_changed_fields()
changed_fields += ['%s%s' % (key, k) for k in changed if k]
elif (isinstance(data, (list, tuple, dict)) and
db_field_name not in changed_fields):
elif isinstance(data, (list, tuple, dict)):
if (hasattr(field, 'field') and
isinstance(field.field, (ReferenceField, GenericReferenceField))):
continue
@ -570,7 +567,7 @@ class BaseDocument(object):
continue
self._nestable_types_changed_fields(
changed_fields, key, data, inspected)
changed_fields, key, data)
return changed_fields
def _delta(self):

View File

@ -1422,6 +1422,60 @@ class InstanceTest(MongoDBTestCase):
self.assertEqual(person.age, 21)
self.assertEqual(person.active, False)
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
# Refers to Issue #1685
class EmbeddedChildModel(EmbeddedDocument):
id = DictField(primary_key=True)
class ParentModel(Document):
child = EmbeddedDocumentField(
EmbeddedChildModel)
emb = EmbeddedChildModel(id={'1': [1]})
ParentModel(children=emb)._get_changed_fields()
def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self):
class User(Document):
id = IntField(primary_key=True)
name = StringField()
class Message(Document):
id = IntField(primary_key=True)
author = ReferenceField(User)
Message.drop_collection()
# All objects share the same id, but each in a different collection
user = User(id=1, name='user-name').save()
message = Message(id=1, author=user).save()
message.author.name = 'tutu'
self.assertEqual(message._get_changed_fields(), [])
self.assertEqual(user._get_changed_fields(), ['name'])
def test__get_changed_fields_same_ids_embedded(self):
# Refers to Issue #1768
class User(EmbeddedDocument):
id = IntField()
name = StringField()
class Message(Document):
id = IntField(primary_key=True)
author = EmbeddedDocumentField(User)
Message.drop_collection()
# All objects share the same id, but each in a different collection
user = User(id=1, name='user-name')#.save()
message = Message(id=1, author=user).save()
message.author.name = 'tutu'
self.assertEqual(message._get_changed_fields(), ['author.name'])
message.save()
message_fetched = Message.objects.with_id(message.id)
self.assertEqual(message_fetched.author.name, 'tutu')
def test_query_count_when_saving(self):
"""Ensure references don't cause extra fetches when saving"""
class Organization(Document):

View File

@ -1029,7 +1029,6 @@ class FieldTest(unittest.TestCase):
self.assertEqual(type(foo.bar), Bar)
self.assertEqual(type(foo.baz), Baz)
def test_document_reload_reference_integrity(self):
"""
Ensure reloading a document with multiple similar id