From 9835b382dab4ed3ceef277a2948dde103542303c Mon Sep 17 00:00:00 2001 From: Sagiv Malihi Date: Thu, 3 Apr 2014 12:38:33 +0300 Subject: [PATCH] added __slots__ to BaseDocument and Document changed the _data field to static key-value mapping instead of hash table This implements #624 --- mongoengine/base/datastructures.py | 97 ++++++++++++++++++++++++++ mongoengine/base/document.py | 46 +++++++++---- mongoengine/document.py | 7 +- tests/test_datastructures.py | 107 +++++++++++++++++++++++++++++ 4 files changed, 243 insertions(+), 14 deletions(-) create mode 100644 tests/test_datastructures.py diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 4652fb56..32a66018 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,4 +1,6 @@ import weakref +import functools +import itertools from mongoengine.common import _import_class __all__ = ("BaseDict", "BaseList") @@ -156,3 +158,98 @@ class BaseList(list): def _mark_as_changed(self): if hasattr(self._instance, '_mark_as_changed'): self._instance._mark_as_changed(self._name) + + +class StrictDict(object): + __slots__ = () + _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create']) + _classes = {} + def __init__(self, **kwargs): + for k,v in kwargs.iteritems(): + setattr(self, k, v) + def __getitem__(self, key): + key = '_reserved_' + key if key in self._special_fields else key + try: + return getattr(self, key) + except AttributeError: + raise KeyError(key) + def __setitem__(self, key, value): + key = '_reserved_' + key if key in self._special_fields else key + return setattr(self, key, value) + def __contains__(self, key): + return hasattr(self, key) + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + def pop(self, key, default=None): + v = self.get(key, default) + try: + delattr(self, key) + except AttributeError: + pass + return v + def iteritems(self): + for key in self: + yield key, self[key] + def items(self): + return [(k, self[k]) for k in iter(self)] + def keys(self): + return list(iter(self)) + def __iter__(self): + return (key for key in self.__slots__ if hasattr(self, key)) + def __len__(self): + return len(list(self.iteritems())) + def __eq__(self, other): + return self.items() == other.items() + def __neq__(self, other): + return self.items() != other.items() + + @classmethod + def create(cls, allowed_keys): + allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys) + allowed_keys = frozenset(allowed_keys_tuple) + if allowed_keys not in cls._classes: + class SpecificStrictDict(cls): + __slots__ = allowed_keys_tuple + cls._classes[allowed_keys] = SpecificStrictDict + return cls._classes[allowed_keys] + + +class SemiStrictDict(StrictDict): + __slots__ = ('_extras') + _classes = {} + def __getattr__(self, attr): + try: + super(SemiStrictDict, self).__getattr__(attr) + except AttributeError: + try: + return self.__getattribute__('_extras')[attr] + except KeyError as e: + raise AttributeError(e) + def __setattr__(self, attr, value): + try: + super(SemiStrictDict, self).__setattr__(attr, value) + except AttributeError: + try: + self._extras[attr] = value + except AttributeError: + self._extras = {attr: value} + + def __delattr__(self, attr): + try: + super(SemiStrictDict, self).__delattr__(attr) + except AttributeError: + try: + del self._extras[attr] + except KeyError as e: + raise AttributeError(e) + + def __iter__(self): + try: + extras_iter = iter(self.__getattribute__('_extras')) + except AttributeError: + extras_iter = () + return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter) + diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index cea2f09b..01809aa9 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -16,20 +16,20 @@ from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type, to_str_keys_recursive) from mongoengine.base.common import get_document, ALLOW_INHERITANCE -from mongoengine.base.datastructures import BaseDict, BaseList +from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict from mongoengine.base.fields import ComplexBaseField __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') NON_FIELD_ERRORS = '__all__' - class BaseDocument(object): + __slots__ = ('_changed_fields', '_initialised', '_created', '_data', + '_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__') _dynamic = False - _created = True _dynamic_lock = True - _initialised = False + STRICT = False def __init__(self, *args, **values): """ @@ -38,6 +38,8 @@ class BaseDocument(object): :param __auto_convert: Try and will cast python objects to Object types :param values: A dictionary of values for the document """ + self._initialised = False + self._created = True if args: # Combine positional arguments with named arguments. # We only want named arguments. @@ -52,8 +54,12 @@ class BaseDocument(object): values[name] = value __auto_convert = values.pop("__auto_convert", True) signals.pre_init.send(self.__class__, document=self, values=values) - - self._data = {} + + if self.STRICT and not self._dynamic: + self._data = StrictDict.create(allowed_keys=self._fields.keys())() + else: + self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())() + self._dynamic_fields = SON() # Assign default values to instance @@ -129,17 +135,25 @@ class BaseDocument(object): self._data[name] = value if hasattr(self, '_changed_fields'): self._mark_as_changed(name) + try: + self__created = self._created + except AttributeError: + self__created = True - if (self._is_document and not self._created and + if (self._is_document and not self__created and name in self._meta.get('shard_key', tuple()) and self._data.get(name) != value): OperationError = _import_class('OperationError') msg = "Shard Keys are immutable. Tried to update %s" % name raise OperationError(msg) + try: + self__initialised = self._initialised + except AttributeError: + self__initialised = False # Check if the user has created a new instance of a class - if (self._is_document and self._initialised - and self._created and name == self._meta['id_field']): + if (self._is_document and self__initialised + and self__created and name == self._meta['id_field']): super(BaseDocument, self).__setattr__('_created', False) super(BaseDocument, self).__setattr__(name, value) @@ -157,9 +171,11 @@ class BaseDocument(object): if isinstance(data["_data"], SON): data["_data"] = self.__class__._from_son(data["_data"])._data for k in ('_changed_fields', '_initialised', '_created', '_data', - '_fields_ordered', '_dynamic_fields'): + '_dynamic_fields'): if k in data: setattr(self, k, data[k]) + if '_fields_ordered' in data: + setattr(type(self), '_fields_ordered', data['_fields_ordered']) dynamic_fields = data.get('_dynamic_fields') or SON() for k in dynamic_fields.keys(): setattr(self, k, data["_data"].get(k)) @@ -576,7 +592,9 @@ class BaseDocument(object): msg = ("Invalid data to create a `%s` instance.\n%s" % (cls._class_name, errors)) raise InvalidDocumentError(msg) - + + if cls.STRICT: + data = dict((k, v) for k,v in data.iteritems() if k in cls._fields) obj = cls(__auto_convert=False, **data) obj._changed_fields = changed_fields obj._created = False @@ -813,7 +831,11 @@ class BaseDocument(object): """Dynamically set the display value for a field with choices""" for attr_name, field in self._fields.items(): if field.choices: - setattr(self, + if self._dynamic: + obj = self + else: + obj = type(self) + setattr(obj, 'get_%s_display' % attr_name, partial(self.__get_field_display, field=field)) diff --git a/mongoengine/document.py b/mongoengine/document.py index 1bbd7b73..98e1d2a3 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -52,16 +52,17 @@ class EmbeddedDocument(BaseDocument): `_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta` dictionary. """ + + __slots__ = ('_instance') # The __metaclass__ attribute is removed by 2to3 when running with Python3 # my_metaclass is defined so that metaclass can be queried in Python 2 & 3 my_metaclass = DocumentMetaclass __metaclass__ = DocumentMetaclass - _instance = None - def __init__(self, *args, **kwargs): super(EmbeddedDocument, self).__init__(*args, **kwargs) + self._instance = None self._changed_fields = [] def __eq__(self, other): @@ -124,6 +125,8 @@ class Document(BaseDocument): my_metaclass = TopLevelDocumentMetaclass __metaclass__ = TopLevelDocumentMetaclass + __slots__ = ('__objects' ) + def pk(): """Primary key alias """ diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py new file mode 100644 index 00000000..c761a41e --- /dev/null +++ b/tests/test_datastructures.py @@ -0,0 +1,107 @@ +import unittest +from mongoengine.base.datastructures import StrictDict, SemiStrictDict + +class TestStrictDict(unittest.TestCase): + def strict_dict_class(self, *args, **kwargs): + return StrictDict.create(*args, **kwargs) + def setUp(self): + self.dtype = self.strict_dict_class(("a", "b", "c")) + def test_init(self): + d = self.dtype(a=1, b=1, c=1) + self.assertEqual((d.a, d.b, d.c), (1, 1, 1)) + + def test_init_fails_on_nonexisting_attrs(self): + self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3)) + + def test_eq(self): + d = self.dtype(a=1, b=1, c=1) + dd = self.dtype(a=1, b=1, c=1) + e = self.dtype(a=1, b=1, c=3) + f = self.dtype(a=1, b=1) + g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1) + h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1) + i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2) + + self.assertEqual(d, dd) + self.assertNotEqual(d, e) + self.assertNotEqual(d, f) + self.assertNotEqual(d, g) + self.assertNotEqual(f, d) + self.assertEqual(d, h) + self.assertNotEqual(d, i) + + def test_setattr_getattr(self): + d = self.dtype() + d.a = 1 + self.assertEqual(d.a, 1) + self.assertRaises(AttributeError, lambda: d.b) + + def test_setattr_raises_on_nonexisting_attr(self): + d = self.dtype() + def _f(): + d.x=1 + self.assertRaises(AttributeError, _f) + + def test_setattr_getattr_special(self): + d = self.strict_dict_class(["items"]) + d.items = 1 + self.assertEqual(d.items, 1) + + def test_get(self): + d = self.dtype(a=1) + self.assertEqual(d.get('a'), 1) + self.assertEqual(d.get('b', 'bla'), 'bla') + + def test_items(self): + d = self.dtype(a=1) + self.assertEqual(d.items(), [('a', 1)]) + d = self.dtype(a=1, b=2) + self.assertEqual(d.items(), [('a', 1), ('b', 2)]) + + def test_mappings_protocol(self): + d = self.dtype(a=1, b=2) + assert dict(d) == {'a': 1, 'b': 2} + assert dict(**d) == {'a': 1, 'b': 2} + + +class TestSemiSrictDict(TestStrictDict): + def strict_dict_class(self, *args, **kwargs): + return SemiStrictDict.create(*args, **kwargs) + + def test_init_fails_on_nonexisting_attrs(self): + # disable irrelevant test + pass + + def test_setattr_raises_on_nonexisting_attr(self): + # disable irrelevant test + pass + + def test_setattr_getattr_nonexisting_attr_succeeds(self): + d = self.dtype() + d.x = 1 + self.assertEqual(d.x, 1) + + def test_init_succeeds_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2)) + + def test_iter_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d), ['a', 'b', 'c', 'x']) + + def test_iteritems_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)]) + + def tets_cmp_with_strict_dicts(self): + d = self.dtype(a=1, b=1, c=1) + dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1) + self.assertEqual(d, dd) + + def test_cmp_with_strict_dict_with_nonexisting_attrs(self): + d = self.dtype(a=1, b=1, c=1, x=2) + dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2) + self.assertEqual(d, dd) + +if __name__ == '__main__': + unittest.main()