From 290b821a3ae2e130192a90bb62ad5ae8d5461005 Mon Sep 17 00:00:00 2001 From: Erdenezul Date: Sat, 2 Sep 2017 02:05:27 +0900 Subject: [PATCH] add fix for reload(fields) affect changed fields #1371 --- mongoengine/document.py | 5 +++-- tests/document/instance.py | 20 +++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/mongoengine/document.py b/mongoengine/document.py index 2de0b1a3..182733c7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -705,7 +705,6 @@ class Document(BaseDocument): obj = obj[0] else: raise self.DoesNotExist('Document does not exist') - for field in obj._data: if not fields or field in fields: try: @@ -721,7 +720,9 @@ class Document(BaseDocument): # i.e. obj.update(unset__field=1) followed by obj.reload() delattr(self, field) - self._changed_fields = obj._changed_fields + self._changed_fields = list( + set(self._changed_fields) - set(fields) + ) if fields else obj._changed_fields self._created = False return self diff --git a/tests/document/instance.py b/tests/document/instance.py index 38c7fcaf..b255e8a6 100644 --- a/tests/document/instance.py +++ b/tests/document/instance.py @@ -476,6 +476,24 @@ class InstanceTest(unittest.TestCase): doc.save() doc.reload() + def test_reload_with_changed_fields(self): + """Ensures reloading will not affect changed fields""" + class User(Document): + name = StringField() + number = IntField() + User.drop_collection() + + user = User(name="Bob", number=1).save() + user.name = "John" + user.number = 2 + + self.assertEqual(user._get_changed_fields(), ['name', 'number']) + user.reload('number') + self.assertEqual(user._get_changed_fields(), ['name']) + user.save() + user.reload() + self.assertEqual(user.name, "John") + def test_reload_referencing(self): """Ensures reloading updates weakrefs correctly.""" class Embedded(EmbeddedDocument): @@ -521,7 +539,7 @@ class InstanceTest(unittest.TestCase): doc.save() doc.dict_field['extra'] = 1 doc = doc.reload(10, 'list_field') - self.assertEqual(doc._get_changed_fields(), []) + self.assertEqual(doc._get_changed_fields(), ['dict_field.extra']) self.assertEqual(len(doc.list_field), 5) self.assertEqual(len(doc.dict_field), 3) self.assertEqual(len(doc.embedded_field.list_field), 4)