added __slots__ to BaseDocument and Document
changed the _data field to static key-value mapping instead of hash table This implements #624
This commit is contained in:
parent
bcbe740598
commit
9835b382da
@ -1,4 +1,6 @@
|
||||
import weakref
|
||||
import functools
|
||||
import itertools
|
||||
from mongoengine.common import _import_class
|
||||
|
||||
__all__ = ("BaseDict", "BaseList")
|
||||
@ -156,3 +158,98 @@ class BaseList(list):
|
||||
def _mark_as_changed(self):
|
||||
if hasattr(self._instance, '_mark_as_changed'):
|
||||
self._instance._mark_as_changed(self._name)
|
||||
|
||||
|
||||
class StrictDict(object):
|
||||
__slots__ = ()
|
||||
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
|
||||
_classes = {}
|
||||
def __init__(self, **kwargs):
|
||||
for k,v in kwargs.iteritems():
|
||||
setattr(self, k, v)
|
||||
def __getitem__(self, key):
|
||||
key = '_reserved_' + key if key in self._special_fields else key
|
||||
try:
|
||||
return getattr(self, key)
|
||||
except AttributeError:
|
||||
raise KeyError(key)
|
||||
def __setitem__(self, key, value):
|
||||
key = '_reserved_' + key if key in self._special_fields else key
|
||||
return setattr(self, key, value)
|
||||
def __contains__(self, key):
|
||||
return hasattr(self, key)
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self[key]
|
||||
except KeyError:
|
||||
return default
|
||||
def pop(self, key, default=None):
|
||||
v = self.get(key, default)
|
||||
try:
|
||||
delattr(self, key)
|
||||
except AttributeError:
|
||||
pass
|
||||
return v
|
||||
def iteritems(self):
|
||||
for key in self:
|
||||
yield key, self[key]
|
||||
def items(self):
|
||||
return [(k, self[k]) for k in iter(self)]
|
||||
def keys(self):
|
||||
return list(iter(self))
|
||||
def __iter__(self):
|
||||
return (key for key in self.__slots__ if hasattr(self, key))
|
||||
def __len__(self):
|
||||
return len(list(self.iteritems()))
|
||||
def __eq__(self, other):
|
||||
return self.items() == other.items()
|
||||
def __neq__(self, other):
|
||||
return self.items() != other.items()
|
||||
|
||||
@classmethod
|
||||
def create(cls, allowed_keys):
|
||||
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
|
||||
allowed_keys = frozenset(allowed_keys_tuple)
|
||||
if allowed_keys not in cls._classes:
|
||||
class SpecificStrictDict(cls):
|
||||
__slots__ = allowed_keys_tuple
|
||||
cls._classes[allowed_keys] = SpecificStrictDict
|
||||
return cls._classes[allowed_keys]
|
||||
|
||||
|
||||
class SemiStrictDict(StrictDict):
|
||||
__slots__ = ('_extras')
|
||||
_classes = {}
|
||||
def __getattr__(self, attr):
|
||||
try:
|
||||
super(SemiStrictDict, self).__getattr__(attr)
|
||||
except AttributeError:
|
||||
try:
|
||||
return self.__getattribute__('_extras')[attr]
|
||||
except KeyError as e:
|
||||
raise AttributeError(e)
|
||||
def __setattr__(self, attr, value):
|
||||
try:
|
||||
super(SemiStrictDict, self).__setattr__(attr, value)
|
||||
except AttributeError:
|
||||
try:
|
||||
self._extras[attr] = value
|
||||
except AttributeError:
|
||||
self._extras = {attr: value}
|
||||
|
||||
def __delattr__(self, attr):
|
||||
try:
|
||||
super(SemiStrictDict, self).__delattr__(attr)
|
||||
except AttributeError:
|
||||
try:
|
||||
del self._extras[attr]
|
||||
except KeyError as e:
|
||||
raise AttributeError(e)
|
||||
|
||||
def __iter__(self):
|
||||
try:
|
||||
extras_iter = iter(self.__getattribute__('_extras'))
|
||||
except AttributeError:
|
||||
extras_iter = ()
|
||||
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)
|
||||
|
||||
|
@ -16,20 +16,20 @@ from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
|
||||
to_str_keys_recursive)
|
||||
|
||||
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
|
||||
from mongoengine.base.datastructures import BaseDict, BaseList
|
||||
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
|
||||
from mongoengine.base.fields import ComplexBaseField
|
||||
|
||||
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
|
||||
|
||||
NON_FIELD_ERRORS = '__all__'
|
||||
|
||||
|
||||
class BaseDocument(object):
|
||||
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
|
||||
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
|
||||
|
||||
_dynamic = False
|
||||
_created = True
|
||||
_dynamic_lock = True
|
||||
_initialised = False
|
||||
STRICT = False
|
||||
|
||||
def __init__(self, *args, **values):
|
||||
"""
|
||||
@ -38,6 +38,8 @@ class BaseDocument(object):
|
||||
:param __auto_convert: Try and will cast python objects to Object types
|
||||
:param values: A dictionary of values for the document
|
||||
"""
|
||||
self._initialised = False
|
||||
self._created = True
|
||||
if args:
|
||||
# Combine positional arguments with named arguments.
|
||||
# We only want named arguments.
|
||||
@ -52,8 +54,12 @@ class BaseDocument(object):
|
||||
values[name] = value
|
||||
__auto_convert = values.pop("__auto_convert", True)
|
||||
signals.pre_init.send(self.__class__, document=self, values=values)
|
||||
|
||||
self._data = {}
|
||||
|
||||
if self.STRICT and not self._dynamic:
|
||||
self._data = StrictDict.create(allowed_keys=self._fields.keys())()
|
||||
else:
|
||||
self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())()
|
||||
|
||||
self._dynamic_fields = SON()
|
||||
|
||||
# Assign default values to instance
|
||||
@ -129,17 +135,25 @@ class BaseDocument(object):
|
||||
self._data[name] = value
|
||||
if hasattr(self, '_changed_fields'):
|
||||
self._mark_as_changed(name)
|
||||
try:
|
||||
self__created = self._created
|
||||
except AttributeError:
|
||||
self__created = True
|
||||
|
||||
if (self._is_document and not self._created and
|
||||
if (self._is_document and not self__created and
|
||||
name in self._meta.get('shard_key', tuple()) and
|
||||
self._data.get(name) != value):
|
||||
OperationError = _import_class('OperationError')
|
||||
msg = "Shard Keys are immutable. Tried to update %s" % name
|
||||
raise OperationError(msg)
|
||||
|
||||
try:
|
||||
self__initialised = self._initialised
|
||||
except AttributeError:
|
||||
self__initialised = False
|
||||
# Check if the user has created a new instance of a class
|
||||
if (self._is_document and self._initialised
|
||||
and self._created and name == self._meta['id_field']):
|
||||
if (self._is_document and self__initialised
|
||||
and self__created and name == self._meta['id_field']):
|
||||
super(BaseDocument, self).__setattr__('_created', False)
|
||||
|
||||
super(BaseDocument, self).__setattr__(name, value)
|
||||
@ -157,9 +171,11 @@ class BaseDocument(object):
|
||||
if isinstance(data["_data"], SON):
|
||||
data["_data"] = self.__class__._from_son(data["_data"])._data
|
||||
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
||||
'_fields_ordered', '_dynamic_fields'):
|
||||
'_dynamic_fields'):
|
||||
if k in data:
|
||||
setattr(self, k, data[k])
|
||||
if '_fields_ordered' in data:
|
||||
setattr(type(self), '_fields_ordered', data['_fields_ordered'])
|
||||
dynamic_fields = data.get('_dynamic_fields') or SON()
|
||||
for k in dynamic_fields.keys():
|
||||
setattr(self, k, data["_data"].get(k))
|
||||
@ -576,7 +592,9 @@ class BaseDocument(object):
|
||||
msg = ("Invalid data to create a `%s` instance.\n%s"
|
||||
% (cls._class_name, errors))
|
||||
raise InvalidDocumentError(msg)
|
||||
|
||||
|
||||
if cls.STRICT:
|
||||
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
|
||||
obj = cls(__auto_convert=False, **data)
|
||||
obj._changed_fields = changed_fields
|
||||
obj._created = False
|
||||
@ -813,7 +831,11 @@ class BaseDocument(object):
|
||||
"""Dynamically set the display value for a field with choices"""
|
||||
for attr_name, field in self._fields.items():
|
||||
if field.choices:
|
||||
setattr(self,
|
||||
if self._dynamic:
|
||||
obj = self
|
||||
else:
|
||||
obj = type(self)
|
||||
setattr(obj,
|
||||
'get_%s_display' % attr_name,
|
||||
partial(self.__get_field_display, field=field))
|
||||
|
||||
|
@ -52,16 +52,17 @@ class EmbeddedDocument(BaseDocument):
|
||||
`_cls` set :attr:`allow_inheritance` to ``False`` in the :attr:`meta`
|
||||
dictionary.
|
||||
"""
|
||||
|
||||
__slots__ = ('_instance')
|
||||
|
||||
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
||||
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
||||
my_metaclass = DocumentMetaclass
|
||||
__metaclass__ = DocumentMetaclass
|
||||
|
||||
_instance = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
||||
self._instance = None
|
||||
self._changed_fields = []
|
||||
|
||||
def __eq__(self, other):
|
||||
@ -124,6 +125,8 @@ class Document(BaseDocument):
|
||||
my_metaclass = TopLevelDocumentMetaclass
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
__slots__ = ('__objects' )
|
||||
|
||||
def pk():
|
||||
"""Primary key alias
|
||||
"""
|
||||
|
107
tests/test_datastructures.py
Normal file
107
tests/test_datastructures.py
Normal file
@ -0,0 +1,107 @@
|
||||
import unittest
|
||||
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
|
||||
|
||||
class TestStrictDict(unittest.TestCase):
|
||||
def strict_dict_class(self, *args, **kwargs):
|
||||
return StrictDict.create(*args, **kwargs)
|
||||
def setUp(self):
|
||||
self.dtype = self.strict_dict_class(("a", "b", "c"))
|
||||
def test_init(self):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
|
||||
|
||||
def test_init_fails_on_nonexisting_attrs(self):
|
||||
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
|
||||
|
||||
def test_eq(self):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
dd = self.dtype(a=1, b=1, c=1)
|
||||
e = self.dtype(a=1, b=1, c=3)
|
||||
f = self.dtype(a=1, b=1)
|
||||
g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
|
||||
h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
|
||||
i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
|
||||
|
||||
self.assertEqual(d, dd)
|
||||
self.assertNotEqual(d, e)
|
||||
self.assertNotEqual(d, f)
|
||||
self.assertNotEqual(d, g)
|
||||
self.assertNotEqual(f, d)
|
||||
self.assertEqual(d, h)
|
||||
self.assertNotEqual(d, i)
|
||||
|
||||
def test_setattr_getattr(self):
|
||||
d = self.dtype()
|
||||
d.a = 1
|
||||
self.assertEqual(d.a, 1)
|
||||
self.assertRaises(AttributeError, lambda: d.b)
|
||||
|
||||
def test_setattr_raises_on_nonexisting_attr(self):
|
||||
d = self.dtype()
|
||||
def _f():
|
||||
d.x=1
|
||||
self.assertRaises(AttributeError, _f)
|
||||
|
||||
def test_setattr_getattr_special(self):
|
||||
d = self.strict_dict_class(["items"])
|
||||
d.items = 1
|
||||
self.assertEqual(d.items, 1)
|
||||
|
||||
def test_get(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertEqual(d.get('a'), 1)
|
||||
self.assertEqual(d.get('b', 'bla'), 'bla')
|
||||
|
||||
def test_items(self):
|
||||
d = self.dtype(a=1)
|
||||
self.assertEqual(d.items(), [('a', 1)])
|
||||
d = self.dtype(a=1, b=2)
|
||||
self.assertEqual(d.items(), [('a', 1), ('b', 2)])
|
||||
|
||||
def test_mappings_protocol(self):
|
||||
d = self.dtype(a=1, b=2)
|
||||
assert dict(d) == {'a': 1, 'b': 2}
|
||||
assert dict(**d) == {'a': 1, 'b': 2}
|
||||
|
||||
|
||||
class TestSemiSrictDict(TestStrictDict):
|
||||
def strict_dict_class(self, *args, **kwargs):
|
||||
return SemiStrictDict.create(*args, **kwargs)
|
||||
|
||||
def test_init_fails_on_nonexisting_attrs(self):
|
||||
# disable irrelevant test
|
||||
pass
|
||||
|
||||
def test_setattr_raises_on_nonexisting_attr(self):
|
||||
# disable irrelevant test
|
||||
pass
|
||||
|
||||
def test_setattr_getattr_nonexisting_attr_succeeds(self):
|
||||
d = self.dtype()
|
||||
d.x = 1
|
||||
self.assertEqual(d.x, 1)
|
||||
|
||||
def test_init_succeeds_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
|
||||
|
||||
def test_iter_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual(list(d), ['a', 'b', 'c', 'x'])
|
||||
|
||||
def test_iteritems_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)])
|
||||
|
||||
def tets_cmp_with_strict_dicts(self):
|
||||
d = self.dtype(a=1, b=1, c=1)
|
||||
dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1)
|
||||
self.assertEqual(d, dd)
|
||||
|
||||
def test_cmp_with_strict_dict_with_nonexisting_attrs(self):
|
||||
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||
dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2)
|
||||
self.assertEqual(d, dd)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user