From 7ddbea697e84dc8a9b06dbc144a03a27706a21d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Thu, 30 Aug 2018 14:33:57 +0200 Subject: [PATCH] fix CI that fails due to pypi + override BaseDict.get as it was missing --- mongoengine/base/datastructures.py | 9 +++++- tests/test_datastructures.py | 46 ++++++++++++++++++++---------- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 6340dcdd..d5faa71b 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -43,7 +43,14 @@ class BaseDict(dict): self._name = name super(BaseDict, self).__init__(dict_items) - def __getitem__(self, key, *args, **kwargs): + def get(self, key, default=None): + # get does not use __getitem__ by default so we must override it as well + try: + return self.__getitem__(key) + except KeyError: + return default + + def __getitem__(self, key): value = super(BaseDict, self).__getitem__(key) EmbeddedDocument = _import_class('EmbeddedDocument') diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index f2ec098e..2f1277e6 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -68,6 +68,11 @@ class TestBaseDict(unittest.TestCase): self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) self.assertEqual(base_dict, {}) + def test___getitem____KeyError(self): + base_dict = self._get_basedict({}) + with self.assertRaises(KeyError): + base_dict['new'] + def test___getitem____simple_value(self): base_dict = self._get_basedict({'k': 'v'}) base_dict['k'] = 'v' @@ -98,6 +103,24 @@ class TestBaseDict(unittest.TestCase): sub_dict['subk'] = None self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk']) + def test_get_sublist_gets_converted_to_BaseList_just_like__getitem__(self): + base_dict = self._get_basedict({'k': [0, 1, 2]}) + sub_list = base_dict.get('k') + self.assertEqual(sub_list, [0, 1, 2]) + self.assertIsInstance(sub_list, BaseList) + + def test_get_returns_the_same_as___getitem__(self): + base_dict = self._get_basedict({'k': [0, 1, 2]}) + get_ = base_dict.get('k') + getitem_ = base_dict['k'] + self.assertEqual(get_, getitem_) + + def test_get_default(self): + base_dict = self._get_basedict({}) + sentinel = object() + self.assertEqual(base_dict.get('new'), None) + self.assertIs(base_dict.get('new', sentinel), sentinel) + def test___setitem___calls_mark_as_changed(self): base_dict = self._get_basedict({}) base_dict['k'] = 'v' @@ -244,10 +267,9 @@ class TestBaseList(unittest.TestCase): def test_remove_not_mark_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - try: + with self.assertRaises(ValueError): base_list.remove(False) - except ValueError: - self.assertFalse(base_list._instance._changed_fields) + self.assertFalse(base_list._instance._changed_fields) def test_pop_calls_mark_as_changed(self): base_list = self._get_baselist([True]) @@ -311,24 +333,18 @@ class TestBaseList(unittest.TestCase): base_list += [False] self.assertEqual(base_list._instance._changed_fields, ['my_name']) - def test___iadd___not_mark_as_changed_when_it_fails(self): - base_list = self._get_baselist([True]) - try: - base_list += None - except TypeError: - self.assertFalse(base_list._instance._changed_fields) - def test___imul___calls_mark_as_changed(self): base_list = self._get_baselist([True]) + self.assertEqual(base_list._instance._changed_fields, []) base_list *= 2 self.assertEqual(base_list._instance._changed_fields, ['my_name']) - def test___imul___not_mark_as_changed_when_it_fails(self): + def test_sort_calls_not_marked_as_changed_when_it_fails(self): base_list = self._get_baselist([True]) - try: - base_list *= None - except TypeError: - self.assertFalse(base_list._instance._changed_fields) + with self.assertRaises(TypeError): + base_list.sort(key=1) + + self.assertEqual(base_list._instance._changed_fields, []) def test_sort_calls_mark_as_changed(self): base_list = self._get_baselist([True, False])