fix connection problems with pymongo3 and added tests

This commit is contained in:
mrigal 2015-04-09 03:40:42 +02:00 committed by Matthieu Rigal
parent d6b2d8dcb5
commit a34fa74eaa
3 changed files with 55 additions and 14 deletions

View File

@ -1,4 +1,5 @@
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser import pymongo
from pymongo import MongoClient, ReadPreference, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection', __all__ = ['ConnectionError', 'connect', 'register_connection',
@ -6,6 +7,10 @@ __all__ = ['ConnectionError', 'connect', 'register_connection',
DEFAULT_CONNECTION_NAME = 'default' DEFAULT_CONNECTION_NAME = 'default'
if pymongo.version_tuple[0] >= 3:
READ_PREFERENCE = ReadPreference.SECONDARY_PREFERRED
else:
READ_PREFERENCE = False
class ConnectionError(Exception): class ConnectionError(Exception):
@ -18,7 +23,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None, def register_connection(alias, name=None, host=None, port=None,
read_preference=False, read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None, username=None, password=None, authentication_source=None,
**kwargs): **kwargs):
"""Add a connection. """Add a connection.
@ -109,7 +114,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard replicaSet if not base string # Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], basestring): if not isinstance(conn_settings['replicaSet'], basestring):
conn_settings.pop('replicaSet', None) conn_settings.pop('replicaSet', None)
connection_class = MongoReplicaSetClient
try: try:
connection = None connection = None

View File

@ -1,4 +1,6 @@
import sys import sys
from time import sleep
sys.path[0:0] = [""] sys.path[0:0] = [""]
try: try:
@ -19,6 +21,13 @@ import mongoengine.connection
from mongoengine.connection import get_db, get_connection, ConnectionError from mongoengine.connection import get_db, get_connection, ConnectionError
def get_tz_awareness(connection):
if pymongo.version_tuple[0] < 3:
return connection.tz_aware
else:
return connection.codec_options.tz_aware
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
@ -51,6 +60,9 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='testdb2') connect('mongoenginetest', alias='testdb2')
actual_connection = get_connection('testdb2') actual_connection = get_connection('testdb2')
# horrible, but since PyMongo3+, connection are created asynchronously
sleep(0.1)
self.assertEqual(expected_connection, actual_connection) self.assertEqual(expected_connection, actual_connection)
def test_connect_uri(self): def test_connect_uri(self):
@ -64,6 +76,7 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password") c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password") c.mongoenginetest.add_user("username", "password")
if pymongo.version_tuple[0] < 3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
@ -90,6 +103,7 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password") c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password") c.mongoenginetest.add_user("username", "password")
if pymongo.version_tuple[0] < 3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/') connect("mongoenginetest", host='mongodb://localhost/')
@ -160,11 +174,11 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='t1', tz_aware=True) connect('mongoenginetest', alias='t1', tz_aware=True)
conn = get_connection('t1') conn = get_connection('t1')
self.assertTrue(conn.tz_aware) self.assertTrue(get_tz_awareness(conn))
connect('mongoenginetest2', alias='t2') connect('mongoenginetest2', alias='t2')
conn = get_connection('t2') conn = get_connection('t2')
self.assertFalse(conn.tz_aware) self.assertFalse(get_tz_awareness(conn))
def test_datetime(self): def test_datetime(self):
connect('mongoenginetest', tz_aware=True) connect('mongoenginetest', tz_aware=True)
@ -188,8 +202,14 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(len(mongo_connections.items()), 2) self.assertEqual(len(mongo_connections.items()), 2)
self.assertTrue('t1' in mongo_connections.keys()) self.assertTrue('t1' in mongo_connections.keys())
self.assertTrue('t2' in mongo_connections.keys()) self.assertTrue('t2' in mongo_connections.keys())
if pymongo.version_tuple[0] < 3:
self.assertEqual(mongo_connections['t1'].host, 'localhost') self.assertEqual(mongo_connections['t1'].host, 'localhost')
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
else:
# horrible, but since PyMongo3+, connection are created asynchronously
sleep(0.1)
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,15 +3,29 @@ sys.path[0:0] = [""]
import unittest import unittest
import pymongo import pymongo
from pymongo import ReadPreference, ReplicaSetConnection from pymongo import ReadPreference
if pymongo.version_tuple[0] >= 3:
from pymongo import MongoClient
CONN_CLASS = MongoClient
READ_PREF = ReadPreference.SECONDARY
else:
from pymongo import ReplicaSetConnection
CONN_CLASS = ReplicaSetConnection
READ_PREF = ReadPreference.SECONDARY_ONLY
import mongoengine import mongoengine
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db, get_connection, ConnectionError from mongoengine.connection import ConnectionError
class ConnectionTest(unittest.TestCase): class ConnectionTest(unittest.TestCase):
def setUp(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
mongoengine.connection._dbs = {}
def tearDown(self): def tearDown(self):
mongoengine.connection._connection_settings = {} mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {} mongoengine.connection._connections = {}
@ -22,14 +36,17 @@ class ConnectionTest(unittest.TestCase):
""" """
try: try:
conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY) conn = connect(db='mongoenginetest',
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
read_preference=READ_PREF)
except ConnectionError, e: except ConnectionError, e:
return return
if not isinstance(conn, ReplicaSetConnection): if not isinstance(conn, CONN_CLASS):
# really???
return return
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY) self.assertEqual(conn.read_preference, READ_PREF)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()