fix connection problems with pymongo3 and added tests

This commit is contained in:
mrigal 2015-04-09 03:40:42 +02:00 committed by Matthieu Rigal
parent d6b2d8dcb5
commit a34fa74eaa
3 changed files with 55 additions and 14 deletions

View File

@ -1,4 +1,5 @@
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
import pymongo
from pymongo import MongoClient, ReadPreference, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection',
@ -6,6 +7,10 @@ __all__ = ['ConnectionError', 'connect', 'register_connection',
DEFAULT_CONNECTION_NAME = 'default'
if pymongo.version_tuple[0] >= 3:
READ_PREFERENCE = ReadPreference.SECONDARY_PREFERRED
else:
READ_PREFERENCE = False
class ConnectionError(Exception):
@ -18,7 +23,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None,
read_preference=False,
read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None,
**kwargs):
"""Add a connection.
@ -109,7 +114,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
# Discard replicaSet if not base string
if not isinstance(conn_settings['replicaSet'], basestring):
conn_settings.pop('replicaSet', None)
connection_class = MongoReplicaSetClient
try:
connection = None

View File

@ -1,4 +1,6 @@
import sys
from time import sleep
sys.path[0:0] = [""]
try:
@ -19,6 +21,13 @@ import mongoengine.connection
from mongoengine.connection import get_db, get_connection, ConnectionError
def get_tz_awareness(connection):
if pymongo.version_tuple[0] < 3:
return connection.tz_aware
else:
return connection.codec_options.tz_aware
class ConnectionTest(unittest.TestCase):
def tearDown(self):
@ -51,6 +60,9 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='testdb2')
actual_connection = get_connection('testdb2')
# horrible, but since PyMongo3+, connection are created asynchronously
sleep(0.1)
self.assertEqual(expected_connection, actual_connection)
def test_connect_uri(self):
@ -64,7 +76,8 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
if pymongo.version_tuple[0] < 3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
@ -90,7 +103,8 @@ class ConnectionTest(unittest.TestCase):
c.admin.authenticate("admin", "password")
c.mongoenginetest.add_user("username", "password")
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
if pymongo.version_tuple[0] < 3:
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
connect("mongoenginetest", host='mongodb://localhost/')
@ -160,11 +174,11 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest', alias='t1', tz_aware=True)
conn = get_connection('t1')
self.assertTrue(conn.tz_aware)
self.assertTrue(get_tz_awareness(conn))
connect('mongoenginetest2', alias='t2')
conn = get_connection('t2')
self.assertFalse(conn.tz_aware)
self.assertFalse(get_tz_awareness(conn))
def test_datetime(self):
connect('mongoenginetest', tz_aware=True)
@ -188,8 +202,14 @@ class ConnectionTest(unittest.TestCase):
self.assertEqual(len(mongo_connections.items()), 2)
self.assertTrue('t1' in mongo_connections.keys())
self.assertTrue('t2' in mongo_connections.keys())
self.assertEqual(mongo_connections['t1'].host, 'localhost')
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
if pymongo.version_tuple[0] < 3:
self.assertEqual(mongo_connections['t1'].host, 'localhost')
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
else:
# horrible, but since PyMongo3+, connection are created asynchronously
sleep(0.1)
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
if __name__ == '__main__':

View File

@ -3,15 +3,29 @@ sys.path[0:0] = [""]
import unittest
import pymongo
from pymongo import ReadPreference, ReplicaSetConnection
from pymongo import ReadPreference
if pymongo.version_tuple[0] >= 3:
from pymongo import MongoClient
CONN_CLASS = MongoClient
READ_PREF = ReadPreference.SECONDARY
else:
from pymongo import ReplicaSetConnection
CONN_CLASS = ReplicaSetConnection
READ_PREF = ReadPreference.SECONDARY_ONLY
import mongoengine
from mongoengine import *
from mongoengine.connection import get_db, get_connection, ConnectionError
from mongoengine.connection import ConnectionError
class ConnectionTest(unittest.TestCase):
def setUp(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
mongoengine.connection._dbs = {}
def tearDown(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
@ -22,14 +36,17 @@ class ConnectionTest(unittest.TestCase):
"""
try:
conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY)
conn = connect(db='mongoenginetest',
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
read_preference=READ_PREF)
except ConnectionError, e:
return
if not isinstance(conn, ReplicaSetConnection):
if not isinstance(conn, CONN_CLASS):
# really???
return
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY)
self.assertEqual(conn.read_preference, READ_PREF)
if __name__ == '__main__':
unittest.main()