diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 5fae9507..6249225c 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -245,6 +245,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): } raw_conn_settings = _connection_settings[alias].copy() + # Retrieve a copy of the connection settings associated with the requested # alias and remove the database name and authentication info (we don't # care about them at this point). @@ -269,22 +270,20 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): if existing_connection: _connections[alias] = existing_connection else: - # Otherwise, create the new connection for this alias. Raise - # MongoEngineConnectionError if it can't be established. - try: - _connections[alias] = connection_class(**conn_settings) - except Exception as e: - raise MongoEngineConnectionError( - 'Cannot connect to database %s :\n%s' % (alias, e)) + _create_connection(alias=alias, + connection_class=connection_class, + **conn_settings) return _connections[alias] -def _create_connection(connection_class, **connection_settings): - # Otherwise, create the new connection for this alias. Raise - # MongoEngineConnectionError if it can't be established. +def _create_connection(alias, connection_class, **connection_settings): + """ + Create the new connection for this alias. Raise + MongoEngineConnectionError if it can't be established. + """ try: - _connections[alias] = connection_class(**conn_settings) + _connections[alias] = connection_class(**connection_settings) except Exception as e: raise MongoEngineConnectionError( 'Cannot connect to database %s :\n%s' % (alias, e)) @@ -300,7 +299,7 @@ def _find_existing_connection(connection_settings): :param connection_settings: the settings of the new connection :return: An existing connection or None """ - connection_settings_iterator = ( + connection_settings_bis = ( (db_alias, settings.copy()) for db_alias, settings in _connection_settings.items() ) @@ -312,7 +311,7 @@ def _find_existing_connection(connection_settings): return {k: v for k, v in settings_dict.items() if k != 'name'} cleaned_conn_settings = _clean_settings(connection_settings) - for db_alias, connection_settings in connection_settings_iterator: + for db_alias, connection_settings in connection_settings_bis: db_conn_settings = _clean_settings(connection_settings) if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): return _connections[db_alias] diff --git a/tests/test_connection.py b/tests/test_connection.py index 5473b8a0..d3fcc395 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -611,6 +611,16 @@ class ConnectionTest(unittest.TestCase): self.assertEqual(mongo_connections['t1'].address[0], 'localhost') self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') + def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): + c1 = connect(alias='testdb1', db='testdb1') + c2 = connect(alias='testdb2', db='testdb2') + self.assertIs(c1, c2) + + def test_connect_2_databases_uses_different_client_if_different_parameters(self): + c1 = connect(alias='testdb1', db='testdb1', username='u1') + c2 = connect(alias='testdb2', db='testdb2', username='u2') + self.assertIsNot(c1, c2) + if __name__ == '__main__': unittest.main()