# mongoengine/mongoengine/base/datastructures.py
#
# 462 lines
# 15 KiB
# Python

import weakref
import functools
import itertools
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
# Public names exported by ``from <this module> import *``. Only the three
# change-tracking containers are listed; the other classes defined in this
# module are not part of ``__all__``.
__all__ = ("BaseDict", "BaseList", "EmbeddedDocumentList")
class BaseDict(dict):
    """A special dict so we can watch any changes.

    Mutating operations call ``_mark_as_changed`` so that the owning object
    (``_instance``), when it exposes a ``_mark_as_changed`` method, is told
    which dotted path was modified. Plain ``dict``/``list`` values are
    wrapped into ``BaseDict``/``BaseList`` the first time they are read so
    that nested mutations are reported as well.
    """

    # Reset to False by ``__getstate__``; where it is set to True is not
    # visible in this class.
    _dereferenced = False
    # Weak proxy to the owning document (or None when there is no owner).
    _instance = None
    # Dotted path of this container inside the owner, e.g. ``"field.sub"``.
    _name = None

    def __init__(self, dict_items, instance, name):
        """Initialise the dict and remember its owner and path.

        :param dict_items: Initial content, passed to ``dict.__init__``.
        :param instance: The owning object. It is only kept (as a weak
            proxy) when it is a ``Document`` or ``EmbeddedDocument``.
        :param name: Dotted path used when reporting changes.
        """
        # Classes are resolved by name at call time through ``_import_class``
        # (presumably to avoid a circular import -- confirm).
        Document = _import_class('Document')
        EmbeddedDocument = _import_class('EmbeddedDocument')

        if isinstance(instance, (Document, EmbeddedDocument)):
            # A weak proxy avoids a strong reference back to the owner.
            self._instance = weakref.proxy(instance)
        self._name = name
        super(BaseDict, self).__init__(dict_items)

    def get(self, key, default=None):
        """Return ``self[key]`` or ``default`` when the key is missing.

        ``dict.get`` does not go through ``__getitem__``, so without this
        override nested containers fetched with ``get()`` would be returned
        unwrapped and their mutations would not be reported.
        """
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __getitem__(self, key, *args, **kwargs):
        """Return the value for ``key``, binding/wrapping it on first access.

        Raises ``KeyError`` when ``key`` is missing.
        """
        value = super(BaseDict, self).__getitem__(key)

        EmbeddedDocument = _import_class('EmbeddedDocument')
        if isinstance(value, EmbeddedDocument) and value._instance is None:
            value._instance = self._instance
        elif not isinstance(value, BaseDict) and isinstance(value, dict):
            # Wrap in place so later reads return the same tracked object.
            value = BaseDict(value, None, '%s.%s' % (self._name, key))
            super(BaseDict, self).__setitem__(key, value)
            value._instance = self._instance
        elif not isinstance(value, BaseList) and isinstance(value, list):
            value = BaseList(value, None, '%s.%s' % (self._name, key))
            super(BaseDict, self).__setitem__(key, value)
            value._instance = self._instance
        return value

    def __setitem__(self, key, value, *args, **kwargs):
        """Store ``value`` under ``key`` and report ``<name>.<key>``."""
        self._mark_as_changed(key)
        return super(BaseDict, self).__setitem__(key, value)

    def __delete__(self, *args, **kwargs):
        # NOTE(review): ``dict`` defines no ``__delete__``, so the super call
        # below is expected to raise AttributeError if this is ever invoked.
        # Kept unchanged for interface compatibility -- confirm intent.
        self._mark_as_changed()
        return super(BaseDict, self).__delete__(*args, **kwargs)

    def __delitem__(self, key, *args, **kwargs):
        """Remove ``key`` and report ``<name>.<key>``."""
        self._mark_as_changed(key)
        return super(BaseDict, self).__delitem__(key)

    def __delattr__(self, key, *args, **kwargs):
        """Delete attribute ``key`` and report ``<name>.<key>``."""
        self._mark_as_changed(key)
        return super(BaseDict, self).__delattr__(key)

    def __getstate__(self):
        # NOTE(review): this assigns ``instance`` (no underscore), not
        # ``_instance``; it looks like a typo but is preserved because
        # changing it would detach the live object from its owner.
        self.instance = None
        self._dereferenced = False
        return self

    def __setstate__(self, state):
        # NOTE(review): rebinding the local name ``self`` has no effect on
        # the object; preserved as-is.
        self = state
        return self

    def clear(self, *args, **kwargs):
        """Empty the dict and report the whole container as changed."""
        self._mark_as_changed()
        return super(BaseDict, self).clear(*args, **kwargs)

    def pop(self, *args, **kwargs):
        """``dict.pop`` that reports the whole container as changed."""
        self._mark_as_changed()
        return super(BaseDict, self).pop(*args, **kwargs)

    def popitem(self, *args, **kwargs):
        """``dict.popitem`` that reports the whole container as changed."""
        self._mark_as_changed()
        return super(BaseDict, self).popitem(*args, **kwargs)

    def setdefault(self, *args, **kwargs):
        """``dict.setdefault`` that reports the whole container as changed."""
        self._mark_as_changed()
        return super(BaseDict, self).setdefault(*args, **kwargs)

    def update(self, *args, **kwargs):
        """``dict.update`` that reports the whole container as changed."""
        self._mark_as_changed()
        return super(BaseDict, self).update(*args, **kwargs)

    def _mark_as_changed(self, key=None):
        """Tell the owner that this dict, or one of its keys, changed.

        :param key: When truthy, ``<name>.<key>`` is reported; otherwise the
            whole container path ``<name>`` is reported.
        """
        # Nothing to do when there is no owner (or it cannot be notified).
        if hasattr(self._instance, '_mark_as_changed'):
            if key:
                self._instance._mark_as_changed('%s.%s' % (self._name, key))
            else:
                self._instance._mark_as_changed(self._name)
class BaseList(list):
    """A special list so we can watch any changes.

    Mutating operations call ``_mark_as_changed`` so that the owning object
    (``_instance``), when it exposes a ``_mark_as_changed`` method, is told
    which dotted path was modified. Plain ``dict``/``list`` elements are
    wrapped into ``BaseDict``/``BaseList`` the first time they are read by
    index so that nested mutations are reported as well.
    """

    # Reset to False by ``__getstate__``; where it is set to True is not
    # visible in this class.
    _dereferenced = False
    # Weak proxy to the owning document (or None when there is no owner).
    _instance = None
    # Dotted path of this container inside the owner, e.g. ``"field.sub"``.
    _name = None

    def __init__(self, list_items, instance, name):
        """Initialise the list and remember its owner and path.

        :param list_items: Initial content, passed to ``list.__init__``.
        :param instance: The owning object. It is only kept (as a weak
            proxy) when it is a ``Document`` or ``EmbeddedDocument``.
        :param name: Dotted path used when reporting changes.
        """
        # Classes are resolved by name at call time through ``_import_class``
        # (presumably to avoid a circular import -- confirm).
        Document = _import_class('Document')
        EmbeddedDocument = _import_class('EmbeddedDocument')

        if isinstance(instance, (Document, EmbeddedDocument)):
            # A weak proxy avoids a strong reference back to the owner.
            self._instance = weakref.proxy(instance)
        self._name = name
        super(BaseList, self).__init__(list_items)

    def __getitem__(self, key, *args, **kwargs):
        """Return the element at ``key``, binding/wrapping it on first access.

        A slice returns the plain ``list`` slice untouched: wrapping it would
        give the copy a meaningless ``<name>.slice(...)`` path.

        Raises ``IndexError`` when an integer ``key`` is out of range.
        """
        # Look up with the caller's key first so out-of-range negative
        # indexes still raise IndexError.
        value = super(BaseList, self).__getitem__(key)
        if isinstance(key, slice):
            return value

        # Child paths always use the non-negative position.
        index = key
        if isinstance(key, int) and key < 0:
            index = key + len(self)

        EmbeddedDocument = _import_class('EmbeddedDocument')
        if isinstance(value, EmbeddedDocument) and value._instance is None:
            value._instance = self._instance
        elif not isinstance(value, BaseDict) and isinstance(value, dict):
            # Wrap in place so later reads return the same tracked object.
            value = BaseDict(value, None, '%s.%s' % (self._name, index))
            super(BaseList, self).__setitem__(key, value)
            value._instance = self._instance
        elif not isinstance(value, BaseList) and isinstance(value, list):
            value = BaseList(value, None, '%s.%s' % (self._name, index))
            super(BaseList, self).__setitem__(key, value)
            value._instance = self._instance
        return value

    def __iter__(self):
        """Yield elements through ``__getitem__`` so each one gets wrapped."""
        # ``range`` works on both Python 2 and 3 (``xrange`` is Python 2 only).
        for i in range(len(self)):
            yield self[i]

    def __setitem__(self, key, value, *args, **kwargs):
        """Assign ``value`` at ``key`` and report the change.

        A slice assignment reports the whole list; an index assignment
        reports ``<name>.<index>`` using the non-negative index.
        """
        # Assign first so nothing is reported if the assignment raises.
        result = super(BaseList, self).__setitem__(key, value)
        if isinstance(key, slice):
            self._mark_as_changed()
        else:
            if isinstance(key, int) and key < 0:
                # The assignment succeeded, so the index is in range.
                key += len(self)
            self._mark_as_changed(key)
        return result

    def __delitem__(self, key, *args, **kwargs):
        """Delete the element(s) at ``key`` and report the whole list.

        Deleting shifts every following element, so a single-index path
        would not describe the change.
        """
        result = super(BaseList, self).__delitem__(key)
        self._mark_as_changed()
        return result

    def __setslice__(self, *args, **kwargs):
        # Python 2 only hook for simple slice assignment.
        self._mark_as_changed()
        return super(BaseList, self).__setslice__(*args, **kwargs)

    def __delslice__(self, *args, **kwargs):
        # Python 2 only hook for simple slice deletion.
        self._mark_as_changed()
        return super(BaseList, self).__delslice__(*args, **kwargs)

    def __getstate__(self):
        # NOTE(review): this assigns ``instance`` (no underscore), not
        # ``_instance``; it looks like a typo but is preserved because
        # changing it would detach the live object from its owner.
        self.instance = None
        self._dereferenced = False
        return self

    def __setstate__(self, state):
        # NOTE(review): rebinding the local name ``self`` has no effect on
        # the object; preserved as-is.
        self = state
        return self

    def __iadd__(self, other):
        """``list.__iadd__`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).__iadd__(other)

    def __imul__(self, other):
        """``list.__imul__`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).__imul__(other)

    def append(self, *args, **kwargs):
        """``list.append`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).append(*args, **kwargs)

    def extend(self, *args, **kwargs):
        """``list.extend`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).extend(*args, **kwargs)

    def insert(self, *args, **kwargs):
        """``list.insert`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).insert(*args, **kwargs)

    def pop(self, *args, **kwargs):
        """``list.pop`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).pop(*args, **kwargs)

    def remove(self, *args, **kwargs):
        """``list.remove`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).remove(*args, **kwargs)

    def reverse(self, *args, **kwargs):
        """``list.reverse`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).reverse(*args, **kwargs)

    def sort(self, *args, **kwargs):
        """``list.sort`` that reports the whole list as changed."""
        self._mark_as_changed()
        return super(BaseList, self).sort(*args, **kwargs)

    def _mark_as_changed(self, key=None):
        """Tell the owner that this list, or one of its indexes, changed.

        :param key: Index that changed, reported as ``<name>.<key>``. When
            ``None`` the whole container path ``<name>`` is reported.
        """
        # Nothing to do when there is no owner (or it cannot be notified).
        if hasattr(self._instance, '_mark_as_changed'):
            # ``is not None`` rather than truthiness: index 0 is a valid key.
            if key is not None:
                self._instance._mark_as_changed('%s.%s' % (self._name, key))
            else:
                self._instance._mark_as_changed(self._name)
class EmbeddedDocumentList(BaseList):
    """A ``BaseList`` of embedded documents with queryset-like helpers
    (``filter``, ``exclude``, ``get``, ``create``, ``update``, ...).
    """

    @classmethod
    def __match_all(cls, i, kwargs):
        """Return True when item ``i`` matches every ``field=value`` pair.

        A field matches when the attribute equals the value, or when its
        ``str()`` form does.
        """
        # Every pair is evaluated (no short-circuit) so that an unknown
        # field name always raises AttributeError.
        matches = []
        for k, v in kwargs.items():
            attr = getattr(i, k)
            matches.append(attr == v or str(attr) == v)
        return all(matches)

    @classmethod
    def __only_matches(cls, obj, kwargs):
        """Return the items of ``obj`` that match ``kwargs``.

        With no ``kwargs`` the input ``obj`` itself is returned. Otherwise a
        real list is built: the builtin ``filter`` is a one-shot iterator on
        Python 3, which does not support ``len()`` or repeated membership
        tests.
        """
        if not kwargs:
            return obj
        return [i for i in obj if cls.__match_all(i, kwargs)]

    def __init__(self, list_items, instance, name):
        """Initialise the list and keep a direct reference to ``instance``.

        :param list_items: Initial content of the list.
        :param instance: The owning object.
        :param name: Name of this list on the owner.
        """
        super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
        # Replaces whatever the parent stored with the instance itself.
        self._instance = instance

    def filter(self, **kwargs):
        """
        Filters the list by only including embedded documents with the
        given keyword arguments.

        :param kwargs: The keyword arguments corresponding to the fields to
         filter on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: A new ``EmbeddedDocumentList`` containing the matching
         embedded documents.

        Raises ``AttributeError`` if a given keyword is not a valid field for
        the embedded document class.
        """
        values = self.__only_matches(self, kwargs)
        return EmbeddedDocumentList(values, self._instance, self._name)

    def exclude(self, **kwargs):
        """
        Filters the list by excluding embedded documents with the given
        keyword arguments.

        :param kwargs: The keyword arguments corresponding to the fields to
         exclude on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: A new ``EmbeddedDocumentList`` containing the non-matching
         embedded documents.

        Raises ``AttributeError`` if a given keyword is not a valid field for
        the embedded document class.
        """
        excluded = self.__only_matches(self, kwargs)
        values = [item for item in self if item not in excluded]
        return EmbeddedDocumentList(values, self._instance, self._name)

    def count(self):
        """
        The number of embedded documents in the list.

        :return: The length of the list, equivalent to the result of ``len()``.
        """
        return len(self)

    def get(self, **kwargs):
        """
        Retrieves an embedded document determined by the given keyword
        arguments.

        :param kwargs: The keyword arguments corresponding to the fields to
         search on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: The embedded document matched by the given keyword arguments.

        Raises ``DoesNotExist`` if the arguments used to query an embedded
        document returns no results. ``MultipleObjectsReturned`` if more
        than one result is returned.
        """
        values = self.__only_matches(self, kwargs)
        found = len(values)
        if found == 0:
            raise DoesNotExist(
                "%s matching query does not exist." % self._name
            )
        elif found > 1:
            raise MultipleObjectsReturned(
                "%d items returned, instead of 1" % found
            )
        return values[0]

    def first(self):
        """
        Returns the first embedded document in the list, or ``None`` if empty.
        """
        if len(self) > 0:
            return self[0]
        return None

    def create(self, **values):
        """
        Creates a new embedded document and appends it to the list held by
        the owning document.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :param values: A dictionary of values for the embedded document.
        :return: The new embedded document instance.
        """
        name = self._name
        # The embedded class is read from the owner's field definition.
        EmbeddedClass = self._instance._fields[name].field.document_type_obj
        self._instance[name].append(EmbeddedClass(**values))
        return self._instance[name][-1]

    def save(self, *args, **kwargs):
        """
        Saves the ancestor document.

        :param args: Arguments passed up to the ancestor Document's save
         method.
        :param kwargs: Keyword arguments passed up to the ancestor Document's
         save method.
        """
        self._instance.save(*args, **kwargs)

    def delete(self):
        """
        Removes the embedded documents in this list from the list held by
        the owning document.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :return: The number of entries deleted.
        """
        # Snapshot first: the owner's list may be this very list.
        values = list(self)
        for item in values:
            self._instance[self._name].remove(item)

        return len(values)

    def update(self, **update):
        """
        Updates the embedded documents with the given update values.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :param update: A dictionary of update values to apply to each
         embedded document.
        :return: The number of entries updated.
        """
        if not update:
            return 0
        values = list(self)
        for item in values:
            for k, v in update.items():
                setattr(item, k, v)

        return len(values)
class StrictDict(object):
    """Dict-like object whose keys are stored as ``__slots__`` attributes.

    Concrete classes with a fixed key set are produced by ``create``. Keys
    that collide with the names in ``_special_fields`` are stored under a
    ``_reserved_`` prefixed attribute.
    """

    __slots__ = ()
    # Key names that would otherwise clash with methods of this class.
    _special_fields = set(['get', 'pop', 'iteritems', 'items', 'keys', 'create'])
    # Cache of generated classes, keyed by the frozenset of their slot names.
    _classes = {}

    def __init__(self, **kwargs):
        """Populate the instance from keyword arguments."""
        # ``items()`` works on Python 2 and 3; item assignment (rather than
        # ``setattr``) applies the ``_reserved_`` mapping to special names.
        for k, v in kwargs.items():
            self[k] = v

    def __getitem__(self, key):
        """Return the value stored for ``key``; raise ``KeyError`` if unset."""
        key = '_reserved_' + key if key in self._special_fields else key
        try:
            return getattr(self, key)
        except AttributeError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        """Store ``value`` for ``key`` (as an attribute)."""
        key = '_reserved_' + key if key in self._special_fields else key
        return setattr(self, key, value)

    def __contains__(self, key):
        """Return True when an attribute exists for ``key``."""
        # Apply the same mapping as ``__getitem__`` so that e.g. ``'get'``
        # is not reported as present merely because the method exists.
        key = '_reserved_' + key if key in self._special_fields else key
        return hasattr(self, key)

    def get(self, key, default=None):
        """Return ``self[key]`` or ``default`` when the key is unset."""
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        """Remove ``key`` and return its value (``default`` when unset)."""
        v = self.get(key, default)
        # Delete the attribute the value is actually stored under.
        attr = '_reserved_' + key if key in self._special_fields else key
        try:
            delattr(self, attr)
        except AttributeError:
            pass
        return v

    def iteritems(self):
        """Yield ``(key, value)`` pairs for every key that is set."""
        for key in self:
            yield key, self[key]

    def items(self):
        """Return a list of ``(key, value)`` pairs for every key that is set."""
        return [(k, self[k]) for k in iter(self)]

    def keys(self):
        """Return a list of the keys that are set."""
        return list(iter(self))

    def __iter__(self):
        """Iterate over the slot names that currently hold a value."""
        return (key for key in self.__slots__ if hasattr(self, key))

    def __len__(self):
        """Return the number of keys that are set."""
        return sum(1 for _ in self)

    def __eq__(self, other):
        """Compare by ``items()``; operands without ``items`` never match."""
        if not hasattr(other, 'items'):
            return NotImplemented
        return self.items() == list(other.items())

    def __ne__(self, other):
        """Inverse of ``__eq__`` (Python 2 does not derive it automatically)."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    # Historical misspelling of ``__ne__``; kept so existing callers work.
    __neq__ = __ne__

    @classmethod
    def create(cls, allowed_keys):
        """Return a subclass of ``cls`` whose slots are ``allowed_keys``.

        Classes are cached per set of keys in ``cls._classes``, so calling
        this again with the same keys returns the same class.

        :param allowed_keys: Iterable of key names the class may hold.
        :return: The (possibly cached) generated class.
        """
        allowed_keys_tuple = tuple(
            ('_reserved_' + k if k in cls._special_fields else k)
            for k in allowed_keys
        )
        allowed_keys = frozenset(allowed_keys_tuple)
        if allowed_keys not in cls._classes:
            class SpecificStrictDict(cls):
                __slots__ = allowed_keys_tuple

                def __repr__(self):
                    # ``{0}`` is the key and ``{1}`` the value.
                    return "{%s}" % ', '.join(
                        '"{0!s}": {1!r}'.format(k, v)
                        for (k, v) in self.iteritems()
                    )
            cls._classes[allowed_keys] = SpecificStrictDict
        return cls._classes[allowed_keys]
class SemiStrictDict(StrictDict):
    """A ``StrictDict`` that also accepts keys outside its slots.

    Attributes that cannot be stored in a slot are kept in the ``_extras``
    dict, which is created on first use.
    """

    # One-element tuple: a bare string here would be iterated character by
    # character by code that walks ``__slots__``.
    __slots__ = ('_extras',)
    # Separate cache from the parent's so generated classes do not mix.
    _classes = {}

    def __getattr__(self, attr):
        """Fall back to ``_extras`` when normal attribute lookup fails.

        Raises ``AttributeError`` when ``attr`` is not in ``_extras`` or
        ``_extras`` has not been created yet.
        """
        try:
            return super(SemiStrictDict, self).__getattr__(attr)
        except AttributeError:
            try:
                # ``__getattribute__`` is used directly so a missing
                # ``_extras`` does not re-enter this method.
                return self.__getattribute__('_extras')[attr]
            except KeyError as e:
                raise AttributeError(e)

    def __setattr__(self, attr, value):
        """Set a slot attribute, or store the value in ``_extras``."""
        try:
            super(SemiStrictDict, self).__setattr__(attr, value)
        except AttributeError:
            try:
                self._extras[attr] = value
            except AttributeError:
                # ``_extras`` does not exist yet: create it with this entry.
                self._extras = {attr: value}

    def __delattr__(self, attr):
        """Delete a slot attribute, or remove the key from ``_extras``.

        Raises ``AttributeError`` when ``attr`` is in neither place.
        """
        try:
            super(SemiStrictDict, self).__delattr__(attr)
        except AttributeError:
            try:
                del self._extras[attr]
            except KeyError as e:
                raise AttributeError(e)

    def __iter__(self):
        """Iterate over the slot keys that are set, then the extra keys."""
        try:
            extras_iter = iter(self.__getattribute__('_extras'))
        except AttributeError:
            extras_iter = ()
        # ``_extras`` is internal storage, not a key.
        slot_keys = (
            key for key in super(SemiStrictDict, self).__iter__()
            if key != '_extras'
        )
        return itertools.chain(slot_keys, extras_iter)