added __slots__ to BaseDocument and Document
changed the _data field to static key-value mapping instead of hash table This implements #624
This commit is contained in:
parent
bcbe740598
commit
9835b382da
@ -1,4 +1,6 @@
|
|||||||
import weakref
|
import weakref
|
||||||
|
import functools
|
||||||
|
import itertools
|
||||||
from mongoengine.common import _import_class
|
from mongoengine.common import _import_class
|
||||||
|
|
||||||
__all__ = ("BaseDict", "BaseList")
|
__all__ = ("BaseDict", "BaseList")
|
||||||
@ -156,3 +158,98 @@ class BaseList(list):
|
|||||||
def _mark_as_changed(self):
|
def _mark_as_changed(self):
|
||||||
if hasattr(self._instance, '_mark_as_changed'):
|
if hasattr(self._instance, '_mark_as_changed'):
|
||||||
self._instance._mark_as_changed(self._name)
|
self._instance._mark_as_changed(self._name)
|
||||||
|
|
||||||
|
|
||||||
|
class StrictDict(object):
|
||||||
|
__slots__ = ()
|
||||||
|
_special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
|
||||||
|
_classes = {}
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
for k,v in kwargs.iteritems():
|
||||||
|
setattr(self, k, v)
|
||||||
|
def __getitem__(self, key):
|
||||||
|
key = '_reserved_' + key if key in self._special_fields else key
|
||||||
|
try:
|
||||||
|
return getattr(self, key)
|
||||||
|
except AttributeError:
|
||||||
|
raise KeyError(key)
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
key = '_reserved_' + key if key in self._special_fields else key
|
||||||
|
return setattr(self, key, value)
|
||||||
|
def __contains__(self, key):
|
||||||
|
return hasattr(self, key)
|
||||||
|
def get(self, key, default=None):
|
||||||
|
try:
|
||||||
|
return self[key]
|
||||||
|
except KeyError:
|
||||||
|
return default
|
||||||
|
def pop(self, key, default=None):
|
||||||
|
v = self.get(key, default)
|
||||||
|
try:
|
||||||
|
delattr(self, key)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
return v
|
||||||
|
def iteritems(self):
|
||||||
|
for key in self:
|
||||||
|
yield key, self[key]
|
||||||
|
def items(self):
|
||||||
|
return [(k, self[k]) for k in iter(self)]
|
||||||
|
def keys(self):
|
||||||
|
return list(iter(self))
|
||||||
|
def __iter__(self):
|
||||||
|
return (key for key in self.__slots__ if hasattr(self, key))
|
||||||
|
def __len__(self):
|
||||||
|
return len(list(self.iteritems()))
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.items() == other.items()
|
||||||
|
def __neq__(self, other):
|
||||||
|
return self.items() != other.items()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, allowed_keys):
|
||||||
|
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
|
||||||
|
allowed_keys = frozenset(allowed_keys_tuple)
|
||||||
|
if allowed_keys not in cls._classes:
|
||||||
|
class SpecificStrictDict(cls):
|
||||||
|
__slots__ = allowed_keys_tuple
|
||||||
|
cls._classes[allowed_keys] = SpecificStrictDict
|
||||||
|
return cls._classes[allowed_keys]
|
||||||
|
|
||||||
|
|
||||||
|
class SemiStrictDict(StrictDict):
|
||||||
|
__slots__ = ('_extras')
|
||||||
|
_classes = {}
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
try:
|
||||||
|
super(SemiStrictDict, self).__getattr__(attr)
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
return self.__getattribute__('_extras')[attr]
|
||||||
|
except KeyError as e:
|
||||||
|
raise AttributeError(e)
|
||||||
|
def __setattr__(self, attr, value):
|
||||||
|
try:
|
||||||
|
super(SemiStrictDict, self).__setattr__(attr, value)
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
self._extras[attr] = value
|
||||||
|
except AttributeError:
|
||||||
|
self._extras = {attr: value}
|
||||||
|
|
||||||
|
def __delattr__(self, attr):
|
||||||
|
try:
|
||||||
|
super(SemiStrictDict, self).__delattr__(attr)
|
||||||
|
except AttributeError:
|
||||||
|
try:
|
||||||
|
del self._extras[attr]
|
||||||
|
except KeyError as e:
|
||||||
|
raise AttributeError(e)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
try:
|
||||||
|
extras_iter = iter(self.__getattribute__('_extras'))
|
||||||
|
except AttributeError:
|
||||||
|
extras_iter = ()
|
||||||
|
return itertools.chain(super(SemiStrictDict, self).__iter__(), extras_iter)
|
||||||
|
|
||||||
|
@ -16,20 +16,20 @@ from mongoengine.python_support import (PY3, UNICODE_KWARGS, txt_type,
|
|||||||
to_str_keys_recursive)
|
to_str_keys_recursive)
|
||||||
|
|
||||||
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
|
from mongoengine.base.common import get_document, ALLOW_INHERITANCE
|
||||||
from mongoengine.base.datastructures import BaseDict, BaseList
|
from mongoengine.base.datastructures import BaseDict, BaseList, StrictDict, SemiStrictDict
|
||||||
from mongoengine.base.fields import ComplexBaseField
|
from mongoengine.base.fields import ComplexBaseField
|
||||||
|
|
||||||
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
|
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
|
||||||
|
|
||||||
NON_FIELD_ERRORS = '__all__'
|
NON_FIELD_ERRORS = '__all__'
|
||||||
|
|
||||||
|
|
||||||
class BaseDocument(object):
|
class BaseDocument(object):
|
||||||
|
__slots__ = ('_changed_fields', '_initialised', '_created', '_data',
|
||||||
|
'_dynamic_fields', '_auto_id_field', '_db_field_map', '_cls', '__weakref__')
|
||||||
|
|
||||||
_dynamic = False
|
_dynamic = False
|
||||||
_created = True
|
|
||||||
_dynamic_lock = True
|
_dynamic_lock = True
|
||||||
_initialised = False
|
STRICT = False
|
||||||
|
|
||||||
def __init__(self, *args, **values):
|
def __init__(self, *args, **values):
|
||||||
"""
|
"""
|
||||||
@ -38,6 +38,8 @@ class BaseDocument(object):
|
|||||||
:param __auto_convert: Try and will cast python objects to Object types
|
:param __auto_convert: Try and will cast python objects to Object types
|
||||||
:param values: A dictionary of values for the document
|
:param values: A dictionary of values for the document
|
||||||
"""
|
"""
|
||||||
|
self._initialised = False
|
||||||
|
self._created = True
|
||||||
if args:
|
if args:
|
||||||
# Combine positional arguments with named arguments.
|
# Combine positional arguments with named arguments.
|
||||||
# We only want named arguments.
|
# We only want named arguments.
|
||||||
@ -53,7 +55,11 @@ class BaseDocument(object):
|
|||||||
__auto_convert = values.pop("__auto_convert", True)
|
__auto_convert = values.pop("__auto_convert", True)
|
||||||
signals.pre_init.send(self.__class__, document=self, values=values)
|
signals.pre_init.send(self.__class__, document=self, values=values)
|
||||||
|
|
||||||
self._data = {}
|
if self.STRICT and not self._dynamic:
|
||||||
|
self._data = StrictDict.create(allowed_keys=self._fields.keys())()
|
||||||
|
else:
|
||||||
|
self._data = SemiStrictDict.create(allowed_keys=self._fields.keys())()
|
||||||
|
|
||||||
self._dynamic_fields = SON()
|
self._dynamic_fields = SON()
|
||||||
|
|
||||||
# Assign default values to instance
|
# Assign default values to instance
|
||||||
@ -129,17 +135,25 @@ class BaseDocument(object):
|
|||||||
self._data[name] = value
|
self._data[name] = value
|
||||||
if hasattr(self, '_changed_fields'):
|
if hasattr(self, '_changed_fields'):
|
||||||
self._mark_as_changed(name)
|
self._mark_as_changed(name)
|
||||||
|
try:
|
||||||
|
self__created = self._created
|
||||||
|
except AttributeError:
|
||||||
|
self__created = True
|
||||||
|
|
||||||
if (self._is_document and not self._created and
|
if (self._is_document and not self__created and
|
||||||
name in self._meta.get('shard_key', tuple()) and
|
name in self._meta.get('shard_key', tuple()) and
|
||||||
self._data.get(name) != value):
|
self._data.get(name) != value):
|
||||||
OperationError = _import_class('OperationError')
|
OperationError = _import_class('OperationError')
|
||||||
msg = "Shard Keys are immutable. Tried to update %s" % name
|
msg = "Shard Keys are immutable. Tried to update %s" % name
|
||||||
raise OperationError(msg)
|
raise OperationError(msg)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self__initialised = self._initialised
|
||||||
|
except AttributeError:
|
||||||
|
self__initialised = False
|
||||||
# Check if the user has created a new instance of a class
|
# Check if the user has created a new instance of a class
|
||||||
if (self._is_document and self._initialised
|
if (self._is_document and self__initialised
|
||||||
and self._created and name == self._meta['id_field']):
|
and self__created and name == self._meta['id_field']):
|
||||||
super(BaseDocument, self).__setattr__('_created', False)
|
super(BaseDocument, self).__setattr__('_created', False)
|
||||||
|
|
||||||
super(BaseDocument, self).__setattr__(name, value)
|
super(BaseDocument, self).__setattr__(name, value)
|
||||||
@ -157,9 +171,11 @@ class BaseDocument(object):
|
|||||||
if isinstance(data["_data"], SON):
|
if isinstance(data["_data"], SON):
|
||||||
data["_data"] = self.__class__._from_son(data["_data"])._data
|
data["_data"] = self.__class__._from_son(data["_data"])._data
|
||||||
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
for k in ('_changed_fields', '_initialised', '_created', '_data',
|
||||||
'_fields_ordered', '_dynamic_fields'):
|
'_dynamic_fields'):
|
||||||
if k in data:
|
if k in data:
|
||||||
setattr(self, k, data[k])
|
setattr(self, k, data[k])
|
||||||
|
if '_fields_ordered' in data:
|
||||||
|
setattr(type(self), '_fields_ordered', data['_fields_ordered'])
|
||||||
dynamic_fields = data.get('_dynamic_fields') or SON()
|
dynamic_fields = data.get('_dynamic_fields') or SON()
|
||||||
for k in dynamic_fields.keys():
|
for k in dynamic_fields.keys():
|
||||||
setattr(self, k, data["_data"].get(k))
|
setattr(self, k, data["_data"].get(k))
|
||||||
@ -577,6 +593,8 @@ class BaseDocument(object):
|
|||||||
% (cls._class_name, errors))
|
% (cls._class_name, errors))
|
||||||
raise InvalidDocumentError(msg)
|
raise InvalidDocumentError(msg)
|
||||||
|
|
||||||
|
if cls.STRICT:
|
||||||
|
data = dict((k, v) for k,v in data.iteritems() if k in cls._fields)
|
||||||
obj = cls(__auto_convert=False, **data)
|
obj = cls(__auto_convert=False, **data)
|
||||||
obj._changed_fields = changed_fields
|
obj._changed_fields = changed_fields
|
||||||
obj._created = False
|
obj._created = False
|
||||||
@ -813,7 +831,11 @@ class BaseDocument(object):
|
|||||||
"""Dynamically set the display value for a field with choices"""
|
"""Dynamically set the display value for a field with choices"""
|
||||||
for attr_name, field in self._fields.items():
|
for attr_name, field in self._fields.items():
|
||||||
if field.choices:
|
if field.choices:
|
||||||
setattr(self,
|
if self._dynamic:
|
||||||
|
obj = self
|
||||||
|
else:
|
||||||
|
obj = type(self)
|
||||||
|
setattr(obj,
|
||||||
'get_%s_display' % attr_name,
|
'get_%s_display' % attr_name,
|
||||||
partial(self.__get_field_display, field=field))
|
partial(self.__get_field_display, field=field))
|
||||||
|
|
||||||
|
@ -53,15 +53,16 @@ class EmbeddedDocument(BaseDocument):
|
|||||||
dictionary.
|
dictionary.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
__slots__ = ('_instance')
|
||||||
|
|
||||||
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
# The __metaclass__ attribute is removed by 2to3 when running with Python3
|
||||||
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
|
||||||
my_metaclass = DocumentMetaclass
|
my_metaclass = DocumentMetaclass
|
||||||
__metaclass__ = DocumentMetaclass
|
__metaclass__ = DocumentMetaclass
|
||||||
|
|
||||||
_instance = None
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
super(EmbeddedDocument, self).__init__(*args, **kwargs)
|
||||||
|
self._instance = None
|
||||||
self._changed_fields = []
|
self._changed_fields = []
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
@ -124,6 +125,8 @@ class Document(BaseDocument):
|
|||||||
my_metaclass = TopLevelDocumentMetaclass
|
my_metaclass = TopLevelDocumentMetaclass
|
||||||
__metaclass__ = TopLevelDocumentMetaclass
|
__metaclass__ = TopLevelDocumentMetaclass
|
||||||
|
|
||||||
|
__slots__ = ('__objects' )
|
||||||
|
|
||||||
def pk():
|
def pk():
|
||||||
"""Primary key alias
|
"""Primary key alias
|
||||||
"""
|
"""
|
||||||
|
107
tests/test_datastructures.py
Normal file
107
tests/test_datastructures.py
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import unittest
|
||||||
|
from mongoengine.base.datastructures import StrictDict, SemiStrictDict
|
||||||
|
|
||||||
|
class TestStrictDict(unittest.TestCase):
|
||||||
|
def strict_dict_class(self, *args, **kwargs):
|
||||||
|
return StrictDict.create(*args, **kwargs)
|
||||||
|
def setUp(self):
|
||||||
|
self.dtype = self.strict_dict_class(("a", "b", "c"))
|
||||||
|
def test_init(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1)
|
||||||
|
self.assertEqual((d.a, d.b, d.c), (1, 1, 1))
|
||||||
|
|
||||||
|
def test_init_fails_on_nonexisting_attrs(self):
|
||||||
|
self.assertRaises(AttributeError, lambda: self.dtype(a=1, b=2, d=3))
|
||||||
|
|
||||||
|
def test_eq(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1)
|
||||||
|
dd = self.dtype(a=1, b=1, c=1)
|
||||||
|
e = self.dtype(a=1, b=1, c=3)
|
||||||
|
f = self.dtype(a=1, b=1)
|
||||||
|
g = self.strict_dict_class(("a", "b", "c", "d"))(a=1, b=1, c=1, d=1)
|
||||||
|
h = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=1)
|
||||||
|
i = self.strict_dict_class(("a", "c", "b"))(a=1, b=1, c=2)
|
||||||
|
|
||||||
|
self.assertEqual(d, dd)
|
||||||
|
self.assertNotEqual(d, e)
|
||||||
|
self.assertNotEqual(d, f)
|
||||||
|
self.assertNotEqual(d, g)
|
||||||
|
self.assertNotEqual(f, d)
|
||||||
|
self.assertEqual(d, h)
|
||||||
|
self.assertNotEqual(d, i)
|
||||||
|
|
||||||
|
def test_setattr_getattr(self):
|
||||||
|
d = self.dtype()
|
||||||
|
d.a = 1
|
||||||
|
self.assertEqual(d.a, 1)
|
||||||
|
self.assertRaises(AttributeError, lambda: d.b)
|
||||||
|
|
||||||
|
def test_setattr_raises_on_nonexisting_attr(self):
|
||||||
|
d = self.dtype()
|
||||||
|
def _f():
|
||||||
|
d.x=1
|
||||||
|
self.assertRaises(AttributeError, _f)
|
||||||
|
|
||||||
|
def test_setattr_getattr_special(self):
|
||||||
|
d = self.strict_dict_class(["items"])
|
||||||
|
d.items = 1
|
||||||
|
self.assertEqual(d.items, 1)
|
||||||
|
|
||||||
|
def test_get(self):
|
||||||
|
d = self.dtype(a=1)
|
||||||
|
self.assertEqual(d.get('a'), 1)
|
||||||
|
self.assertEqual(d.get('b', 'bla'), 'bla')
|
||||||
|
|
||||||
|
def test_items(self):
|
||||||
|
d = self.dtype(a=1)
|
||||||
|
self.assertEqual(d.items(), [('a', 1)])
|
||||||
|
d = self.dtype(a=1, b=2)
|
||||||
|
self.assertEqual(d.items(), [('a', 1), ('b', 2)])
|
||||||
|
|
||||||
|
def test_mappings_protocol(self):
|
||||||
|
d = self.dtype(a=1, b=2)
|
||||||
|
assert dict(d) == {'a': 1, 'b': 2}
|
||||||
|
assert dict(**d) == {'a': 1, 'b': 2}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSemiSrictDict(TestStrictDict):
|
||||||
|
def strict_dict_class(self, *args, **kwargs):
|
||||||
|
return SemiStrictDict.create(*args, **kwargs)
|
||||||
|
|
||||||
|
def test_init_fails_on_nonexisting_attrs(self):
|
||||||
|
# disable irrelevant test
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_setattr_raises_on_nonexisting_attr(self):
|
||||||
|
# disable irrelevant test
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_setattr_getattr_nonexisting_attr_succeeds(self):
|
||||||
|
d = self.dtype()
|
||||||
|
d.x = 1
|
||||||
|
self.assertEqual(d.x, 1)
|
||||||
|
|
||||||
|
def test_init_succeeds_with_nonexisting_attrs(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||||
|
self.assertEqual((d.a, d.b, d.c, d.x), (1, 1, 1, 2))
|
||||||
|
|
||||||
|
def test_iter_with_nonexisting_attrs(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||||
|
self.assertEqual(list(d), ['a', 'b', 'c', 'x'])
|
||||||
|
|
||||||
|
def test_iteritems_with_nonexisting_attrs(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||||
|
self.assertEqual(list(d.iteritems()), [('a', 1), ('b', 1), ('c', 1), ('x', 2)])
|
||||||
|
|
||||||
|
def tets_cmp_with_strict_dicts(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1)
|
||||||
|
dd = StrictDict.create(("a", "b", "c"))(a=1, b=1, c=1)
|
||||||
|
self.assertEqual(d, dd)
|
||||||
|
|
||||||
|
def test_cmp_with_strict_dict_with_nonexisting_attrs(self):
|
||||||
|
d = self.dtype(a=1, b=1, c=1, x=2)
|
||||||
|
dd = StrictDict.create(("a", "b", "c", "x"))(a=1, b=1, c=1, x=2)
|
||||||
|
self.assertEqual(d, dd)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user