From cace665858f78782d6f0aaecf9cbd68d39c2f224 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Tue, 12 Jul 2011 10:20:36 +0100 Subject: [PATCH] _delta checking didn't handle db_field_names at all Fixed and added tests, thanks to @wpjunior and @iapain for initial test cases [fixes #226] --- mongoengine/base.py | 11 +- tests/document.py | 274 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 281 insertions(+), 4 deletions(-) diff --git a/mongoengine/base.py b/mongoengine/base.py index 25b049a3..c2f4d214 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -381,6 +381,7 @@ class DocumentMetaclass(type): attr_value.db_field = attr_name doc_fields[attr_name] = attr_value attrs['_fields'] = doc_fields + attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items()]) new_class = super_new(cls, name, bases, attrs) for field in new_class._fields.values(): @@ -696,6 +697,7 @@ class BaseDocument(object): """ if not key: return + key = self._db_field_map.get(key, key) if hasattr(self, '_changed_fields') and key not in self._changed_fields: self._changed_fields.append(key) @@ -705,13 +707,13 @@ class BaseDocument(object): from mongoengine import EmbeddedDocument _changed_fields = [] _changed_fields += getattr(self, '_changed_fields', []) - for field_name in self._fields: - key = '%s.' % field_name + db_field_name = self._db_field_map.get(field_name, field_name) + key = '%s.' % db_field_name field = getattr(self, field_name, None) - if isinstance(field, EmbeddedDocument) and field_name not in _changed_fields: # Grab all embedded fields that have been changed + if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed _changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k] - elif isinstance(field, (list, tuple)) and field_name not in _changed_fields: # Loop list fields as they contain documents + elif isinstance(field, (list, tuple)) and db_field_name not in _changed_fields: # Loop list fields as they contain documents for index, value in enumerate(field): if not hasattr(value, '_get_changed_fields'): continue @@ -726,6 +728,7 @@ class BaseDocument(object): # Handles cases where not loaded from_son but has _id doc = self.to_mongo() set_fields = self._get_changed_fields() + set_data = {} unset_data = {} if hasattr(self, '_changed_fields'): diff --git a/tests/document.py b/tests/document.py index a8164697..df3b4fa1 100644 --- a/tests/document.py +++ b/tests/document.py @@ -1203,6 +1203,59 @@ class DocumentTest(unittest.TestCase): self.assertEqual(person.name, None) self.assertEqual(person.age, None) + def test_embedded_update(self): + """ + Test update on `EmbeddedDocumentField` fields + """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + + Site.drop_collection() + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + + def test_embedded_update_db_field(self): + """ + Test update on `EmbeddedDocumentField` fields when db_field is other + than default. + """ + + class Page(EmbeddedDocument): + log_message = StringField(verbose_name="Log message", + db_field="page_log_message", + required=True) + + class Site(Document): + page = EmbeddedDocumentField(Page) + + + Site.drop_collection() + + site = Site(page=Page(log_message="Warning: Dummy message")) + site.save() + + # Update + site = Site.objects.first() + site.page.log_message = "Error: Dummy message" + site.save() + + site = Site.objects.first() + self.assertEqual(site.page.log_message, "Error: Dummy message") + def test_delta(self): class Doc(Document): @@ -1408,6 +1461,227 @@ class DocumentTest(unittest.TestCase): del(doc.embedded_field.list_field[2].list_field) self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1})) + def test_delta_db_field(self): + + class Doc(Document): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + doc.string_field = 'hello' + self.assertEquals(doc._get_changed_fields(), ['db_string_field']) + self.assertEquals(doc._delta(), ({'db_string_field': 'hello'}, {})) + + doc._changed_fields = [] + doc.int_field = 1 + self.assertEquals(doc._get_changed_fields(), ['db_int_field']) + self.assertEquals(doc._delta(), ({'db_int_field': 1}, {})) + + doc._changed_fields = [] + dict_value = {'hello': 'world', 'ping': 'pong'} + doc.dict_field = dict_value + self.assertEquals(doc._get_changed_fields(), ['db_dict_field']) + self.assertEquals(doc._delta(), ({'db_dict_field': dict_value}, {})) + + doc._changed_fields = [] + list_value = ['1', 2, {'hello': 'world'}] + doc.list_field = list_value + self.assertEquals(doc._get_changed_fields(), ['db_list_field']) + self.assertEquals(doc._delta(), ({'db_list_field': list_value}, {})) + + # Test unsetting + doc._changed_fields = [] + doc.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['db_dict_field']) + self.assertEquals(doc._delta(), ({}, {'db_dict_field': 1})) + + doc._changed_fields = [] + doc.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['db_list_field']) + self.assertEquals(doc._delta(), ({}, {'db_list_field': 1})) + + # Test it saves that data + doc = Doc() + doc.save() + + doc.string_field = 'hello' + doc.int_field = 1 + doc.dict_field = {'hello': 'world'} + doc.list_field = ['1', 2, {'hello': 'world'}] + doc.save() + doc.reload() + + self.assertEquals(doc.string_field, 'hello') + self.assertEquals(doc.int_field, 1) + self.assertEquals(doc.dict_field, {'hello': 'world'}) + self.assertEquals(doc.list_field, ['1', 2, {'hello': 'world'}]) + + def test_delta_recursive_db_field(self): + + class Embedded(EmbeddedDocument): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + + class Doc(Document): + string_field = StringField(db_field='db_string_field') + int_field = IntField(db_field='db_int_field') + dict_field = DictField(db_field='db_dict_field') + list_field = ListField(db_field='db_list_field') + embedded_field = EmbeddedDocumentField(Embedded, db_field='db_embedded_field') + + Doc.drop_collection() + doc = Doc() + doc.save() + + doc = Doc.objects.first() + self.assertEquals(doc._get_changed_fields(), []) + self.assertEquals(doc._delta(), ({}, {})) + + embedded_1 = Embedded() + embedded_1.string_field = 'hello' + embedded_1.int_field = 1 + embedded_1.dict_field = {'hello': 'world'} + embedded_1.list_field = ['1', 2, {'hello': 'world'}] + doc.embedded_field = embedded_1 + + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field']) + + embedded_delta = { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'db_string_field': 'hello', + 'db_int_field': 1, + 'db_dict_field': {'hello': 'world'}, + 'db_list_field': ['1', 2, {'hello': 'world'}] + } + self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {})) + self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {})) + + doc.save() + doc.reload() + + doc.embedded_field.dict_field = {} + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1})) + self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.dict_field, {}) + + doc.embedded_field.list_field = [] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1})) + self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field, []) + + embedded_2 = Embedded() + embedded_2.string_field = 'hello' + embedded_2.int_field = 1 + embedded_2.dict_field = {'hello': 'world'} + embedded_2.list_field = ['1', 2, {'hello': 'world'}] + + doc.embedded_field.list_field = ['1', 2, embedded_2] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + + self.assertEquals(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_cls': 'Embedded', + '_types': ['Embedded'], + 'db_string_field': 'hello', + 'db_dict_field': {'hello': 'world'}, + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + }] + }, {})) + doc.save() + doc.reload() + + self.assertEquals(doc.embedded_field.list_field[0], '1') + self.assertEquals(doc.embedded_field.list_field[1], 2) + for k in doc.embedded_field.list_field[2]._fields: + self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k]) + + doc.embedded_field.list_field[2].string_field = 'world' + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field.2.db_string_field']) + self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {})) + self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world') + + # Test multiple assignments + doc.embedded_field.list_field[2].string_field = 'hello world' + doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2] + self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field']) + self.assertEquals(doc.embedded_field._delta(), ({ + 'db_list_field': ['1', 2, { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}}]}, {})) + self.assertEquals(doc._delta(), ({ + 'db_embedded_field.db_list_field': ['1', 2, { + '_types': ['Embedded'], + '_cls': 'Embedded', + 'db_string_field': 'hello world', + 'db_int_field': 1, + 'db_list_field': ['1', 2, {'hello': 'world'}], + 'db_dict_field': {'hello': 'world'}} + ]}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world') + + # Test list native methods + doc.embedded_field.list_field[2].list_field.pop(0) + self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {})) + doc.save() + doc.reload() + + doc.embedded_field.list_field[2].list_field.append(1) + self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {})) + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1]) + + doc.embedded_field.list_field[2].list_field.sort() + doc.save() + doc.reload() + self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}]) + + del(doc.embedded_field.list_field[2].list_field[2]['hello']) + self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {})) + doc.save() + doc.reload() + + del(doc.embedded_field.list_field[2].list_field) + self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1})) + def test_save_only_changed_fields(self): """Ensure save only sets / unsets changed fields """