fix connection problems with pymongo3 and added tests
This commit is contained in:
		| @@ -1,4 +1,5 @@ | ||||
| from pymongo import MongoClient, MongoReplicaSetClient, uri_parser | ||||
| import pymongo | ||||
| from pymongo import MongoClient, ReadPreference, uri_parser | ||||
|  | ||||
|  | ||||
| __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
| @@ -6,6 +7,10 @@ __all__ = ['ConnectionError', 'connect', 'register_connection', | ||||
|  | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
| if pymongo.version_tuple[0] >= 3: | ||||
|     READ_PREFERENCE = ReadPreference.SECONDARY_PREFERRED | ||||
| else: | ||||
|     READ_PREFERENCE = False | ||||
|  | ||||
|  | ||||
| class ConnectionError(Exception): | ||||
| @@ -18,7 +23,7 @@ _dbs = {} | ||||
|  | ||||
|  | ||||
| def register_connection(alias, name=None, host=None, port=None, | ||||
|                         read_preference=False, | ||||
|                         read_preference=READ_PREFERENCE, | ||||
|                         username=None, password=None, authentication_source=None, | ||||
|                         **kwargs): | ||||
|     """Add a connection. | ||||
| @@ -109,7 +114,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | ||||
|             # Discard replicaSet if not base string | ||||
|             if not isinstance(conn_settings['replicaSet'], basestring): | ||||
|                 conn_settings.pop('replicaSet', None) | ||||
|             connection_class = MongoReplicaSetClient | ||||
|  | ||||
|         try: | ||||
|             connection = None | ||||
|   | ||||
| @@ -1,4 +1,6 @@ | ||||
| import sys | ||||
| from time import sleep | ||||
|  | ||||
| sys.path[0:0] = [""] | ||||
|  | ||||
| try: | ||||
| @@ -19,6 +21,13 @@ import mongoengine.connection | ||||
| from mongoengine.connection import get_db, get_connection, ConnectionError | ||||
|  | ||||
|  | ||||
| def get_tz_awareness(connection): | ||||
|     if pymongo.version_tuple[0] < 3: | ||||
|         return connection.tz_aware | ||||
|     else: | ||||
|         return connection.codec_options.tz_aware | ||||
|  | ||||
|  | ||||
| class ConnectionTest(unittest.TestCase): | ||||
|  | ||||
|     def tearDown(self): | ||||
| @@ -51,6 +60,9 @@ class ConnectionTest(unittest.TestCase): | ||||
|  | ||||
|         connect('mongoenginetest', alias='testdb2') | ||||
|         actual_connection = get_connection('testdb2') | ||||
|  | ||||
|         # horrible, but since PyMongo3+, connection are created asynchronously | ||||
|         sleep(0.1) | ||||
|         self.assertEqual(expected_connection, actual_connection) | ||||
|  | ||||
|     def test_connect_uri(self): | ||||
| @@ -64,7 +76,8 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.admin.authenticate("admin", "password") | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|         if pymongo.version_tuple[0] < 3: | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') | ||||
|  | ||||
| @@ -90,7 +103,8 @@ class ConnectionTest(unittest.TestCase): | ||||
|         c.admin.authenticate("admin", "password") | ||||
|         c.mongoenginetest.add_user("username", "password") | ||||
|  | ||||
|         self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|         if pymongo.version_tuple[0] < 3: | ||||
|             self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') | ||||
|  | ||||
|         connect("mongoenginetest", host='mongodb://localhost/') | ||||
|  | ||||
| @@ -160,11 +174,11 @@ class ConnectionTest(unittest.TestCase): | ||||
|         connect('mongoenginetest', alias='t1', tz_aware=True) | ||||
|         conn = get_connection('t1') | ||||
|  | ||||
|         self.assertTrue(conn.tz_aware) | ||||
|         self.assertTrue(get_tz_awareness(conn)) | ||||
|  | ||||
|         connect('mongoenginetest2', alias='t2') | ||||
|         conn = get_connection('t2') | ||||
|         self.assertFalse(conn.tz_aware) | ||||
|         self.assertFalse(get_tz_awareness(conn)) | ||||
|  | ||||
|     def test_datetime(self): | ||||
|         connect('mongoenginetest', tz_aware=True) | ||||
| @@ -188,8 +202,14 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertEqual(len(mongo_connections.items()), 2) | ||||
|         self.assertTrue('t1' in mongo_connections.keys()) | ||||
|         self.assertTrue('t2' in mongo_connections.keys()) | ||||
|         self.assertEqual(mongo_connections['t1'].host, 'localhost') | ||||
|         self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') | ||||
|         if pymongo.version_tuple[0] < 3: | ||||
|             self.assertEqual(mongo_connections['t1'].host, 'localhost') | ||||
|             self.assertEqual(mongo_connections['t2'].host, '127.0.0.1') | ||||
|         else: | ||||
|             # horrible, but since PyMongo3+, connection are created asynchronously | ||||
|             sleep(0.1) | ||||
|             self.assertEqual(mongo_connections['t1'].address[0], 'localhost') | ||||
|             self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|   | ||||
| @@ -3,15 +3,29 @@ sys.path[0:0] = [""] | ||||
| import unittest | ||||
|  | ||||
| import pymongo | ||||
| from pymongo import ReadPreference, ReplicaSetConnection | ||||
| from pymongo import ReadPreference | ||||
|  | ||||
| if pymongo.version_tuple[0] >= 3: | ||||
|     from pymongo import MongoClient | ||||
|     CONN_CLASS = MongoClient | ||||
|     READ_PREF = ReadPreference.SECONDARY | ||||
| else: | ||||
|     from pymongo import ReplicaSetConnection | ||||
|     CONN_CLASS = ReplicaSetConnection | ||||
|     READ_PREF = ReadPreference.SECONDARY_ONLY | ||||
|  | ||||
| import mongoengine | ||||
| from mongoengine import * | ||||
| from mongoengine.connection import get_db, get_connection, ConnectionError | ||||
| from mongoengine.connection import ConnectionError | ||||
|  | ||||
|  | ||||
| class ConnectionTest(unittest.TestCase): | ||||
|  | ||||
|     def setUp(self): | ||||
|         mongoengine.connection._connection_settings = {} | ||||
|         mongoengine.connection._connections = {} | ||||
|         mongoengine.connection._dbs = {} | ||||
|  | ||||
|     def tearDown(self): | ||||
|         mongoengine.connection._connection_settings = {} | ||||
|         mongoengine.connection._connections = {} | ||||
| @@ -22,14 +36,17 @@ class ConnectionTest(unittest.TestCase): | ||||
|         """ | ||||
|  | ||||
|         try: | ||||
|             conn = connect(db='mongoenginetest', host="mongodb://localhost/mongoenginetest?replicaSet=rs", read_preference=ReadPreference.SECONDARY_ONLY) | ||||
|             conn = connect(db='mongoenginetest', | ||||
|                            host="mongodb://localhost/mongoenginetest?replicaSet=rs", | ||||
|                            read_preference=READ_PREF) | ||||
|         except ConnectionError, e: | ||||
|             return | ||||
|  | ||||
|         if not isinstance(conn, ReplicaSetConnection): | ||||
|         if not isinstance(conn, CONN_CLASS): | ||||
|             # really??? | ||||
|             return | ||||
|  | ||||
|         self.assertEqual(conn.read_preference, ReadPreference.SECONDARY_ONLY) | ||||
|         self.assertEqual(conn.read_preference, READ_PREF) | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|     unittest.main() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user