From f168682a68451e7e55fb993ddfaff09ea3358fba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Stefan=20W=C3=B3jcik?= Date: Mon, 5 Dec 2016 22:31:00 -0500 Subject: [PATCH] Dont let the MongoDB URI override connection settings it doesnt explicitly specify (#1421) --- mongoengine/connection.py | 23 ++++++++++++++-------- tests/test_connection.py | 41 ++++++++++++++++++++++++++------------- 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/mongoengine/connection.py b/mongoengine/connection.py index b4e852d4..ee21ba90 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -25,7 +25,8 @@ _dbs = {} def register_connection(alias, name=None, host=None, port=None, read_preference=READ_PREFERENCE, - username=None, password=None, authentication_source=None, + username=None, password=None, + authentication_source=None, authentication_mechanism=None, **kwargs): """Add a connection. @@ -70,20 +71,26 @@ def register_connection(alias, name=None, host=None, port=None, resolved_hosts = [] for entity in conn_host: - # Handle uri style connections + + # Handle Mongomock if entity.startswith('mongomock://'): conn_settings['is_mock'] = True # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` resolved_hosts.append(entity.replace('mongomock://', 'mongodb://', 1)) + + # Handle URI style connections, only updating connection params which + # were explicitly specified in the URI. elif '://' in entity: uri_dict = uri_parser.parse_uri(entity) resolved_hosts.append(entity) - conn_settings.update({ - 'name': uri_dict.get('database') or name, - 'username': uri_dict.get('username'), - 'password': uri_dict.get('password'), - 'read_preference': read_preference, - }) + + if uri_dict.get('database'): + conn_settings['name'] = uri_dict.get('database') + + for param in ('read_preference', 'username', 'password'): + if uri_dict.get(param): + conn_settings[param] = uri_dict[param] + uri_options = uri_dict['options'] if 'replicaset' in uri_options: conn_settings['replicaSet'] = True diff --git a/tests/test_connection.py b/tests/test_connection.py index 1d422d09..e6431891 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -174,19 +174,9 @@ class ConnectionTest(unittest.TestCase): c.mongoenginetest.system.users.remove({}) def test_connect_uri_without_db(self): - """Ensure connect() method works properly with uri's without database_name + """Ensure connect() method works properly if the URI doesn't + include a database name. """ - c = connect(db='mongoenginetest', alias='admin') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) - - c.admin.add_user("admin", "password") - c.admin.authenticate("admin", "password") - c.mongoenginetest.add_user("username", "password") - - if not IS_PYMONGO_3: - self.assertRaises(ConnectionError, connect, "testdb_uri_bad", host='mongodb://test:password@localhost') - connect("mongoenginetest", host='mongodb://localhost/') conn = get_connection() @@ -196,8 +186,31 @@ class ConnectionTest(unittest.TestCase): self.assertTrue(isinstance(db, pymongo.database.Database)) self.assertEqual(db.name, 'mongoenginetest') - c.admin.system.users.remove({}) - c.mongoenginetest.system.users.remove({}) + def test_connect_uri_default_db(self): + """Ensure connect() defaults to the right database name if + the URI and the database_name don't explicitly specify it. + """ + connect(host='mongodb://localhost/') + + conn = get_connection() + self.assertTrue(isinstance(conn, pymongo.mongo_client.MongoClient)) + + db = get_db() + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'test') + + def test_uri_without_credentials_doesnt_override_conn_settings(self): + """Ensure connect() uses the username & password params if the URI + doesn't explicitly specify them. + """ + c = connect(host='mongodb://localhost/mongoenginetest', + username='user', + password='pass') + + # OperationFailure means that mongoengine attempted authentication + # w/ the provided username/password and failed - that's the desired + # behavior. If the MongoDB URI would override the credentials + self.assertRaises(OperationFailure, get_db) def test_connect_uri_with_authsource(self): """Ensure that the connect() method works well with