add fix for reload(fields) affect changed fields #1371
This commit is contained in:
		| @@ -702,7 +702,6 @@ class Document(BaseDocument): | |||||||
|             obj = obj[0] |             obj = obj[0] | ||||||
|         else: |         else: | ||||||
|             raise self.DoesNotExist('Document does not exist') |             raise self.DoesNotExist('Document does not exist') | ||||||
|  |  | ||||||
|         for field in obj._data: |         for field in obj._data: | ||||||
|             if not fields or field in fields: |             if not fields or field in fields: | ||||||
|                 try: |                 try: | ||||||
| @@ -718,7 +717,9 @@ class Document(BaseDocument): | |||||||
|                         # i.e. obj.update(unset__field=1) followed by obj.reload() |                         # i.e. obj.update(unset__field=1) followed by obj.reload() | ||||||
|                         delattr(self, field) |                         delattr(self, field) | ||||||
|  |  | ||||||
|         self._changed_fields = obj._changed_fields |         self._changed_fields = list( | ||||||
|  |             set(self._changed_fields) - set(fields) | ||||||
|  |             ) if fields else obj._changed_fields | ||||||
|         self._created = False |         self._created = False | ||||||
|         return self |         return self | ||||||
|  |  | ||||||
|   | |||||||
| @@ -476,6 +476,24 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc.save() |         doc.save() | ||||||
|         doc.reload() |         doc.reload() | ||||||
|  |  | ||||||
|  |     def test_reload_with_changed_fields(self): | ||||||
|  |         """Ensures reloading will not affect changed fields""" | ||||||
|  |         class User(Document): | ||||||
|  |             name = StringField() | ||||||
|  |             number = IntField() | ||||||
|  |         User.drop_collection() | ||||||
|  |  | ||||||
|  |         user = User(name="Bob", number=1).save() | ||||||
|  |         user.name = "John" | ||||||
|  |         user.number = 2 | ||||||
|  |  | ||||||
|  |         self.assertEqual(user._get_changed_fields(), ['name', 'number']) | ||||||
|  |         user.reload('number') | ||||||
|  |         self.assertEqual(user._get_changed_fields(), ['name']) | ||||||
|  |         user.save() | ||||||
|  |         user.reload() | ||||||
|  |         self.assertEqual(user.name, "John") | ||||||
|  |  | ||||||
|     def test_reload_referencing(self): |     def test_reload_referencing(self): | ||||||
|         """Ensures reloading updates weakrefs correctly.""" |         """Ensures reloading updates weakrefs correctly.""" | ||||||
|         class Embedded(EmbeddedDocument): |         class Embedded(EmbeddedDocument): | ||||||
| @@ -521,7 +539,7 @@ class InstanceTest(unittest.TestCase): | |||||||
|         doc.save() |         doc.save() | ||||||
|         doc.dict_field['extra'] = 1 |         doc.dict_field['extra'] = 1 | ||||||
|         doc = doc.reload(10, 'list_field') |         doc = doc.reload(10, 'list_field') | ||||||
|         self.assertEqual(doc._get_changed_fields(), []) |         self.assertEqual(doc._get_changed_fields(), ['dict_field.extra']) | ||||||
|         self.assertEqual(len(doc.list_field), 5) |         self.assertEqual(len(doc.list_field), 5) | ||||||
|         self.assertEqual(len(doc.dict_field), 3) |         self.assertEqual(len(doc.dict_field), 3) | ||||||
|         self.assertEqual(len(doc.embedded_field.list_field), 4) |         self.assertEqual(len(doc.embedded_field.list_field), 4) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user