From a34fa74eaa74a8e0eede9cbb109dae2136f075c7 Mon Sep 17 00:00:00 2001 From: mrigal Date: Thu, 9 Apr 2015 03:40:42 +0200 Subject: [PATCH] fix connection problems with pymongo3 and added tests --- mongoengine/connection.py | 10 ++++++--- tests/test_connection.py | 32 +++++++++++++++++++++++------ tests/test_replicaset_connection.py | 27 +++++++++++++++++++----- 3 files changed, 55 insertions(+), 14 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 5e18efb7..31f4cbcc 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,4 +1,5 @@ -from pymongo import MongoClient, MongoReplicaSetClient, uri_parser +import pymongo +from pymongo import MongoClient, ReadPreference, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', @@ -6,6 +7,10 @@ __all__ = ['ConnectionError', 'connect', 'register_connection', DEFAULT_CONNECTION_NAME = 'default' +if pymongo.version_tuple[0] >= 3: + READ_PREFERENCE = ReadPreference.SECONDARY_PREFERRED +else: + READ_PREFERENCE = False class ConnectionError(Exception): @@ -18,7 +23,7 @@ _dbs = {} def register_connection(alias, name=None, host=None, port=None, - read_preference=False, + read_preference=READ_PREFERENCE, username=None, password=None, authentication_source=None, **kwargs): """Add a connection. @@ -109,7 +114,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], basestring): conn_settings.pop('replicaSet', None) - connection_class = MongoReplicaSetClient try: connection = None diff --git a/tests/test_connection.py b/tests/test_connection.py index 9204d80c..88e03994 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,6 @@ import sys +from time import sleep + sys.path[0:0] = [""] try: @@ -19,6 +21,13 @@ import mongoengine.connection from mongoengine.connection import get_db, get_connection, ConnectionError +def get_tz_awareness(connection): + if pymongo.version_tuple[0] < 3: + return connection.tz_aware + else: + return connection.codec_options.tz_aware + + class ConnectionTest(unittest.TestCase): def tearDown(self): @@ -51,6 +60,9 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest', alias='testdb2') actual_connection = get_connection('testdb2') + + # horrible, but since PyMongo3+, connection are created asynchronously + sleep(0.1) self.assertEqual(expected_connection, actual_connection) def test_connect_uri(self): @@ -64,7 +76,8 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.mongoenginetest.add_user("username", "password") - self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') + if pymongo.version_tuple[0] < 3: + self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') @@ -90,7 +103,8 @@ class ConnectionTest(unittest.TestCase): c.admin.authenticate("admin", "password") c.mongoenginetest.add_user("username", "password") - self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') + if pymongo.version_tuple[0] < 3: + self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') connect("mongoenginetest", host='mongodb://localhost/') @@ -160,11 +174,11 @@ class ConnectionTest(unittest.TestCase): connect('mongoenginetest', alias='t1', tz_aware=True) conn = get_connection('t1') - self.assertTrue(conn.tz_aware) + self.assertTrue(get_tz_awareness(conn)) connect('mongoenginetest2', alias='t2') conn = get_connection('t2') - self.assertFalse(conn.tz_aware) + self.assertFalse(get_tz_awareness(conn)) def test_datetime(self): connect('mongoenginetest', tz_aware=True) @@ -188,8 +202,14 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(len(mongo_connections.items()), 2) self.assertTrue('t1' in mongo_connections.keys()) self.assertTrue('t2' in mongo_connections.keys()) - self.assertEqual(mongo_connections['t1'].host, 'localhost') - self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') + if pymongo.version_tuple[0] < 3: + self.assertEqual(mongo_connections['t1'].host, 'localhost') + self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') + else: + # horrible, but since PyMongo3+, connection are created asynchronously + sleep(0.1) + self.assertEqual(mongo_connections['t1'].address[0], 'localhost') + self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') if __name__ == '__main__': diff --git a/tests/test_replicaset_connection.py b/tests/test_replicaset_connection.py index d27960f7..b3a7e1bf 100644 --- a/tests/test_replicaset_connection.py +++ b/tests/test_replicaset_connection.py @@ -3,15 +3,29 @@ sys.path[0:0] = [""] import unittest import pymongo -from pymongo import ReadPreference, ReplicaSetConnection +from pymongo import ReadPreference + +if pymongo.version_tuple[0] >= 3: + from pymongo import MongoClient + CONN_CLASS = MongoClient + READ_PREF = ReadPreference.SECONDARY +else: + from pymongo import ReplicaSetConnection + CONN_CLASS = ReplicaSetConnection + READ_PREF = ReadPreference.SECONDARY_ONLY import mongoengine from mongoengine import * -from mongoengine.connection import get_db, get_connection, ConnectionError +from mongoengine.connection import ConnectionError class ConnectionTest(unittest.TestCase): + def setUp(self): + mongoengine.connection._connection_settings = {} + mongoengine.connection._connections = {} + mongoengine.connection._dbs = {} + def tearDown(self): mongoengine.connection._connection_settings = {} mongoengine.connection._connections = {} @@ -22,14 +36,17 @@ class ConnectionTest(unittest.TestCase): """ try: - conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY) + conn = connect(db='mongoenginetest', + host="mongodb://localhost/mongoenginetest?replicaSet=rs", + read_preference=READ_PREF) except ConnectionError, e: return - if not isinstance(conn, ReplicaSetConnection): + if not isinstance(conn, CONN_CLASS): + # really??? return - self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY) + self.assertEqual(conn.read_preference, READ_PREF) if __name__ == '__main__': unittest.main()