fix CI that fails due to pypi + override BaseDict.get as it was missing

This commit is contained in:
Bastien Gérard 2018-08-30 14:33:57 +02:00
parent d72daf5f39
commit 7ddbea697e
2 changed files with 39 additions and 16 deletions

View File

@ -43,7 +43,14 @@ class BaseDict(dict):
self._name = name self._name = name
super(BaseDict, self).__init__(dict_items) super(BaseDict, self).__init__(dict_items)
def __getitem__(self, key, *args, **kwargs): def get(self, key, default=None):
# get does not use __getitem__ by default so we must override it as well
try:
return self.__getitem__(key)
except KeyError:
return default
def __getitem__(self, key):
value = super(BaseDict, self).__getitem__(key) value = super(BaseDict, self).__getitem__(key)
EmbeddedDocument = _import_class('EmbeddedDocument') EmbeddedDocument = _import_class('EmbeddedDocument')

View File

@ -68,6 +68,11 @@ class TestBaseDict(unittest.TestCase):
self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) self.assertEqual(base_dict._instance._changed_fields, ['my_name.k'])
self.assertEqual(base_dict, {}) self.assertEqual(base_dict, {})
def test___getitem____KeyError(self):
base_dict = self._get_basedict({})
with self.assertRaises(KeyError):
base_dict['new']
def test___getitem____simple_value(self): def test___getitem____simple_value(self):
base_dict = self._get_basedict({'k': 'v'}) base_dict = self._get_basedict({'k': 'v'})
base_dict['k'] = 'v' base_dict['k'] = 'v'
@ -98,6 +103,24 @@ class TestBaseDict(unittest.TestCase):
sub_dict['subk'] = None sub_dict['subk'] = None
self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk']) self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk'])
def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self):
base_dict = self._get_basedict({'k': [0, 1, 2]})
sub_list = base_dict.get('k')
self.assertEqual(sub_list, [0, 1, 2])
self.assertIsInstance(sub_list, BaseList)
def test_get_returns_the_same_as___getitem__(self):
base_dict = self._get_basedict({'k': [0, 1, 2]})
get_ = base_dict.get('k')
getitem_ = base_dict['k']
self.assertEqual(get_, getitem_)
def test_get_default(self):
base_dict = self._get_basedict({})
sentinel = object()
self.assertEqual(base_dict.get('new'), None)
self.assertIs(base_dict.get('new', sentinel), sentinel)
def test___setitem___calls_mark_as_changed(self): def test___setitem___calls_mark_as_changed(self):
base_dict = self._get_basedict({}) base_dict = self._get_basedict({})
base_dict['k'] = 'v' base_dict['k'] = 'v'
@ -244,10 +267,9 @@ class TestBaseList(unittest.TestCase):
def test_remove_not_mark_as_changed_when_it_fails(self): def test_remove_not_mark_as_changed_when_it_fails(self):
base_list = self._get_baselist([True]) base_list = self._get_baselist([True])
try: with self.assertRaises(ValueError):
base_list.remove(False) base_list.remove(False)
except ValueError: self.assertFalse(base_list._instance._changed_fields)
self.assertFalse(base_list._instance._changed_fields)
def test_pop_calls_mark_as_changed(self): def test_pop_calls_mark_as_changed(self):
base_list = self._get_baselist([True]) base_list = self._get_baselist([True])
@ -311,24 +333,18 @@ class TestBaseList(unittest.TestCase):
base_list += [False] base_list += [False]
self.assertEqual(base_list._instance._changed_fields, ['my_name']) self.assertEqual(base_list._instance._changed_fields, ['my_name'])
def test___iadd___not_mark_as_changed_when_it_fails(self):
base_list = self._get_baselist([True])
try:
base_list += None
except TypeError:
self.assertFalse(base_list._instance._changed_fields)
def test___imul___calls_mark_as_changed(self): def test___imul___calls_mark_as_changed(self):
base_list = self._get_baselist([True]) base_list = self._get_baselist([True])
self.assertEqual(base_list._instance._changed_fields, [])
base_list *= 2 base_list *= 2
self.assertEqual(base_list._instance._changed_fields, ['my_name']) self.assertEqual(base_list._instance._changed_fields, ['my_name'])
def test___imul___not_mark_as_changed_when_it_fails(self): def test_sort_calls_not_marked_as_changed_when_it_fails(self):
base_list = self._get_baselist([True]) base_list = self._get_baselist([True])
try: with self.assertRaises(TypeError):
base_list *= None base_list.sort(key=1)
except TypeError:
self.assertFalse(base_list._instance._changed_fields) self.assertEqual(base_list._instance._changed_fields, [])
def test_sort_calls_mark_as_changed(self): def test_sort_calls_mark_as_changed(self):
base_list = self._get_baselist([True, False]) base_list = self._get_baselist([True, False])