Remove more code related to supporting pymongo2

This commit is contained in:
Bastien Gérard 2019-05-15 22:23:35 +02:00
parent ac64ade10f
commit cf38ef70cb
9 changed files with 70 additions and 231 deletions

View File

@ -2,8 +2,6 @@ from pymongo import MongoClient, ReadPreference, uri_parser
from pymongo.database import _check_name from pymongo.database import _check_name
import six import six
from mongoengine.pymongo_support import IS_PYMONGO_3
__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', __all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all',
'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', 'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME',
'get_db', 'get_connection'] 'get_db', 'get_connection']
@ -14,11 +12,11 @@ DEFAULT_DATABASE_NAME = 'test'
DEFAULT_HOST = 'localhost' DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 27017 DEFAULT_PORT = 27017
if IS_PYMONGO_3: _connection_settings = {}
READ_PREFERENCE = ReadPreference.PRIMARY _connections = {}
else: _dbs = {}
from pymongo import MongoReplicaSetClient
READ_PREFERENCE = False READ_PREFERENCE = ReadPreference.PRIMARY
class MongoEngineConnectionError(Exception): class MongoEngineConnectionError(Exception):
@ -28,12 +26,7 @@ class MongoEngineConnectionError(Exception):
pass pass
_connection_settings = {} def _check_db_name(name):
_connections = {}
_dbs = {}
def check_db_name(name):
"""Check if a database name is valid. """Check if a database name is valid.
This functionality is copied from pymongo Database class constructor. This functionality is copied from pymongo Database class constructor.
""" """
@ -57,7 +50,6 @@ def _get_connection_settings(
: param host: the host name of the: program: `mongod` instance to connect to : param host: the host name of the: program: `mongod` instance to connect to
: param port: the port that the: program: `mongod` instance is running on : param port: the port that the: program: `mongod` instance is running on
: param read_preference: The read preference for the collection : param read_preference: The read preference for the collection
** Added pymongo 2.1
: param username: username to authenticate with : param username: username to authenticate with
: param password: password to authenticate with : param password: password to authenticate with
: param authentication_source: database to authenticate against : param authentication_source: database to authenticate against
@ -83,7 +75,7 @@ def _get_connection_settings(
'authentication_mechanism': authentication_mechanism 'authentication_mechanism': authentication_mechanism
} }
check_db_name(conn_settings['name']) _check_db_name(conn_settings['name'])
conn_host = conn_settings['host'] conn_host = conn_settings['host']
# Host can be a list or a string, so if string, force to a list. # Host can be a list or a string, so if string, force to a list.
@ -119,7 +111,7 @@ def _get_connection_settings(
conn_settings['authentication_source'] = uri_options['authsource'] conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options: if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism'] conn_settings['authentication_mechanism'] = uri_options['authmechanism']
if IS_PYMONGO_3 and 'readpreference' in uri_options: if 'readpreference' in uri_options:
read_preferences = ( read_preferences = (
ReadPreference.NEAREST, ReadPreference.NEAREST,
ReadPreference.PRIMARY, ReadPreference.PRIMARY,
@ -158,7 +150,6 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
: param host: the host name of the: program: `mongod` instance to connect to : param host: the host name of the: program: `mongod` instance to connect to
: param port: the port that the: program: `mongod` instance is running on : param port: the port that the: program: `mongod` instance is running on
: param read_preference: The read preference for the collection : param read_preference: The read preference for the collection
** Added pymongo 2.1
: param username: username to authenticate with : param username: username to authenticate with
: param password: password to authenticate with : param password: password to authenticate with
: param authentication_source: database to authenticate against : param authentication_source: database to authenticate against
@ -259,22 +250,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
else: else:
connection_class = MongoClient connection_class = MongoClient
# For replica set connections with PyMongo 2.x, use
# MongoReplicaSetClient.
# TODO remove this once we stop supporting PyMongo 2.x.
if 'replicaSet' in conn_settings and not IS_PYMONGO_3:
connection_class = MongoReplicaSetClient
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
# hosts_or_uri has to be a string, so if 'host' was provided
# as a list, join its parts and separate them by ','
if isinstance(conn_settings['hosts_or_uri'], list):
conn_settings['hosts_or_uri'] = ','.join(
conn_settings['hosts_or_uri'])
# Discard port since it can't be used on MongoReplicaSetClient
conn_settings.pop('port', None)
# Iterate over all of the connection settings and if a connection with # Iterate over all of the connection settings and if a connection with
# the same parameters is already established, use it instead of creating # the same parameters is already established, use it instead of creating
# a new one. # a new one.

View File

@ -18,7 +18,7 @@ from mongoengine.context_managers import (set_write_concern,
switch_db) switch_db)
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError, from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
SaveConditionError) SaveConditionError)
from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names from mongoengine.pymongo_support import list_collection_names
from mongoengine.queryset import (NotUniqueError, OperationError, from mongoengine.queryset import (NotUniqueError, OperationError,
QuerySet, transform) QuerySet, transform)
@ -822,10 +822,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
index_spec['background'] = background index_spec['background'] = background
index_spec.update(kwargs) index_spec.update(kwargs)
if IS_PYMONGO_3: return cls._get_collection().create_index(fields, **index_spec)
return cls._get_collection().create_index(fields, **index_spec)
else:
return cls._get_collection().ensure_index(fields, **index_spec)
@classmethod @classmethod
def ensure_index(cls, key_or_list, drop_dups=False, background=False, def ensure_index(cls, key_or_list, drop_dups=False, background=False,
@ -858,7 +855,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
drop_dups = cls._meta.get('index_drop_dups', False) drop_dups = cls._meta.get('index_drop_dups', False)
index_opts = cls._meta.get('index_opts') or {} index_opts = cls._meta.get('index_opts') or {}
index_cls = cls._meta.get('index_cls', True) index_cls = cls._meta.get('index_cls', True)
if IS_PYMONGO_3 and drop_dups: if drop_dups:
msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.' msg = 'drop_dups is deprecated and is removed when using PyMongo 3+.'
warnings.warn(msg, DeprecationWarning) warnings.warn(msg, DeprecationWarning)
@ -889,11 +886,7 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if 'cls' in opts: if 'cls' in opts:
del opts['cls'] del opts['cls']
if IS_PYMONGO_3: collection.create_index(fields, background=background, **opts)
collection.create_index(fields, background=background, **opts)
else:
collection.ensure_index(fields, background=background,
drop_dups=drop_dups, **opts)
# If _cls is being used (for polymorphism), it needs an index, # If _cls is being used (for polymorphism), it needs an index,
# only if another index doesn't begin with _cls # only if another index doesn't begin with _cls
@ -904,12 +897,8 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)):
if 'cls' in index_opts: if 'cls' in index_opts:
del index_opts['cls'] del index_opts['cls']
if IS_PYMONGO_3: collection.create_index('_cls', background=background,
collection.create_index('_cls', background=background, **index_opts)
**index_opts)
else:
collection.ensure_index('_cls', background=background,
**index_opts)
@classmethod @classmethod
def list_indexes(cls): def list_indexes(cls):

View File

@ -7,7 +7,6 @@ _PYMONGO_37 = (3, 7)
PYMONGO_VERSION = tuple(pymongo.version_tuple[:2]) PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])
IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37 IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37

View File

@ -10,6 +10,7 @@ from bson import SON, json_util
from bson.code import Code from bson.code import Code
import pymongo import pymongo
import pymongo.errors import pymongo.errors
from pymongo.collection import ReturnDocument
from pymongo.common import validate_read_preference from pymongo.common import validate_read_preference
import six import six
from six import iteritems from six import iteritems
@ -21,14 +22,10 @@ from mongoengine.connection import get_db
from mongoengine.context_managers import set_write_concern, switch_db from mongoengine.context_managers import set_write_concern, switch_db
from mongoengine.errors import (InvalidQueryError, LookUpError, from mongoengine.errors import (InvalidQueryError, LookUpError,
NotUniqueError, OperationError) NotUniqueError, OperationError)
from mongoengine.pymongo_support import IS_PYMONGO_3
from mongoengine.queryset import transform from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode from mongoengine.queryset.visitor import Q, QNode
if IS_PYMONGO_3:
from pymongo.collection import ReturnDocument
__all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL') __all__ = ('BaseQuerySet', 'DO_NOTHING', 'NULLIFY', 'CASCADE', 'DENY', 'PULL')
@ -631,26 +628,20 @@ class BaseQuerySet(object):
sort = queryset._ordering sort = queryset._ordering
try: try:
if IS_PYMONGO_3: if full_response:
if full_response: msg = 'With PyMongo 3+, it is not possible anymore to get the full response.'
msg = 'With PyMongo 3+, it is not possible anymore to get the full response.' warnings.warn(msg, DeprecationWarning)
warnings.warn(msg, DeprecationWarning) if remove:
if remove: result = queryset._collection.find_one_and_delete(
result = queryset._collection.find_one_and_delete( query, sort=sort, **self._cursor_args)
query, sort=sort, **self._cursor_args)
else:
if new:
return_doc = ReturnDocument.AFTER
else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
else: else:
result = queryset._collection.find_and_modify( if new:
query, update, upsert=upsert, sort=sort, remove=remove, new=new, return_doc = ReturnDocument.AFTER
full_response=full_response, **self._cursor_args) else:
return_doc = ReturnDocument.BEFORE
result = queryset._collection.find_one_and_update(
query, update, upsert=upsert, sort=sort, return_document=return_doc,
**self._cursor_args)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u'Update failed (%s)' % err) raise NotUniqueError(u'Update failed (%s)' % err)
except pymongo.errors.OperationFailure as err: except pymongo.errors.OperationFailure as err:
@ -1082,9 +1073,8 @@ class BaseQuerySet(object):
..versionchanged:: 0.5 - made chainable ..versionchanged:: 0.5 - made chainable
.. deprecated:: Ignored with PyMongo 3+ .. deprecated:: Ignored with PyMongo 3+
""" """
if IS_PYMONGO_3: msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.'
msg = 'snapshot is deprecated as it has no impact when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning)
warnings.warn(msg, DeprecationWarning)
queryset = self.clone() queryset = self.clone()
queryset._snapshot = enabled queryset._snapshot = enabled
return queryset return queryset
@ -1108,9 +1098,8 @@ class BaseQuerySet(object):
.. deprecated:: Ignored with PyMongo 3+ .. deprecated:: Ignored with PyMongo 3+
""" """
if IS_PYMONGO_3: msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.'
msg = 'slave_okay is deprecated as it has no impact when using PyMongo 3+.' warnings.warn(msg, DeprecationWarning)
warnings.warn(msg, DeprecationWarning)
queryset = self.clone() queryset = self.clone()
queryset._slave_okay = enabled queryset._slave_okay = enabled
return queryset return queryset
@ -1211,7 +1200,7 @@ class BaseQuerySet(object):
pipeline = initial_pipeline + list(pipeline) pipeline = initial_pipeline + list(pipeline)
if IS_PYMONGO_3 and self._read_preference is not None: if self._read_preference is not None:
return self._collection.with_options(read_preference=self._read_preference) \ return self._collection.with_options(read_preference=self._read_preference) \
.aggregate(pipeline, cursor={}, **kwargs) .aggregate(pipeline, cursor={}, **kwargs)
@ -1421,11 +1410,7 @@ class BaseQuerySet(object):
if isinstance(field_instances[-1], ListField): if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field}) pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline) result = tuple(self._document._get_collection().aggregate(pipeline))
if IS_PYMONGO_3:
result = tuple(result)
else:
result = result.get('result')
if result: if result:
return result[0]['total'] return result[0]['total']
@ -1452,11 +1437,7 @@ class BaseQuerySet(object):
if isinstance(field_instances[-1], ListField): if isinstance(field_instances[-1], ListField):
pipeline.insert(1, {'$unwind': '$' + field}) pipeline.insert(1, {'$unwind': '$' + field})
result = self._document._get_collection().aggregate(pipeline) result = tuple(self._document._get_collection().aggregate(pipeline))
if IS_PYMONGO_3:
result = tuple(result)
else:
result = result.get('result')
if result: if result:
return result[0]['total'] return result[0]['total']
return 0 return 0
@ -1564,7 +1545,7 @@ class BaseQuerySet(object):
# XXX In PyMongo 3+, we define the read preference on a collection # XXX In PyMongo 3+, we define the read preference on a collection
# level, not a cursor level. Thus, we need to get a cloned collection # level, not a cursor level. Thus, we need to get a cloned collection
# object using `with_options` first. # object using `with_options` first.
if IS_PYMONGO_3 and self._read_preference is not None: if self._read_preference is not None:
self._cursor_obj = self._collection\ self._cursor_obj = self._collection\
.with_options(read_preference=self._read_preference)\ .with_options(read_preference=self._read_preference)\
.find(self._query, **self._cursor_args) .find(self._query, **self._cursor_args)

View File

@ -8,9 +8,7 @@ from six import iteritems
from mongoengine.base import UPDATE_OPERATORS from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.connection import get_connection
from mongoengine.errors import InvalidQueryError from mongoengine.errors import InvalidQueryError
from mongoengine.pymongo_support import IS_PYMONGO_3
__all__ = ('query', 'update') __all__ = ('query', 'update')
@ -163,16 +161,14 @@ def query(_doc_cls=None, **kwargs):
# PyMongo 3+ and MongoDB < 2.6 # PyMongo 3+ and MongoDB < 2.6
near_embedded = False near_embedded = False
for near_op in ('$near', '$nearSphere'): for near_op in ('$near', '$nearSphere'):
if isinstance(value_dict.get(near_op), dict) and ( if isinstance(value_dict.get(near_op), dict):
IS_PYMONGO_3 or get_connection().max_wire_version > 1):
value_son[near_op] = SON(value_son[near_op]) value_son[near_op] = SON(value_son[near_op])
if '$maxDistance' in value_dict: if '$maxDistance' in value_dict:
value_son[near_op][ value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
'$maxDistance'] = value_dict['$maxDistance']
if '$minDistance' in value_dict: if '$minDistance' in value_dict:
value_son[near_op][ value_son[near_op]['$minDistance'] = value_dict['$minDistance']
'$minDistance'] = value_dict['$minDistance']
near_embedded = True near_embedded = True
if not near_embedded: if not near_embedded:
if '$maxDistance' in value_dict: if '$maxDistance' in value_dict:
value_son['$maxDistance'] = value_dict['$maxDistance'] value_son['$maxDistance'] = value_dict['$maxDistance']

View File

@ -19,10 +19,9 @@ from mongoengine.connection import get_connection, get_db
from mongoengine.context_managers import query_counter, switch_db from mongoengine.context_managers import query_counter, switch_db
from mongoengine.errors import InvalidQueryError from mongoengine.errors import InvalidQueryError
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32 from mongoengine.mongodb_support import get_mongodb_version, MONGODB_32
from mongoengine.pymongo_support import IS_PYMONGO_3
from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned, from mongoengine.queryset import (DoesNotExist, MultipleObjectsReturned,
QuerySet, QuerySetManager, queryset_manager) QuerySet, QuerySetManager, queryset_manager)
from tests.utils import requires_mongodb_gte_26, skip_pymongo3 from tests.utils import requires_mongodb_gte_26
class db_ops_tracker(query_counter): class db_ops_tracker(query_counter):
@ -1047,48 +1046,6 @@ class QuerySetTest(unittest.TestCase):
org.save() # saves the org org.save() # saves the org
self.assertEqual(q, 2) self.assertEqual(q, 2)
@skip_pymongo3
def test_slave_okay(self):
"""Ensures that a query can take slave_okay syntax.
Useless with PyMongo 3+ as well as with MongoDB 3+.
"""
person1 = self.Person(name="User A", age=20)
person1.save()
person2 = self.Person(name="User B", age=30)
person2.save()
# Retrieve the first person from the database
person = self.Person.objects.slave_okay(True).first()
self.assertIsInstance(person, self.Person)
self.assertEqual(person.name, "User A")
self.assertEqual(person.age, 20)
@requires_mongodb_gte_26
@skip_pymongo3
def test_cursor_args(self):
"""Ensures the cursor args can be set as expected
"""
p = self.Person.objects
# Check default
self.assertEqual(p._cursor_args,
{'snapshot': False, 'slave_okay': False, 'timeout': True})
p = p.snapshot(False).slave_okay(False).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': False, 'slave_okay': False, 'timeout': False})
p = p.snapshot(True).slave_okay(False).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': False, 'timeout': False})
p = p.snapshot(True).slave_okay(True).timeout(False)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': True, 'timeout': False})
p = p.snapshot(True).slave_okay(True).timeout(True)
self.assertEqual(p._cursor_args,
{'snapshot': True, 'slave_okay': True, 'timeout': True})
def test_repeated_iteration(self): def test_repeated_iteration(self):
"""Ensure that QuerySet rewinds itself one iteration finishes. """Ensure that QuerySet rewinds itself one iteration finishes.
""" """
@ -4568,12 +4525,8 @@ class QuerySetTest(unittest.TestCase):
bars = Bar.objects \ bars = Bar.objects \
.read_preference(ReadPreference.SECONDARY_PREFERRED) \ .read_preference(ReadPreference.SECONDARY_PREFERRED) \
.aggregate() .aggregate()
if IS_PYMONGO_3: self.assertEqual(bars._CommandCursor__collection.read_preference,
self.assertEqual(bars._CommandCursor__collection.read_preference, ReadPreference.SECONDARY_PREFERRED)
ReadPreference.SECONDARY_PREFERRED)
else:
self.assertNotEqual(bars._CommandCursor__collection.read_preference,
ReadPreference.SECONDARY_PREFERRED)
def test_json_simple(self): def test_json_simple(self):

View File

@ -2,6 +2,7 @@ import datetime
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.errors import OperationFailure, InvalidName from pymongo.errors import OperationFailure, InvalidName
from pymongo import ReadPreference
try: try:
import unittest2 as unittest import unittest2 as unittest
@ -16,7 +17,6 @@ from mongoengine import (
connect, register_connection, connect, register_connection,
Document, DateTimeField, Document, DateTimeField,
disconnect_all, StringField) disconnect_all, StringField)
from mongoengine.pymongo_support import IS_PYMONGO_3
import mongoengine.connection import mongoengine.connection
from mongoengine.connection import (MongoEngineConnectionError, get_db, from mongoengine.connection import (MongoEngineConnectionError, get_db,
get_connection, disconnect, DEFAULT_DATABASE_NAME) get_connection, disconnect, DEFAULT_DATABASE_NAME)
@ -404,11 +404,7 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetests', alias='testdb2') connect('mongoenginetests', alias='testdb2')
actual_connection = get_connection('testdb2') actual_connection = get_connection('testdb2')
# Handle PyMongo 3+ Async Connection expected_connection.server_info()
if IS_PYMONGO_3:
# Ensure we are connected, throws ServerSelectionTimeoutError otherwise.
# Purposely not catching exception to fail test if thrown.
expected_connection.server_info()
self.assertEqual(expected_connection, actual_connection) self.assertEqual(expected_connection, actual_connection)
@ -484,19 +480,11 @@ class ConnectionTest(unittest.TestCase):
c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"]) c.admin.command("createUser", "username2", pwd="password", roles=["dbOwner"])
# Authentication fails without "authSource" # Authentication fails without "authSource"
if IS_PYMONGO_3: test_conn = connect(
test_conn = connect( 'mongoenginetest', alias='test1',
'mongoenginetest', alias='test1', host='mongodb://username2:password@localhost/mongoenginetest'
host='mongodb://username2:password@localhost/mongoenginetest' )
) self.assertRaises(OperationFailure, test_conn.server_info)
self.assertRaises(OperationFailure, test_conn.server_info)
else:
self.assertRaises(
MongoEngineConnectionError,
connect, 'mongoenginetest', alias='test1',
host='mongodb://username2:password@localhost/mongoenginetest'
)
self.assertRaises(MongoEngineConnectionError, get_db, 'test1')
# Authentication succeeds with "authSource" # Authentication succeeds with "authSource"
authd_conn = connect( authd_conn = connect(
@ -565,44 +553,28 @@ class ConnectionTest(unittest.TestCase):
""" """
conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true') conn1 = connect(alias='conn1', host='mongodb://localhost/testing?w=1&j=true')
conn2 = connect('testing', alias='conn2', w=1, j=True) conn2 = connect('testing', alias='conn2', w=1, j=True)
if IS_PYMONGO_3: self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True})
self.assertEqual(conn1.write_concern.document, {'w': 1, 'j': True}) self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
self.assertEqual(conn2.write_concern.document, {'w': 1, 'j': True})
else:
self.assertEqual(dict(conn1.write_concern), {'w': 1, 'j': True})
self.assertEqual(dict(conn2.write_concern), {'w': 1, 'j': True})
def test_connect_with_replicaset_via_uri(self): def test_connect_with_replicaset_via_uri(self):
"""Ensure connect() works when specifying a replicaSet via the """Ensure connect() works when specifying a replicaSet via the
MongoDB URI. MongoDB URI.
""" """
if IS_PYMONGO_3: c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
c = connect(host='mongodb://localhost/test?replicaSet=local-rs') db = get_db()
db = get_db() self.assertIsInstance(db, pymongo.database.Database)
self.assertIsInstance(db, pymongo.database.Database) self.assertEqual(db.name, 'test')
self.assertEqual(db.name, 'test')
else:
# PyMongo < v3.x raises an exception:
# "localhost:27017 is not a member of replica set local-rs"
with self.assertRaises(MongoEngineConnectionError):
c = connect(host='mongodb://localhost/test?replicaSet=local-rs')
def test_connect_with_replicaset_via_kwargs(self): def test_connect_with_replicaset_via_kwargs(self):
"""Ensure connect() works when specifying a replicaSet via the """Ensure connect() works when specifying a replicaSet via the
connection kwargs connection kwargs
""" """
if IS_PYMONGO_3: c = connect(replicaset='local-rs')
c = connect(replicaset='local-rs') self.assertEqual(c._MongoClient__options.replica_set_name,
self.assertEqual(c._MongoClient__options.replica_set_name, 'local-rs')
'local-rs') db = get_db()
db = get_db() self.assertIsInstance(db, pymongo.database.Database)
self.assertIsInstance(db, pymongo.database.Database) self.assertEqual(db.name, 'test')
self.assertEqual(db.name, 'test')
else:
# PyMongo < v3.x raises an exception:
# "localhost:27017 is not a member of replica set local-rs"
with self.assertRaises(MongoEngineConnectionError):
c = connect(replicaset='local-rs')
def test_connect_tz_aware(self): def test_connect_tz_aware(self):
connect('mongoenginetest', tz_aware=True) connect('mongoenginetest', tz_aware=True)
@ -618,10 +590,8 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(d, date_doc.the_date) self.assertEqual(d, date_doc.the_date)
def test_read_preference_from_parse(self): def test_read_preference_from_parse(self):
if IS_PYMONGO_3: conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
from pymongo import ReadPreference self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
conn = connect(host="mongodb://a1.vpc,a2.vpc,a3.vpc/prod?readPreference=secondaryPreferred")
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_PREFERRED)
def test_multiple_connection_settings(self): def test_multiple_connection_settings(self):
connect('mongoenginetest', alias='t1', host="localhost") connect('mongoenginetest', alias='t1', host="localhost")

View File

@ -1,23 +1,16 @@
import unittest import unittest
from pymongo import ReadPreference from pymongo import ReadPreference
from pymongo import MongoClient
from mongoengine.pymongo_support import IS_PYMONGO_3
if IS_PYMONGO_3:
from pymongo import MongoClient
CONN_CLASS = MongoClient
READ_PREF = ReadPreference.SECONDARY
else:
from pymongo import ReplicaSetConnection
CONN_CLASS = ReplicaSetConnection
READ_PREF = ReadPreference.SECONDARY_ONLY
import mongoengine import mongoengine
from mongoengine import *
from mongoengine.connection import MongoEngineConnectionError from mongoengine.connection import MongoEngineConnectionError
CONN_CLASS = MongoClient
READ_PREF = ReadPreference.SECONDARY
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -35,7 +28,7 @@ class ConnectionTest(unittest.TestCase):
""" """
try: try:
conn = connect(db='mongoenginetest', conn = mongoengine.connect(db='mongoenginetest',
host="mongodb://localhost/mongoenginetest?replicaSet=rs", host="mongodb://localhost/mongoenginetest?replicaSet=rs",
read_preference=READ_PREF) read_preference=READ_PREF)
except MongoEngineConnectionError as e: except MongoEngineConnectionError as e:

View File

@ -6,7 +6,6 @@ from nose.plugins.skip import SkipTest
from mongoengine import connect from mongoengine import connect
from mongoengine.connection import get_db, disconnect_all from mongoengine.connection import get_db, disconnect_all
from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34
from mongoengine.pymongo_support import IS_PYMONGO_3
MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database MONGO_TEST_DB = 'mongoenginetest' # standard name for the test database
@ -80,19 +79,3 @@ def requires_mongodb_gte_3(func):
lower than v3.0. lower than v3.0.
""" """
return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge) return _decorated_with_ver_requirement(func, MONGODB_3, oper=operator.ge)
def skip_pymongo3(f):
"""Raise a SkipTest exception if we're running a test against
PyMongo v3.x.
"""
def _inner(*args, **kwargs):
if IS_PYMONGO_3:
raise SkipTest("Useless with PyMongo 3+")
return f(*args, **kwargs)
_inner.__name__ = f.__name__
_inner.__doc__ = f.__doc__
return _inner