Merge branch 'master' into feature/allow-setting-read-concern-queryset

This commit is contained in:
Bastien Gérard
2020-04-26 22:36:47 +02:00
committed by GitHub
65 changed files with 687 additions and 863 deletions

View File

@@ -1,5 +1,3 @@
from __future__ import absolute_import
import copy
import itertools
import re
@@ -14,8 +12,6 @@ import pymongo.errors
from pymongo.collection import ReturnDocument
from pymongo.common import validate_read_preference
from pymongo.read_concern import ReadConcern
import six
from six import iteritems
from mongoengine import signals
from mongoengine.base import get_document
@@ -48,7 +44,7 @@ DENY = 3
PULL = 4
class BaseQuerySet(object):
class BaseQuerySet:
"""A set of results returned from a query. Wraps a MongoDB cursor,
providing :class:`~mongoengine.Document` objects as the results.
"""
@@ -67,7 +63,6 @@ class BaseQuerySet(object):
self._ordering = None
self._snapshot = False
self._timeout = True
self._slave_okay = False
self._read_preference = None
self._read_concern = None
self._iter = False
@@ -212,8 +207,6 @@ class BaseQuerySet(object):
"""Avoid to open all records in an if stmt in Py3."""
return self._has_data()
__nonzero__ = __bool__ # For Py2 support
# Core functions
def all(self):
@@ -264,20 +257,21 @@ class BaseQuerySet(object):
queryset = queryset.filter(*q_objs, **query)
try:
result = six.next(queryset)
result = next(queryset)
except StopIteration:
msg = "%s matching query does not exist." % queryset._document._class_name
raise queryset._document.DoesNotExist(msg)
try:
six.next(queryset)
# Check if there is another match
next(queryset)
except StopIteration:
return result
# If we were able to retrieve the 2nd doc, rewind the cursor and
# raise the MultipleObjectsReturned exception.
queryset.rewind()
message = u"%d items returned, instead of 1" % queryset.count()
raise queryset._document.MultipleObjectsReturned(message)
# If we were able to retrieve the 2nd doc, raise the MultipleObjectsReturned exception.
raise queryset._document.MultipleObjectsReturned(
"2 or more items returned, instead of 1"
)
def create(self, **kwargs):
"""Create new object. Returns the saved object instance.
@@ -361,20 +355,20 @@ class BaseQuerySet(object):
)
except pymongo.errors.DuplicateKeyError as err:
message = "Could not save document (%s)"
raise NotUniqueError(message % six.text_type(err))
raise NotUniqueError(message % err)
except pymongo.errors.BulkWriteError as err:
# inserting documents that already have an _id field will
# give huge performance debt or raise
message = u"Bulk write error: (%s)"
raise BulkWriteError(message % six.text_type(err.details))
message = "Bulk write error: (%s)"
raise BulkWriteError(message % err.details)
except pymongo.errors.OperationFailure as err:
message = "Could not save document (%s)"
if re.match("^E1100[01] duplicate key", six.text_type(err)):
if re.match("^E1100[01] duplicate key", str(err)):
# E11000 - duplicate key error index
# E11001 - duplicate key on update
message = u"Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % six.text_type(err))
message = "Tried to save duplicate unique keys (%s)"
raise NotUniqueError(message % err)
raise OperationError(message % err)
# Apply inserted_ids to documents
for doc, doc_id in zip(docs, ids):
@@ -555,12 +549,12 @@ class BaseQuerySet(object):
elif result.raw_result:
return result.raw_result["n"]
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u"Update failed (%s)" % six.text_type(err))
raise NotUniqueError("Update failed (%s)" % err)
except pymongo.errors.OperationFailure as err:
if six.text_type(err) == u"multi not coded yet":
message = u"update() method requires MongoDB 1.1.3+"
if str(err) == "multi not coded yet":
message = "update() method requires MongoDB 1.1.3+"
raise OperationError(message)
raise OperationError(u"Update failed (%s)" % six.text_type(err))
raise OperationError("Update failed (%s)" % err)
def upsert_one(self, write_concern=None, read_concern=None, **update):
"""Overwrite or add the first document matched by the query.
@@ -680,9 +674,9 @@ class BaseQuerySet(object):
**self._cursor_args
)
except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u"Update failed (%s)" % err)
raise NotUniqueError("Update failed (%s)" % err)
except pymongo.errors.OperationFailure as err:
raise OperationError(u"Update failed (%s)" % err)
raise OperationError("Update failed (%s)" % err)
if full_response:
if result["value"] is not None:
@@ -711,7 +705,7 @@ class BaseQuerySet(object):
return queryset.filter(pk=object_id).first()
def in_bulk(self, object_ids):
"""Retrieve a set of documents by their ids.
""""Retrieve a set of documents by their ids.
:param object_ids: a list or tuple of ObjectId's
:rtype: dict of ObjectId's as keys and collection-specific
@@ -794,7 +788,6 @@ class BaseQuerySet(object):
"_ordering",
"_snapshot",
"_timeout",
"_slave_okay",
"_read_preference",
"_iter",
"_scalar",
@@ -1008,7 +1001,7 @@ class BaseQuerySet(object):
.. versionchanged:: 0.5 - Added subfield support
"""
fields = {f: QueryFieldList.ONLY for f in fields}
self.only_fields = fields.keys()
self.only_fields = list(fields.keys())
return self.fields(True, **fields)
def exclude(self, *fields):
@@ -1191,20 +1184,6 @@ class BaseQuerySet(object):
queryset._timeout = enabled
return queryset
# DEPRECATED. Has no more impact on PyMongo 3+
def slave_okay(self, enabled):
"""Enable or disable the slave_okay when querying.
:param enabled: whether or not the slave_okay is enabled
.. deprecated:: Ignored with PyMongo 3+
"""
msg = "slave_okay is deprecated as it has no impact when using PyMongo 3+."
warnings.warn(msg, DeprecationWarning)
queryset = self.clone()
queryset._slave_okay = enabled
return queryset
def read_preference(self, read_preference):
"""Change the read_preference when querying.
@@ -1387,13 +1366,13 @@ class BaseQuerySet(object):
map_f_scope = {}
if isinstance(map_f, Code):
map_f_scope = map_f.scope
map_f = six.text_type(map_f)
map_f = str(map_f)
map_f = Code(queryset._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {}
if isinstance(reduce_f, Code):
reduce_f_scope = reduce_f.scope
reduce_f = six.text_type(reduce_f)
reduce_f = str(reduce_f)
reduce_f_code = queryset._sub_js_fields(reduce_f)
reduce_f = Code(reduce_f_code, reduce_f_scope)
@@ -1403,7 +1382,7 @@ class BaseQuerySet(object):
finalize_f_scope = {}
if isinstance(finalize_f, Code):
finalize_f_scope = finalize_f.scope
finalize_f = six.text_type(finalize_f)
finalize_f = str(finalize_f)
finalize_f_code = queryset._sub_js_fields(finalize_f)
finalize_f = Code(finalize_f_code, finalize_f_scope)
mr_args["finalize"] = finalize_f
@@ -1419,7 +1398,7 @@ class BaseQuerySet(object):
else:
map_reduce_function = "map_reduce"
if isinstance(output, six.string_types):
if isinstance(output, str):
mr_args["out"] = output
elif isinstance(output, dict):
@@ -1606,7 +1585,7 @@ class BaseQuerySet(object):
if self._limit == 0 or self._none:
raise StopIteration
raw_doc = six.next(self._cursor)
raw_doc = next(self._cursor)
if self._as_pymongo:
return raw_doc
@@ -1851,13 +1830,13 @@ class BaseQuerySet(object):
}
"""
total, data, types = self.exec_js(freq_func, field)
values = {types.get(k): int(v) for k, v in iteritems(data)}
values = {types.get(k): int(v) for k, v in data.items()}
if normalize:
values = {k: float(v) / total for k, v in values.items()}
frequencies = {}
for k, v in iteritems(values):
for k, v in values.items():
if isinstance(k, float):
if int(k) == k:
k = int(k)
@@ -1877,7 +1856,7 @@ class BaseQuerySet(object):
field_parts = field.split(".")
try:
field = ".".join(
f if isinstance(f, six.string_types) else f.db_field
f if isinstance(f, str) else f.db_field
for f in self._document._lookup_field(field_parts)
)
db_field_paths.append(field)
@@ -1889,7 +1868,7 @@ class BaseQuerySet(object):
for subdoc in subclasses:
try:
subfield = ".".join(
f if isinstance(f, six.string_types) else f.db_field
f if isinstance(f, str) else f.db_field
for f in subdoc._lookup_field(field_parts)
)
db_field_paths.append(subfield)
@@ -1963,7 +1942,7 @@ class BaseQuerySet(object):
field_name = match.group(1).split(".")
fields = self._document._lookup_field(field_name)
# Substitute the correct name for the field into the javascript
return u'["%s"]' % fields[-1].db_field
return '["%s"]' % fields[-1].db_field
def field_path_sub(match):
# Extract just the field name, and look up the field objects
@@ -1993,23 +1972,3 @@ class BaseQuerySet(object):
setattr(queryset, "_" + method_name, val)
return queryset
# Deprecated
def ensure_index(self, **kwargs):
"""Deprecated use :func:`Document.ensure_index`"""
msg = (
"Doc.objects()._ensure_index() is deprecated. "
"Use Doc.ensure_index() instead."
)
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_index(**kwargs)
return self
def _ensure_indexes(self):
"""Deprecated use :func:`~Document.ensure_indexes`"""
msg = (
"Doc.objects()._ensure_indexes() is deprecated. "
"Use Doc.ensure_indexes() instead."
)
warnings.warn(msg, DeprecationWarning)
self._document.__class__.ensure_indexes()

View File

@@ -1,7 +1,7 @@
__all__ = ("QueryFieldList",)
class QueryFieldList(object):
class QueryFieldList:
"""Object that handles combinations of .only() and .exclude() calls"""
ONLY = 1
@@ -69,8 +69,6 @@ class QueryFieldList(object):
def __bool__(self):
return bool(self.fields)
__nonzero__ = __bool__ # For Py2 support
def as_dict(self):
field_list = {field: self.value for field in self.fields}
if self.slice:
@@ -80,7 +78,7 @@ class QueryFieldList(object):
return field_list
def reset(self):
self.fields = set([])
self.fields = set()
self.slice = {}
self.value = self.ONLY

View File

@@ -4,7 +4,7 @@ from mongoengine.queryset.queryset import QuerySet
__all__ = ("queryset_manager", "QuerySetManager")
class QuerySetManager(object):
class QuerySetManager:
"""
The default QuerySet Manager.

View File

@@ -1,5 +1,3 @@
import six
from mongoengine.errors import OperationError
from mongoengine.queryset.base import (
BaseQuerySet,
@@ -127,8 +125,8 @@ class QuerySet(BaseQuerySet):
# Pull in ITER_CHUNK_SIZE docs from the database and store them in
# the result cache.
try:
for _ in six.moves.range(ITER_CHUNK_SIZE):
self._result_cache.append(six.next(self))
for _ in range(ITER_CHUNK_SIZE):
self._result_cache.append(next(self))
except StopIteration:
# Getting this exception means there are no more docs in the
# db cursor. Set _has_more to False so that we can use that
@@ -143,10 +141,10 @@ class QuerySet(BaseQuerySet):
getting the count
"""
if with_limit_and_skip is False:
return super(QuerySet, self).count(with_limit_and_skip)
return super().count(with_limit_and_skip)
if self._len is None:
self._len = super(QuerySet, self).count(with_limit_and_skip)
self._len = super().count(with_limit_and_skip)
return self._len
@@ -180,9 +178,9 @@ class QuerySetNoCache(BaseQuerySet):
return ".. queryset mid-iteration .."
data = []
for _ in six.moves.range(REPR_OUTPUT_SIZE + 1):
for _ in range(REPR_OUTPUT_SIZE + 1):
try:
data.append(six.next(self))
data.append(next(self))
except StopIteration:
break

View File

@@ -3,14 +3,12 @@ from collections import defaultdict
from bson import ObjectId, SON
from bson.dbref import DBRef
import pymongo
import six
from six import iteritems
from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class
from mongoengine.errors import InvalidQueryError
__all__ = ("query", "update")
__all__ = ("query", "update", "STRING_OPERATORS")
COMPARISON_OPERATORS = (
"ne",
@@ -101,7 +99,7 @@ def query(_doc_cls=None, **kwargs):
cleaned_fields = []
for field in fields:
append_field = True
if isinstance(field, six.string_types):
if isinstance(field, str):
parts.append(field)
append_field = False
# is last and CachedReferenceField
@@ -180,7 +178,7 @@ def query(_doc_cls=None, **kwargs):
"$near" in value_dict or "$nearSphere" in value_dict
):
value_son = SON()
for k, v in iteritems(value_dict):
for k, v in value_dict.items():
if k == "$maxDistance" or k == "$minDistance":
continue
value_son[k] = v
@@ -281,7 +279,7 @@ def update(_doc_cls=None, **update):
appended_sub_field = False
for field in fields:
append_field = True
if isinstance(field, six.string_types):
if isinstance(field, str):
# Convert the S operator to $
if field == "S":
field = "$"
@@ -435,7 +433,9 @@ def _geo_operator(field, op, value):
value = {"$near": _infer_geometry(value)}
else:
raise NotImplementedError(
'Geo method "%s" has not been implemented for a %s ' % (op, field._name)
'Geo method "{}" has not been implemented for a {} '.format(
op, field._name
)
)
return value

View File

@@ -7,7 +7,7 @@ from mongoengine.queryset import transform
__all__ = ("Q", "QNode")
class QNodeVisitor(object):
class QNodeVisitor:
"""Base visitor class for visiting Q-object nodes in a query tree.
"""
@@ -79,7 +79,7 @@ class QueryCompilerVisitor(QNodeVisitor):
return transform.query(self.document, **query.query)
class QNode(object):
class QNode:
"""Base class for nodes in query trees."""
AND = 0
@@ -143,8 +143,6 @@ class QCombination(QNode):
def __bool__(self):
return bool(self.children)
__nonzero__ = __bool__ # For Py2 support
def accept(self, visitor):
for i in range(len(self.children)):
if isinstance(self.children[i], QNode):
@@ -180,8 +178,6 @@ class Q(QNode):
def __bool__(self):
return bool(self.query)
__nonzero__ = __bool__ # For Py2 support
def __eq__(self, other):
return self.__class__ == other.__class__ and self.query == other.query