From 8ef771912dbcd0cc56cfb7a1b5fefb6172518139 Mon Sep 17 00:00:00 2001 From: Bastien Gerard Date: Sun, 28 Feb 2021 14:07:15 +0100 Subject: [PATCH] fixing incompatibility with mongoengine aggregation to support mongo 4.4 --- mongoengine/mongodb_support.py | 1 + mongoengine/queryset/base.py | 9 +++------ tests/document/test_indexes.py | 5 ++++- tests/queryset/test_queryset.py | 18 +++++++++++++++--- tests/utils.py | 8 ++++++++ 5 files changed, 31 insertions(+), 10 deletions(-) diff --git a/mongoengine/mongodb_support.py b/mongoengine/mongodb_support.py index 522f064e..df51100d 100644 --- a/mongoengine/mongodb_support.py +++ b/mongoengine/mongodb_support.py @@ -8,6 +8,7 @@ from mongoengine.connection import get_connection # get_mongodb_version() MONGODB_34 = (3, 4) MONGODB_36 = (3, 6) +MONGODB_44 = (4, 4) def get_mongodb_version(): diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index ae8cd407..47a5f733 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -1355,21 +1355,18 @@ class BaseQuerySet: MapReduceDocument = _import_class("MapReduceDocument") - if not hasattr(self._collection, "map_reduce"): - raise NotImplementedError("Requires MongoDB >= 1.7.1") - map_f_scope = {} if isinstance(map_f, Code): map_f_scope = map_f.scope map_f = str(map_f) - map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) + map_f = Code(queryset._sub_js_fields(map_f), map_f_scope or None) reduce_f_scope = {} if isinstance(reduce_f, Code): reduce_f_scope = reduce_f.scope reduce_f = str(reduce_f) reduce_f_code = queryset._sub_js_fields(reduce_f) - reduce_f = Code(reduce_f_code, reduce_f_scope) + reduce_f = Code(reduce_f_code, reduce_f_scope or None) mr_args = {"query": queryset._query} @@ -1379,7 +1376,7 @@ class BaseQuerySet: finalize_f_scope = finalize_f.scope finalize_f = str(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f) - finalize_f = Code(finalize_f_code, finalize_f_scope) + finalize_f = Code(finalize_f_code, finalize_f_scope or None) mr_args["finalize"] = finalize_f if scope: diff --git a/tests/document/test_indexes.py b/tests/document/test_indexes.py index 55a56931..17643dd8 100644 --- a/tests/document/test_indexes.py +++ b/tests/document/test_indexes.py @@ -7,6 +7,7 @@ import pytest from mongoengine import * from mongoengine.connection import get_db +from mongoengine.mongodb_support import MONGODB_44, get_mongodb_version class TestIndexes(unittest.TestCase): @@ -452,9 +453,11 @@ class TestIndexes(unittest.TestCase): .get("stage") == "IXSCAN" ) + mongo_db = get_mongodb_version() + PROJECTION_STR = "PROJECTION" if mongo_db < MONGODB_44 else "PROJECTION_COVERED" assert ( query_plan.get("queryPlanner").get("winningPlan").get("stage") - == "PROJECTION" + == PROJECTION_STR ) query_plan = Test.objects(a=1).explain() diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 4d281c60..c346abde 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -21,7 +21,11 @@ from mongoengine.queryset import ( QuerySetManager, queryset_manager, ) -from tests.utils import requires_mongodb_gte_44 +from tests.utils import ( + requires_mongodb_gte_44, + requires_mongodb_lt_42, + requires_mongodb_lte_42, +) class db_ops_tracker(query_counter): @@ -1490,6 +1494,7 @@ class TestQueryset(unittest.TestCase): BlogPost.drop_collection() + @requires_mongodb_lt_42 def test_exec_js_query(self): """Ensure that queries are properly formed for use in exec_js.""" @@ -1527,6 +1532,7 @@ class TestQueryset(unittest.TestCase): BlogPost.drop_collection() + @requires_mongodb_lt_42 def test_exec_js_field_sub(self): """Ensure that field substitutions occur properly in exec_js functions.""" @@ -3109,6 +3115,7 @@ class TestQueryset(unittest.TestCase): freq = Person.objects.item_frequencies("city", normalize=True, map_reduce=True) assert freq == {"CRB": 0.5, None: 0.5} + @requires_mongodb_lte_42 def test_item_frequencies_with_null_embedded(self): class Data(EmbeddedDocument): name = StringField() @@ -3137,6 +3144,7 @@ class TestQueryset(unittest.TestCase): ot = Person.objects.item_frequencies("extra.tag", map_reduce=True) assert ot == {None: 1.0, "friend": 1.0} + @requires_mongodb_lte_42 def test_item_frequencies_with_0_values(self): class Test(Document): val = IntField() @@ -3151,6 +3159,7 @@ class TestQueryset(unittest.TestCase): ot = Test.objects.item_frequencies("val", map_reduce=False) assert ot == {0: 1} + @requires_mongodb_lte_42 def test_item_frequencies_with_False_values(self): class Test(Document): val = BooleanField() @@ -3165,6 +3174,7 @@ class TestQueryset(unittest.TestCase): ot = Test.objects.item_frequencies("val", map_reduce=False) assert ot == {False: 1} + @requires_mongodb_lte_42 def test_item_frequencies_normalize(self): class Test(Document): val = IntField() @@ -3551,7 +3561,8 @@ class TestQueryset(unittest.TestCase): Book.objects.create(title="The Stories", authors=[mark_twain, john_tolkien]) authors = Book.objects.distinct("authors") - assert authors == [mark_twain, john_tolkien] + authors_names = {author.name for author in authors} + assert authors_names == {mark_twain.name, john_tolkien.name} def test_distinct_ListField_EmbeddedDocumentField_EmbeddedDocumentField(self): class Continent(EmbeddedDocument): @@ -3588,7 +3599,8 @@ class TestQueryset(unittest.TestCase): assert country_list == [scotland, tibet] continent_list = Book.objects.distinct("authors.country.continent") - assert continent_list == [europe, asia] + continent_list_names = {c.continent_name for c in continent_list} + assert continent_list_names == {europe.continent_name, asia.continent_name} def test_distinct_ListField_ReferenceField(self): class Bar(Document): diff --git a/tests/utils.py b/tests/utils.py index adb0bdb4..19596afa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -34,6 +34,14 @@ def get_as_pymongo(doc): return doc.__class__.objects.as_pymongo().get(id=doc.id) +def requires_mongodb_lt_42(func): + return _decorated_with_ver_requirement(func, (4, 2), oper=operator.lt) + + +def requires_mongodb_lte_42(func): + return _decorated_with_ver_requirement(func, (4, 2), oper=operator.le) + + def requires_mongodb_gte_44(func): return _decorated_with_ver_requirement(func, (4, 4), oper=operator.ge)