From 48316ba60d4051d38d156c5704d422b077ce11bd Mon Sep 17 00:00:00 2001 From: mrigal Date: Fri, 10 Apr 2015 22:23:56 +0200 Subject: [PATCH] implemented global IS_PYMONGO_3 --- mongoengine/connection.py | 7 +++---- mongoengine/document.py | 2 +- mongoengine/python_support.py | 7 +++++++ mongoengine/queryset/base.py | 3 ++- tests/document/indexes.py | 11 ++++++----- tests/queryset/queryset.py | 8 ++++---- tests/test_connection.py | 9 +++++---- tests/test_replicaset_connection.py | 6 ++++-- 8 files changed, 32 insertions(+), 21 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 897dcc2a..b203e168 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,13 +1,12 @@ -import pymongo from pymongo import MongoClient, ReadPreference, uri_parser - +from mongoengine.python_support import IS_PYMONGO_3 __all__ = ['ConnectionError', 'connect', 'register_connection', 'DEFAULT_CONNECTION_NAME'] DEFAULT_CONNECTION_NAME = 'default' -if pymongo.version_tuple[0] >= 3: +if IS_PYMONGO_3: READ_PREFERENCE = ReadPreference.PRIMARY else: from pymongo import MongoReplicaSetClient @@ -115,7 +114,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], basestring): conn_settings.pop('replicaSet', None) - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: connection_class = MongoReplicaSetClient try: diff --git a/mongoengine/document.py b/mongoengine/document.py index b27bd086..8cc92866 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -295,7 +295,7 @@ class Document(BaseDocument): # I think the self._created flag is not necessarily required in PyMongo3 # but may cause test test_collection_name_and_primary to fail - # if pymongo.version_tuple[0] < 3: + # if not IS_PYMONGO_3: created = ('_id' not in doc or self._created or force_insert) # else: # created = ('_id' not in doc or force_insert) diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 1214b490..3412c841 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,6 +1,13 @@ """Helper functions and types to aid with Python 2.5 - 3 support.""" import sys +import pymongo + + +if pymongo.version_tuple[0] < 3: + IS_PYMONGO_3 = False +else: + IS_PYMONGO_3 = True PY3 = sys.version_info[0] == 3 diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index 09a4c3bc..6867cef3 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -21,6 +21,7 @@ from mongoengine.common import _import_class from mongoengine.base.common import get_document from mongoengine.errors import (OperationError, NotUniqueError, InvalidQueryError, LookUpError) +from mongoengine.python_support import IS_PYMONGO_3 from mongoengine.queryset import transform from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.visitor import Q, QNode @@ -1385,7 +1386,7 @@ class BaseQuerySet(object): @property def _cursor_args(self): - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: fields_name = 'fields' cursor_args = { 'timeout': self._timeout, diff --git a/tests/document/indexes.py b/tests/document/indexes.py index b8b3ba0c..7f40e2fd 100644 --- a/tests/document/indexes.py +++ b/tests/document/indexes.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- import unittest import sys + sys.path[0:0] = [""] -import os import pymongo from nose.plugins.skip import SkipTest @@ -11,6 +11,7 @@ from datetime import datetime from mongoengine import * from mongoengine.connection import get_db, get_connection +from mongoengine.python_support import IS_PYMONGO_3 __all__ = ("IndexesTest", ) @@ -447,26 +448,26 @@ class IndexesTest(unittest.TestCase): # Need to be explicit about covered indexes as mongoDB doesn't know if # the documents returned might have more keys in that here. query_plan = Test.objects(id=obj.id).exclude('a').explain() - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertFalse(query_plan['indexOnly']) else: self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') query_plan = Test.objects(id=obj.id).only('id').explain() - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertTrue(query_plan['indexOnly']) else: self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IDHACK') query_plan = Test.objects(a=1).only('a').exclude('id').explain() - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertTrue(query_plan['indexOnly']) else: self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('stage'), 'PROJECTION') query_plan = Test.objects(a=1).explain() - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertFalse(query_plan['indexOnly']) else: self.assertEqual(query_plan.get('queryPlanner').get('winningPlan').get('inputStage').get('stage'), 'IXSCAN') diff --git a/tests/queryset/queryset.py b/tests/queryset/queryset.py index a0268361..b757ddda 100644 --- a/tests/queryset/queryset.py +++ b/tests/queryset/queryset.py @@ -17,7 +17,7 @@ from bson import ObjectId from mongoengine import * from mongoengine.connection import get_connection, get_db -from mongoengine.python_support import PY3 +from mongoengine.python_support import PY3, IS_PYMONGO_3 from mongoengine.context_managers import query_counter, switch_db from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, @@ -54,7 +54,7 @@ def skip_older_mongodb(f): def skip_pymongo3(f): def _inner(*args, **kwargs): - if pymongo.version_tuple[0] >= 3: + if IS_PYMONGO_3: raise SkipTest("Useless with PyMongo 3+") return f(*args, **kwargs) @@ -2942,7 +2942,7 @@ class QuerySetTest(unittest.TestCase): self.assertEqual(query.count(), 3) self.assertEqual(query._query, {'$text': {'$search': 'brasil'}}) cursor_args = query._cursor_args - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: cursor_args_fields = cursor_args['fields'] else: cursor_args_fields = cursor_args['projection'] @@ -4012,7 +4012,7 @@ class QuerySetTest(unittest.TestCase): bars = list(Bar.objects(read_preference=ReadPreference.PRIMARY)) self.assertEqual([], bars) - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: error_class = ConfigurationError else: error_class = TypeError diff --git a/tests/test_connection.py b/tests/test_connection.py index 88e03994..f9cc9c78 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -17,12 +17,13 @@ from mongoengine import ( connect, register_connection, Document, DateTimeField ) +from mongoengine.python_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import get_db, get_connection, ConnectionError def get_tz_awareness(connection): - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: return connection.tz_aware else: return connection.codec_options.tz_aware @@ -76,7 +77,7 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.mongoenginetest.add_user("username", "password") - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') @@ -103,7 +104,7 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.mongoenginetest.add_user("username", "password") - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') connect("mongoenginetest", host='mongodb://localhost/') @@ -202,7 +203,7 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(len(mongo_connections.items()), 2) self.assertTrue('t1' in mongo_connections.keys()) self.assertTrue('t2' in mongo_connections.keys()) - if pymongo.version_tuple[0] < 3: + if not IS_PYMONGO_3: self.assertEqual(mongo_connections['t1'].host, 'localhost') self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') else: diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index b3a7e1bf..361cff41 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -1,11 +1,13 @@ import sys + sys.path[0:0] = [""] import unittest -import pymongo from pymongo import ReadPreference -if pymongo.version_tuple[0] >= 3: +from mongoengine.python_support import IS_PYMONGO_3 + +if IS_PYMONGO_3: from pymongo import MongoClient CONN_CLASS = MongoClient READ_PREF = ReadPreference.SECONDARY