From 928770c43a892e7e6a01e5a85f5d90a5b9a63c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bastien=20G=C3=A9rard?= Date: Fri, 20 Dec 2019 21:50:00 +0100 Subject: [PATCH] switching to count_documents --- mongoengine/pymongo_support.py | 28 ++++++++++++++++++++++++---- mongoengine/queryset/base.py | 32 +++++++++++++++++++++++++++++++- mongoengine/queryset/queryset.py | 1 + tests/test_connection.py | 2 +- 4 files changed, 57 insertions(+), 6 deletions(-) diff --git a/mongoengine/pymongo_support.py b/mongoengine/pymongo_support.py index 80c0661b..1fea9525 100644 --- a/mongoengine/pymongo_support.py +++ b/mongoengine/pymongo_support.py @@ -10,12 +10,32 @@ PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 -def count_documents(collection, filter): - """Pymongo>3.7 deprecates count in favour of count_documents""" +def count_documents(collection, filter, skip=None, limit=None, hint=None, collation=None): + """Pymongo>3.7 deprecates count in favour of count_documents + """ + if limit == 0: + return 0 # Pymongo raises an OperationFailure if called with limit=0 + if IS_PYMONGO_GTE_37: - return collection.count_documents(filter) + kwargs = {} + if skip is not None: + kwargs["skip"] = skip + if limit is not None: + kwargs["limit"] = limit + if collation is not None: + kwargs["collation"] = collation + if hint not in (-1, None): + kwargs["hint"] = hint + return collection.count_documents(filter=filter, **kwargs) else: - count = collection.find(filter).count() + cursor = collection.find(filter) + if limit: + cursor = cursor.limit(limit) + if skip: + cursor = cursor.skip(skip) + if hint != -1: + cursor = cursor.hint(hint) + count = cursor.count() return count diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 50cb37ac..d7b4007e 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -12,6 +12,7 @@ import pymongo.errors from pymongo.collection import ReturnDocument from pymongo.common import validate_read_preference import six +from pymongo.errors import OperationFailure from six import iteritems from mongoengine import signals @@ -26,6 +27,7 @@ from mongoengine.errors import ( NotUniqueError, OperationError, ) +from mongoengine.pymongo_support import count_documents from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode @@ -392,9 +394,37 @@ class BaseQuerySet(object): :meth:`skip` that has been applied to this cursor into account when getting the count """ + # mimic the fact that setting .limit(0) in pymongo sets no limit + # https://docs.mongodb.com/manual/reference/method/cursor.limit/#zero-value if self._limit == 0 and with_limit_and_skip is False or self._none: return 0 - count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) + + kwargs = ( + {"limit": self._limit, "skip": self._skip} if with_limit_and_skip else {} + ) + + if self._limit == 0: + # mimic the fact that historically .limit(0) sets no limit + kwargs.pop('limit', None) + + if self._hint not in (-1, None): + kwargs["hint"] = self._hint + + if self._collation: + kwargs["collation"] = self._collation + + try: + count = count_documents( + collection=self._cursor.collection, + filter=self._cursor._Cursor__spec, + **kwargs + ) + except OperationFailure: + # Accounts for some operators that used to work with .count but are no longer working + # with count_documents (i.e $geoNear, $near, and $nearSphere) + # fallback to deprecated Cursor.count + count = self._cursor.count(with_limit_and_skip=with_limit_and_skip) + self._cursor_obj = None return count diff --git a/mongoengine/queryset/queryset.py b/mongoengine/queryset/queryset.py index 4ba62d46..cc1891f6 100644 --- a/mongoengine/queryset/queryset.py +++ b/mongoengine/queryset/queryset.py @@ -146,6 +146,7 @@ class QuerySet(BaseQuerySet): return super(QuerySet, self).count(with_limit_and_skip) if self._len is None: + # cache the length self._len = super(QuerySet, self).count(with_limit_and_skip) return self._len diff --git a/tests/test_connection.py b/tests/test_connection.py index e40a6994..542da4f0 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -282,7 +282,7 @@ class ConnectionTest(unittest.TestCase): # database won't exist until we save a document some_document.save() assert conn.get_default_database().name == "mongoenginetest" - assert conn.database_names()[0] == "mongoenginetest" + assert conn.list_database_names()[0] == "mongoenginetest" @require_mongomock def test_connect_with_host_list(self):