fix connection problems with pymongo3 and added tests
This commit is contained in:
parent
d6b2d8dcb5
commit
a34fa74eaa
@ -1,4 +1,5 @@
|
|||||||
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
|
import pymongo
|
||||||
|
from pymongo import MongoClient, ReadPreference, uri_parser
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
||||||
@ -6,6 +7,10 @@ __all__ = ['ConnectionError', 'connect', 'register_connection',
|
|||||||
|
|
||||||
|
|
||||||
DEFAULT_CONNECTION_NAME = 'default'
|
DEFAULT_CONNECTION_NAME = 'default'
|
||||||
|
if pymongo.version_tuple[0] >= 3:
|
||||||
|
READ_PREFERENCE = ReadPreference.SECONDARY_PREFERRED
|
||||||
|
else:
|
||||||
|
READ_PREFERENCE = False
|
||||||
|
|
||||||
|
|
||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
@ -18,7 +23,7 @@ _dbs = {}
|
|||||||
|
|
||||||
|
|
||||||
def register_connection(alias, name=None, host=None, port=None,
|
def register_connection(alias, name=None, host=None, port=None,
|
||||||
read_preference=False,
|
read_preference=READ_PREFERENCE,
|
||||||
username=None, password=None, authentication_source=None,
|
username=None, password=None, authentication_source=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Add a connection.
|
"""Add a connection.
|
||||||
@ -109,7 +114,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
|||||||
# Discard replicaSet if not base string
|
# Discard replicaSet if not base string
|
||||||
if not isinstance(conn_settings['replicaSet'], basestring):
|
if not isinstance(conn_settings['replicaSet'], basestring):
|
||||||
conn_settings.pop('replicaSet', None)
|
conn_settings.pop('replicaSet', None)
|
||||||
connection_class = MongoReplicaSetClient
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connection = None
|
connection = None
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import sys
|
import sys
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
sys.path[0:0] = [""]
|
sys.path[0:0] = [""]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -19,6 +21,13 @@ import mongoengine.connection
|
|||||||
from mongoengine.connection import get_db, get_connection, ConnectionError
|
from mongoengine.connection import get_db, get_connection, ConnectionError
|
||||||
|
|
||||||
|
|
||||||
|
def get_tz_awareness(connection):
|
||||||
|
if pymongo.version_tuple[0] < 3:
|
||||||
|
return connection.tz_aware
|
||||||
|
else:
|
||||||
|
return connection.codec_options.tz_aware
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTest(unittest.TestCase):
|
class ConnectionTest(unittest.TestCase):
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
@ -51,6 +60,9 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
|
|
||||||
connect('mongoenginetest', alias='testdb2')
|
connect('mongoenginetest', alias='testdb2')
|
||||||
actual_connection = get_connection('testdb2')
|
actual_connection = get_connection('testdb2')
|
||||||
|
|
||||||
|
# horrible, but since PyMongo3+, connection are created asynchronously
|
||||||
|
sleep(0.1)
|
||||||
self.assertEqual(expected_connection, actual_connection)
|
self.assertEqual(expected_connection, actual_connection)
|
||||||
|
|
||||||
def test_connect_uri(self):
|
def test_connect_uri(self):
|
||||||
@ -64,6 +76,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
c.admin.authenticate("admin", "password")
|
c.admin.authenticate("admin", "password")
|
||||||
c.mongoenginetest.add_user("username", "password")
|
c.mongoenginetest.add_user("username", "password")
|
||||||
|
|
||||||
|
if pymongo.version_tuple[0] < 3:
|
||||||
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
|
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
|
||||||
|
|
||||||
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
|
connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest')
|
||||||
@ -90,6 +103,7 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
c.admin.authenticate("admin", "password")
|
c.admin.authenticate("admin", "password")
|
||||||
c.mongoenginetest.add_user("username", "password")
|
c.mongoenginetest.add_user("username", "password")
|
||||||
|
|
||||||
|
if pymongo.version_tuple[0] < 3:
|
||||||
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
|
self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost')
|
||||||
|
|
||||||
connect("mongoenginetest", host='mongodb://localhost/')
|
connect("mongoenginetest", host='mongodb://localhost/')
|
||||||
@ -160,11 +174,11 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
connect('mongoenginetest', alias='t1', tz_aware=True)
|
connect('mongoenginetest', alias='t1', tz_aware=True)
|
||||||
conn = get_connection('t1')
|
conn = get_connection('t1')
|
||||||
|
|
||||||
self.assertTrue(conn.tz_aware)
|
self.assertTrue(get_tz_awareness(conn))
|
||||||
|
|
||||||
connect('mongoenginetest2', alias='t2')
|
connect('mongoenginetest2', alias='t2')
|
||||||
conn = get_connection('t2')
|
conn = get_connection('t2')
|
||||||
self.assertFalse(conn.tz_aware)
|
self.assertFalse(get_tz_awareness(conn))
|
||||||
|
|
||||||
def test_datetime(self):
|
def test_datetime(self):
|
||||||
connect('mongoenginetest', tz_aware=True)
|
connect('mongoenginetest', tz_aware=True)
|
||||||
@ -188,8 +202,14 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(mongo_connections.items()), 2)
|
self.assertEqual(len(mongo_connections.items()), 2)
|
||||||
self.assertTrue('t1' in mongo_connections.keys())
|
self.assertTrue('t1' in mongo_connections.keys())
|
||||||
self.assertTrue('t2' in mongo_connections.keys())
|
self.assertTrue('t2' in mongo_connections.keys())
|
||||||
|
if pymongo.version_tuple[0] < 3:
|
||||||
self.assertEqual(mongo_connections['t1'].host, 'localhost')
|
self.assertEqual(mongo_connections['t1'].host, 'localhost')
|
||||||
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
|
self.assertEqual(mongo_connections['t2'].host, '127.0.0.1')
|
||||||
|
else:
|
||||||
|
# horrible, but since PyMongo3+, connection are created asynchronously
|
||||||
|
sleep(0.1)
|
||||||
|
self.assertEqual(mongo_connections['t1'].address[0], 'localhost')
|
||||||
|
self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -3,15 +3,29 @@ sys.path[0:0] = [""]
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import pymongo
|
import pymongo
|
||||||
from pymongo import ReadPreference, ReplicaSetConnection
|
from pymongo import ReadPreference
|
||||||
|
|
||||||
|
if pymongo.version_tuple[0] >= 3:
|
||||||
|
from pymongo import MongoClient
|
||||||
|
CONN_CLASS = MongoClient
|
||||||
|
READ_PREF = ReadPreference.SECONDARY
|
||||||
|
else:
|
||||||
|
from pymongo import ReplicaSetConnection
|
||||||
|
CONN_CLASS = ReplicaSetConnection
|
||||||
|
READ_PREF = ReadPreference.SECONDARY_ONLY
|
||||||
|
|
||||||
import mongoengine
|
import mongoengine
|
||||||
from mongoengine import *
|
from mongoengine import *
|
||||||
from mongoengine.connection import get_db, get_connection, ConnectionError
|
from mongoengine.connection import ConnectionError
|
||||||
|
|
||||||
|
|
||||||
class ConnectionTest(unittest.TestCase):
|
class ConnectionTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
mongoengine.connection._connection_settings = {}
|
||||||
|
mongoengine.connection._connections = {}
|
||||||
|
mongoengine.connection._dbs = {}
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
mongoengine.connection._connection_settings = {}
|
mongoengine.connection._connection_settings = {}
|
||||||
mongoengine.connection._connections = {}
|
mongoengine.connection._connections = {}
|
||||||
@ -22,14 +36,17 @@ class ConnectionTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY)
|
conn = connect(db='mongoenginetest',
|
||||||
|
host="mongodb://localhost/mongoenginetest?replicaSet=rs",
|
||||||
|
read_preference=READ_PREF)
|
||||||
except ConnectionError, e:
|
except ConnectionError, e:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not isinstance(conn, ReplicaSetConnection):
|
if not isinstance(conn, CONN_CLASS):
|
||||||
|
# really???
|
||||||
return
|
return
|
||||||
|
|
||||||
self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY)
|
self.assertEqual(conn.read_preference, READ_PREF)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user