Merge branch 'master' of github.com:MongoEngine/mongoengine into negative_indexes_in_list

This commit is contained in:
Bastien Gérard
2019-07-24 21:14:55 +02:00
86 changed files with 9016 additions and 7443 deletions

View File

@@ -18,12 +18,17 @@ from mongoengine.queryset import *
from mongoengine.signals import *
__all__ = (list(document.__all__) + list(fields.__all__) +
list(connection.__all__) + list(queryset.__all__) +
list(signals.__all__) + list(errors.__all__))
__all__ = (
list(document.__all__)
+ list(fields.__all__)
+ list(connection.__all__)
+ list(queryset.__all__)
+ list(signals.__all__)
+ list(errors.__all__)
)
VERSION = (0, 18, 0)
VERSION = (0, 18, 2)
def get_version():
@@ -31,7 +36,7 @@ def get_version():
For example, if `VERSION == (0, 10, 7)`, return '0.10.7'.
"""
return '.'.join(map(str, VERSION))
return ".".join(map(str, VERSION))
__version__ = get_version()

View File

@@ -12,17 +12,22 @@ from mongoengine.base.metaclasses import *
__all__ = (
# common
'UPDATE_OPERATORS', '_document_registry', 'get_document',
"UPDATE_OPERATORS",
"_document_registry",
"get_document",
# datastructures
'BaseDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference',
"BaseDict",
"BaseList",
"EmbeddedDocumentList",
"LazyReference",
# document
'BaseDocument',
"BaseDocument",
# fields
'BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField',
"BaseField",
"ComplexBaseField",
"ObjectIdField",
"GeoJsonBaseField",
# metaclasses
'DocumentMetaclass', 'TopLevelDocumentMetaclass'
"DocumentMetaclass",
"TopLevelDocumentMetaclass",
)

View File

@@ -1,12 +1,25 @@
from mongoengine.errors import NotRegistered
__all__ = ('UPDATE_OPERATORS', 'get_document', '_document_registry')
__all__ = ("UPDATE_OPERATORS", "get_document", "_document_registry")
UPDATE_OPERATORS = {'set', 'unset', 'inc', 'dec', 'mul',
'pop', 'push', 'push_all', 'pull',
'pull_all', 'add_to_set', 'set_on_insert',
'min', 'max', 'rename'}
UPDATE_OPERATORS = {
"set",
"unset",
"inc",
"dec",
"mul",
"pop",
"push",
"push_all",
"pull",
"pull_all",
"add_to_set",
"set_on_insert",
"min",
"max",
"rename",
}
_document_registry = {}
@@ -17,25 +30,33 @@ def get_document(name):
doc = _document_registry.get(name, None)
if not doc:
# Possible old style name
single_end = name.split('.')[-1]
compound_end = '.%s' % single_end
possible_match = [k for k in _document_registry
if k.endswith(compound_end) or k == single_end]
single_end = name.split(".")[-1]
compound_end = ".%s" % single_end
possible_match = [
k for k in _document_registry if k.endswith(compound_end) or k == single_end
]
if len(possible_match) == 1:
doc = _document_registry.get(possible_match.pop(), None)
if not doc:
raise NotRegistered("""
raise NotRegistered(
"""
`%s` has not been registered in the document registry.
Importing the document class automatically registers it, has it
been imported?
""".strip() % name)
""".strip()
% name
)
return doc
def _get_documents_by_db(connection_alias, default_connection_alias):
"""Get all registered Documents class attached to a given database"""
def get_doc_alias(doc_cls):
return doc_cls._meta.get('db_alias', default_connection_alias)
return [doc_cls for doc_cls in _document_registry.values()
if get_doc_alias(doc_cls) == connection_alias]
def get_doc_alias(doc_cls):
return doc_cls._meta.get("db_alias", default_connection_alias)
return [
doc_cls
for doc_cls in _document_registry.values()
if get_doc_alias(doc_cls) == connection_alias
]

View File

@@ -7,24 +7,36 @@ from six import iteritems
from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
__all__ = ('BaseDict', 'StrictDict', 'BaseList', 'EmbeddedDocumentList', 'LazyReference')
__all__ = (
"BaseDict",
"StrictDict",
"BaseList",
"EmbeddedDocumentList",
"LazyReference",
)
def mark_as_changed_wrapper(parent_method):
"""Decorators that ensures _mark_as_changed method gets called"""
"""Decorator that ensures _mark_as_changed method gets called."""
def wrapper(self, *args, **kwargs):
result = parent_method(self, *args, **kwargs) # Can't use super() in the decorator
# Can't use super() in the decorator.
result = parent_method(self, *args, **kwargs)
self._mark_as_changed()
return result
return wrapper
def mark_key_as_changed_wrapper(parent_method):
"""Decorators that ensures _mark_as_changed method gets called with the key argument"""
"""Decorator that ensures _mark_as_changed method gets called with the key argument"""
def wrapper(self, key, *args, **kwargs):
result = parent_method(self, key, *args, **kwargs) # Can't use super() in the decorator
# Can't use super() in the decorator.
result = parent_method(self, key, *args, **kwargs)
self._mark_as_changed(key)
return result
return wrapper
@@ -36,7 +48,7 @@ class BaseDict(dict):
_name = None
def __init__(self, dict_items, instance, name):
BaseDocument = _import_class('BaseDocument')
BaseDocument = _import_class("BaseDocument")
if isinstance(instance, BaseDocument):
self._instance = weakref.proxy(instance)
@@ -53,15 +65,15 @@ class BaseDict(dict):
def __getitem__(self, key):
value = super(BaseDict, self).__getitem__(key)
EmbeddedDocument = _import_class('EmbeddedDocument')
EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
elif isinstance(value, dict) and not isinstance(value, BaseDict):
value = BaseDict(value, None, '%s.%s' % (self._name, key))
value = BaseDict(value, None, "%s.%s" % (self._name, key))
super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance
elif isinstance(value, list) and not isinstance(value, BaseList):
value = BaseList(value, None, '%s.%s' % (self._name, key))
value = BaseList(value, None, "%s.%s" % (self._name, key))
super(BaseDict, self).__setitem__(key, value)
value._instance = self._instance
return value
@@ -85,9 +97,9 @@ class BaseDict(dict):
setdefault = mark_as_changed_wrapper(dict.setdefault)
def _mark_as_changed(self, key=None):
if hasattr(self._instance, '_mark_as_changed'):
if hasattr(self._instance, "_mark_as_changed"):
if key:
self._instance._mark_as_changed('%s.%s' % (self._name, key))
self._instance._mark_as_changed("%s.%s" % (self._name, key))
else:
self._instance._mark_as_changed(self._name)
@@ -100,7 +112,7 @@ class BaseList(list):
_name = None
def __init__(self, list_items, instance, name):
BaseDocument = _import_class('BaseDocument')
BaseDocument = _import_class("BaseDocument")
if isinstance(instance, BaseDocument):
self._instance = weakref.proxy(instance)
@@ -118,17 +130,17 @@ class BaseList(list):
# to parent's instance. This is buggy for now but would require more work to be handled properly
return value
EmbeddedDocument = _import_class('EmbeddedDocument')
EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument) and value._instance is None:
value._instance = self._instance
elif isinstance(value, dict) and not isinstance(value, BaseDict):
# Replace dict by BaseDict
value = BaseDict(value, None, '%s.%s' % (self._name, key))
value = BaseDict(value, None, "%s.%s" % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
elif isinstance(value, list) and not isinstance(value, BaseList):
# Replace list by BaseList
value = BaseList(value, None, '%s.%s' % (self._name, key))
value = BaseList(value, None, "%s.%s" % (self._name, key))
super(BaseList, self).__setitem__(key, value)
value._instance = self._instance
return value
@@ -182,17 +194,14 @@ class BaseList(list):
return self.__getitem__(slice(i, j))
def _mark_as_changed(self, key=None):
if hasattr(self._instance, '_mark_as_changed'):
if hasattr(self._instance, "_mark_as_changed"):
if key:
self._instance._mark_as_changed(
'%s.%s' % (self._name, key % len(self))
)
self._instance._mark_as_changed("%s.%s" % (self._name, key % len(self)))
else:
self._instance._mark_as_changed(self._name)
class EmbeddedDocumentList(BaseList):
def __init__(self, list_items, instance, name):
super(EmbeddedDocumentList, self).__init__(list_items, instance, name)
self._instance = instance
@@ -277,12 +286,10 @@ class EmbeddedDocumentList(BaseList):
"""
values = self.__only_matches(self, kwargs)
if len(values) == 0:
raise DoesNotExist(
'%s matching query does not exist.' % self._name
)
raise DoesNotExist("%s matching query does not exist." % self._name)
elif len(values) > 1:
raise MultipleObjectsReturned(
'%d items returned, instead of 1' % len(values)
"%d items returned, instead of 1" % len(values)
)
return values[0]
@@ -363,7 +370,7 @@ class EmbeddedDocumentList(BaseList):
class StrictDict(object):
__slots__ = ()
_special_fields = {'get', 'pop', 'iteritems', 'items', 'keys', 'create'}
_special_fields = {"get", "pop", "iteritems", "items", "keys", "create"}
_classes = {}
def __init__(self, **kwargs):
@@ -371,14 +378,14 @@ class StrictDict(object):
setattr(self, k, v)
def __getitem__(self, key):
key = '_reserved_' + key if key in self._special_fields else key
key = "_reserved_" + key if key in self._special_fields else key
try:
return getattr(self, key)
except AttributeError:
raise KeyError(key)
def __setitem__(self, key, value):
key = '_reserved_' + key if key in self._special_fields else key
key = "_reserved_" + key if key in self._special_fields else key
return setattr(self, key, value)
def __contains__(self, key):
@@ -425,27 +432,32 @@ class StrictDict(object):
@classmethod
def create(cls, allowed_keys):
allowed_keys_tuple = tuple(('_reserved_' + k if k in cls._special_fields else k) for k in allowed_keys)
allowed_keys_tuple = tuple(
("_reserved_" + k if k in cls._special_fields else k) for k in allowed_keys
)
allowed_keys = frozenset(allowed_keys_tuple)
if allowed_keys not in cls._classes:
class SpecificStrictDict(cls):
__slots__ = allowed_keys_tuple
def __repr__(self):
return '{%s}' % ', '.join('"{0!s}": {1!r}'.format(k, v) for k, v in self.items())
return "{%s}" % ", ".join(
'"{0!s}": {1!r}'.format(k, v) for k, v in self.items()
)
cls._classes[allowed_keys] = SpecificStrictDict
return cls._classes[allowed_keys]
class LazyReference(DBRef):
__slots__ = ('_cached_doc', 'passthrough', 'document_type')
__slots__ = ("_cached_doc", "passthrough", "document_type")
def fetch(self, force=False):
if not self._cached_doc or force:
self._cached_doc = self.document_type.objects.get(pk=self.pk)
if not self._cached_doc:
raise DoesNotExist('Trying to dereference unknown document %s' % (self))
raise DoesNotExist("Trying to dereference unknown document %s" % (self))
return self._cached_doc
@property
@@ -456,7 +468,9 @@ class LazyReference(DBRef):
self.document_type = document_type
self._cached_doc = cached_doc
self.passthrough = passthrough
super(LazyReference, self).__init__(self.document_type._get_collection_name(), pk)
super(LazyReference, self).__init__(
self.document_type._get_collection_name(), pk
)
def __getitem__(self, name):
if not self.passthrough:
@@ -465,7 +479,7 @@ class LazyReference(DBRef):
return document[name]
def __getattr__(self, name):
if not object.__getattribute__(self, 'passthrough'):
if not object.__getattribute__(self, "passthrough"):
raise AttributeError()
document = self.fetch()
try:

File diff suppressed because it is too large Load Diff

View File

@@ -8,13 +8,11 @@ import six
from six import iteritems
from mongoengine.base.common import UPDATE_OPERATORS
from mongoengine.base.datastructures import (BaseDict, BaseList,
EmbeddedDocumentList)
from mongoengine.base.datastructures import BaseDict, BaseList, EmbeddedDocumentList
from mongoengine.common import _import_class
from mongoengine.errors import DeprecatedError, ValidationError
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
'GeoJsonBaseField')
__all__ = ("BaseField", "ComplexBaseField", "ObjectIdField", "GeoJsonBaseField")
class BaseField(object):
@@ -23,6 +21,7 @@ class BaseField(object):
.. versionchanged:: 0.5 - added verbose and help text
"""
name = None
_geo_index = False
_auto_gen = False # Call `generate` to generate a value
@@ -34,10 +33,21 @@ class BaseField(object):
creation_counter = 0
auto_creation_counter = -1
def __init__(self, db_field=None, name=None, required=False, default=None,
unique=False, unique_with=None, primary_key=False,
validation=None, choices=None, null=False, sparse=False,
**kwargs):
def __init__(
self,
db_field=None,
name=None,
required=False,
default=None,
unique=False,
unique_with=None,
primary_key=False,
validation=None,
choices=None,
null=False,
sparse=False,
**kwargs
):
"""
:param db_field: The database field to store this field in
(defaults to the name of the field)
@@ -65,7 +75,7 @@ class BaseField(object):
existing attributes. Common metadata includes `verbose_name` and
`help_text`.
"""
self.db_field = (db_field or name) if not primary_key else '_id'
self.db_field = (db_field or name) if not primary_key else "_id"
if name:
msg = 'Field\'s "name" attribute deprecated in favour of "db_field"'
@@ -82,17 +92,16 @@ class BaseField(object):
self._owner_document = None
# Make sure db_field is a string (if it's explicitly defined).
if (
self.db_field is not None and
not isinstance(self.db_field, six.string_types)
if self.db_field is not None and not isinstance(
self.db_field, six.string_types
):
raise TypeError('db_field should be a string.')
raise TypeError("db_field should be a string.")
# Make sure db_field doesn't contain any forbidden characters.
if isinstance(self.db_field, six.string_types) and (
'.' in self.db_field or
'\0' in self.db_field or
self.db_field.startswith('$')
"." in self.db_field
or "\0" in self.db_field
or self.db_field.startswith("$")
):
raise ValueError(
'field names cannot contain dots (".") or null characters '
@@ -102,15 +111,17 @@ class BaseField(object):
# Detect and report conflicts between metadata and base properties.
conflicts = set(dir(self)) & set(kwargs)
if conflicts:
raise TypeError('%s already has attribute(s): %s' % (
self.__class__.__name__, ', '.join(conflicts)))
raise TypeError(
"%s already has attribute(s): %s"
% (self.__class__.__name__, ", ".join(conflicts))
)
# Assign metadata to the instance
# This efficient method is available because no __slots__ are defined.
self.__dict__.update(kwargs)
# Adjust the appropriate creation counter, and save our local copy.
if self.db_field == '_id':
if self.db_field == "_id":
self.creation_counter = BaseField.auto_creation_counter
BaseField.auto_creation_counter -= 1
else:
@@ -128,10 +139,9 @@ class BaseField(object):
return instance._data.get(self.name)
def __set__(self, instance, value):
"""Descriptor for assigning a value to a field in a document.
"""
# If setting to None and there is a default
# Then set the value to the default value
"""Descriptor for assigning a value to a field in a document."""
# If setting to None and there is a default value provided for this
# field, then set the value to the default value.
if value is None:
if self.null:
value = None
@@ -142,24 +152,29 @@ class BaseField(object):
if instance._initialised:
try:
if (self.name not in instance._data or
instance._data[self.name] != value):
value_has_changed = (
self.name not in instance._data
or instance._data[self.name] != value
)
if value_has_changed:
instance._mark_as_changed(self.name)
except Exception:
# Values cant be compared eg: naive and tz datetimes
# So mark it as changed
# Some values can't be compared and throw an error when we
# attempt to do so (e.g. tz-naive and tz-aware datetimes).
# Mark the field as changed in such cases.
instance._mark_as_changed(self.name)
EmbeddedDocument = _import_class('EmbeddedDocument')
EmbeddedDocument = _import_class("EmbeddedDocument")
if isinstance(value, EmbeddedDocument):
value._instance = weakref.proxy(instance)
elif isinstance(value, (list, tuple)):
for v in value:
if isinstance(v, EmbeddedDocument):
v._instance = weakref.proxy(instance)
instance._data[self.name] = value
def error(self, message='', errors=None, field_name=None):
def error(self, message="", errors=None, field_name=None):
"""Raise a ValidationError."""
field_name = field_name if field_name else self.name
raise ValidationError(message, errors=errors, field_name=field_name)
@@ -176,11 +191,11 @@ class BaseField(object):
"""Helper method to call to_mongo with proper inputs."""
f_inputs = self.to_mongo.__code__.co_varnames
ex_vars = {}
if 'fields' in f_inputs:
ex_vars['fields'] = fields
if "fields" in f_inputs:
ex_vars["fields"] = fields
if 'use_db_field' in f_inputs:
ex_vars['use_db_field'] = use_db_field
if "use_db_field" in f_inputs:
ex_vars["use_db_field"] = use_db_field
return self.to_mongo(value, **ex_vars)
@@ -195,8 +210,8 @@ class BaseField(object):
pass
def _validate_choices(self, value):
Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
choice_list = self.choices
if isinstance(next(iter(choice_list)), (list, tuple)):
@@ -207,15 +222,13 @@ class BaseField(object):
if isinstance(value, (Document, EmbeddedDocument)):
if not any(isinstance(value, c) for c in choice_list):
self.error(
'Value must be an instance of %s' % (
six.text_type(choice_list)
)
"Value must be an instance of %s" % (six.text_type(choice_list))
)
# Choices which are types other than Documents
else:
values = value if isinstance(value, (list, tuple)) else [value]
if len(set(values) - set(choice_list)):
self.error('Value must be one of %s' % six.text_type(choice_list))
self.error("Value must be one of %s" % six.text_type(choice_list))
def _validate(self, value, **kwargs):
# Check the Choices Constraint
@@ -231,13 +244,17 @@ class BaseField(object):
# in favor of having validation raising a ValidationError
ret = self.validation(value)
if ret is not None:
raise DeprecatedError('validation argument for `%s` must not return anything, '
'it should raise a ValidationError if validation fails' % self.name)
raise DeprecatedError(
"validation argument for `%s` must not return anything, "
"it should raise a ValidationError if validation fails"
% self.name
)
except ValidationError as ex:
self.error(str(ex))
else:
raise ValueError('validation argument for `"%s"` must be a '
'callable.' % self.name)
raise ValueError(
'validation argument for `"%s"` must be a ' "callable." % self.name
)
self.validate(value, **kwargs)
@@ -271,35 +288,41 @@ class ComplexBaseField(BaseField):
# Document class being used rather than a document object
return self
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField")
auto_dereference = instance._fields[self.name]._auto_dereference
dereference = (auto_dereference and
(self.field is None or isinstance(self.field,
(GenericReferenceField, ReferenceField))))
dereference = auto_dereference and (
self.field is None
or isinstance(self.field, (GenericReferenceField, ReferenceField))
)
_dereference = _import_class('DeReference')()
_dereference = _import_class("DeReference")()
if (instance._initialised and
dereference and
instance._data.get(self.name) and
not getattr(instance._data[self.name], '_dereferenced', False)):
if (
instance._initialised
and dereference
and instance._data.get(self.name)
and not getattr(instance._data[self.name], "_dereferenced", False)
):
instance._data[self.name] = _dereference(
instance._data.get(self.name), max_depth=1, instance=instance,
name=self.name
instance._data.get(self.name),
max_depth=1,
instance=instance,
name=self.name,
)
if hasattr(instance._data[self.name], '_dereferenced'):
if hasattr(instance._data[self.name], "_dereferenced"):
instance._data[self.name]._dereferenced = True
value = super(ComplexBaseField, self).__get__(instance, owner)
# Convert lists / values so we can watch for any changes on them
if isinstance(value, (list, tuple)):
if (issubclass(type(self), EmbeddedDocumentListField) and
not isinstance(value, EmbeddedDocumentList)):
if issubclass(type(self), EmbeddedDocumentListField) and not isinstance(
value, EmbeddedDocumentList
):
value = EmbeddedDocumentList(value, instance, self.name)
elif not isinstance(value, BaseList):
value = BaseList(value, instance, self.name)
@@ -308,12 +331,13 @@ class ComplexBaseField(BaseField):
value = BaseDict(value, instance, self.name)
instance._data[self.name] = value
if (auto_dereference and instance._initialised and
isinstance(value, (BaseList, BaseDict)) and
not value._dereferenced):
value = _dereference(
value, max_depth=1, instance=instance, name=self.name
)
if (
auto_dereference
and instance._initialised
and isinstance(value, (BaseList, BaseDict))
and not value._dereferenced
):
value = _dereference(value, max_depth=1, instance=instance, name=self.name)
value._dereferenced = True
instance._data[self.name] = value
@@ -324,16 +348,16 @@ class ComplexBaseField(BaseField):
if isinstance(value, six.string_types):
return value
if hasattr(value, 'to_python'):
if hasattr(value, "to_python"):
return value.to_python()
BaseDocument = _import_class('BaseDocument')
BaseDocument = _import_class("BaseDocument")
if isinstance(value, BaseDocument):
# Something is wrong, return the value as it is
return value
is_list = False
if not hasattr(value, 'items'):
if not hasattr(value, "items"):
try:
is_list = True
value = {idx: v for idx, v in enumerate(value)}
@@ -342,50 +366,54 @@ class ComplexBaseField(BaseField):
if self.field:
self.field._auto_dereference = self._auto_dereference
value_dict = {key: self.field.to_python(item)
for key, item in value.items()}
value_dict = {
key: self.field.to_python(item) for key, item in value.items()
}
else:
Document = _import_class('Document')
Document = _import_class("Document")
value_dict = {}
for k, v in value.items():
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
self.error('You can only reference documents once they'
' have been saved to the database')
self.error(
"You can only reference documents once they"
" have been saved to the database"
)
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_python'):
elif hasattr(v, "to_python"):
value_dict[k] = v.to_python()
else:
value_dict[k] = self.to_python(v)
if is_list: # Convert back to a list
return [v for _, v in sorted(value_dict.items(),
key=operator.itemgetter(0))]
return [
v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))
]
return value_dict
def to_mongo(self, value, use_db_field=True, fields=None):
"""Convert a Python type to a MongoDB-compatible type."""
Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
GenericReferenceField = _import_class('GenericReferenceField')
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
GenericReferenceField = _import_class("GenericReferenceField")
if isinstance(value, six.string_types):
return value
if hasattr(value, 'to_mongo'):
if hasattr(value, "to_mongo"):
if isinstance(value, Document):
return GenericReferenceField().to_mongo(value)
cls = value.__class__
val = value.to_mongo(use_db_field, fields)
# If it's a document that is not inherited add _cls
if isinstance(value, EmbeddedDocument):
val['_cls'] = cls.__name__
val["_cls"] = cls.__name__
return val
is_list = False
if not hasattr(value, 'items'):
if not hasattr(value, "items"):
try:
is_list = True
value = {k: v for k, v in enumerate(value)}
@@ -403,39 +431,42 @@ class ComplexBaseField(BaseField):
if isinstance(v, Document):
# We need the id from the saved object to create the DBRef
if v.pk is None:
self.error('You can only reference documents once they'
' have been saved to the database')
self.error(
"You can only reference documents once they"
" have been saved to the database"
)
# If its a document that is not inheritable it won't have
# any _cls data so make it a generic reference allows
# us to dereference
meta = getattr(v, '_meta', {})
allow_inheritance = meta.get('allow_inheritance')
meta = getattr(v, "_meta", {})
allow_inheritance = meta.get("allow_inheritance")
if not allow_inheritance and not self.field:
value_dict[k] = GenericReferenceField().to_mongo(v)
else:
collection = v._get_collection_name()
value_dict[k] = DBRef(collection, v.pk)
elif hasattr(v, 'to_mongo'):
elif hasattr(v, "to_mongo"):
cls = v.__class__
val = v.to_mongo(use_db_field, fields)
# If it's a document that is not inherited add _cls
if isinstance(v, (Document, EmbeddedDocument)):
val['_cls'] = cls.__name__
val["_cls"] = cls.__name__
value_dict[k] = val
else:
value_dict[k] = self.to_mongo(v, use_db_field, fields)
if is_list: # Convert back to a list
return [v for _, v in sorted(value_dict.items(),
key=operator.itemgetter(0))]
return [
v for _, v in sorted(value_dict.items(), key=operator.itemgetter(0))
]
return value_dict
def validate(self, value):
"""If field is provided ensure the value is valid."""
errors = {}
if self.field:
if hasattr(value, 'iteritems') or hasattr(value, 'items'):
if hasattr(value, "iteritems") or hasattr(value, "items"):
sequence = iteritems(value)
else:
sequence = enumerate(value)
@@ -449,11 +480,10 @@ class ComplexBaseField(BaseField):
if errors:
field_class = self.field.__class__.__name__
self.error('Invalid %s item (%s)' % (field_class, value),
errors=errors)
self.error("Invalid %s item (%s)" % (field_class, value), errors=errors)
# Don't allow empty values if required
if self.required and not value:
self.error('Field is required and cannot be empty')
self.error("Field is required and cannot be empty")
def prepare_query_value(self, op, value):
return self.to_mongo(value)
@@ -496,7 +526,7 @@ class ObjectIdField(BaseField):
try:
ObjectId(six.text_type(value))
except Exception:
self.error('Invalid Object ID')
self.error("Invalid Object ID")
class GeoJsonBaseField(BaseField):
@@ -506,14 +536,14 @@ class GeoJsonBaseField(BaseField):
"""
_geo_index = pymongo.GEOSPHERE
_type = 'GeoBase'
_type = "GeoBase"
def __init__(self, auto_index=True, *args, **kwargs):
"""
:param bool auto_index: Automatically create a '2dsphere' index.\
Defaults to `True`.
"""
self._name = '%sField' % self._type
self._name = "%sField" % self._type
if not auto_index:
self._geo_index = False
super(GeoJsonBaseField, self).__init__(*args, **kwargs)
@@ -521,57 +551,58 @@ class GeoJsonBaseField(BaseField):
def validate(self, value):
"""Validate the GeoJson object based on its type."""
if isinstance(value, dict):
if set(value.keys()) == {'type', 'coordinates'}:
if value['type'] != self._type:
self.error('%s type must be "%s"' %
(self._name, self._type))
return self.validate(value['coordinates'])
if set(value.keys()) == {"type", "coordinates"}:
if value["type"] != self._type:
self.error('%s type must be "%s"' % (self._name, self._type))
return self.validate(value["coordinates"])
else:
self.error('%s can only accept a valid GeoJson dictionary'
' or lists of (x, y)' % self._name)
self.error(
"%s can only accept a valid GeoJson dictionary"
" or lists of (x, y)" % self._name
)
return
elif not isinstance(value, (list, tuple)):
self.error('%s can only accept lists of [x, y]' % self._name)
self.error("%s can only accept lists of [x, y]" % self._name)
return
validate = getattr(self, '_validate_%s' % self._type.lower())
validate = getattr(self, "_validate_%s" % self._type.lower())
error = validate(value)
if error:
self.error(error)
def _validate_polygon(self, value, top_level=True):
if not isinstance(value, (list, tuple)):
return 'Polygons must contain list of linestrings'
return "Polygons must contain list of linestrings"
# Quick and dirty validator
try:
value[0][0][0]
except (TypeError, IndexError):
return 'Invalid Polygon must contain at least one valid linestring'
return "Invalid Polygon must contain at least one valid linestring"
errors = []
for val in value:
error = self._validate_linestring(val, False)
if not error and val[0] != val[-1]:
error = 'LineStrings must start and end at the same point'
error = "LineStrings must start and end at the same point"
if error and error not in errors:
errors.append(error)
if errors:
if top_level:
return 'Invalid Polygon:\n%s' % ', '.join(errors)
return "Invalid Polygon:\n%s" % ", ".join(errors)
else:
return '%s' % ', '.join(errors)
return "%s" % ", ".join(errors)
def _validate_linestring(self, value, top_level=True):
"""Validate a linestring."""
if not isinstance(value, (list, tuple)):
return 'LineStrings must contain list of coordinate pairs'
return "LineStrings must contain list of coordinate pairs"
# Quick and dirty validator
try:
value[0][0]
except (TypeError, IndexError):
return 'Invalid LineString must contain at least one valid point'
return "Invalid LineString must contain at least one valid point"
errors = []
for val in value:
@@ -580,29 +611,30 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
if top_level:
return 'Invalid LineString:\n%s' % ', '.join(errors)
return "Invalid LineString:\n%s" % ", ".join(errors)
else:
return '%s' % ', '.join(errors)
return "%s" % ", ".join(errors)
def _validate_point(self, value):
"""Validate each set of coords"""
if not isinstance(value, (list, tuple)):
return 'Points must be a list of coordinate pairs'
return "Points must be a list of coordinate pairs"
elif not len(value) == 2:
return 'Value (%s) must be a two-dimensional point' % repr(value)
elif (not isinstance(value[0], (float, int)) or
not isinstance(value[1], (float, int))):
return 'Both values (%s) in point must be float or int' % repr(value)
return "Value (%s) must be a two-dimensional point" % repr(value)
elif not isinstance(value[0], (float, int)) or not isinstance(
value[1], (float, int)
):
return "Both values (%s) in point must be float or int" % repr(value)
def _validate_multipoint(self, value):
if not isinstance(value, (list, tuple)):
return 'MultiPoint must be a list of Point'
return "MultiPoint must be a list of Point"
# Quick and dirty validator
try:
value[0][0]
except (TypeError, IndexError):
return 'Invalid MultiPoint must contain at least one valid point'
return "Invalid MultiPoint must contain at least one valid point"
errors = []
for point in value:
@@ -611,17 +643,17 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
return '%s' % ', '.join(errors)
return "%s" % ", ".join(errors)
def _validate_multilinestring(self, value, top_level=True):
if not isinstance(value, (list, tuple)):
return 'MultiLineString must be a list of LineString'
return "MultiLineString must be a list of LineString"
# Quick and dirty validator
try:
value[0][0][0]
except (TypeError, IndexError):
return 'Invalid MultiLineString must contain at least one valid linestring'
return "Invalid MultiLineString must contain at least one valid linestring"
errors = []
for linestring in value:
@@ -631,19 +663,19 @@ class GeoJsonBaseField(BaseField):
if errors:
if top_level:
return 'Invalid MultiLineString:\n%s' % ', '.join(errors)
return "Invalid MultiLineString:\n%s" % ", ".join(errors)
else:
return '%s' % ', '.join(errors)
return "%s" % ", ".join(errors)
def _validate_multipolygon(self, value):
if not isinstance(value, (list, tuple)):
return 'MultiPolygon must be a list of Polygon'
return "MultiPolygon must be a list of Polygon"
# Quick and dirty validator
try:
value[0][0][0][0]
except (TypeError, IndexError):
return 'Invalid MultiPolygon must contain at least one valid Polygon'
return "Invalid MultiPolygon must contain at least one valid Polygon"
errors = []
for polygon in value:
@@ -652,9 +684,9 @@ class GeoJsonBaseField(BaseField):
errors.append(error)
if errors:
return 'Invalid MultiPolygon:\n%s' % ', '.join(errors)
return "Invalid MultiPolygon:\n%s" % ", ".join(errors)
def to_mongo(self, value):
if isinstance(value, dict):
return value
return SON([('type', self._type), ('coordinates', value)])
return SON([("type", self._type), ("coordinates", value)])

View File

@@ -1,3 +1,4 @@
import itertools
import warnings
import six
@@ -7,12 +8,15 @@ from mongoengine.base.common import _document_registry
from mongoengine.base.fields import BaseField, ComplexBaseField, ObjectIdField
from mongoengine.common import _import_class
from mongoengine.errors import InvalidDocumentError
from mongoengine.queryset import (DO_NOTHING, DoesNotExist,
MultipleObjectsReturned,
QuerySetManager)
from mongoengine.queryset import (
DO_NOTHING,
DoesNotExist,
MultipleObjectsReturned,
QuerySetManager,
)
__all__ = ('DocumentMetaclass', 'TopLevelDocumentMetaclass')
__all__ = ("DocumentMetaclass", "TopLevelDocumentMetaclass")
class DocumentMetaclass(type):
@@ -24,44 +28,46 @@ class DocumentMetaclass(type):
super_new = super(DocumentMetaclass, mcs).__new__
# If a base class just call super
metaclass = attrs.get('my_metaclass')
metaclass = attrs.get("my_metaclass")
if metaclass and issubclass(metaclass, DocumentMetaclass):
return super_new(mcs, name, bases, attrs)
attrs['_is_document'] = attrs.get('_is_document', False)
attrs['_cached_reference_fields'] = []
attrs["_is_document"] = attrs.get("_is_document", False)
attrs["_cached_reference_fields"] = []
# EmbeddedDocuments could have meta data for inheritance
if 'meta' in attrs:
attrs['_meta'] = attrs.pop('meta')
if "meta" in attrs:
attrs["_meta"] = attrs.pop("meta")
# EmbeddedDocuments should inherit meta data
if '_meta' not in attrs:
if "_meta" not in attrs:
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
if hasattr(base, 'meta'):
if hasattr(base, "meta"):
meta.merge(base.meta)
elif hasattr(base, '_meta'):
elif hasattr(base, "_meta"):
meta.merge(base._meta)
attrs['_meta'] = meta
attrs['_meta']['abstract'] = False # 789: EmbeddedDocument shouldn't inherit abstract
attrs["_meta"] = meta
attrs["_meta"][
"abstract"
] = False # 789: EmbeddedDocument shouldn't inherit abstract
# If allow_inheritance is True, add a "_cls" string field to the attrs
if attrs['_meta'].get('allow_inheritance'):
StringField = _import_class('StringField')
attrs['_cls'] = StringField()
if attrs["_meta"].get("allow_inheritance"):
StringField = _import_class("StringField")
attrs["_cls"] = StringField()
# Handle document Fields
# Merge all fields from subclasses
doc_fields = {}
for base in flattened_bases[::-1]:
if hasattr(base, '_fields'):
if hasattr(base, "_fields"):
doc_fields.update(base._fields)
# Standard object mixin - merge in any Fields
if not hasattr(base, '_meta'):
if not hasattr(base, "_meta"):
base_fields = {}
for attr_name, attr_value in iteritems(base.__dict__):
if not isinstance(attr_value, BaseField):
@@ -84,27 +90,31 @@ class DocumentMetaclass(type):
doc_fields[attr_name] = attr_value
# Count names to ensure no db_field redefinitions
field_names[attr_value.db_field] = field_names.get(
attr_value.db_field, 0) + 1
field_names[attr_value.db_field] = (
field_names.get(attr_value.db_field, 0) + 1
)
# Ensure no duplicate db_fields
duplicate_db_fields = [k for k, v in field_names.items() if v > 1]
if duplicate_db_fields:
msg = ('Multiple db_fields defined for: %s ' %
', '.join(duplicate_db_fields))
msg = "Multiple db_fields defined for: %s " % ", ".join(duplicate_db_fields)
raise InvalidDocumentError(msg)
# Set _fields and db_field maps
attrs['_fields'] = doc_fields
attrs['_db_field_map'] = {k: getattr(v, 'db_field', k)
for k, v in doc_fields.items()}
attrs['_reverse_db_field_map'] = {
v: k for k, v in attrs['_db_field_map'].items()
attrs["_fields"] = doc_fields
attrs["_db_field_map"] = {
k: getattr(v, "db_field", k) for k, v in doc_fields.items()
}
attrs["_reverse_db_field_map"] = {
v: k for k, v in attrs["_db_field_map"].items()
}
attrs['_fields_ordered'] = tuple(i[1] for i in sorted(
(v.creation_counter, v.name)
for v in itervalues(doc_fields)))
attrs["_fields_ordered"] = tuple(
i[1]
for i in sorted(
(v.creation_counter, v.name) for v in itervalues(doc_fields)
)
)
#
# Set document hierarchy
@@ -112,32 +122,34 @@ class DocumentMetaclass(type):
superclasses = ()
class_name = [name]
for base in flattened_bases:
if (not getattr(base, '_is_base_cls', True) and
not getattr(base, '_meta', {}).get('abstract', True)):
if not getattr(base, "_is_base_cls", True) and not getattr(
base, "_meta", {}
).get("abstract", True):
# Collate hierarchy for _cls and _subclasses
class_name.append(base.__name__)
if hasattr(base, '_meta'):
if hasattr(base, "_meta"):
# Warn if allow_inheritance isn't set and prevent
# inheritance of classes where inheritance is set to False
allow_inheritance = base._meta.get('allow_inheritance')
if not allow_inheritance and not base._meta.get('abstract'):
raise ValueError('Document %s may not be subclassed. '
'To enable inheritance, use the "allow_inheritance" meta attribute.' %
base.__name__)
allow_inheritance = base._meta.get("allow_inheritance")
if not allow_inheritance and not base._meta.get("abstract"):
raise ValueError(
"Document %s may not be subclassed. "
'To enable inheritance, use the "allow_inheritance" meta attribute.'
% base.__name__
)
# Get superclasses from last base superclass
document_bases = [b for b in flattened_bases
if hasattr(b, '_class_name')]
document_bases = [b for b in flattened_bases if hasattr(b, "_class_name")]
if document_bases:
superclasses = document_bases[0]._superclasses
superclasses += (document_bases[0]._class_name, )
superclasses += (document_bases[0]._class_name,)
_cls = '.'.join(reversed(class_name))
attrs['_class_name'] = _cls
attrs['_superclasses'] = superclasses
attrs['_subclasses'] = (_cls, )
attrs['_types'] = attrs['_subclasses'] # TODO depreciate _types
_cls = ".".join(reversed(class_name))
attrs["_class_name"] = _cls
attrs["_superclasses"] = superclasses
attrs["_subclasses"] = (_cls,)
attrs["_types"] = attrs["_subclasses"] # TODO depreciate _types
# Create the new_class
new_class = super_new(mcs, name, bases, attrs)
@@ -148,8 +160,12 @@ class DocumentMetaclass(type):
base._subclasses += (_cls,)
base._types = base._subclasses # TODO depreciate _types
(Document, EmbeddedDocument, DictField,
CachedReferenceField) = mcs._import_classes()
(
Document,
EmbeddedDocument,
DictField,
CachedReferenceField,
) = mcs._import_classes()
if issubclass(new_class, Document):
new_class._collection = None
@@ -168,52 +184,55 @@ class DocumentMetaclass(type):
for val in new_class.__dict__.values():
if isinstance(val, classmethod):
f = val.__get__(new_class)
if hasattr(f, '__func__') and not hasattr(f, 'im_func'):
f.__dict__.update({'im_func': getattr(f, '__func__')})
if hasattr(f, '__self__') and not hasattr(f, 'im_self'):
f.__dict__.update({'im_self': getattr(f, '__self__')})
if hasattr(f, "__func__") and not hasattr(f, "im_func"):
f.__dict__.update({"im_func": getattr(f, "__func__")})
if hasattr(f, "__self__") and not hasattr(f, "im_self"):
f.__dict__.update({"im_self": getattr(f, "__self__")})
# Handle delete rules
for field in itervalues(new_class._fields):
f = field
if f.owner_document is None:
f.owner_document = new_class
delete_rule = getattr(f, 'reverse_delete_rule', DO_NOTHING)
delete_rule = getattr(f, "reverse_delete_rule", DO_NOTHING)
if isinstance(f, CachedReferenceField):
if issubclass(new_class, EmbeddedDocument):
raise InvalidDocumentError('CachedReferenceFields is not '
'allowed in EmbeddedDocuments')
raise InvalidDocumentError(
"CachedReferenceFields is not allowed in EmbeddedDocuments"
)
if f.auto_sync:
f.start_listener()
f.document_type._cached_reference_fields.append(f)
if isinstance(f, ComplexBaseField) and hasattr(f, 'field'):
delete_rule = getattr(f.field,
'reverse_delete_rule',
DO_NOTHING)
if isinstance(f, ComplexBaseField) and hasattr(f, "field"):
delete_rule = getattr(f.field, "reverse_delete_rule", DO_NOTHING)
if isinstance(f, DictField) and delete_rule != DO_NOTHING:
msg = ('Reverse delete rules are not supported '
'for %s (field: %s)' %
(field.__class__.__name__, field.name))
msg = (
"Reverse delete rules are not supported "
"for %s (field: %s)" % (field.__class__.__name__, field.name)
)
raise InvalidDocumentError(msg)
f = field.field
if delete_rule != DO_NOTHING:
if issubclass(new_class, EmbeddedDocument):
msg = ('Reverse delete rules are not supported for '
'EmbeddedDocuments (field: %s)' % field.name)
msg = (
"Reverse delete rules are not supported for "
"EmbeddedDocuments (field: %s)" % field.name
)
raise InvalidDocumentError(msg)
f.document_type.register_delete_rule(new_class,
field.name, delete_rule)
f.document_type.register_delete_rule(new_class, field.name, delete_rule)
if (field.name and hasattr(Document, field.name) and
EmbeddedDocument not in new_class.mro()):
msg = ('%s is a document method and not a valid '
'field name' % field.name)
if (
field.name
and hasattr(Document, field.name)
and EmbeddedDocument not in new_class.mro()
):
msg = "%s is a document method and not a valid field name" % field.name
raise InvalidDocumentError(msg)
return new_class
@@ -238,10 +257,10 @@ class DocumentMetaclass(type):
@classmethod
def _import_classes(mcs):
Document = _import_class('Document')
EmbeddedDocument = _import_class('EmbeddedDocument')
DictField = _import_class('DictField')
CachedReferenceField = _import_class('CachedReferenceField')
Document = _import_class("Document")
EmbeddedDocument = _import_class("EmbeddedDocument")
DictField = _import_class("DictField")
CachedReferenceField = _import_class("CachedReferenceField")
return Document, EmbeddedDocument, DictField, CachedReferenceField
@@ -255,65 +274,67 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
super_new = super(TopLevelDocumentMetaclass, mcs).__new__
# Set default _meta data if base class, otherwise get user defined meta
if attrs.get('my_metaclass') == TopLevelDocumentMetaclass:
if attrs.get("my_metaclass") == TopLevelDocumentMetaclass:
# defaults
attrs['_meta'] = {
'abstract': True,
'max_documents': None,
'max_size': None,
'ordering': [], # default ordering applied at runtime
'indexes': [], # indexes to be ensured at runtime
'id_field': None,
'index_background': False,
'index_drop_dups': False,
'index_opts': None,
'delete_rules': None,
attrs["_meta"] = {
"abstract": True,
"max_documents": None,
"max_size": None,
"ordering": [], # default ordering applied at runtime
"indexes": [], # indexes to be ensured at runtime
"id_field": None,
"index_background": False,
"index_drop_dups": False,
"index_opts": None,
"delete_rules": None,
# allow_inheritance can be True, False, and None. True means
# "allow inheritance", False means "don't allow inheritance",
# None means "do whatever your parent does, or don't allow
# inheritance if you're a top-level class".
'allow_inheritance': None,
"allow_inheritance": None,
}
attrs['_is_base_cls'] = True
attrs['_meta'].update(attrs.get('meta', {}))
attrs["_is_base_cls"] = True
attrs["_meta"].update(attrs.get("meta", {}))
else:
attrs['_meta'] = attrs.get('meta', {})
attrs["_meta"] = attrs.get("meta", {})
# Explicitly set abstract to false unless set
attrs['_meta']['abstract'] = attrs['_meta'].get('abstract', False)
attrs['_is_base_cls'] = False
attrs["_meta"]["abstract"] = attrs["_meta"].get("abstract", False)
attrs["_is_base_cls"] = False
# Set flag marking as document class - as opposed to an object mixin
attrs['_is_document'] = True
attrs["_is_document"] = True
# Ensure queryset_class is inherited
if 'objects' in attrs:
manager = attrs['objects']
if hasattr(manager, 'queryset_class'):
attrs['_meta']['queryset_class'] = manager.queryset_class
if "objects" in attrs:
manager = attrs["objects"]
if hasattr(manager, "queryset_class"):
attrs["_meta"]["queryset_class"] = manager.queryset_class
# Clean up top level meta
if 'meta' in attrs:
del attrs['meta']
if "meta" in attrs:
del attrs["meta"]
# Find the parent document class
parent_doc_cls = [b for b in flattened_bases
if b.__class__ == TopLevelDocumentMetaclass]
parent_doc_cls = [
b for b in flattened_bases if b.__class__ == TopLevelDocumentMetaclass
]
parent_doc_cls = None if not parent_doc_cls else parent_doc_cls[0]
# Prevent classes setting collection different to their parents
# If parent wasn't an abstract class
if (parent_doc_cls and 'collection' in attrs.get('_meta', {}) and
not parent_doc_cls._meta.get('abstract', True)):
msg = 'Trying to set a collection on a subclass (%s)' % name
if (
parent_doc_cls
and "collection" in attrs.get("_meta", {})
and not parent_doc_cls._meta.get("abstract", True)
):
msg = "Trying to set a collection on a subclass (%s)" % name
warnings.warn(msg, SyntaxWarning)
del attrs['_meta']['collection']
del attrs["_meta"]["collection"]
# Ensure abstract documents have abstract bases
if attrs.get('_is_base_cls') or attrs['_meta'].get('abstract'):
if (parent_doc_cls and
not parent_doc_cls._meta.get('abstract', False)):
msg = 'Abstract document cannot have non-abstract base'
if attrs.get("_is_base_cls") or attrs["_meta"].get("abstract"):
if parent_doc_cls and not parent_doc_cls._meta.get("abstract", False):
msg = "Abstract document cannot have non-abstract base"
raise ValueError(msg)
return super_new(mcs, name, bases, attrs)
@@ -322,38 +343,43 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
meta = MetaDict()
for base in flattened_bases[::-1]:
# Add any mixin metadata from plain objects
if hasattr(base, 'meta'):
if hasattr(base, "meta"):
meta.merge(base.meta)
elif hasattr(base, '_meta'):
elif hasattr(base, "_meta"):
meta.merge(base._meta)
# Set collection in the meta if its callable
if (getattr(base, '_is_document', False) and
not base._meta.get('abstract')):
collection = meta.get('collection', None)
if getattr(base, "_is_document", False) and not base._meta.get("abstract"):
collection = meta.get("collection", None)
if callable(collection):
meta['collection'] = collection(base)
meta["collection"] = collection(base)
meta.merge(attrs.get('_meta', {})) # Top level meta
meta.merge(attrs.get("_meta", {})) # Top level meta
# Only simple classes (i.e. direct subclasses of Document) may set
# allow_inheritance to False. If the base Document allows inheritance,
# none of its subclasses can override allow_inheritance to False.
simple_class = all([b._meta.get('abstract')
for b in flattened_bases if hasattr(b, '_meta')])
simple_class = all(
[b._meta.get("abstract") for b in flattened_bases if hasattr(b, "_meta")]
)
if (
not simple_class and
meta['allow_inheritance'] is False and
not meta['abstract']
not simple_class
and meta["allow_inheritance"] is False
and not meta["abstract"]
):
raise ValueError('Only direct subclasses of Document may set '
'"allow_inheritance" to False')
raise ValueError(
"Only direct subclasses of Document may set "
'"allow_inheritance" to False'
)
# Set default collection name
if 'collection' not in meta:
meta['collection'] = ''.join('_%s' % c if c.isupper() else c
for c in name).strip('_').lower()
attrs['_meta'] = meta
if "collection" not in meta:
meta["collection"] = (
"".join("_%s" % c if c.isupper() else c for c in name)
.strip("_")
.lower()
)
attrs["_meta"] = meta
# Call super and get the new class
new_class = super_new(mcs, name, bases, attrs)
@@ -361,79 +387,93 @@ class TopLevelDocumentMetaclass(DocumentMetaclass):
meta = new_class._meta
# Set index specifications
meta['index_specs'] = new_class._build_index_specs(meta['indexes'])
meta["index_specs"] = new_class._build_index_specs(meta["indexes"])
# If collection is a callable - call it and set the value
collection = meta.get('collection')
collection = meta.get("collection")
if callable(collection):
new_class._meta['collection'] = collection(new_class)
new_class._meta["collection"] = collection(new_class)
# Provide a default queryset unless exists or one has been set
if 'objects' not in dir(new_class):
if "objects" not in dir(new_class):
new_class.objects = QuerySetManager()
# Validate the fields and set primary key if needed
for field_name, field in iteritems(new_class._fields):
if field.primary_key:
# Ensure only one primary key is set
current_pk = new_class._meta.get('id_field')
current_pk = new_class._meta.get("id_field")
if current_pk and current_pk != field_name:
raise ValueError('Cannot override primary key field')
raise ValueError("Cannot override primary key field")
# Set primary key
if not current_pk:
new_class._meta['id_field'] = field_name
new_class._meta["id_field"] = field_name
new_class.id = field
# Set primary key if not defined by the document
new_class._auto_id_field = getattr(parent_doc_cls,
'_auto_id_field', False)
if not new_class._meta.get('id_field'):
# After 0.10, find not existing names, instead of overwriting
# If the document doesn't explicitly define a primary key field, create
# one. Make it an ObjectIdField and give it a non-clashing name ("id"
# by default, but can be different if that one's taken).
if not new_class._meta.get("id_field"):
id_name, id_db_name = mcs.get_auto_id_names(new_class)
new_class._auto_id_field = True
new_class._meta['id_field'] = id_name
new_class._meta["id_field"] = id_name
new_class._fields[id_name] = ObjectIdField(db_field=id_db_name)
new_class._fields[id_name].name = id_name
new_class.id = new_class._fields[id_name]
new_class._db_field_map[id_name] = id_db_name
new_class._reverse_db_field_map[id_db_name] = id_name
# Prepend id field to _fields_ordered
new_class._fields_ordered = (id_name, ) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy
# Prepend the ID field to _fields_ordered (so that it's *always*
# the first field).
new_class._fields_ordered = (id_name,) + new_class._fields_ordered
# Merge in exceptions with parent hierarchy.
exceptions_to_merge = (DoesNotExist, MultipleObjectsReturned)
module = attrs.get('__module__')
module = attrs.get("__module__")
for exc in exceptions_to_merge:
name = exc.__name__
parents = tuple(getattr(base, name) for base in flattened_bases
if hasattr(base, name)) or (exc,)
# Create new exception and set to new_class
exception = type(name, parents, {'__module__': module})
parents = tuple(
getattr(base, name) for base in flattened_bases if hasattr(base, name)
) or (exc,)
# Create a new exception and set it as an attribute on the new
# class.
exception = type(name, parents, {"__module__": module})
setattr(new_class, name, exception)
return new_class
@classmethod
def get_auto_id_names(mcs, new_class):
id_name, id_db_name = ('id', '_id')
if id_name not in new_class._fields and \
id_db_name not in (v.db_field for v in new_class._fields.values()):
"""Find a name for the automatic ID field for the given new class.
Return a two-element tuple where the first item is the field name (i.e.
the attribute name on the object) and the second element is the DB
field name (i.e. the name of the key stored in MongoDB).
Defaults to ('id', '_id'), or generates a non-clashing name in the form
of ('auto_id_X', '_auto_id_X') if the default name is already taken.
"""
id_name, id_db_name = ("id", "_id")
existing_fields = {field_name for field_name in new_class._fields}
existing_db_fields = {v.db_field for v in new_class._fields.values()}
if id_name not in existing_fields and id_db_name not in existing_db_fields:
return id_name, id_db_name
id_basename, id_db_basename, i = 'auto_id', '_auto_id', 0
while id_name in new_class._fields or \
id_db_name in (v.db_field for v in new_class._fields.values()):
id_name = '{0}_{1}'.format(id_basename, i)
id_db_name = '{0}_{1}'.format(id_db_basename, i)
i += 1
return id_name, id_db_name
id_basename, id_db_basename, i = ("auto_id", "_auto_id", 0)
for i in itertools.count():
id_name = "{0}_{1}".format(id_basename, i)
id_db_name = "{0}_{1}".format(id_db_basename, i)
if id_name not in existing_fields and id_db_name not in existing_db_fields:
return id_name, id_db_name
class MetaDict(dict):
"""Custom dictionary for meta classes.
Handles the merging of set indexes
"""
_merge_options = ('indexes',)
_merge_options = ("indexes",)
def merge(self, new_options):
for k, v in iteritems(new_options):
@@ -445,4 +485,5 @@ class MetaDict(dict):
class BasesTuple(tuple):
"""Special class to handle introspection of bases tuple in __new__"""
pass

View File

@@ -19,34 +19,44 @@ def _import_class(cls_name):
if cls_name in _class_registry_cache:
return _class_registry_cache.get(cls_name)
doc_classes = ('Document', 'DynamicEmbeddedDocument', 'EmbeddedDocument',
'MapReduceDocument')
doc_classes = (
"Document",
"DynamicEmbeddedDocument",
"EmbeddedDocument",
"MapReduceDocument",
)
# Field Classes
if not _field_list_cache:
from mongoengine.fields import __all__ as fields
_field_list_cache.extend(fields)
from mongoengine.base.fields import __all__ as fields
_field_list_cache.extend(fields)
field_classes = _field_list_cache
deref_classes = ('DeReference',)
deref_classes = ("DeReference",)
if cls_name == 'BaseDocument':
if cls_name == "BaseDocument":
from mongoengine.base import document as module
import_classes = ['BaseDocument']
import_classes = ["BaseDocument"]
elif cls_name in doc_classes:
from mongoengine import document as module
import_classes = doc_classes
elif cls_name in field_classes:
from mongoengine import fields as module
import_classes = field_classes
elif cls_name in deref_classes:
from mongoengine import dereference as module
import_classes = deref_classes
else:
raise ValueError('No import set for: %s' % cls_name)
raise ValueError("No import set for: %s" % cls_name)
for cls in import_classes:
_class_registry_cache[cls] = getattr(module, cls)

View File

@@ -3,21 +3,21 @@ from pymongo.database import _check_name
import six
__all__ = [
'DEFAULT_CONNECTION_NAME',
'DEFAULT_DATABASE_NAME',
'MongoEngineConnectionError',
'connect',
'disconnect',
'disconnect_all',
'get_connection',
'get_db',
'register_connection',
"DEFAULT_CONNECTION_NAME",
"DEFAULT_DATABASE_NAME",
"ConnectionFailure",
"connect",
"disconnect",
"disconnect_all",
"get_connection",
"get_db",
"register_connection",
]
DEFAULT_CONNECTION_NAME = 'default'
DEFAULT_DATABASE_NAME = 'test'
DEFAULT_HOST = 'localhost'
DEFAULT_CONNECTION_NAME = "default"
DEFAULT_DATABASE_NAME = "test"
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 27017
_connection_settings = {}
@@ -27,10 +27,11 @@ _dbs = {}
READ_PREFERENCE = ReadPreference.PRIMARY
class MongoEngineConnectionError(Exception):
class ConnectionFailure(Exception):
"""Error raised when the database connection can't be established or
when a connection with a requested alias can't be retrieved.
"""
pass
@@ -39,18 +40,23 @@ def _check_db_name(name):
This functionality is copied from pymongo Database class constructor.
"""
if not isinstance(name, six.string_types):
raise TypeError('name must be an instance of %s' % six.string_types)
elif name != '$external':
raise TypeError("name must be an instance of %s" % six.string_types)
elif name != "$external":
_check_name(name)
def _get_connection_settings(
db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
db=None,
name=None,
host=None,
port=None,
read_preference=READ_PREFERENCE,
username=None,
password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs
):
"""Get the connection settings as a dict
: param db: the name of the database to use, for compatibility with connect
@@ -73,18 +79,18 @@ def _get_connection_settings(
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = {
'name': name or db or DEFAULT_DATABASE_NAME,
'host': host or DEFAULT_HOST,
'port': port or DEFAULT_PORT,
'read_preference': read_preference,
'username': username,
'password': password,
'authentication_source': authentication_source,
'authentication_mechanism': authentication_mechanism
"name": name or db or DEFAULT_DATABASE_NAME,
"host": host or DEFAULT_HOST,
"port": port or DEFAULT_PORT,
"read_preference": read_preference,
"username": username,
"password": password,
"authentication_source": authentication_source,
"authentication_mechanism": authentication_mechanism,
}
_check_db_name(conn_settings['name'])
conn_host = conn_settings['host']
_check_db_name(conn_settings["name"])
conn_host = conn_settings["host"]
# Host can be a list or a string, so if string, force to a list.
if isinstance(conn_host, six.string_types):
@@ -94,32 +100,40 @@ def _get_connection_settings(
for entity in conn_host:
# Handle Mongomock
if entity.startswith('mongomock://'):
conn_settings['is_mock'] = True
if entity.startswith("mongomock://"):
conn_settings["is_mock"] = True
# `mongomock://` is not a valid url prefix and must be replaced by `mongodb://`
resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1))
new_entity = entity.replace("mongomock://", "mongodb://", 1)
resolved_hosts.append(new_entity)
uri_dict = uri_parser.parse_uri(new_entity)
database = uri_dict.get("database")
if database:
conn_settings["name"] = database
# Handle URI style connections, only updating connection params which
# were explicitly specified in the URI.
elif '://' in entity:
elif "://" in entity:
uri_dict = uri_parser.parse_uri(entity)
resolved_hosts.append(entity)
if uri_dict.get('database'):
conn_settings['name'] = uri_dict.get('database')
database = uri_dict.get("database")
if database:
conn_settings["name"] = database
for param in ('read_preference', 'username', 'password'):
for param in ("read_preference", "username", "password"):
if uri_dict.get(param):
conn_settings[param] = uri_dict[param]
uri_options = uri_dict['options']
if 'replicaset' in uri_options:
conn_settings['replicaSet'] = uri_options['replicaset']
if 'authsource' in uri_options:
conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
if 'readpreference' in uri_options:
uri_options = uri_dict["options"]
if "replicaset" in uri_options:
conn_settings["replicaSet"] = uri_options["replicaset"]
if "authsource" in uri_options:
conn_settings["authentication_source"] = uri_options["authsource"]
if "authmechanism" in uri_options:
conn_settings["authentication_mechanism"] = uri_options["authmechanism"]
if "readpreference" in uri_options:
read_preferences = (
ReadPreference.NEAREST,
ReadPreference.PRIMARY,
@@ -133,40 +147,47 @@ def _get_connection_settings(
# int (e.g. 3).
# TODO simplify the code below once we drop support for
# PyMongo v3.4.
read_pf_mode = uri_options['readpreference']
read_pf_mode = uri_options["readpreference"]
if isinstance(read_pf_mode, six.string_types):
read_pf_mode = read_pf_mode.lower()
for preference in read_preferences:
if (
preference.name.lower() == read_pf_mode or
preference.mode == read_pf_mode
preference.name.lower() == read_pf_mode
or preference.mode == read_pf_mode
):
conn_settings['read_preference'] = preference
conn_settings["read_preference"] = preference
break
else:
resolved_hosts.append(entity)
conn_settings['host'] = resolved_hosts
conn_settings["host"] = resolved_hosts
# Deprecated parameters that should not be passed on
kwargs.pop('slaves', None)
kwargs.pop('is_slave', None)
kwargs.pop("slaves", None)
kwargs.pop("is_slave", None)
conn_settings.update(kwargs)
return conn_settings
def register_connection(alias, db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE,
username=None, password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs):
def register_connection(
alias,
db=None,
name=None,
host=None,
port=None,
read_preference=READ_PREFERENCE,
username=None,
password=None,
authentication_source=None,
authentication_mechanism=None,
**kwargs
):
"""Register the connection settings.
: param alias: the name that will be used to refer to this connection
throughout MongoEngine
: param name: the name of the specific database to use
: param db: the name of the database to use, for compatibility with connect
: param name: the name of the specific database to use
: param host: the host name of the: program: `mongod` instance to connect to
: param port: the port that the: program: `mongod` instance is running on
: param read_preference: The read preference for the collection
@@ -185,12 +206,17 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
.. versionchanged:: 0.10.6 - added mongomock support
"""
conn_settings = _get_connection_settings(
db=db, name=name, host=host, port=port,
db=db,
name=name,
host=host,
port=port,
read_preference=read_preference,
username=username, password=password,
username=username,
password=password,
authentication_source=authentication_source,
authentication_mechanism=authentication_mechanism,
**kwargs)
**kwargs
)
_connection_settings[alias] = conn_settings
@@ -206,7 +232,7 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME):
if alias in _dbs:
# Detach all cached collections in Documents
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
if issubclass(doc_cls, Document): # Skip EmbeddedDocument
if issubclass(doc_cls, Document): # Skip EmbeddedDocument
doc_cls._disconnect()
del _dbs[alias]
@@ -234,22 +260,24 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
return _connections[alias]
# Validate that the requested alias exists in the _connection_settings.
# Raise MongoEngineConnectionError if it doesn't.
# Raise ConnectionFailure if it doesn't.
if alias not in _connection_settings:
if alias == DEFAULT_CONNECTION_NAME:
msg = 'You have not defined a default connection'
msg = "You have not defined a default connection"
else:
msg = 'Connection with alias "%s" has not been defined' % alias
raise MongoEngineConnectionError(msg)
raise ConnectionFailure(msg)
def _clean_settings(settings_dict):
irrelevant_fields_set = {
'name', 'username', 'password',
'authentication_source', 'authentication_mechanism'
"name",
"username",
"password",
"authentication_source",
"authentication_mechanism",
}
return {
k: v for k, v in settings_dict.items()
if k not in irrelevant_fields_set
k: v for k, v in settings_dict.items() if k not in irrelevant_fields_set
}
raw_conn_settings = _connection_settings[alias].copy()
@@ -260,13 +288,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings = _clean_settings(raw_conn_settings)
# Determine if we should use PyMongo's or mongomock's MongoClient.
is_mock = conn_settings.pop('is_mock', False)
is_mock = conn_settings.pop("is_mock", False)
if is_mock:
try:
import mongomock
except ImportError:
raise RuntimeError('You need mongomock installed to mock '
'MongoEngine.')
raise RuntimeError("You need mongomock installed to mock MongoEngine.")
connection_class = mongomock.MongoClient
else:
connection_class = MongoClient
@@ -277,9 +304,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection = existing_connection
else:
connection = _create_connection(
alias=alias,
connection_class=connection_class,
**conn_settings
alias=alias, connection_class=connection_class, **conn_settings
)
_connections[alias] = connection
return _connections[alias]
@@ -288,13 +313,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
def _create_connection(alias, connection_class, **connection_settings):
"""
Create the new connection for this alias. Raise
MongoEngineConnectionError if it can't be established.
ConnectionFailure if it can't be established.
"""
try:
return connection_class(**connection_settings)
except Exception as e:
raise MongoEngineConnectionError(
'Cannot connect to database %s :\n%s' % (alias, e))
raise ConnectionFailure("Cannot connect to database %s :\n%s" % (alias, e))
def _find_existing_connection(connection_settings):
@@ -316,7 +340,7 @@ def _find_existing_connection(connection_settings):
# Only remove the name but it's important to
# keep the username/password/authentication_source/authentication_mechanism
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
return {k: v for k, v in settings_dict.items() if k != 'name'}
return {k: v for k, v in settings_dict.items() if k != "name"}
cleaned_conn_settings = _clean_settings(connection_settings)
for db_alias, connection_settings in connection_settings_bis:
@@ -332,14 +356,18 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
if alias not in _dbs:
conn = get_connection(alias)
conn_settings = _connection_settings[alias]
db = conn[conn_settings['name']]
auth_kwargs = {'source': conn_settings['authentication_source']}
if conn_settings['authentication_mechanism'] is not None:
auth_kwargs['mechanism'] = conn_settings['authentication_mechanism']
db = conn[conn_settings["name"]]
auth_kwargs = {"source": conn_settings["authentication_source"]}
if conn_settings["authentication_mechanism"] is not None:
auth_kwargs["mechanism"] = conn_settings["authentication_mechanism"]
# Authenticate if necessary
if conn_settings['username'] and (conn_settings['password'] or
conn_settings['authentication_mechanism'] == 'MONGODB-X509'):
db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs)
if conn_settings["username"] and (
conn_settings["password"]
or conn_settings["authentication_mechanism"] == "MONGODB-X509"
):
db.authenticate(
conn_settings["username"], conn_settings["password"], **auth_kwargs
)
_dbs[alias] = db
return _dbs[alias]
@@ -368,10 +396,10 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs):
if new_conn_settings != prev_conn_setting:
err_msg = (
u'A different connection with alias `{}` was already '
u'registered. Use disconnect() first'
u"A different connection with alias `{}` was already "
u"registered. Use disconnect() first"
).format(alias)
raise MongoEngineConnectionError(err_msg)
raise ConnectionFailure(err_msg)
else:
register_connection(alias, db, **kwargs)

View File

@@ -7,8 +7,14 @@ from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents
__all__ = ('switch_db', 'switch_collection', 'no_dereference',
'no_sub_classes', 'query_counter', 'set_write_concern')
__all__ = (
"switch_db",
"switch_collection",
"no_dereference",
"no_sub_classes",
"query_counter",
"set_write_concern",
)
class switch_db(object):
@@ -38,17 +44,17 @@ class switch_db(object):
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
"""Change the db_alias and clear the cached collection."""
self.cls._meta['db_alias'] = self.db_alias
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
"""Reset the db_alias and collection."""
self.cls._meta['db_alias'] = self.ori_db_alias
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
@@ -111,14 +117,15 @@ class no_dereference(object):
"""
self.cls = cls
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
ComplexBaseField = _import_class('ComplexBaseField')
ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
ComplexBaseField = _import_class("ComplexBaseField")
self.deref_fields = [k for k, v in iteritems(self.cls._fields)
if isinstance(v, (ReferenceField,
GenericReferenceField,
ComplexBaseField))]
self.deref_fields = [
k
for k, v in iteritems(self.cls._fields)
if isinstance(v, (ReferenceField, GenericReferenceField, ComplexBaseField))
]
def __enter__(self):
"""Change the objects default and _auto_dereference values."""
@@ -180,15 +187,12 @@ class query_counter(object):
"""
self.db = get_db()
self.initial_profiling_level = None
self._ctx_query_counter = 0 # number of queries issued by the context
self._ctx_query_counter = 0 # number of queries issued by the context
self._ignored_query = {
'ns':
{'$ne': '%s.system.indexes' % self.db.name},
'op': # MONGODB < 3.2
{'$ne': 'killcursors'},
'command.killCursors': # MONGODB >= 3.2
{'$exists': False}
"ns": {"$ne": "%s.system.indexes" % self.db.name},
"op": {"$ne": "killcursors"}, # MONGODB < 3.2
"command.killCursors": {"$exists": False}, # MONGODB >= 3.2
}
def _turn_on_profiling(self):
@@ -238,8 +242,13 @@ class query_counter(object):
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
issued so we need to balance that
"""
count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter
self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
count = (
count_documents(self.db.system.profile, self._ignored_query)
- self._ctx_query_counter
)
self._ctx_query_counter += (
1
) # Account for the query we just issued to gather the information
return count

View File

@@ -2,8 +2,13 @@ from bson import DBRef, SON
import six
from six import iteritems
from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document)
from mongoengine.base import (
BaseDict,
BaseList,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
)
from mongoengine.base.datastructures import LazyReference
from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument
@@ -36,21 +41,23 @@ class DeReference(object):
self.max_depth = max_depth
doc_type = None
if instance and isinstance(instance, (Document, EmbeddedDocument,
TopLevelDocumentMetaclass)):
if instance and isinstance(
instance, (Document, EmbeddedDocument, TopLevelDocumentMetaclass)
):
doc_type = instance._fields.get(name)
while hasattr(doc_type, 'field'):
while hasattr(doc_type, "field"):
doc_type = doc_type.field
if isinstance(doc_type, ReferenceField):
field = doc_type
doc_type = doc_type.document_type
is_list = not hasattr(items, 'items')
is_list = not hasattr(items, "items")
if is_list and all([i.__class__ == doc_type for i in items]):
return items
elif not is_list and all(
[i.__class__ == doc_type for i in items.values()]):
[i.__class__ == doc_type for i in items.values()]
):
return items
elif not field.dbref:
# We must turn the ObjectIds into DBRefs
@@ -83,7 +90,7 @@ class DeReference(object):
new_items[k] = value
return new_items
if not hasattr(items, 'items'):
if not hasattr(items, "items"):
items = _get_items_from_list(items)
else:
items = _get_items_from_dict(items)
@@ -120,13 +127,19 @@ class DeReference(object):
continue
elif isinstance(v, DBRef):
reference_map.setdefault(field.document_type, set()).add(v.id)
elif isinstance(v, (dict, SON)) and '_ref' in v:
reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
elif isinstance(v, (dict, SON)) and "_ref" in v:
reference_map.setdefault(get_document(v["_cls"]), set()).add(
v["_ref"].id
)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
field_cls = getattr(
getattr(field, "field", None), "document_type", None
)
references = self._find_references(v, depth)
for key, refs in iteritems(references):
if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
if isinstance(
field_cls, (Document, TopLevelDocumentMetaclass)
):
key = field_cls
reference_map.setdefault(key, set()).update(refs)
elif isinstance(item, LazyReference):
@@ -134,8 +147,10 @@ class DeReference(object):
continue
elif isinstance(item, DBRef):
reference_map.setdefault(item.collection, set()).add(item.id)
elif isinstance(item, (dict, SON)) and '_ref' in item:
reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id)
elif isinstance(item, (dict, SON)) and "_ref" in item:
reference_map.setdefault(get_document(item["_cls"]), set()).add(
item["_ref"].id
)
elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
references = self._find_references(item, depth - 1)
for key, refs in iteritems(references):
@@ -151,12 +166,13 @@ class DeReference(object):
# we use getattr instead of hasattr because hasattr swallows any exception under python2
# so it could hide nasty things without raising exceptions (cfr bug #1688))
ref_document_cls_exists = (getattr(collection, 'objects', None) is not None)
ref_document_cls_exists = getattr(collection, "objects", None) is not None
if ref_document_cls_exists:
col_name = collection._get_collection_name()
refs = [dbref for dbref in dbrefs
if (col_name, dbref) not in object_map]
refs = [
dbref for dbref in dbrefs if (col_name, dbref) not in object_map
]
references = collection.objects.in_bulk(refs)
for key, doc in iteritems(references):
object_map[(col_name, key)] = doc
@@ -164,23 +180,26 @@ class DeReference(object):
if isinstance(doc_type, (ListField, DictField, MapField)):
continue
refs = [dbref for dbref in dbrefs
if (collection, dbref) not in object_map]
refs = [
dbref for dbref in dbrefs if (collection, dbref) not in object_map
]
if doc_type:
references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
references = doc_type._get_db()[collection].find(
{"_id": {"$in": refs}}
)
for ref in references:
doc = doc_type._from_son(ref)
object_map[(collection, doc.id)] = doc
else:
references = get_db()[collection].find({'_id': {'$in': refs}})
references = get_db()[collection].find({"_id": {"$in": refs}})
for ref in references:
if '_cls' in ref:
doc = get_document(ref['_cls'])._from_son(ref)
if "_cls" in ref:
doc = get_document(ref["_cls"])._from_son(ref)
elif doc_type is None:
doc = get_document(
''.join(x.capitalize()
for x in collection.split('_')))._from_son(ref)
"".join(x.capitalize() for x in collection.split("_"))
)._from_son(ref)
else:
doc = doc_type._from_son(ref)
object_map[(collection, doc.id)] = doc
@@ -208,19 +227,20 @@ class DeReference(object):
return BaseList(items, instance, name)
if isinstance(items, (dict, SON)):
if '_ref' in items:
if "_ref" in items:
return self.object_map.get(
(items['_ref'].collection, items['_ref'].id), items)
elif '_cls' in items:
doc = get_document(items['_cls'])._from_son(items)
_cls = doc._data.pop('_cls', None)
del items['_cls']
(items["_ref"].collection, items["_ref"].id), items
)
elif "_cls" in items:
doc = get_document(items["_cls"])._from_son(items)
_cls = doc._data.pop("_cls", None)
del items["_cls"]
doc._data = self._attach_objects(doc._data, depth, doc, None)
if _cls is not None:
doc._data['_cls'] = _cls
doc._data["_cls"] = _cls
return doc
if not hasattr(items, 'items'):
if not hasattr(items, "items"):
is_list = True
list_type = BaseList
if isinstance(items, EmbeddedDocumentList):
@@ -247,17 +267,25 @@ class DeReference(object):
v = data[k]._data.get(field_name, None)
if isinstance(v, DBRef):
data[k]._data[field_name] = self.object_map.get(
(v.collection, v.id), v)
elif isinstance(v, (dict, SON)) and '_ref' in v:
(v.collection, v.id), v
)
elif isinstance(v, (dict, SON)) and "_ref" in v:
data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v)
(v["_ref"].collection, v["_ref"].id), v
)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
item_name = six.text_type("{0}.{1}.{2}").format(
name, k, field_name
)
data[k]._data[field_name] = self._attach_objects(
v, depth, instance=instance, name=item_name
)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name
data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
elif isinstance(v, DBRef) and hasattr(v, 'id'):
item_name = "%s.%s" % (name, k) if name else name
data[k] = self._attach_objects(
v, depth - 1, instance=instance, name=item_name
)
elif isinstance(v, DBRef) and hasattr(v, "id"):
data[k] = self.object_map.get((v.collection, v.id), v)
if instance and name:

View File

@@ -8,23 +8,36 @@ import six
from six import iteritems
from mongoengine import signals
from mongoengine.base import (BaseDict, BaseDocument, BaseList,
DocumentMetaclass, EmbeddedDocumentList,
TopLevelDocumentMetaclass, get_document)
from mongoengine.base import (
BaseDict,
BaseDocument,
BaseList,
DocumentMetaclass,
EmbeddedDocumentList,
TopLevelDocumentMetaclass,
get_document,
)
from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.context_managers import (set_write_concern,
switch_collection,
switch_db)
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
SaveConditionError)
from mongoengine.context_managers import set_write_concern, switch_collection, switch_db
from mongoengine.errors import (
InvalidDocumentError,
InvalidQueryError,
SaveConditionError,
)
from mongoengine.pymongo_support import list_collection_names
from mongoengine.queryset import (NotUniqueError, OperationError,
QuerySet, transform)
from mongoengine.queryset import NotUniqueError, OperationError, QuerySet, transform
__all__ = ('Document', 'EmbeddedDocument', 'DynamicDocument',
'DynamicEmbeddedDocument', 'OperationError',
'InvalidCollectionError', 'NotUniqueError', 'MapReduceDocument')
__all__ = (
"Document",
"EmbeddedDocument",
"DynamicDocument",
"DynamicEmbeddedDocument",
"OperationError",
"InvalidCollectionError",
"NotUniqueError",
"MapReduceDocument",
)
def includes_cls(fields):
@@ -35,7 +48,7 @@ def includes_cls(fields):
first_field = fields[0]
elif isinstance(fields[0], (list, tuple)) and len(fields[0]):
first_field = fields[0][0]
return first_field == '_cls'
return first_field == "_cls"
class InvalidCollectionError(Exception):
@@ -56,7 +69,7 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
:attr:`meta` dictionary.
"""
__slots__ = ('_instance', )
__slots__ = ("_instance",)
# The __metaclass__ attribute is removed by 2to3 when running with Python3
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
@@ -85,8 +98,8 @@ class EmbeddedDocument(six.with_metaclass(DocumentMetaclass, BaseDocument)):
data = super(EmbeddedDocument, self).to_mongo(*args, **kwargs)
# remove _id from the SON if it's in it and it's None
if '_id' in data and data['_id'] is None:
del data['_id']
if "_id" in data and data["_id"] is None:
del data["_id"]
return data
@@ -147,19 +160,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# my_metaclass is defined so that metaclass can be queried in Python 2 & 3
my_metaclass = TopLevelDocumentMetaclass
__slots__ = ('__objects',)
__slots__ = ("__objects",)
@property
def pk(self):
"""Get the primary key."""
if 'id_field' not in self._meta:
if "id_field" not in self._meta:
return None
return getattr(self, self._meta['id_field'])
return getattr(self, self._meta["id_field"])
@pk.setter
def pk(self, value):
"""Set the primary key."""
return setattr(self, self._meta['id_field'], value)
return setattr(self, self._meta["id_field"], value)
def __hash__(self):
"""Return the hash based on the PK of this document. If it's new
@@ -173,7 +186,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
@classmethod
def _get_db(cls):
"""Some Model using other db_alias"""
return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME))
return get_db(cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME))
@classmethod
def _disconnect(cls):
@@ -190,9 +203,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
2. Creates indexes defined in this document's :attr:`meta` dictionary.
This happens only if `auto_create_index` is True.
"""
if not hasattr(cls, '_collection') or cls._collection is None:
if not hasattr(cls, "_collection") or cls._collection is None:
# Get the collection, either capped or regular.
if cls._meta.get('max_size') or cls._meta.get('max_documents'):
if cls._meta.get("max_size") or cls._meta.get("max_documents"):
cls._collection = cls._get_capped_collection()
else:
db = cls._get_db()
@@ -203,8 +216,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# set to False.
# Also there is no need to ensure indexes on slave.
db = cls._get_db()
if cls._meta.get('auto_create_index', True) and\
db.client.is_primary:
if cls._meta.get("auto_create_index", True) and db.client.is_primary:
cls.ensure_indexes()
return cls._collection
@@ -216,8 +228,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
collection_name = cls._get_collection_name()
# Get max document limit and max byte size from meta.
max_size = cls._meta.get('max_size') or 10 * 2 ** 20 # 10MB default
max_documents = cls._meta.get('max_documents')
max_size = cls._meta.get("max_size") or 10 * 2 ** 20 # 10MB default
max_documents = cls._meta.get("max_documents")
# MongoDB will automatically raise the size to make it a multiple of
# 256 bytes. We raise it here ourselves to be able to reliably compare
@@ -227,24 +239,23 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# If the collection already exists and has different options
# (i.e. isn't capped or has different max/size), raise an error.
if collection_name in list_collection_names(db, include_system_collections=True):
if collection_name in list_collection_names(
db, include_system_collections=True
):
collection = db[collection_name]
options = collection.options()
if (
options.get('max') != max_documents or
options.get('size') != max_size
):
if options.get("max") != max_documents or options.get("size") != max_size:
raise InvalidCollectionError(
'Cannot create collection "{}" as a capped '
'collection as it already exists'.format(cls._collection)
"collection as it already exists".format(cls._collection)
)
return collection
# Create a new capped collection.
opts = {'capped': True, 'size': max_size}
opts = {"capped": True, "size": max_size}
if max_documents:
opts['max'] = max_documents
opts["max"] = max_documents
return db.create_collection(collection_name, **opts)
@@ -253,11 +264,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# If '_id' is None, try and set it from self._data. If that
# doesn't exist either, remove '_id' from the SON completely.
if data['_id'] is None:
if self._data.get('id') is None:
del data['_id']
if data["_id"] is None:
if self._data.get("id") is None:
del data["_id"]
else:
data['_id'] = self._data['id']
data["_id"] = self._data["id"]
return data
@@ -279,15 +290,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
query = {}
if self.pk is None:
raise InvalidDocumentError('The document does not have a primary key.')
raise InvalidDocumentError("The document does not have a primary key.")
id_field = self._meta['id_field']
id_field = self._meta["id_field"]
query = query.copy() if isinstance(query, dict) else query.to_query(self)
if id_field not in query:
query[id_field] = self.pk
elif query[id_field] != self.pk:
raise InvalidQueryError('Invalid document modify query: it must modify only this document.')
raise InvalidQueryError(
"Invalid document modify query: it must modify only this document."
)
# Need to add shard key to query, or you get an error
query.update(self._object_key)
@@ -304,9 +317,19 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
return True
def save(self, force_insert=False, validate=True, clean=True,
write_concern=None, cascade=None, cascade_kwargs=None,
_refs=None, save_condition=None, signal_kwargs=None, **kwargs):
def save(
self,
force_insert=False,
validate=True,
clean=True,
write_concern=None,
cascade=None,
cascade_kwargs=None,
_refs=None,
save_condition=None,
signal_kwargs=None,
**kwargs
):
"""Save the :class:`~mongoengine.Document` to the database. If the
document already exists, it will be updated, otherwise it will be
created.
@@ -360,8 +383,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
signal_kwargs = signal_kwargs or {}
if self._meta.get('abstract'):
raise InvalidDocumentError('Cannot save an abstract document.')
if self._meta.get("abstract"):
raise InvalidDocumentError("Cannot save an abstract document.")
signals.pre_save.send(self.__class__, document=self, **signal_kwargs)
@@ -371,15 +394,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if write_concern is None:
write_concern = {}
doc_id = self.to_mongo(fields=[self._meta['id_field']])
created = ('_id' not in doc_id or self._created or force_insert)
doc_id = self.to_mongo(fields=[self._meta["id_field"]])
created = "_id" not in doc_id or self._created or force_insert
signals.pre_save_post_validation.send(self.__class__, document=self,
created=created, **signal_kwargs)
signals.pre_save_post_validation.send(
self.__class__, document=self, created=created, **signal_kwargs
)
# it might be refreshed by the pre_save_post_validation hook, e.g., for etag generation
doc = self.to_mongo()
if self._meta.get('auto_create_index', True):
if self._meta.get("auto_create_index", True):
self.ensure_indexes()
try:
@@ -387,44 +411,45 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if created:
object_id = self._save_create(doc, force_insert, write_concern)
else:
object_id, created = self._save_update(doc, save_condition,
write_concern)
object_id, created = self._save_update(
doc, save_condition, write_concern
)
if cascade is None:
cascade = (self._meta.get('cascade', False) or
cascade_kwargs is not None)
cascade = self._meta.get("cascade", False) or cascade_kwargs is not None
if cascade:
kwargs = {
'force_insert': force_insert,
'validate': validate,
'write_concern': write_concern,
'cascade': cascade
"force_insert": force_insert,
"validate": validate,
"write_concern": write_concern,
"cascade": cascade,
}
if cascade_kwargs: # Allow granular control over cascades
kwargs.update(cascade_kwargs)
kwargs['_refs'] = _refs
kwargs["_refs"] = _refs
self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)'
message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err:
message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', six.text_type(err)):
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", six.text_type(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
message = u'Tried to save duplicate unique keys (%s)'
message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
# Make sure we store the PK on this document now that it's saved
id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []):
id_field = self._meta["id_field"]
if created or id_field not in self._meta.get("shard_key", []):
self[id_field] = self._fields[id_field].to_python(object_id)
signals.post_save.send(self.__class__, document=self,
created=created, **signal_kwargs)
signals.post_save.send(
self.__class__, document=self, created=created, **signal_kwargs
)
self._clear_changed_fields()
self._created = False
@@ -442,11 +467,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
return wc_collection.insert_one(doc).inserted_id
# insert_one will provoke UniqueError alongside save does not
# therefore, it need to catch and call replace_one.
if '_id' in doc:
if "_id" in doc:
raw_object = wc_collection.find_one_and_replace(
{'_id': doc['_id']}, doc)
{"_id": doc["_id"]}, doc
)
if raw_object:
return doc['_id']
return doc["_id"]
object_id = wc_collection.insert_one(doc).inserted_id
@@ -461,9 +487,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
update_doc = {}
if updates:
update_doc['$set'] = updates
update_doc["$set"] = updates
if removals:
update_doc['$unset'] = removals
update_doc["$unset"] = removals
return update_doc
@@ -473,39 +499,38 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
Helper method, should only be used inside save().
"""
collection = self._get_collection()
object_id = doc['_id']
object_id = doc["_id"]
created = False
select_dict = {}
if save_condition is not None:
select_dict = transform.query(self.__class__, **save_condition)
select_dict['_id'] = object_id
select_dict["_id"] = object_id
# Need to add shard key to query, or you get an error
shard_key = self._meta.get('shard_key', tuple())
shard_key = self._meta.get("shard_key", tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
path = self._lookup_field(k.split("."))
actual_key = [p.db_field for p in path]
val = doc
for ak in actual_key:
val = val[ak]
select_dict['.'.join(actual_key)] = val
select_dict[".".join(actual_key)] = val
update_doc = self._get_update_doc()
if update_doc:
upsert = save_condition is None
with set_write_concern(collection, write_concern) as wc_collection:
last_error = wc_collection.update_one(
select_dict,
update_doc,
upsert=upsert
select_dict, update_doc, upsert=upsert
).raw_result
if not upsert and last_error['n'] == 0:
raise SaveConditionError('Race condition preventing'
' document update detected')
if not upsert and last_error["n"] == 0:
raise SaveConditionError(
"Race condition preventing document update detected"
)
if last_error is not None:
updated_existing = last_error.get('updatedExisting')
updated_existing = last_error.get("updatedExisting")
if updated_existing is False:
created = True
# !!! This is bad, means we accidentally created a new,
@@ -518,21 +543,20 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""Recursively save any references and generic references on the
document.
"""
_refs = kwargs.get('_refs') or []
_refs = kwargs.get("_refs") or []
ReferenceField = _import_class('ReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
ReferenceField = _import_class("ReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
for name, cls in self._fields.items():
if not isinstance(cls, (ReferenceField,
GenericReferenceField)):
if not isinstance(cls, (ReferenceField, GenericReferenceField)):
continue
ref = self._data.get(name)
if not ref or isinstance(ref, DBRef):
continue
if not getattr(ref, '_changed_fields', True):
if not getattr(ref, "_changed_fields", True):
continue
ref_id = "%s,%s" % (ref.__class__.__name__, str(ref._data))
@@ -544,27 +568,31 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
@property
def _qs(self):
"""Return the queryset to use for updating / reloading / deletions."""
if not hasattr(self, '__objects'):
"""Return the default queryset corresponding to this document."""
if not hasattr(self, "__objects"):
self.__objects = QuerySet(self, self._get_collection())
return self.__objects
@property
def _object_key(self):
"""Get the query dict that can be used to fetch this object from
the database. Most of the time it's a simple PK lookup, but in
case of a sharded collection with a compound shard key, it can
contain a more complex query.
"""Return a query dict that can be used to fetch this document.
Most of the time the dict is a simple PK lookup, but in case of
a sharded collection with a compound shard key, it can contain a more
complex query.
Note that the dict returned by this method uses MongoEngine field
names instead of PyMongo field names (e.g. "pk" instead of "_id",
"some__nested__field" instead of "some.nested.field", etc.).
"""
select_dict = {'pk': self.pk}
shard_key = self.__class__._meta.get('shard_key', tuple())
select_dict = {"pk": self.pk}
shard_key = self.__class__._meta.get("shard_key", tuple())
for k in shard_key:
path = self._lookup_field(k.split('.'))
actual_key = [p.db_field for p in path]
val = self
for ak in actual_key:
val = getattr(val, ak)
select_dict['__'.join(actual_key)] = val
field_parts = k.split(".")
for part in field_parts:
val = getattr(val, part)
select_dict["__".join(field_parts)] = val
return select_dict
def update(self, **kwargs):
@@ -575,14 +603,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
been saved.
"""
if self.pk is None:
if kwargs.get('upsert', False):
if kwargs.get("upsert", False):
query = self.to_mongo()
if '_cls' in query:
del query['_cls']
if "_cls" in query:
del query["_cls"]
return self._qs.filter(**query).update_one(**kwargs)
else:
raise OperationError(
'attempt to update a document not yet saved')
raise OperationError("attempt to update a document not yet saved")
# Need to add shard key to query, or you get an error
return self._qs.filter(**self._object_key).update_one(**kwargs)
@@ -606,16 +633,17 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
signals.pre_delete.send(self.__class__, document=self, **signal_kwargs)
# Delete FileFields separately
FileField = _import_class('FileField')
FileField = _import_class("FileField")
for name, field in iteritems(self._fields):
if isinstance(field, FileField):
getattr(self, name).delete()
try:
self._qs.filter(
**self._object_key).delete(write_concern=write_concern, _from_doc_delete=True)
self._qs.filter(**self._object_key).delete(
write_concern=write_concern, _from_doc_delete=True
)
except pymongo.errors.OperationFailure as err:
message = u'Could not delete document (%s)' % err.message
message = u"Could not delete document (%s)" % err.message
raise OperationError(message)
signals.post_delete.send(self.__class__, document=self, **signal_kwargs)
@@ -684,7 +712,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. versionadded:: 0.5
"""
DeReference = _import_class('DeReference')
DeReference = _import_class("DeReference")
DeReference()([self], max_depth + 1)
return self
@@ -702,20 +730,24 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if fields and isinstance(fields[0], int):
max_depth = fields[0]
fields = fields[1:]
elif 'max_depth' in kwargs:
max_depth = kwargs['max_depth']
elif "max_depth" in kwargs:
max_depth = kwargs["max_depth"]
if self.pk is None:
raise self.DoesNotExist('Document does not exist')
raise self.DoesNotExist("Document does not exist")
obj = self._qs.read_preference(ReadPreference.PRIMARY).filter(
**self._object_key).only(*fields).limit(
1).select_related(max_depth=max_depth)
obj = (
self._qs.read_preference(ReadPreference.PRIMARY)
.filter(**self._object_key)
.only(*fields)
.limit(1)
.select_related(max_depth=max_depth)
)
if obj:
obj = obj[0]
else:
raise self.DoesNotExist('Document does not exist')
raise self.DoesNotExist("Document does not exist")
for field in obj._data:
if not fields or field in fields:
try:
@@ -731,9 +763,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
# i.e. obj.update(unset__field=1) followed by obj.reload()
delattr(self, field)
self._changed_fields = list(
set(self._changed_fields) - set(fields)
) if fields else obj._changed_fields
self._changed_fields = (
list(set(self._changed_fields) - set(fields))
if fields
else obj._changed_fields
)
self._created = False
return self
@@ -759,7 +793,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""Returns an instance of :class:`~bson.dbref.DBRef` useful in
`__raw__` queries."""
if self.pk is None:
msg = 'Only saved documents can have a valid dbref'
msg = "Only saved documents can have a valid dbref"
raise OperationError(msg)
return DBRef(self.__class__._get_collection_name(), self.pk)
@@ -768,18 +802,22 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""This method registers the delete rules to apply when removing this
object.
"""
classes = [get_document(class_name)
for class_name in cls._subclasses
if class_name != cls.__name__] + [cls]
documents = [get_document(class_name)
for class_name in document_cls._subclasses
if class_name != document_cls.__name__] + [document_cls]
classes = [
get_document(class_name)
for class_name in cls._subclasses
if class_name != cls.__name__
] + [cls]
documents = [
get_document(class_name)
for class_name in document_cls._subclasses
if class_name != document_cls.__name__
] + [document_cls]
for klass in classes:
for document_cls in documents:
delete_rules = klass._meta.get('delete_rules') or {}
delete_rules = klass._meta.get("delete_rules") or {}
delete_rules[(document_cls, field_name)] = rule
klass._meta['delete_rules'] = delete_rules
klass._meta["delete_rules"] = delete_rules
@classmethod
def drop_collection(cls):
@@ -794,8 +832,9 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
coll_name = cls._get_collection_name()
if not coll_name:
raise OperationError('Document %s has no collection defined '
'(is it abstract ?)' % cls)
raise OperationError(
"Document %s has no collection defined (is it abstract ?)" % cls
)
cls._collection = None
db = cls._get_db()
db.drop_collection(coll_name)
@@ -811,19 +850,18 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
"""
index_spec = cls._build_index_spec(keys)
index_spec = index_spec.copy()
fields = index_spec.pop('fields')
drop_dups = kwargs.get('drop_dups', False)
fields = index_spec.pop("fields")
drop_dups = kwargs.get("drop_dups", False)
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
index_spec['background'] = background
index_spec["background"] = background
index_spec.update(kwargs)
return cls._get_collection().create_index(fields, **index_spec)
@classmethod
def ensure_index(cls, key_or_list, drop_dups=False, background=False,
**kwargs):
def ensure_index(cls, key_or_list, drop_dups=False, background=False, **kwargs):
"""Ensure that the given indexes are in place. Deprecated in favour
of create_index.
@@ -835,7 +873,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
will be removed if PyMongo3+ is used
"""
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
return cls.create_index(key_or_list, background=background, **kwargs)
@@ -848,12 +886,12 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
.. note:: You can disable automatic index creation by setting
`auto_create_index` to False in the documents meta data
"""
background = cls._meta.get('index_background', False)
drop_dups = cls._meta.get('index_drop_dups', False)
index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True)
background = cls._meta.get("index_background", False)
drop_dups = cls._meta.get("index_drop_dups", False)
index_opts = cls._meta.get("index_opts") or {}
index_cls = cls._meta.get("index_cls", True)
if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
msg = "drop_dups is deprecated and is removed when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
collection = cls._get_collection()
@@ -869,40 +907,39 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
cls_indexed = False
# Ensure document-defined indexes are created
if cls._meta['index_specs']:
index_spec = cls._meta['index_specs']
if cls._meta["index_specs"]:
index_spec = cls._meta["index_specs"]
for spec in index_spec:
spec = spec.copy()
fields = spec.pop('fields')
fields = spec.pop("fields")
cls_indexed = cls_indexed or includes_cls(fields)
opts = index_opts.copy()
opts.update(spec)
# we shouldn't pass 'cls' to the collection.ensureIndex options
# because of https://jira.mongodb.org/browse/SERVER-769
if 'cls' in opts:
del opts['cls']
if "cls" in opts:
del opts["cls"]
collection.create_index(fields, background=background, **opts)
# If _cls is being used (for polymorphism), it needs an index,
# only if another index doesn't begin with _cls
if index_cls and not cls_indexed and cls._meta.get('allow_inheritance'):
if index_cls and not cls_indexed and cls._meta.get("allow_inheritance"):
# we shouldn't pass 'cls' to the collection.ensureIndex options
# because of https://jira.mongodb.org/browse/SERVER-769
if 'cls' in index_opts:
del index_opts['cls']
if "cls" in index_opts:
del index_opts["cls"]
collection.create_index('_cls', background=background,
**index_opts)
collection.create_index("_cls", background=background, **index_opts)
@classmethod
def list_indexes(cls):
""" Lists all of the indexes that should be created for given
collection. It includes all the indexes from super- and sub-classes.
"""
if cls._meta.get('abstract'):
if cls._meta.get("abstract"):
return []
# get all the base classes, subclasses and siblings
@@ -910,22 +947,27 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
def get_classes(cls):
if (cls not in classes and
isinstance(cls, TopLevelDocumentMetaclass)):
if cls not in classes and isinstance(cls, TopLevelDocumentMetaclass):
classes.append(cls)
for base_cls in cls.__bases__:
if (isinstance(base_cls, TopLevelDocumentMetaclass) and
base_cls != Document and
not base_cls._meta.get('abstract') and
base_cls._get_collection().full_name == cls._get_collection().full_name and
base_cls not in classes):
if (
isinstance(base_cls, TopLevelDocumentMetaclass)
and base_cls != Document
and not base_cls._meta.get("abstract")
and base_cls._get_collection().full_name
== cls._get_collection().full_name
and base_cls not in classes
):
classes.append(base_cls)
get_classes(base_cls)
for subclass in cls.__subclasses__():
if (isinstance(base_cls, TopLevelDocumentMetaclass) and
subclass._get_collection().full_name == cls._get_collection().full_name and
subclass not in classes):
if (
isinstance(base_cls, TopLevelDocumentMetaclass)
and subclass._get_collection().full_name
== cls._get_collection().full_name
and subclass not in classes
):
classes.append(subclass)
get_classes(subclass)
@@ -935,11 +977,11 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
def get_indexes_spec(cls):
indexes = []
if cls._meta['index_specs']:
index_spec = cls._meta['index_specs']
if cls._meta["index_specs"]:
index_spec = cls._meta["index_specs"]
for spec in index_spec:
spec = spec.copy()
fields = spec.pop('fields')
fields = spec.pop("fields")
indexes.append(fields)
return indexes
@@ -950,10 +992,10 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
indexes.append(index)
# finish up by appending { '_id': 1 } and { '_cls': 1 }, if needed
if [(u'_id', 1)] not in indexes:
indexes.append([(u'_id', 1)])
if cls._meta.get('index_cls', True) and cls._meta.get('allow_inheritance'):
indexes.append([(u'_cls', 1)])
if [(u"_id", 1)] not in indexes:
indexes.append([(u"_id", 1)])
if cls._meta.get("index_cls", True) and cls._meta.get("allow_inheritance"):
indexes.append([(u"_cls", 1)])
return indexes
@@ -967,27 +1009,26 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
existing = []
for info in cls._get_collection().index_information().values():
if '_fts' in info['key'][0]:
index_type = info['key'][0][1]
text_index_fields = info.get('weights').keys()
existing.append(
[(key, index_type) for key in text_index_fields])
if "_fts" in info["key"][0]:
index_type = info["key"][0][1]
text_index_fields = info.get("weights").keys()
existing.append([(key, index_type) for key in text_index_fields])
else:
existing.append(info['key'])
existing.append(info["key"])
missing = [index for index in required if index not in existing]
extra = [index for index in existing if index not in required]
# if { _cls: 1 } is missing, make sure it's *really* necessary
if [(u'_cls', 1)] in missing:
if [(u"_cls", 1)] in missing:
cls_obsolete = False
for index in existing:
if includes_cls(index) and index not in extra:
cls_obsolete = True
break
if cls_obsolete:
missing.remove([(u'_cls', 1)])
missing.remove([(u"_cls", 1)])
return {'missing': missing, 'extra': extra}
return {"missing": missing, "extra": extra}
class DynamicDocument(six.with_metaclass(TopLevelDocumentMetaclass, Document)):
@@ -1072,17 +1113,16 @@ class MapReduceDocument(object):
"""Lazy-load the object referenced by ``self.key``. ``self.key``
should be the ``primary_key``.
"""
id_field = self._document()._meta['id_field']
id_field = self._document()._meta["id_field"]
id_field_type = type(id_field)
if not isinstance(self.key, id_field_type):
try:
self.key = id_field_type(self.key)
except Exception:
raise Exception('Could not cast key as %s' %
id_field_type.__name__)
raise Exception("Could not cast key as %s" % id_field_type.__name__)
if not hasattr(self, '_key_object'):
if not hasattr(self, "_key_object"):
self._key_object = self._document.objects.with_id(self.key)
return self._key_object
return self._key_object

View File

@@ -3,10 +3,20 @@ from collections import defaultdict
import six
from six import iteritems
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
'OperationError', 'NotUniqueError', 'FieldDoesNotExist',
'ValidationError', 'SaveConditionError', 'DeprecatedError')
__all__ = (
"NotRegistered",
"InvalidDocumentError",
"LookUpError",
"DoesNotExist",
"MultipleObjectsReturned",
"InvalidQueryError",
"OperationError",
"NotUniqueError",
"FieldDoesNotExist",
"ValidationError",
"SaveConditionError",
"DeprecatedError",
)
class NotRegistered(Exception):
@@ -71,25 +81,25 @@ class ValidationError(AssertionError):
field_name = None
_message = None
def __init__(self, message='', **kwargs):
def __init__(self, message="", **kwargs):
super(ValidationError, self).__init__(message)
self.errors = kwargs.get('errors', {})
self.field_name = kwargs.get('field_name')
self.errors = kwargs.get("errors", {})
self.field_name = kwargs.get("field_name")
self.message = message
def __str__(self):
return six.text_type(self.message)
def __repr__(self):
return '%s(%s,)' % (self.__class__.__name__, self.message)
return "%s(%s,)" % (self.__class__.__name__, self.message)
def __getattribute__(self, name):
message = super(ValidationError, self).__getattribute__(name)
if name == 'message':
if name == "message":
if self.field_name:
message = '%s' % message
message = "%s" % message
if self.errors:
message = '%s(%s)' % (message, self._format_errors())
message = "%s(%s)" % (message, self._format_errors())
return message
def _get_message(self):
@@ -128,22 +138,22 @@ class ValidationError(AssertionError):
def _format_errors(self):
"""Returns a string listing all errors within a document"""
def generate_key(value, prefix=''):
def generate_key(value, prefix=""):
if isinstance(value, list):
value = ' '.join([generate_key(k) for k in value])
value = " ".join([generate_key(k) for k in value])
elif isinstance(value, dict):
value = ' '.join(
[generate_key(v, k) for k, v in iteritems(value)])
value = " ".join([generate_key(v, k) for k, v in iteritems(value)])
results = '%s.%s' % (prefix, value) if prefix else value
results = "%s.%s" % (prefix, value) if prefix else value
return results
error_dict = defaultdict(list)
for k, v in iteritems(self.to_dict()):
error_dict[generate_key(v)].append(k)
return ' '.join(['%s: %s' % (k, v) for k, v in iteritems(error_dict)])
return " ".join(["%s: %s" % (k, v) for k, v in iteritems(error_dict)])
class DeprecatedError(Exception):
"""Raise when a user uses a feature that has been Deprecated"""
pass

File diff suppressed because it is too large Load Diff

View File

@@ -15,5 +15,5 @@ def get_mongodb_version():
:return: tuple(int, int)
"""
version_list = get_connection().server_info()['versionArray'][:2] # e.g: (3, 2)
version_list = get_connection().server_info()["versionArray"][:2] # e.g: (3, 2)
return tuple(version_list)

View File

@@ -27,6 +27,6 @@ def list_collection_names(db, include_system_collections=False):
collections = db.collection_names()
if not include_system_collections:
collections = [c for c in collections if not c.startswith('system.')]
collections = [c for c in collections if not c.startswith("system.")]
return collections

View File

@@ -7,11 +7,22 @@ from mongoengine.queryset.visitor import *
# Expose just the public subset of all imported objects and constants.
__all__ = (
'QuerySet', 'QuerySetNoCache', 'Q', 'queryset_manager', 'QuerySetManager',
'QueryFieldList', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL',
"QuerySet",
"QuerySetNoCache",
"Q",
"queryset_manager",
"QuerySetManager",
"QueryFieldList",
"DO_NOTHING",
"NULLIFY",
"CASCADE",
"DENY",
"PULL",
# Errors that might be related to a queryset, mostly here for backward
# compatibility
'DoesNotExist', 'InvalidQueryError', 'MultipleObjectsReturned',
'NotUniqueError', 'OperationError',
"DoesNotExist",
"InvalidQueryError",
"MultipleObjectsReturned",
"NotUniqueError",
"OperationError",
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,15 @@
__all__ = ('QueryFieldList',)
__all__ = ("QueryFieldList",)
class QueryFieldList(object):
"""Object that handles combinations of .only() and .exclude() calls"""
ONLY = 1
EXCLUDE = 0
def __init__(self, fields=None, value=ONLY, always_include=None, _only_called=False):
def __init__(
self, fields=None, value=ONLY, always_include=None, _only_called=False
):
"""The QueryFieldList builder
:param fields: A list of fields used in `.only()` or `.exclude()`
@@ -49,7 +52,7 @@ class QueryFieldList(object):
self.fields = f.fields - self.fields
self._clean_slice()
if '_id' in f.fields:
if "_id" in f.fields:
self._id = f.value
if self.always_include:
@@ -59,7 +62,7 @@ class QueryFieldList(object):
else:
self.fields -= self.always_include
if getattr(f, '_only_called', False):
if getattr(f, "_only_called", False):
self._only_called = True
return self
@@ -73,7 +76,7 @@ class QueryFieldList(object):
if self.slice:
field_list.update(self.slice)
if self._id is not None:
field_list['_id'] = self._id
field_list["_id"] = self._id
return field_list
def reset(self):

View File

@@ -1,7 +1,7 @@
from functools import partial
from mongoengine.queryset.queryset import QuerySet
__all__ = ('queryset_manager', 'QuerySetManager')
__all__ = ("queryset_manager", "QuerySetManager")
class QuerySetManager(object):
@@ -33,7 +33,7 @@ class QuerySetManager(object):
return self
# owner is the document that contains the QuerySetManager
queryset_class = owner._meta.get('queryset_class', self.default)
queryset_class = owner._meta.get("queryset_class", self.default)
queryset = queryset_class(owner, owner._get_collection())
if self.get_queryset:
arg_count = self.get_queryset.__code__.co_argcount

View File

@@ -1,11 +1,24 @@
import six
from mongoengine.errors import OperationError
from mongoengine.queryset.base import (BaseQuerySet, CASCADE, DENY, DO_NOTHING,
NULLIFY, PULL)
from mongoengine.queryset.base import (
BaseQuerySet,
CASCADE,
DENY,
DO_NOTHING,
NULLIFY,
PULL,
)
__all__ = ('QuerySet', 'QuerySetNoCache', 'DO_NOTHING', 'NULLIFY', 'CASCADE',
'DENY', 'PULL')
__all__ = (
"QuerySet",
"QuerySetNoCache",
"DO_NOTHING",
"NULLIFY",
"CASCADE",
"DENY",
"PULL",
)
# The maximum number of items to display in a QuerySet.__repr__
REPR_OUTPUT_SIZE = 20
@@ -57,12 +70,12 @@ class QuerySet(BaseQuerySet):
def __repr__(self):
"""Provide a string representation of the QuerySet"""
if self._iter:
return '.. queryset mid-iteration ..'
return ".. queryset mid-iteration .."
self._populate_cache()
data = self._result_cache[:REPR_OUTPUT_SIZE + 1]
data = self._result_cache[: REPR_OUTPUT_SIZE + 1]
if len(data) > REPR_OUTPUT_SIZE:
data[-1] = '...(remaining elements truncated)...'
data[-1] = "...(remaining elements truncated)..."
return repr(data)
def _iter_results(self):
@@ -143,10 +156,9 @@ class QuerySet(BaseQuerySet):
.. versionadded:: 0.8.3 Convert to non caching queryset
"""
if self._result_cache is not None:
raise OperationError('QuerySet already cached')
raise OperationError("QuerySet already cached")
return self._clone_into(QuerySetNoCache(self._document,
self._collection))
return self._clone_into(QuerySetNoCache(self._document, self._collection))
class QuerySetNoCache(BaseQuerySet):
@@ -165,7 +177,7 @@ class QuerySetNoCache(BaseQuerySet):
.. versionchanged:: 0.6.13 Now doesnt modify the cursor
"""
if self._iter:
return '.. queryset mid-iteration ..'
return ".. queryset mid-iteration .."
data = []
for _ in six.moves.range(REPR_OUTPUT_SIZE + 1):
@@ -175,7 +187,7 @@ class QuerySetNoCache(BaseQuerySet):
break
if len(data) > REPR_OUTPUT_SIZE:
data[-1] = '...(remaining elements truncated)...'
data[-1] = "...(remaining elements truncated)..."
self.rewind()
return repr(data)

View File

@@ -10,21 +10,54 @@ from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError
__all__ = ('query', 'update')
__all__ = ("query", "update")
COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
'all', 'size', 'exists', 'not', 'elemMatch', 'type')
GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
'within_box', 'within_polygon', 'near', 'near_sphere',
'max_distance', 'min_distance', 'geo_within', 'geo_within_box',
'geo_within_polygon', 'geo_within_center',
'geo_within_sphere', 'geo_intersects')
STRING_OPERATORS = ('contains', 'icontains', 'startswith',
'istartswith', 'endswith', 'iendswith',
'exact', 'iexact')
CUSTOM_OPERATORS = ('match',)
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
STRING_OPERATORS + CUSTOM_OPERATORS)
COMPARISON_OPERATORS = (
"ne",
"gt",
"gte",
"lt",
"lte",
"in",
"nin",
"mod",
"all",
"size",
"exists",
"not",
"elemMatch",
"type",
)
GEO_OPERATORS = (
"within_distance",
"within_spherical_distance",
"within_box",
"within_polygon",
"near",
"near_sphere",
"max_distance",
"min_distance",
"geo_within",
"geo_within_box",
"geo_within_polygon",
"geo_within_center",
"geo_within_sphere",
"geo_intersects",
)
STRING_OPERATORS = (
"contains",
"icontains",
"startswith",
"istartswith",
"endswith",
"iendswith",
"exact",
"iexact",
)
CUSTOM_OPERATORS = ("match",)
MATCH_OPERATORS = (
COMPARISON_OPERATORS + GEO_OPERATORS + STRING_OPERATORS + CUSTOM_OPERATORS
)
# TODO make this less complex
@@ -33,11 +66,11 @@ def query(_doc_cls=None, **kwargs):
mongo_query = {}
merge_query = defaultdict(list)
for key, value in sorted(kwargs.items()):
if key == '__raw__':
if key == "__raw__":
mongo_query.update(value)
continue
parts = key.rsplit('__')
parts = key.rsplit("__")
indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
parts = [part for part in parts if not part.isdigit()]
# Check for an operator and transform to mongo-style if there is
@@ -46,11 +79,11 @@ def query(_doc_cls=None, **kwargs):
op = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == '':
if len(parts) > 1 and parts[-1] == "":
parts.pop()
negate = False
if len(parts) > 1 and parts[-1] == 'not':
if len(parts) > 1 and parts[-1] == "not":
parts.pop()
negate = True
@@ -62,8 +95,8 @@ def query(_doc_cls=None, **kwargs):
raise InvalidQueryError(e)
parts = []
CachedReferenceField = _import_class('CachedReferenceField')
GenericReferenceField = _import_class('GenericReferenceField')
CachedReferenceField = _import_class("CachedReferenceField")
GenericReferenceField = _import_class("GenericReferenceField")
cleaned_fields = []
for field in fields:
@@ -73,7 +106,7 @@ def query(_doc_cls=None, **kwargs):
append_field = False
# is last and CachedReferenceField
elif isinstance(field, CachedReferenceField) and fields[-1] == field:
parts.append('%s._id' % field.db_field)
parts.append("%s._id" % field.db_field)
else:
parts.append(field.db_field)
@@ -83,15 +116,15 @@ def query(_doc_cls=None, **kwargs):
# Convert value to proper value
field = cleaned_fields[-1]
singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
singular_ops = [None, "ne", "gt", "gte", "lt", "lte", "not"]
singular_ops += STRING_OPERATORS
if op in singular_ops:
value = field.prepare_query_value(op, value)
if isinstance(field, CachedReferenceField) and value:
value = value['_id']
value = value["_id"]
elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
elif op in ("in", "nin", "all", "near") and not isinstance(value, dict):
# Raise an error if the in/nin/all/near param is not iterable.
value = _prepare_query_for_iterable(field, op, value)
@@ -101,39 +134,40 @@ def query(_doc_cls=None, **kwargs):
# * If the value is an ObjectId, the key should be "field_name._ref.$id".
if isinstance(field, GenericReferenceField):
if isinstance(value, DBRef):
parts[-1] += '._ref'
parts[-1] += "._ref"
elif isinstance(value, ObjectId):
parts[-1] += '._ref.$id'
parts[-1] += "._ref.$id"
# if op and op not in COMPARISON_OPERATORS:
if op:
if op in GEO_OPERATORS:
value = _geo_operator(field, op, value)
elif op in ('match', 'elemMatch'):
ListField = _import_class('ListField')
EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
elif op in ("match", "elemMatch"):
ListField = _import_class("ListField")
EmbeddedDocumentField = _import_class("EmbeddedDocumentField")
if (
isinstance(value, dict) and
isinstance(field, ListField) and
isinstance(field.field, EmbeddedDocumentField)
isinstance(value, dict)
and isinstance(field, ListField)
and isinstance(field.field, EmbeddedDocumentField)
):
value = query(field.field.document_type, **value)
else:
value = field.prepare_query_value(op, value)
value = {'$elemMatch': value}
value = {"$elemMatch": value}
elif op in CUSTOM_OPERATORS:
NotImplementedError('Custom method "%s" has not '
'been implemented' % op)
NotImplementedError(
'Custom method "%s" has not ' "been implemented" % op
)
elif op not in STRING_OPERATORS:
value = {'$' + op: value}
value = {"$" + op: value}
if negate:
value = {'$not': value}
value = {"$not": value}
for i, part in indices:
parts.insert(i, part)
key = '.'.join(parts)
key = ".".join(parts)
if op is None or key not in mongo_query:
mongo_query[key] = value
@@ -142,30 +176,35 @@ def query(_doc_cls=None, **kwargs):
mongo_query[key].update(value)
# $max/minDistance needs to come last - convert to SON
value_dict = mongo_query[key]
if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
('$near' in value_dict or '$nearSphere' in value_dict):
if ("$maxDistance" in value_dict or "$minDistance" in value_dict) and (
"$near" in value_dict or "$nearSphere" in value_dict
):
value_son = SON()
for k, v in iteritems(value_dict):
if k == '$maxDistance' or k == '$minDistance':
if k == "$maxDistance" or k == "$minDistance":
continue
value_son[k] = v
# Required for MongoDB >= 2.6, may fail when combining
# PyMongo 3+ and MongoDB < 2.6
near_embedded = False
for near_op in ('$near', '$nearSphere'):
for near_op in ("$near", "$nearSphere"):
if isinstance(value_dict.get(near_op), dict):
value_son[near_op] = SON(value_son[near_op])
if '$maxDistance' in value_dict:
value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
if '$minDistance' in value_dict:
value_son[near_op]['$minDistance'] = value_dict['$minDistance']
if "$maxDistance" in value_dict:
value_son[near_op]["$maxDistance"] = value_dict[
"$maxDistance"
]
if "$minDistance" in value_dict:
value_son[near_op]["$minDistance"] = value_dict[
"$minDistance"
]
near_embedded = True
if not near_embedded:
if '$maxDistance' in value_dict:
value_son['$maxDistance'] = value_dict['$maxDistance']
if '$minDistance' in value_dict:
value_son['$minDistance'] = value_dict['$minDistance']
if "$maxDistance" in value_dict:
value_son["$maxDistance"] = value_dict["$maxDistance"]
if "$minDistance" in value_dict:
value_son["$minDistance"] = value_dict["$minDistance"]
mongo_query[key] = value_son
else:
# Store for manually merging later
@@ -177,10 +216,10 @@ def query(_doc_cls=None, **kwargs):
del mongo_query[k]
if isinstance(v, list):
value = [{k: val} for val in v]
if '$and' in mongo_query.keys():
mongo_query['$and'].extend(value)
if "$and" in mongo_query.keys():
mongo_query["$and"].extend(value)
else:
mongo_query['$and'] = value
mongo_query["$and"] = value
return mongo_query
@@ -192,15 +231,15 @@ def update(_doc_cls=None, **update):
mongo_update = {}
for key, value in update.items():
if key == '__raw__':
if key == "__raw__":
mongo_update.update(value)
continue
parts = key.split('__')
parts = key.split("__")
# if there is no operator, default to 'set'
if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
parts.insert(0, 'set')
parts.insert(0, "set")
# Check for an operator and transform to mongo-style if there is
op = None
@@ -208,13 +247,13 @@ def update(_doc_cls=None, **update):
op = parts.pop(0)
# Convert Pythonic names to Mongo equivalents
operator_map = {
'push_all': 'pushAll',
'pull_all': 'pullAll',
'dec': 'inc',
'add_to_set': 'addToSet',
'set_on_insert': 'setOnInsert'
"push_all": "pushAll",
"pull_all": "pullAll",
"dec": "inc",
"add_to_set": "addToSet",
"set_on_insert": "setOnInsert",
}
if op == 'dec':
if op == "dec":
# Support decrement by flipping a positive value's sign
# and using 'inc'
value = -value
@@ -227,7 +266,7 @@ def update(_doc_cls=None, **update):
match = parts.pop()
# Allow to escape operator-like field name by __
if len(parts) > 1 and parts[-1] == '':
if len(parts) > 1 and parts[-1] == "":
parts.pop()
if _doc_cls:
@@ -244,8 +283,8 @@ def update(_doc_cls=None, **update):
append_field = True
if isinstance(field, six.string_types):
# Convert the S operator to $
if field == 'S':
field = '$'
if field == "S":
field = "$"
parts.append(field)
append_field = False
else:
@@ -253,7 +292,7 @@ def update(_doc_cls=None, **update):
if append_field:
appended_sub_field = False
cleaned_fields.append(field)
if hasattr(field, 'field'):
if hasattr(field, "field"):
cleaned_fields.append(field.field)
appended_sub_field = True
@@ -263,52 +302,53 @@ def update(_doc_cls=None, **update):
else:
field = cleaned_fields[-1]
GeoJsonBaseField = _import_class('GeoJsonBaseField')
GeoJsonBaseField = _import_class("GeoJsonBaseField")
if isinstance(field, GeoJsonBaseField):
value = field.to_mongo(value)
if op == 'pull':
if op == "pull":
if field.required or value is not None:
if match in ('in', 'nin') and not isinstance(value, dict):
if match in ("in", "nin") and not isinstance(value, dict):
value = _prepare_query_for_iterable(field, op, value)
else:
value = field.prepare_query_value(op, value)
elif op == 'push' and isinstance(value, (list, tuple, set)):
elif op == "push" and isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif op in (None, 'set', 'push'):
elif op in (None, "set", "push"):
if field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op in ('pushAll', 'pullAll'):
elif op in ("pushAll", "pullAll"):
value = [field.prepare_query_value(op, v) for v in value]
elif op in ('addToSet', 'setOnInsert'):
elif op in ("addToSet", "setOnInsert"):
if isinstance(value, (list, tuple, set)):
value = [field.prepare_query_value(op, v) for v in value]
elif field.required or value is not None:
value = field.prepare_query_value(op, value)
elif op == 'unset':
elif op == "unset":
value = 1
elif op == 'inc':
elif op == "inc":
value = field.prepare_query_value(op, value)
if match:
match = '$' + match
match = "$" + match
value = {match: value}
key = '.'.join(parts)
key = ".".join(parts)
if 'pull' in op and '.' in key:
if "pull" in op and "." in key:
# Dot operators don't work on pull operations
# unless they point to a list field
# Otherwise it uses nested dict syntax
if op == 'pullAll':
raise InvalidQueryError('pullAll operations only support '
'a single field depth')
if op == "pullAll":
raise InvalidQueryError(
"pullAll operations only support a single field depth"
)
# Look for the last list field and use dot notation until there
field_classes = [c.__class__ for c in cleaned_fields]
field_classes.reverse()
ListField = _import_class('ListField')
EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
ListField = _import_class("ListField")
EmbeddedDocumentListField = _import_class("EmbeddedDocumentListField")
if ListField in field_classes or EmbeddedDocumentListField in field_classes:
# Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
# Then process as normal
@@ -317,37 +357,36 @@ def update(_doc_cls=None, **update):
else:
_check_field = EmbeddedDocumentListField
last_listField = len(
cleaned_fields) - field_classes.index(_check_field)
key = '.'.join(parts[:last_listField])
last_listField = len(cleaned_fields) - field_classes.index(_check_field)
key = ".".join(parts[:last_listField])
parts = parts[last_listField:]
parts.insert(0, key)
parts.reverse()
for key in parts:
value = {key: value}
elif op == 'addToSet' and isinstance(value, list):
value = {key: {'$each': value}}
elif op in ('push', 'pushAll'):
elif op == "addToSet" and isinstance(value, list):
value = {key: {"$each": value}}
elif op in ("push", "pushAll"):
if parts[-1].isdigit():
key = '.'.join(parts[0:-1])
key = ".".join(parts[0:-1])
position = int(parts[-1])
# $position expects an iterable. If pushing a single value,
# wrap it in a list.
if not isinstance(value, (set, tuple, list)):
value = [value]
value = {key: {'$each': value, '$position': position}}
value = {key: {"$each": value, "$position": position}}
else:
if op == 'pushAll':
op = 'push' # convert to non-deprecated keyword
if op == "pushAll":
op = "push" # convert to non-deprecated keyword
if not isinstance(value, (set, tuple, list)):
value = [value]
value = {key: {'$each': value}}
value = {key: {"$each": value}}
else:
value = {key: value}
else:
value = {key: value}
key = '$' + op
key = "$" + op
if key not in mongo_update:
mongo_update[key] = value
elif key in mongo_update and isinstance(mongo_update[key], dict):
@@ -358,45 +397,45 @@ def update(_doc_cls=None, **update):
def _geo_operator(field, op, value):
"""Helper to return the query for a given geo query."""
if op == 'max_distance':
value = {'$maxDistance': value}
elif op == 'min_distance':
value = {'$minDistance': value}
if op == "max_distance":
value = {"$maxDistance": value}
elif op == "min_distance":
value = {"$minDistance": value}
elif field._geo_index == pymongo.GEO2D:
if op == 'within_distance':
value = {'$within': {'$center': value}}
elif op == 'within_spherical_distance':
value = {'$within': {'$centerSphere': value}}
elif op == 'within_polygon':
value = {'$within': {'$polygon': value}}
elif op == 'near':
value = {'$near': value}
elif op == 'near_sphere':
value = {'$nearSphere': value}
elif op == 'within_box':
value = {'$within': {'$box': value}}
else:
raise NotImplementedError('Geo method "%s" has not been '
'implemented for a GeoPointField' % op)
else:
if op == 'geo_within':
value = {'$geoWithin': _infer_geometry(value)}
elif op == 'geo_within_box':
value = {'$geoWithin': {'$box': value}}
elif op == 'geo_within_polygon':
value = {'$geoWithin': {'$polygon': value}}
elif op == 'geo_within_center':
value = {'$geoWithin': {'$center': value}}
elif op == 'geo_within_sphere':
value = {'$geoWithin': {'$centerSphere': value}}
elif op == 'geo_intersects':
value = {'$geoIntersects': _infer_geometry(value)}
elif op == 'near':
value = {'$near': _infer_geometry(value)}
if op == "within_distance":
value = {"$within": {"$center": value}}
elif op == "within_spherical_distance":
value = {"$within": {"$centerSphere": value}}
elif op == "within_polygon":
value = {"$within": {"$polygon": value}}
elif op == "near":
value = {"$near": value}
elif op == "near_sphere":
value = {"$nearSphere": value}
elif op == "within_box":
value = {"$within": {"$box": value}}
else:
raise NotImplementedError(
'Geo method "%s" has not been implemented for a %s '
% (op, field._name)
'Geo method "%s" has not been ' "implemented for a GeoPointField" % op
)
else:
if op == "geo_within":
value = {"$geoWithin": _infer_geometry(value)}
elif op == "geo_within_box":
value = {"$geoWithin": {"$box": value}}
elif op == "geo_within_polygon":
value = {"$geoWithin": {"$polygon": value}}
elif op == "geo_within_center":
value = {"$geoWithin": {"$center": value}}
elif op == "geo_within_sphere":
value = {"$geoWithin": {"$centerSphere": value}}
elif op == "geo_intersects":
value = {"$geoIntersects": _infer_geometry(value)}
elif op == "near":
value = {"$near": _infer_geometry(value)}
else:
raise NotImplementedError(
'Geo method "%s" has not been implemented for a %s ' % (op, field._name)
)
return value
@@ -406,51 +445,58 @@ def _infer_geometry(value):
given value.
"""
if isinstance(value, dict):
if '$geometry' in value:
if "$geometry" in value:
return value
elif 'coordinates' in value and 'type' in value:
return {'$geometry': value}
raise InvalidQueryError('Invalid $geometry dictionary should have '
'type and coordinates keys')
elif "coordinates" in value and "type" in value:
return {"$geometry": value}
raise InvalidQueryError(
"Invalid $geometry dictionary should have type and coordinates keys"
)
elif isinstance(value, (list, set)):
# TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
try:
value[0][0][0]
return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
return {"$geometry": {"type": "Polygon", "coordinates": value}}
except (TypeError, IndexError):
pass
try:
value[0][0]
return {'$geometry': {'type': 'LineString', 'coordinates': value}}
return {"$geometry": {"type": "LineString", "coordinates": value}}
except (TypeError, IndexError):
pass
try:
value[0]
return {'$geometry': {'type': 'Point', 'coordinates': value}}
return {"$geometry": {"type": "Point", "coordinates": value}}
except (TypeError, IndexError):
pass
raise InvalidQueryError('Invalid $geometry data. Can be either a '
'dictionary or (nested) lists of coordinate(s)')
raise InvalidQueryError(
"Invalid $geometry data. Can be either a "
"dictionary or (nested) lists of coordinate(s)"
)
def _prepare_query_for_iterable(field, op, value):
# We need a special check for BaseDocument, because - although it's iterable - using
# it as such in the context of this method is most definitely a mistake.
BaseDocument = _import_class('BaseDocument')
BaseDocument = _import_class("BaseDocument")
if isinstance(value, BaseDocument):
raise TypeError("When using the `in`, `nin`, `all`, or "
"`near`-operators you can\'t use a "
"`Document`, you must wrap your object "
"in a list (object -> [object]).")
raise TypeError(
"When using the `in`, `nin`, `all`, or "
"`near`-operators you can't use a "
"`Document`, you must wrap your object "
"in a list (object -> [object])."
)
if not hasattr(value, '__iter__'):
raise TypeError("The `in`, `nin`, `all`, or "
"`near`-operators must be applied to an "
"iterable (e.g. a list).")
if not hasattr(value, "__iter__"):
raise TypeError(
"The `in`, `nin`, `all`, or "
"`near`-operators must be applied to an "
"iterable (e.g. a list)."
)
return [field.prepare_query_value(op, v) for v in value]

View File

@@ -3,7 +3,7 @@ import copy
from mongoengine.errors import InvalidQueryError
from mongoengine.queryset import transform
__all__ = ('Q', 'QNode')
__all__ = ("Q", "QNode")
class QNodeVisitor(object):
@@ -69,9 +69,9 @@ class QueryCompilerVisitor(QNodeVisitor):
self.document = document
def visit_combination(self, combination):
operator = '$and'
operator = "$and"
if combination.operation == combination.OR:
operator = '$or'
operator = "$or"
return {operator: combination.children}
def visit_query(self, query):
@@ -96,7 +96,7 @@ class QNode(object):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
if getattr(other, "empty", True):
return self
if self.empty:
@@ -132,8 +132,8 @@ class QCombination(QNode):
self.children.append(node)
def __repr__(self):
op = ' & ' if self.operation is self.AND else ' | '
return '(%s)' % op.join([repr(node) for node in self.children])
op = " & " if self.operation is self.AND else " | "
return "(%s)" % op.join([repr(node) for node in self.children])
def accept(self, visitor):
for i in range(len(self.children)):
@@ -156,7 +156,7 @@ class Q(QNode):
self.query = query
def __repr__(self):
return 'Q(**%s)' % repr(self.query)
return "Q(**%s)" % repr(self.query)
def accept(self, visitor):
return visitor.visit_query(self)

View File

@@ -1,5 +1,12 @@
__all__ = ('pre_init', 'post_init', 'pre_save', 'pre_save_post_validation',
'post_save', 'pre_delete', 'post_delete')
__all__ = (
"pre_init",
"post_init",
"pre_save",
"pre_save_post_validation",
"post_save",
"pre_delete",
"post_delete",
)
signals_available = False
try:
@@ -7,6 +14,7 @@ try:
signals_available = True
except ImportError:
class Namespace(object):
def signal(self, name, doc=None):
return _FakeSignal(name, doc)
@@ -23,13 +31,16 @@ except ImportError:
self.__doc__ = doc
def _fail(self, *args, **kwargs):
raise RuntimeError('signalling support is unavailable '
'because the blinker library is '
'not installed.')
raise RuntimeError(
"signalling support is unavailable "
"because the blinker library is "
"not installed."
)
send = lambda *a, **kw: None # noqa
connect = disconnect = has_receivers_for = receivers_for = \
temporarily_connected_to = _fail
connect = (
disconnect
) = has_receivers_for = receivers_for = temporarily_connected_to = _fail
del _fail
@@ -37,12 +48,12 @@ except ImportError:
# not put signals in here. Create your own namespace instead.
_signals = Namespace()
pre_init = _signals.signal('pre_init')
post_init = _signals.signal('post_init')
pre_save = _signals.signal('pre_save')
pre_save_post_validation = _signals.signal('pre_save_post_validation')
post_save = _signals.signal('post_save')
pre_delete = _signals.signal('pre_delete')
post_delete = _signals.signal('post_delete')
pre_bulk_insert = _signals.signal('pre_bulk_insert')
post_bulk_insert = _signals.signal('post_bulk_insert')
pre_init = _signals.signal("pre_init")
post_init = _signals.signal("post_init")
pre_save = _signals.signal("pre_save")
pre_save_post_validation = _signals.signal("pre_save_post_validation")
post_save = _signals.signal("post_save")
pre_delete = _signals.signal("pre_delete")
post_delete = _signals.signal("post_delete")
pre_bulk_insert = _signals.signal("pre_bulk_insert")
post_bulk_insert = _signals.signal("post_bulk_insert")