import weakref

from bson import DBRef
import six
from six import iteritems

from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned

__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference')


def mark_as_changed_wrapper(parent_method):
    """Decorator that ensures ``_mark_as_changed`` gets called.

    :param parent_method: The unbound method to wrap (e.g. ``list.append``).
    :return: A function that calls ``parent_method`` and then
        ``self._mark_as_changed()`` with no key, returning the parent
        method's result.
    """
    def wrapper(self, *args, **kwargs):
        # Can't use super() in the decorator
        result = parent_method(self, *args, **kwargs)
        self._mark_as_changed()
        return result
    return wrapper


def mark_key_as_changed_wrapper(parent_method):
    """Decorator that ensures ``_mark_as_changed`` gets called with the key argument.

    :param parent_method: The unbound method to wrap; its first positional
        argument after ``self`` must be the key (e.g. ``dict.__setitem__``).
    :return: A function that calls ``parent_method`` and then
        ``self._mark_as_changed(key)``, returning the parent method's result.
    """
    def wrapper(self, key, *args, **kwargs):
        # Can't use super() in the decorator
        result = parent_method(self, key, *args, **kwargs)
        self._mark_as_changed(key)
        return result
    return wrapper


class BaseDict(dict):
    """A special dict so we can watch any changes.

    Mutating methods are wrapped so that every change is reported to the
    owning instance through its ``_mark_as_changed`` method (when it has one).
    """

    _dereferenced = False
    _instance = None
    _name = None

    def __init__(self, dict_items, instance, name):
        """
        :param dict_items: Initial content, passed to ``dict.__init__``.
        :param instance: The owner. Only kept (as a weak proxy) when it is a
            ``BaseDocument``; otherwise ``_instance`` stays ``None``.
        :param name: The name used as prefix when reporting changes.
        """
        BaseDocument = _import_class('BaseDocument')

        if isinstance(instance, BaseDocument):
            self._instance = weakref.proxy(instance)
        self._name = name
        super(BaseDict, self).__init__(dict_items)

    def get(self, key, default=None):
        # get does not use __getitem__ by default so we must override it as well
        try:
            return self.__getitem__(key)
        except KeyError:
            return default

    def __getitem__(self, key):
        """Return the value for ``key``, binding it to this dict's owner.

        Plain ``dict``/``list`` values are replaced in place by
        ``BaseDict``/``BaseList`` so that nested changes are tracked too.
        """
        value = super(BaseDict, self).__getitem__(key)

        EmbeddedDocument = _import_class('EmbeddedDocument')
        if isinstance(value, EmbeddedDocument) and value._instance is None:
            value._instance = self._instance
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, None, '%s.%s' % (self._name, key))
            super(BaseDict, self).__setitem__(key, value)
            value._instance = self._instance
        elif isinstance(value, list) and not isinstance(value, BaseList):
            value = BaseList(value, None, '%s.%s' % (self._name, key))
            super(BaseDict, self).__setitem__(key, value)
            value._instance = self._instance
        return value

    def __getstate__(self):
        # NOTE(review): this assigns ``instance`` rather than ``_instance``.
        # It looks like a typo, but it is kept as-is: assigning ``_instance``
        # here would detach the live object from its owner whenever it is
        # pickled. Confirm the intent before changing it.
        self.instance = None
        self._dereferenced = False
        return self

    def __setstate__(self, state):
        # Rebinding the local name has no effect on the object itself.
        self = state
        return self

    __setitem__ = mark_key_as_changed_wrapper(dict.__setitem__)
    __delattr__ = mark_key_as_changed_wrapper(dict.__delattr__)
    __delitem__ = mark_key_as_changed_wrapper(dict.__delitem__)
    pop = mark_as_changed_wrapper(dict.pop)
    clear = mark_as_changed_wrapper(dict.clear)
    update = mark_as_changed_wrapper(dict.update)
    popitem = mark_as_changed_wrapper(dict.popitem)
    setdefault = mark_as_changed_wrapper(dict.setdefault)

    def _mark_as_changed(self, key=None):
        """Report a change to the owner.

        :param key: When truthy, ``<name>.<key>`` is reported; otherwise the
            whole dict (``<name>``) is reported.
        """
        if hasattr(self._instance, '_mark_as_changed'):
            if key:
                self._instance._mark_as_changed('%s.%s' % (self._name, key))
            else:
                self._instance._mark_as_changed(self._name)


class BaseList(list):
    """A special list so we can watch any changes.

    Mutating methods are wrapped so that every change is reported to the
    owning instance through its ``_mark_as_changed`` method (when it has one).
    """

    _dereferenced = False
    _instance = None
    _name = None

    def __init__(self, list_items, instance, name):
        """
        :param list_items: Initial content, passed to ``list.__init__``.
        :param instance: The owner. Only kept (as a weak proxy) when it is a
            ``BaseDocument``; otherwise ``_instance`` stays ``None``.
        :param name: The name used as prefix when reporting changes.
        """
        BaseDocument = _import_class('BaseDocument')

        if isinstance(instance, BaseDocument):
            self._instance = weakref.proxy(instance)
        self._name = name
        super(BaseList, self).__init__(list_items)

    def __getitem__(self, key):
        """Return the item(s) at ``key``, binding a single item to the owner.

        Plain ``dict``/``list`` items are replaced in place by
        ``BaseDict``/``BaseList`` so that nested changes are tracked too.
        Slices are returned untouched.
        """
        value = super(BaseList, self).__getitem__(key)

        if isinstance(key, slice):
            # When receiving a slice operator, we don't convert the structure and bind
            # to parent's instance. This is buggy for now but would require more work
            # to be handled properly
            return value

        # The lookup above succeeded, so a negative index is known to be in
        # range. Make it positive so nested containers are not named
        # "<name>.-1" (the name is later reported as the changed path).
        if isinstance(key, six.integer_types) and key < 0:
            key += len(self)

        EmbeddedDocument = _import_class('EmbeddedDocument')
        if isinstance(value, EmbeddedDocument) and value._instance is None:
            value._instance = self._instance
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            # Replace dict by BaseDict
            value = BaseDict(value, None, '%s.%s' % (self._name, key))
            super(BaseList, self).__setitem__(key, value)
            value._instance = self._instance
        elif isinstance(value, list) and not isinstance(value, BaseList):
            # Replace list by BaseList
            value = BaseList(value, None, '%s.%s' % (self._name, key))
            super(BaseList, self).__setitem__(key, value)
            value._instance = self._instance
        return value

    def __iter__(self):
        for v in super(BaseList, self).__iter__():
            yield v

    def __getstate__(self):
        # NOTE(review): this assigns ``instance`` rather than ``_instance``.
        # It looks like a typo, but it is kept as-is: assigning ``_instance``
        # here would detach the live object from its owner whenever it is
        # pickled. Confirm the intent before changing it.
        self.instance = None
        self._dereferenced = False
        return self

    def __setstate__(self, state):
        # Rebinding the local name has no effect on the object itself.
        self = state
        return self

    def __setitem__(self, key, value):
        changed_key = key
        if isinstance(key, slice):
            # In case of slice, we don't bother to identify the exact elements being updated
            # instead, we simply mark the whole list as changed
            changed_key = None

        result = super(BaseList, self).__setitem__(key, value)
        self._mark_as_changed(changed_key)
        return result

    append = mark_as_changed_wrapper(list.append)
    extend = mark_as_changed_wrapper(list.extend)
    insert = mark_as_changed_wrapper(list.insert)
    pop = mark_as_changed_wrapper(list.pop)
    remove = mark_as_changed_wrapper(list.remove)
    reverse = mark_as_changed_wrapper(list.reverse)
    sort = mark_as_changed_wrapper(list.sort)
    __delitem__ = mark_as_changed_wrapper(list.__delitem__)
    __iadd__ = mark_as_changed_wrapper(list.__iadd__)
    __imul__ = mark_as_changed_wrapper(list.__imul__)

    if six.PY2:
        # Under py3 __setslice__, __delslice__ and __getslice__
        # are replaced by __setitem__, __delitem__ and __getitem__ with a slice as parameter
        # so we mimic this under python 2
        def __setslice__(self, i, j, sequence):
            return self.__setitem__(slice(i, j), sequence)

        def __delslice__(self, i, j):
            return self.__delitem__(slice(i, j))

        def __getslice__(self, i, j):
            return self.__getitem__(slice(i, j))

    def _mark_as_changed(self, key=None):
        """Report a change to the owner.

        :param key: Index of the changed element, or ``None`` to report the
            whole list (``<name>``) as changed.
        """
        if hasattr(self._instance, '_mark_as_changed'):
            # Compare against None explicitly: index 0 is falsy but is a
            # perfectly valid element index.
            if key is not None:
                # The modulo turns a negative index into its positive equivalent
                self._instance._mark_as_changed(
                    '%s.%s' % (self._name, key % len(self))
                )
            else:
                self._instance._mark_as_changed(self._name)


class EmbeddedDocumentList(BaseList):
    """A ``BaseList`` of embedded documents with queryset-like helpers."""

    def __init__(self, list_items, instance, name):
        super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
        # Overrides whatever BaseList stored: the instance is kept directly
        # (not as a weak proxy), whatever its type.
        self._instance = instance

    @classmethod
    def __match_all(cls, embedded_doc, kwargs):
        """Return True if a given embedded doc matches all the filter
        kwargs. If it doesn't return False.
        """
        for key, expected_value in kwargs.items():
            doc_val = getattr(embedded_doc, key)
            # A value matches either directly or through its text representation
            if doc_val != expected_value and six.text_type(doc_val) != expected_value:
                return False
        return True

    @classmethod
    def __only_matches(cls, embedded_docs, kwargs):
        """Return embedded docs that match the filter kwargs."""
        if not kwargs:
            return embedded_docs
        return [doc for doc in embedded_docs if cls.__match_all(doc, kwargs)]

    def filter(self, **kwargs):
        """
        Filters the list by only including embedded documents with the
        given keyword arguments.

        This method only supports simple comparison (e.g: .filter(name='John Doe'))
        and does not support operators like __gte, __lte, __icontains like queryset.filter does

        :param kwargs: The keyword arguments corresponding to the fields to
         filter on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: A new ``EmbeddedDocumentList`` containing the matching
         embedded documents.

        Raises ``AttributeError`` if a given keyword is not a valid field for
        the embedded document class.
        """
        values = self.__only_matches(self, kwargs)
        return EmbeddedDocumentList(values, self._instance, self._name)

    def exclude(self, **kwargs):
        """
        Filters the list by excluding embedded documents with the given
        keyword arguments.

        :param kwargs: The keyword arguments corresponding to the fields to
         exclude on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: A new ``EmbeddedDocumentList`` containing the non-matching
         embedded documents.

        Raises ``AttributeError`` if a given keyword is not a valid field for
        the embedded document class.
        """
        exclude = self.__only_matches(self, kwargs)
        values = [item for item in self if item not in exclude]
        return EmbeddedDocumentList(values, self._instance, self._name)

    def count(self):
        """
        The number of embedded documents in the list.

        :return: The length of the list, equivalent to the result of ``len()``.
        """
        return len(self)

    def get(self, **kwargs):
        """
        Retrieves an embedded document determined by the given keyword
        arguments.

        :param kwargs: The keyword arguments corresponding to the fields to
         search on. *Multiple arguments are treated as if they are ANDed
         together.*
        :return: The embedded document matched by the given keyword arguments.

        Raises ``DoesNotExist`` if the arguments used to query an embedded
        document returns no results. ``MultipleObjectsReturned`` if more
        than one result is returned.
        """
        values = self.__only_matches(self, kwargs)
        if len(values) == 0:
            raise DoesNotExist(
                '%s matching query does not exist.' % self._name
            )
        elif len(values) > 1:
            raise MultipleObjectsReturned(
                '%d items returned, instead of 1' % len(values)
            )

        return values[0]

    def first(self):
        """Return the first embedded document in the list, or ``None``
        if empty.
        """
        if len(self) > 0:
            return self[0]
        return None

    def create(self, **values):
        """
        Creates a new embedded document and appends it to the list held by
        the owning instance under this list's name.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :param values: A dictionary of values for the embedded document.
        :return: The new embedded document instance.
        """
        name = self._name
        EmbeddedClass = self._instance._fields[name].field.document_type_obj
        self._instance[self._name].append(EmbeddedClass(**values))

        return self._instance[self._name][-1]

    def save(self, *args, **kwargs):
        """
        Saves the ancestor document.

        :param args: Arguments passed up to the ancestor Document's save
         method.
        :param kwargs: Keyword arguments passed up to the ancestor Document's
         save method.
        """
        self._instance.save(*args, **kwargs)

    def delete(self):
        """
        Removes every embedded document of this list from the list held by
        the owning instance under this list's name.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :return: The number of entries deleted.
        """
        # Iterate over a copy: this list may be the very one being emptied
        values = list(self)
        for item in values:
            self._instance[self._name].remove(item)

        return len(values)

    def update(self, **update):
        """
        Updates the embedded documents with the given replacement values. This
        function does not support mongoDB update operators such as ``inc__``.

        .. note::
            The embedded document changes are not automatically saved
            to the database after calling this method.

        :param update: A dictionary of update values to apply to each
         embedded document.
        :return: The number of entries updated.
        """
        if len(update) == 0:
            return 0
        values = list(self)
        for item in values:
            for k, v in update.items():
                setattr(item, k, v)

        return len(values)


class StrictDict(object):
    """A dict-like object whose allowed keys are fixed through ``__slots__``.

    Concrete classes are built by :meth:`create`. Keys that collide with the
    names listed in ``_special_fields`` are stored under a ``_reserved_``
    prefixed attribute.
    """

    __slots__ = ()
    _special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'}
    # Cache of generated classes, keyed by the frozenset of their slot names
    _classes = {}

    def __init__(self, **kwargs):
        for k, v in iteritems(kwargs):
            setattr(self, k, v)

    def __getitem__(self, key):
        key = '_reserved_' + key if key in self._special_fields else key
        try:
            return getattr(self, key)
        except AttributeError:
            raise KeyError(key)

    def __setitem__(self, key, value):
        key = '_reserved_' + key if key in self._special_fields else key
        return setattr(self, key, value)

    def __contains__(self, key):
        # Apply the same mapping as __getitem__/__setitem__; otherwise a key
        # such as 'get' would always be reported as present because the
        # method of that name exists on the class.
        key = '_reserved_' + key if key in self._special_fields else key
        return hasattr(self, key)

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

    def pop(self, key, default=None):
        """Remove ``key`` and return its value, or ``default`` if it is unset."""
        v = self.get(key, default)
        # Delete the attribute the value is actually stored under; without
        # the mapping, reserved keys would be returned but never removed.
        key = '_reserved_' + key if key in self._special_fields else key
        try:
            delattr(self, key)
        except AttributeError:
            pass
        return v

    def iteritems(self):
        for key in self:
            yield key, self[key]

    def items(self):
        return [(k, self[k]) for k in iter(self)]

    def iterkeys(self):
        return iter(self)

    def keys(self):
        return list(iter(self))

    def __iter__(self):
        # Only the slots that currently hold a value are yielded
        return (key for key in self.__slots__ if hasattr(self, key))

    def __len__(self):
        return len(list(iteritems(self)))

    def __eq__(self, other):
        return self.items() == other.items()

    def __ne__(self, other):
        return self.items() != other.items()

    @classmethod
    def create(cls, allowed_keys):
        """Return a ``StrictDict`` subclass accepting only ``allowed_keys``.

        The generated classes are cached, so the same set of keys always
        yields the same class.

        :param allowed_keys: An iterable of key names.
        :return: A subclass of ``cls`` whose ``__slots__`` are the given keys
            (reserved names being ``_reserved_`` prefixed).
        """
        allowed_keys_tuple = tuple(
            ('_reserved_' + k if k in cls._special_fields else k)
            for k in allowed_keys
        )
        allowed_keys = frozenset(allowed_keys_tuple)
        if allowed_keys not in cls._classes:
            class SpecificStrictDict(cls):
                __slots__ = allowed_keys_tuple

                def __repr__(self):
                    return '{%s}' % ', '.join(
                        '"{0!s}": {1!r}'.format(k, v) for k, v in self.items()
                    )

            cls._classes[allowed_keys] = SpecificStrictDict
        return cls._classes[allowed_keys]


class LazyReference(DBRef):
    """A ``DBRef`` that fetches (and caches) the referenced document on demand.

    When ``passthrough`` is enabled, item and attribute access are forwarded
    to the fetched document.
    """

    __slots__ = ('_cached_doc', 'passthrough', 'document_type')

    def fetch(self, force=False):
        """Return the referenced document, querying for it if needed.

        :param force: Query again even if a document is already cached.
        :return: The referenced document.

        Raises ``DoesNotExist`` if the query yields a falsy result.
        """
        if not self._cached_doc or force:
            self._cached_doc = self.document_type.objects.get(pk=self.pk)
            if not self._cached_doc:
                raise DoesNotExist('Trying to dereference unknown document %s' % (self))
        return self._cached_doc

    @property
    def pk(self):
        return self.id

    def __init__(self, document_type, pk, cached_doc=None, passthrough=False):
        """
        :param document_type: The class of the referenced document.
        :param pk: The primary key of the referenced document.
        :param cached_doc: An already fetched document, if any.
        :param passthrough: Whether item/attribute access is forwarded to
            the fetched document.
        """
        self.document_type = document_type
        self._cached_doc = cached_doc
        self.passthrough = passthrough
        super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk)

    def __getitem__(self, name):
        if not self.passthrough:
            raise KeyError(name)
        document = self.fetch()
        return document[name]

    def __getattr__(self, name):
        # Go through object.__getattribute__ so that a missing ``passthrough``
        # does not re-enter this method.
        if not object.__getattribute__(self, 'passthrough'):
            raise AttributeError(name)
        document = self.fetch()
        try:
            return document[name]
        except KeyError:
            raise AttributeError(name)

    def __repr__(self):
        # The format string must contain one placeholder per argument,
        # otherwise the % operator raises TypeError.
        return '<LazyReference(%s, %r)>' % (self.document_type, self.pk)