diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 8948243e..6340dcdd 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,4 +1,3 @@ -import itertools import weakref from bson import DBRef @@ -7,7 +6,25 @@ import six from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned -__all__ = ('BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference') +__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference') + + +def mark_as_changed_wrapper(parent_method): + """Decorators that ensures _mark_as_changed method gets called""" + def wrapper(self, *args, **kwargs): + result = parent_method(self, *args, **kwargs) # Can't use super() in the decorator + self._mark_as_changed() + return result + return wrapper + + +def mark_key_as_changed_wrapper(parent_method): + """Decorators that ensures _mark_as_changed method gets called with the key argument""" + def wrapper(self, key, *args, **kwargs): + result = parent_method(self, key, *args, **kwargs) # Can't use super() in the decorator + self._mark_as_changed(key) + return result + return wrapper class BaseDict(dict): @@ -42,22 +59,6 @@ class BaseDict(dict): value._instance = self._instance return value - def __setitem__(self, key, value, *args, **kwargs): - self._mark_as_changed(key) - return super(BaseDict, self).__setitem__(key, value) - - def __delete__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).__delete__(*args, **kwargs) - - def __delitem__(self, key, *args, **kwargs): - self._mark_as_changed(key) - return super(BaseDict, self).__delitem__(key) - - def __delattr__(self, key, *args, **kwargs): - self._mark_as_changed(key) - return super(BaseDict, self).__delattr__(key) - def __getstate__(self): self.instance = None self._dereferenced = False @@ -67,25 +68,14 @@ class BaseDict(dict): self = state return self - def clear(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).clear() - - def pop(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).pop(*args, **kwargs) - - def popitem(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).popitem() - - def setdefault(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).setdefault(*args, **kwargs) - - def update(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseDict, self).update(*args, **kwargs) + __setitem__ = mark_key_as_changed_wrapper(dict.__setitem__) + __delattr__ = mark_key_as_changed_wrapper(dict.__delattr__) + __delitem__ = mark_key_as_changed_wrapper(dict.__delitem__) + pop = mark_as_changed_wrapper(dict.pop) + clear = mark_as_changed_wrapper(dict.clear) + update = mark_as_changed_wrapper(dict.update) + popitem = mark_as_changed_wrapper(dict.popitem) + setdefault = mark_as_changed_wrapper(dict.setdefault) def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): @@ -111,17 +101,24 @@ class BaseList(list): self._name = name super(BaseList, self).__init__(list_items) - def __getitem__(self, key, *args, **kwargs): + def __getitem__(self, key): value = super(BaseList, self).__getitem__(key) + if isinstance(key, slice): + # When receiving a slice operator, we don't convert the structure and bind + # to parent's instance. This is buggy for now but would require more work to be handled properly + return value + EmbeddedDocument = _import_class('EmbeddedDocument') if isinstance(value, EmbeddedDocument) and value._instance is None: value._instance = self._instance elif not isinstance(value, BaseDict) and isinstance(value, dict): + # Replace dict by BaseDict value = BaseDict(value, None, '%s.%s' % (self._name, key)) super(BaseList, self).__setitem__(key, value) value._instance = self._instance elif not isinstance(value, BaseList) and isinstance(value, list): + # Replace list by BaseList value = BaseList(value, None, '%s.%s' % (self._name, key)) super(BaseList, self).__setitem__(key, value) value._instance = self._instance @@ -131,25 +128,6 @@ class BaseList(list): for v in super(BaseList, self).__iter__(): yield v - def __setitem__(self, key, value, *args, **kwargs): - if isinstance(key, slice): - self._mark_as_changed() - else: - self._mark_as_changed(key) - return super(BaseList, self).__setitem__(key, value) - - def __delitem__(self, key): - self._mark_as_changed() - return super(BaseList, self).__delitem__(key) - - def __setslice__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__setslice__(*args, **kwargs) - - def __delslice__(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).__delslice__(*args, **kwargs) - def __getstate__(self): self.instance = None self._dereferenced = False @@ -159,41 +137,40 @@ class BaseList(list): self = state return self - def __iadd__(self, other): - self._mark_as_changed() - return super(BaseList, self).__iadd__(other) + def __setitem__(self, key, value): + changed_key = key + if isinstance(key, slice): + # In case of slice, we don't bother to identify the exact elements being updated + # instead, we simply marks the whole list as changed + changed_key = None - def __imul__(self, other): - self._mark_as_changed() - return super(BaseList, self).__imul__(other) + result = super(BaseList, self).__setitem__(key, value) + self._mark_as_changed(changed_key) + return result - def append(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).append(*args, **kwargs) + append = mark_as_changed_wrapper(list.append) + extend = mark_as_changed_wrapper(list.extend) + insert = mark_as_changed_wrapper(list.insert) + pop = mark_as_changed_wrapper(list.pop) + remove = mark_as_changed_wrapper(list.remove) + reverse = mark_as_changed_wrapper(list.reverse) + sort = mark_as_changed_wrapper(list.sort) + __delitem__ = mark_as_changed_wrapper(list.__delitem__) + __iadd__ = mark_as_changed_wrapper(list.__iadd__) + __imul__ = mark_as_changed_wrapper(list.__imul__) - def extend(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).extend(*args, **kwargs) + if six.PY2: + # Under py3 __setslice__, __delslice__ and __getslice__ + # are replaced by __setitem__, __delitem__ and __getitem__ with a slice as parameter + # so we mimic this under python 2 + def __setslice__(self, i, j, sequence): + return self.__setitem__(slice(i, j), sequence) - def insert(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).insert(*args, **kwargs) + def __delslice__(self, i, j): + return self.__delitem__(slice(i, j)) - def pop(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).pop(*args, **kwargs) - - def remove(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).remove(*args, **kwargs) - - def reverse(self): - self._mark_as_changed() - return super(BaseList, self).reverse() - - def sort(self, *args, **kwargs): - self._mark_as_changed() - return super(BaseList, self).sort(*args, **kwargs) + def __getslice__(self, i, j): + return self.__getitem__(slice(i, j)) def _mark_as_changed(self, key=None): if hasattr(self._instance, '_mark_as_changed'): @@ -207,6 +184,10 @@ class BaseList(list): class EmbeddedDocumentList(BaseList): + def __init__(self, list_items, instance, name): + super(EmbeddedDocumentList, self).__init__(list_items, instance, name) + self._instance = instance + @classmethod def __match_all(cls, embedded_doc, kwargs): """Return True if a given embedded doc matches all the filter @@ -225,10 +206,6 @@ class EmbeddedDocumentList(BaseList): return embedded_docs return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)] - def __init__(self, list_items, instance, name): - super(EmbeddedDocumentList, self).__init__(list_items, instance, name) - self._instance = instance - def filter(self, **kwargs): """ Filters the list by only including embedded documents with the diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index 1ea562a5..f2ec098e 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,22 +1,345 @@ import unittest -from mongoengine.base.datastructures import StrictDict, BaseList +from mongoengine import Document +from mongoengine.base.datastructures import StrictDict, BaseList, BaseDict + + +class DocumentStub(object): + def __init__(self): + self._changed_fields = [] + + def _mark_as_changed(self, key): + self._changed_fields.append(key) + + +class TestBaseDict(unittest.TestCase): + + @staticmethod + def _get_basedict(dict_items): + """Get a BaseList bound to a fake document instance""" + fake_doc = DocumentStub() + base_list = BaseDict(dict_items, instance=None, name='my_name') + base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor + return base_list + + def test___init___(self): + class MyDoc(Document): + pass + + dict_items = {'k': 'v'} + doc = MyDoc() + base_dict = BaseDict(dict_items, instance=doc, name='my_name') + self.assertIsInstance(base_dict._instance, Document) + self.assertEqual(base_dict._name, 'my_name') + self.assertEqual(base_dict, dict_items) + + def test_setdefault_calls_mark_as_changed(self): + base_dict = self._get_basedict({}) + base_dict.setdefault('k', 'v') + self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) + + def test_popitems_calls_mark_as_changed(self): + base_dict = self._get_basedict({'k': 'v'}) + self.assertEqual(base_dict.popitem(), ('k', 'v')) + self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) + self.assertFalse(base_dict) + + def test_pop_calls_mark_as_changed(self): + base_dict = self._get_basedict({'k': 'v'}) + self.assertEqual(base_dict.pop('k'), 'v') + self.assertEqual(base_dict._instance._changed_fields, [base_dict._name]) + self.assertFalse(base_dict) + + def test_pop_calls_does_not_mark_as_changed_when_it_fails(self): + base_dict = self._get_basedict({'k': 'v'}) + with self.assertRaises(KeyError): + base_dict.pop('X') + self.assertFalse(base_dict._instance._changed_fields) + + def test_clear_calls_mark_as_changed(self): + base_dict = self._get_basedict({'k': 'v'}) + base_dict.clear() + self.assertEqual(base_dict._instance._changed_fields, ['my_name']) + self.assertEqual(base_dict, {}) + + def test___delitem___calls_mark_as_changed(self): + base_dict = self._get_basedict({'k': 'v'}) + del base_dict['k'] + self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) + self.assertEqual(base_dict, {}) + + def test___getitem____simple_value(self): + base_dict = self._get_basedict({'k': 'v'}) + base_dict['k'] = 'v' + + def test___getitem____sublist_gets_converted_to_BaseList(self): + base_dict = self._get_basedict({'k': [0, 1, 2]}) + sub_list = base_dict['k'] + self.assertEqual(sub_list, [0, 1, 2]) + self.assertIsInstance(sub_list, BaseList) + self.assertIs(sub_list._instance, base_dict._instance) + self.assertEqual(sub_list._name, 'my_name.k') + self.assertEqual(base_dict._instance._changed_fields, []) + + # Challenge mark_as_changed from sublist + sub_list[1] = None + self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.1']) + + def test___getitem____subdict_gets_converted_to_BaseDict(self): + base_dict = self._get_basedict({'k': {'subk': 'subv'}}) + sub_dict = base_dict['k'] + self.assertEqual(sub_dict, {'subk': 'subv'}) + self.assertIsInstance(sub_dict, BaseDict) + self.assertIs(sub_dict._instance, base_dict._instance) + self.assertEqual(sub_dict._name, 'my_name.k') + self.assertEqual(base_dict._instance._changed_fields, []) + + # Challenge mark_as_changed from subdict + sub_dict['subk'] = None + self.assertEqual(base_dict._instance._changed_fields, ['my_name.k.subk']) + + def test___setitem___calls_mark_as_changed(self): + base_dict = self._get_basedict({}) + base_dict['k'] = 'v' + self.assertEqual(base_dict._instance._changed_fields, ['my_name.k']) + self.assertEqual(base_dict, {'k': 'v'}) + + def test_update_calls_mark_as_changed(self): + base_dict = self._get_basedict({}) + base_dict.update({'k': 'v'}) + self.assertEqual(base_dict._instance._changed_fields, ['my_name']) + + def test___setattr____not_tracked_by_changes(self): + base_dict = self._get_basedict({}) + base_dict.a_new_attr = 'test' + self.assertEqual(base_dict._instance._changed_fields, []) + + def test___delattr____tracked_by_changes(self): + # This is probably a bug as __setattr__ is not tracked + # This is even bad because it could be that there is an attribute + # with the same name as a key + base_dict = self._get_basedict({}) + base_dict.a_new_attr = 'test' + del base_dict.a_new_attr + self.assertEqual(base_dict._instance._changed_fields, ['my_name.a_new_attr']) class TestBaseList(unittest.TestCase): - def test_iter_simple(self): + @staticmethod + def _get_baselist(list_items): + """Get a BaseList bound to a fake document instance""" + fake_doc = DocumentStub() + base_list = BaseList(list_items, instance=None, name='my_name') + base_list._instance = fake_doc # hack to inject the mock, it does not work in the constructor + return base_list + + def test___init___(self): + class MyDoc(Document): + pass + + list_items = [True] + doc = MyDoc() + base_list = BaseList(list_items, instance=doc, name='my_name') + self.assertIsInstance(base_list._instance, Document) + self.assertEqual(base_list._name, 'my_name') + self.assertEqual(base_list, list_items) + + def test___iter__(self): values = [True, False, True, False] base_list = BaseList(values, instance=None, name='my_name') self.assertEqual(values, list(base_list)) - def test_iter_allow_modification_while_iterating_withou_error(self): + def test___iter___allow_modification_while_iterating_withou_error(self): # regular list allows for this, thus this subclass must comply to that base_list = BaseList([True, False, True, False], instance=None, name='my_name') for idx, val in enumerate(base_list): if val: base_list.pop(idx) + def test_append_calls_mark_as_changed(self): + base_list = self._get_baselist([]) + self.assertFalse(base_list._instance._changed_fields) + base_list.append(True) + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_subclass_append(self): + # Due to the way mark_as_changed_wrapper is implemented + # it is good to test subclasses + class SubBaseList(BaseList): + pass + + base_list = SubBaseList([], instance=None, name='my_name') + base_list.append(True) + + def test___getitem__using_simple_index(self): + base_list = self._get_baselist([0, 1, 2]) + self.assertEqual(base_list[0], 0) + self.assertEqual(base_list[1], 1) + self.assertEqual(base_list[-1], 2) + + def test___getitem__using_slice(self): + base_list = self._get_baselist([0, 1, 2]) + self.assertEqual(base_list[1:3], [1,2]) + self.assertEqual(base_list[0:3:2], [0, 2]) + + def test___getitem___using_slice_returns_list(self): + # Bug: using slice does not properly handles the instance + # and mark_as_changed behaviour. + base_list = self._get_baselist([0, 1, 2]) + sliced = base_list[1:3] + self.assertEqual(sliced, [1, 2]) + self.assertIsInstance(sliced, list) + self.assertEqual(base_list._instance._changed_fields, []) + + def test___getitem__sublist_returns_BaseList_bound_to_instance(self): + base_list = self._get_baselist( + [ + [1,2], + [3, 4] + ] + ) + sub_list = base_list[0] + self.assertEqual(sub_list, [1, 2]) + self.assertIsInstance(sub_list, BaseList) + self.assertIs(sub_list._instance, base_list._instance) + self.assertEqual(sub_list._name, 'my_name.0') + self.assertEqual(base_list._instance._changed_fields, []) + + # Challenge mark_as_changed from sublist + sub_list[1] = None + self.assertEqual(base_list._instance._changed_fields, ['my_name.0.1']) + + def test___getitem__subdict_returns_BaseList_bound_to_instance(self): + base_list = self._get_baselist( + [ + {'subk': 'subv'} + ] + ) + sub_dict = base_list[0] + self.assertEqual(sub_dict, {'subk': 'subv'}) + self.assertIsInstance(sub_dict, BaseDict) + self.assertIs(sub_dict._instance, base_list._instance) + self.assertEqual(sub_dict._name, 'my_name.0') + self.assertEqual(base_list._instance._changed_fields, []) + + # Challenge mark_as_changed from subdict + sub_dict['subk'] = None + self.assertEqual(base_list._instance._changed_fields, ['my_name.0.subk']) + + def test_extend_calls_mark_as_changed(self): + base_list = self._get_baselist([]) + base_list.extend([True]) + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_insert_calls_mark_as_changed(self): + base_list = self._get_baselist([]) + base_list.insert(0, True) + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_remove_calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + base_list.remove(True) + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_remove_not_mark_as_changed_when_it_fails(self): + base_list = self._get_baselist([True]) + try: + base_list.remove(False) + except ValueError: + self.assertFalse(base_list._instance._changed_fields) + + def test_pop_calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + base_list.pop() + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_reverse_calls_mark_as_changed(self): + base_list = self._get_baselist([True, False]) + base_list.reverse() + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test___delitem___calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + del base_list[0] + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test___setitem___calls_with_full_slice_mark_as_changed(self): + base_list = self._get_baselist([]) + base_list[:] = [0, 1] # Will use __setslice__ under py2 and __setitem__ under py3 + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [0, 1]) + + def test___setitem___calls_with_partial_slice_mark_as_changed(self): + base_list = self._get_baselist([0, 1, 2]) + base_list[0:2] = [1, 0] # Will use __setslice__ under py2 and __setitem__ under py3 + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [1, 0, 2]) + + def test___setitem___calls_with_step_slice_mark_as_changed(self): + base_list = self._get_baselist([0, 1, 2]) + base_list[0:3:2] = [-1, -2] # uses __setitem__ in both py2 & 3 + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [-1, 1, -2]) + + def test___setitem___with_slice(self): + base_list = self._get_baselist([0,1,2,3,4,5]) + base_list[0:6:2] = [None, None, None] + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [None,1,None,3,None,5]) + + def test___setitem___item_0_calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + base_list[0] = False + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [False]) + + def test___setitem___item_1_calls_mark_as_changed(self): + base_list = self._get_baselist([True, True]) + base_list[1] = False + self.assertEqual(base_list._instance._changed_fields, ['my_name.1']) + self.assertEqual(base_list, [True, False]) + + def test___delslice___calls_mark_as_changed(self): + base_list = self._get_baselist([0, 1]) + del base_list[0:1] + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + self.assertEqual(base_list, [1]) + + def test___iadd___calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + base_list += [False] + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test___iadd___not_mark_as_changed_when_it_fails(self): + base_list = self._get_baselist([True]) + try: + base_list += None + except TypeError: + self.assertFalse(base_list._instance._changed_fields) + + def test___imul___calls_mark_as_changed(self): + base_list = self._get_baselist([True]) + base_list *= 2 + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test___imul___not_mark_as_changed_when_it_fails(self): + base_list = self._get_baselist([True]) + try: + base_list *= None + except TypeError: + self.assertFalse(base_list._instance._changed_fields) + + def test_sort_calls_mark_as_changed(self): + base_list = self._get_baselist([True, False]) + base_list.sort() + self.assertEqual(base_list._instance._changed_fields, ['my_name']) + + def test_sort_calls_with_key(self): + base_list = self._get_baselist([1, 2, 11]) + base_list.sort(key=lambda i: str(i)) + self.assertEqual(base_list, [1, 11, 2]) + class TestStrictDict(unittest.TestCase): def strict_dict_class(self, *args, **kwargs):