_delta checking didn't handle db_field_names at all

Fixed and added tests, thanks to @wpjunior and @iapain for initial test cases
[fixes #226]
This commit is contained in:
Ross Lawley 2011-07-12 10:20:36 +01:00
parent 2a8d001213
commit cace665858
2 changed files with 281 additions and 4 deletions

View File

@ -381,6 +381,7 @@ class DocumentMetaclass(type):
attr_value.db_field = attr_name
doc_fields[attr_name] = attr_value
attrs['_fields'] = doc_fields
attrs['_db_field_map'] = dict([(k, v.db_field) for k, v in doc_fields.items()])
new_class = super_new(cls, name, bases, attrs)
for field in new_class._fields.values():
@ -696,6 +697,7 @@ class BaseDocument(object):
"""
if not key:
return
key = self._db_field_map.get(key, key)
if hasattr(self, '_changed_fields') and key not in self._changed_fields:
self._changed_fields.append(key)
@ -705,13 +707,13 @@ class BaseDocument(object):
from mongoengine import EmbeddedDocument
_changed_fields = []
_changed_fields += getattr(self, '_changed_fields', [])
for field_name in self._fields:
key = '%s.' % field_name
db_field_name = self._db_field_map.get(field_name, field_name)
key = '%s.' % db_field_name
field = getattr(self, field_name, None)
if isinstance(field, EmbeddedDocument) and field_name not in _changed_fields: # Grab all embedded fields that have been changed
if isinstance(field, EmbeddedDocument) and db_field_name not in _changed_fields: # Grab all embedded fields that have been changed
_changed_fields += ["%s%s" % (key, k) for k in field._get_changed_fields(key) if k]
elif isinstance(field, (list, tuple)) and field_name not in _changed_fields: # Loop list fields as they contain documents
elif isinstance(field, (list, tuple)) and db_field_name not in _changed_fields: # Loop list fields as they contain documents
for index, value in enumerate(field):
if not hasattr(value, '_get_changed_fields'):
continue
@ -726,6 +728,7 @@ class BaseDocument(object):
# Handles cases where not loaded from_son but has _id
doc = self.to_mongo()
set_fields = self._get_changed_fields()
set_data = {}
unset_data = {}
if hasattr(self, '_changed_fields'):

View File

@ -1203,6 +1203,59 @@ class DocumentTest(unittest.TestCase):
self.assertEqual(person.name, None)
self.assertEqual(person.age, None)
def test_embedded_update(self):
"""
Test update on `EmbeddedDocumentField` fields
"""
class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message",
required=True)
class Site(Document):
page = EmbeddedDocumentField(Page)
Site.drop_collection()
site = Site(page=Page(log_message="Warning: Dummy message"))
site.save()
# Update
site = Site.objects.first()
site.page.log_message = "Error: Dummy message"
site.save()
site = Site.objects.first()
self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_embedded_update_db_field(self):
"""
Test update on `EmbeddedDocumentField` fields when db_field is other
than default.
"""
class Page(EmbeddedDocument):
log_message = StringField(verbose_name="Log message",
db_field="page_log_message",
required=True)
class Site(Document):
page = EmbeddedDocumentField(Page)
Site.drop_collection()
site = Site(page=Page(log_message="Warning: Dummy message"))
site.save()
# Update
site = Site.objects.first()
site.page.log_message = "Error: Dummy message"
site.save()
site = Site.objects.first()
self.assertEqual(site.page.log_message, "Error: Dummy message")
def test_delta(self):
class Doc(Document):
@ -1408,6 +1461,227 @@ class DocumentTest(unittest.TestCase):
del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'embedded_field.list_field.2.list_field': 1}))
def test_delta_db_field(self):
class Doc(Document):
string_field = StringField(db_field='db_string_field')
int_field = IntField(db_field='db_int_field')
dict_field = DictField(db_field='db_dict_field')
list_field = ListField(db_field='db_list_field')
Doc.drop_collection()
doc = Doc()
doc.save()
doc = Doc.objects.first()
self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(doc._delta(), ({}, {}))
doc.string_field = 'hello'
self.assertEquals(doc._get_changed_fields(), ['db_string_field'])
self.assertEquals(doc._delta(), ({'db_string_field': 'hello'}, {}))
doc._changed_fields = []
doc.int_field = 1
self.assertEquals(doc._get_changed_fields(), ['db_int_field'])
self.assertEquals(doc._delta(), ({'db_int_field': 1}, {}))
doc._changed_fields = []
dict_value = {'hello': 'world', 'ping': 'pong'}
doc.dict_field = dict_value
self.assertEquals(doc._get_changed_fields(), ['db_dict_field'])
self.assertEquals(doc._delta(), ({'db_dict_field': dict_value}, {}))
doc._changed_fields = []
list_value = ['1', 2, {'hello': 'world'}]
doc.list_field = list_value
self.assertEquals(doc._get_changed_fields(), ['db_list_field'])
self.assertEquals(doc._delta(), ({'db_list_field': list_value}, {}))
# Test unsetting
doc._changed_fields = []
doc.dict_field = {}
self.assertEquals(doc._get_changed_fields(), ['db_dict_field'])
self.assertEquals(doc._delta(), ({}, {'db_dict_field': 1}))
doc._changed_fields = []
doc.list_field = []
self.assertEquals(doc._get_changed_fields(), ['db_list_field'])
self.assertEquals(doc._delta(), ({}, {'db_list_field': 1}))
# Test it saves that data
doc = Doc()
doc.save()
doc.string_field = 'hello'
doc.int_field = 1
doc.dict_field = {'hello': 'world'}
doc.list_field = ['1', 2, {'hello': 'world'}]
doc.save()
doc.reload()
self.assertEquals(doc.string_field, 'hello')
self.assertEquals(doc.int_field, 1)
self.assertEquals(doc.dict_field, {'hello': 'world'})
self.assertEquals(doc.list_field, ['1', 2, {'hello': 'world'}])
def test_delta_recursive_db_field(self):
class Embedded(EmbeddedDocument):
string_field = StringField(db_field='db_string_field')
int_field = IntField(db_field='db_int_field')
dict_field = DictField(db_field='db_dict_field')
list_field = ListField(db_field='db_list_field')
class Doc(Document):
string_field = StringField(db_field='db_string_field')
int_field = IntField(db_field='db_int_field')
dict_field = DictField(db_field='db_dict_field')
list_field = ListField(db_field='db_list_field')
embedded_field = EmbeddedDocumentField(Embedded, db_field='db_embedded_field')
Doc.drop_collection()
doc = Doc()
doc.save()
doc = Doc.objects.first()
self.assertEquals(doc._get_changed_fields(), [])
self.assertEquals(doc._delta(), ({}, {}))
embedded_1 = Embedded()
embedded_1.string_field = 'hello'
embedded_1.int_field = 1
embedded_1.dict_field = {'hello': 'world'}
embedded_1.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field = embedded_1
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field'])
embedded_delta = {
'_types': ['Embedded'],
'_cls': 'Embedded',
'db_string_field': 'hello',
'db_int_field': 1,
'db_dict_field': {'hello': 'world'},
'db_list_field': ['1', 2, {'hello': 'world'}]
}
self.assertEquals(doc.embedded_field._delta(), (embedded_delta, {}))
self.assertEquals(doc._delta(), ({'db_embedded_field': embedded_delta}, {}))
doc.save()
doc.reload()
doc.embedded_field.dict_field = {}
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_dict_field'])
self.assertEquals(doc.embedded_field._delta(), ({}, {'db_dict_field': 1}))
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_dict_field': 1}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.dict_field, {})
doc.embedded_field.list_field = []
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field'])
self.assertEquals(doc.embedded_field._delta(), ({}, {'db_list_field': 1}))
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field': 1}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field, [])
embedded_2 = Embedded()
embedded_2.string_field = 'hello'
embedded_2.int_field = 1
embedded_2.dict_field = {'hello': 'world'}
embedded_2.list_field = ['1', 2, {'hello': 'world'}]
doc.embedded_field.list_field = ['1', 2, embedded_2]
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field'])
self.assertEquals(doc.embedded_field._delta(), ({
'db_list_field': ['1', 2, {
'_cls': 'Embedded',
'_types': ['Embedded'],
'db_string_field': 'hello',
'db_dict_field': {'hello': 'world'},
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
self.assertEquals(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, {
'_cls': 'Embedded',
'_types': ['Embedded'],
'db_string_field': 'hello',
'db_dict_field': {'hello': 'world'},
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
}]
}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[0], '1')
self.assertEquals(doc.embedded_field.list_field[1], 2)
for k in doc.embedded_field.list_field[2]._fields:
self.assertEquals(doc.embedded_field.list_field[2][k], embedded_2[k])
doc.embedded_field.list_field[2].string_field = 'world'
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field.2.db_string_field'])
self.assertEquals(doc.embedded_field._delta(), ({'db_list_field.2.db_string_field': 'world'}, {}))
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_string_field': 'world'}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'world')
# Test multiple assignments
doc.embedded_field.list_field[2].string_field = 'hello world'
doc.embedded_field.list_field[2] = doc.embedded_field.list_field[2]
self.assertEquals(doc._get_changed_fields(), ['db_embedded_field.db_list_field'])
self.assertEquals(doc.embedded_field._delta(), ({
'db_list_field': ['1', 2, {
'_types': ['Embedded'],
'_cls': 'Embedded',
'db_string_field': 'hello world',
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}}]}, {}))
self.assertEquals(doc._delta(), ({
'db_embedded_field.db_list_field': ['1', 2, {
'_types': ['Embedded'],
'_cls': 'Embedded',
'db_string_field': 'hello world',
'db_int_field': 1,
'db_list_field': ['1', 2, {'hello': 'world'}],
'db_dict_field': {'hello': 'world'}}
]}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].string_field, 'hello world')
# Test list native methods
doc.embedded_field.list_field[2].list_field.pop(0)
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}]}, {}))
doc.save()
doc.reload()
doc.embedded_field.list_field[2].list_field.append(1)
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [2, {'hello': 'world'}, 1]}, {}))
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].list_field, [2, {'hello': 'world'}, 1])
doc.embedded_field.list_field[2].list_field.sort()
doc.save()
doc.reload()
self.assertEquals(doc.embedded_field.list_field[2].list_field, [1, 2, {'hello': 'world'}])
del(doc.embedded_field.list_field[2].list_field[2]['hello'])
self.assertEquals(doc._delta(), ({'db_embedded_field.db_list_field.2.db_list_field': [1, 2, {}]}, {}))
doc.save()
doc.reload()
del(doc.embedded_field.list_field[2].list_field)
self.assertEquals(doc._delta(), ({}, {'db_embedded_field.db_list_field.2.db_list_field': 1}))
def test_save_only_changed_fields(self):
"""Ensure save only sets / unsets changed fields
"""