From f89214f9cf39d2a06fd3611ffb6e1647e6fce4b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Sat, 8 Sep 2018 22:20:49 +0200 Subject: [PATCH] Fixes bug where an EmbeddedDocument that shares the same id of its parent Document could be missing updates when .save was called Fixes #1768, Fixes #1685 --- mongoengine/base/document.py | 53 +++++++++++++++++------------------ tests/document/instance.py | 54 ++++++++++++++++++++++++++++++++++++ tests/test_dereference.py | 1 - 3 files changed, 79 insertions(+), 29 deletions(-) diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 85906a3e..7a3e22f9 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -503,7 +503,13 @@ class BaseDocument(object): self._changed_fields = [] - def _nestable_types_changed_fields(self, changed_fields, key, data, inspected): + def _nestable_types_changed_fields(self, changed_fields, base_key, data): + """Inspect nested data for changed fields + + :param changed_fields: Previously collected changed fields + :param base_key: The base key that must be used to prepend changes to this data + :param data: data to inspect for changes + """ # Loop list / dict fields as they contain documents # Determine the iterator to use if not hasattr(data, 'items'): @@ -511,57 +517,48 @@ class BaseDocument(object): else: iterator = data.iteritems() - for index, value in iterator: - list_key = '%s%s.' % (key, index) + for index_or_key, value in iterator: + item_key = '%s%s.' % (base_key, index_or_key) # don't check anything lower if this key is already marked # as changed. - if list_key[:-1] in changed_fields: + if item_key[:-1] in changed_fields: continue + if hasattr(value, '_get_changed_fields'): - changed = value._get_changed_fields(inspected) - changed_fields += ['%s%s' % (list_key, k) - for k in changed if k] + changed = value._get_changed_fields() + changed_fields += ['%s%s' % (item_key, k) for k in changed if k] elif isinstance(value, (list, tuple, dict)): self._nestable_types_changed_fields( - changed_fields, list_key, value, inspected) + changed_fields, item_key, value) - def _get_changed_fields(self, inspected=None): + def _get_changed_fields(self): """Return a list of all fields that have explicitly been changed. """ EmbeddedDocument = _import_class('EmbeddedDocument') - DynamicEmbeddedDocument = _import_class('DynamicEmbeddedDocument') ReferenceField = _import_class('ReferenceField') SortedListField = _import_class('SortedListField') changed_fields = [] changed_fields += getattr(self, '_changed_fields', []) - inspected = inspected or set() - if hasattr(self, 'id') and isinstance(self.id, Hashable): - if self.id in inspected: - return changed_fields - inspected.add(self.id) - for field_name in self._fields_ordered: db_field_name = self._db_field_map.get(field_name, field_name) key = '%s.' % db_field_name data = self._data.get(field_name, None) field = self._fields.get(field_name) - if hasattr(data, 'id'): - if data.id in inspected: - continue - if isinstance(field, ReferenceField): + if db_field_name in changed_fields: + # Whole field already marked as changed, no need to go further continue - elif ( - isinstance(data, (EmbeddedDocument, DynamicEmbeddedDocument)) and - db_field_name not in changed_fields - ): + + if isinstance(field, ReferenceField): # Don't follow referenced documents + continue + + if isinstance(data, EmbeddedDocument): # Find all embedded fields that have been changed - changed = data._get_changed_fields(inspected) + changed = data._get_changed_fields() changed_fields += ['%s%s' % (key, k) for k in changed if k] - elif (isinstance(data, (list, tuple, dict)) and - db_field_name not in changed_fields): + elif isinstance(data, (list, tuple, dict)): if (hasattr(field, 'field') and isinstance(field.field, ReferenceField)): continue @@ -572,7 +569,7 @@ class BaseDocument(object): continue self._nestable_types_changed_fields( - changed_fields, key, data, inspected) + changed_fields, key, data) return changed_fields def _delta(self): diff --git a/tests/document/instance.py b/tests/document/instance.py index e637b3e6..d8ff8b43 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -1422,6 +1422,60 @@ class InstanceTest(MongoDBTestCase): self.assertEqual(person.age, 21) self.assertEqual(person.active, False) + def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self): + # Refers to Issue #1685 + class EmbeddedChildModel(EmbeddedDocument): + id = DictField(primary_key=True) + + class ParentModel(Document): + child = EmbeddedDocumentField( + EmbeddedChildModel) + + emb = EmbeddedChildModel(id={'1': [1]}) + ParentModel(children=emb)._get_changed_fields() + + def test__get_changed_fields_same_ids_reference_field_does_not_enters_infinite_loop(self): + class User(Document): + id = IntField(primary_key=True) + name = StringField() + + class Message(Document): + id = IntField(primary_key=True) + author = ReferenceField(User) + + Message.drop_collection() + + # All objects share the same id, but each in a different collection + user = User(id=1, name='user-name').save() + message = Message(id=1, author=user).save() + + message.author.name = 'tutu' + self.assertEqual(message._get_changed_fields(), []) + self.assertEqual(user._get_changed_fields(), ['name']) + + def test__get_changed_fields_same_ids_embedded(self): + # Refers to Issue #1768 + class User(EmbeddedDocument): + id = IntField() + name = StringField() + + class Message(Document): + id = IntField(primary_key=True) + author = EmbeddedDocumentField(User) + + Message.drop_collection() + + # All objects share the same id, but each in a different collection + user = User(id=1, name='user-name')#.save() + message = Message(id=1, author=user).save() + + message.author.name = 'tutu' + self.assertEqual(message._get_changed_fields(), ['author.name']) + message.save() + + message_fetched = Message.objects.with_id(message.id) + self.assertEqual(message_fetched.author.name, 'tutu') + def test_query_count_when_saving(self): """Ensure references don't cause extra fetches when saving""" class Organization(Document): diff --git a/tests/test_dereference.py b/tests/test_dereference.py index 8b8bcfb2..5cf089f4 100644 --- a/tests/test_dereference.py +++ b/tests/test_dereference.py @@ -1029,7 +1029,6 @@ class FieldTest(unittest.TestCase): self.assertEqual(type(foo.bar), Bar) self.assertEqual(type(foo.baz), Baz) - def test_document_reload_reference_integrity(self): """ Ensure reloading a document with multiple similar id