From f7ac8cea9041a93027e7dacd0be674b78b2beeae Mon Sep 17 00:00:00 2001 From: Jeff Tharp Date: Wed, 19 Oct 2016 11:57:02 -0700 Subject: [PATCH] Fix connecting to a list of hosts --- mongoengine/connection.py | 48 +++++++++++++++++++++++---------------- tests/test_connection.py | 34 +++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 19 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 4055a9b6..0974f83b 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,5 @@ from pymongo import MongoClient, ReadPreference, uri_parser -from mongoengine.python_support import IS_PYMONGO_3 +from mongoengine.python_support import (IS_PYMONGO_3, str_types) __all__ = ['ConnectionError', 'connect', 'register_connection', 'DEFAULT_CONNECTION_NAME'] @@ -56,25 +56,35 @@ def register_connection(alias, name=None, host=None, port=None, 'authentication_source': authentication_source } - # Handle uri style connections conn_host = conn_settings['host'] - if conn_host.startswith('mongomock://'): - conn_settings['is_mock'] = True - # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` - conn_settings['host'] = conn_host.replace('mongomock://', 'mongodb://', 1) - elif '://' in conn_host: - uri_dict = uri_parser.parse_uri(conn_host) - conn_settings.update({ - 'name': uri_dict.get('database') or name, - 'username': uri_dict.get('username'), - 'password': uri_dict.get('password'), - 'read_preference': read_preference, - }) - uri_options = uri_dict['options'] - if 'replicaset' in uri_options: - conn_settings['replicaSet'] = True - if 'authsource' in uri_options: - conn_settings['authentication_source'] = uri_options['authsource'] + # host can be a list or a string, so if string, force to a list + if isinstance(conn_host, str_types): + conn_host = [conn_host] + + resolved_hosts = [] + for entity in conn_host: + # Handle uri style connections + if entity.startswith('mongomock://'): + conn_settings['is_mock'] = True + # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` + resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) + elif '://' in entity: + uri_dict = uri_parser.parse_uri(entity) + resolved_hosts.append(entity) + conn_settings.update({ + 'name': uri_dict.get('database') or name, + 'username': uri_dict.get('username'), + 'password': uri_dict.get('password'), + 'read_preference': read_preference, + }) + uri_options = uri_dict['options'] + if 'replicaset' in uri_options: + conn_settings['replicaSet'] = True + if 'authsource' in uri_options: + conn_settings['authentication_source'] = uri_options['authsource'] + else: + resolved_hosts.append(entity) + conn_settings['host'] = resolved_hosts # Deprecated parameters that should not be passed on kwargs.pop('slaves', None) diff --git a/tests/test_connection.py b/tests/test_connection.py index b2f7406e..1d422d09 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -88,6 +88,40 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb7') self.assertTrue(isinstance(conn, mongomock.MongoClient)) + def test_connect_with_host_list(self): + """Ensure that the connect() method works when host is a list + + Uses mongomock to test w/o needing multiple mongod/mongos processes + """ + try: + import mongomock + except ImportError: + raise SkipTest('you need mongomock installed to run this testcase') + + connect(host=['mongomock://localhost']) + conn = get_connection() + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host=['mongodb://localhost'], is_mock=True, alias='testdb2') + conn = get_connection('testdb2') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host=['localhost'], is_mock=True, alias='testdb3') + conn = get_connection('testdb3') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host=['mongomock://localhost:27017', 'mongomock://localhost:27018'], alias='testdb4') + conn = get_connection('testdb4') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host=['mongodb://localhost:27017', 'mongodb://localhost:27018'], is_mock=True, alias='testdb5') + conn = get_connection('testdb5') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + + connect(host=['localhost:27017', 'localhost:27018'], is_mock=True, alias='testdb6') + conn = get_connection('testdb6') + self.assertTrue(isinstance(conn, mongomock.MongoClient)) + def test_disconnect(self): """Ensure that the disconnect() method works properly """