Add tests
This commit is contained in:
		| @@ -245,6 +245,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|         } |         } | ||||||
|  |  | ||||||
|     raw_conn_settings = _connection_settings[alias].copy() |     raw_conn_settings = _connection_settings[alias].copy() | ||||||
|  |  | ||||||
|     # Retrieve a copy of the connection settings associated with the requested |     # Retrieve a copy of the connection settings associated with the requested | ||||||
|     # alias and remove the database name and authentication info (we don't |     # alias and remove the database name and authentication info (we don't | ||||||
|     # care about them at this point). |     # care about them at this point). | ||||||
| @@ -269,22 +270,20 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): | |||||||
|     if existing_connection: |     if existing_connection: | ||||||
|         _connections[alias] = existing_connection |         _connections[alias] = existing_connection | ||||||
|     else: |     else: | ||||||
|         # Otherwise, create the new connection for this alias. Raise |         _create_connection(alias=alias, | ||||||
|         # MongoEngineConnectionError if it can't be established. |                            connection_class=connection_class, | ||||||
|         try: |                            **conn_settings) | ||||||
|             _connections[alias] = connection_class(**conn_settings) |  | ||||||
|         except Exception as e: |  | ||||||
|             raise MongoEngineConnectionError( |  | ||||||
|                 'Cannot connect to database %s :\n%s' % (alias, e)) |  | ||||||
|  |  | ||||||
|     return _connections[alias] |     return _connections[alias] | ||||||
|  |  | ||||||
|  |  | ||||||
| def _create_connection(connection_class, **connection_settings): | def _create_connection(alias, connection_class, **connection_settings): | ||||||
|     # Otherwise, create the new connection for this alias. Raise |     """ | ||||||
|     # MongoEngineConnectionError if it can't be established. |     Create the new connection for this alias. Raise | ||||||
|  |     MongoEngineConnectionError if it can't be established. | ||||||
|  |     """ | ||||||
|     try: |     try: | ||||||
|         _connections[alias] = connection_class(**conn_settings) |         _connections[alias] = connection_class(**connection_settings) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         raise MongoEngineConnectionError( |         raise MongoEngineConnectionError( | ||||||
|             'Cannot connect to database %s :\n%s' % (alias, e)) |             'Cannot connect to database %s :\n%s' % (alias, e)) | ||||||
| @@ -300,7 +299,7 @@ def _find_existing_connection(connection_settings): | |||||||
|     :param connection_settings: the settings of the new connection |     :param connection_settings: the settings of the new connection | ||||||
|     :return: An existing connection or None |     :return: An existing connection or None | ||||||
|     """ |     """ | ||||||
|     connection_settings_iterator = ( |     connection_settings_bis = ( | ||||||
|         (db_alias, settings.copy()) |         (db_alias, settings.copy()) | ||||||
|         for db_alias, settings in _connection_settings.items() |         for db_alias, settings in _connection_settings.items() | ||||||
|     ) |     ) | ||||||
| @@ -312,7 +311,7 @@ def _find_existing_connection(connection_settings): | |||||||
|         return {k: v for k, v in settings_dict.items() if k != 'name'} |         return {k: v for k, v in settings_dict.items() if k != 'name'} | ||||||
|  |  | ||||||
|     cleaned_conn_settings = _clean_settings(connection_settings) |     cleaned_conn_settings = _clean_settings(connection_settings) | ||||||
|     for db_alias, connection_settings in connection_settings_iterator: |     for db_alias, connection_settings in connection_settings_bis: | ||||||
|         db_conn_settings = _clean_settings(connection_settings) |         db_conn_settings = _clean_settings(connection_settings) | ||||||
|         if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): |         if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias): | ||||||
|             return _connections[db_alias] |             return _connections[db_alias] | ||||||
|   | |||||||
| @@ -611,6 +611,16 @@ class ConnectionTest(unittest.TestCase): | |||||||
|         self.assertEqual(mongo_connections['t1'].address[0], 'localhost') |         self.assertEqual(mongo_connections['t1'].address[0], 'localhost') | ||||||
|         self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') |         self.assertEqual(mongo_connections['t2'].address[0], '127.0.0.1') | ||||||
|  |  | ||||||
|  |     def test_connect_2_databases_uses_same_client_if_only_dbname_differs(self): | ||||||
|  |         c1 = connect(alias='testdb1', db='testdb1') | ||||||
|  |         c2 = connect(alias='testdb2', db='testdb2') | ||||||
|  |         self.assertIs(c1, c2) | ||||||
|  |  | ||||||
|  |     def test_connect_2_databases_uses_different_client_if_different_parameters(self): | ||||||
|  |         c1 = connect(alias='testdb1', db='testdb1', username='u1') | ||||||
|  |         c2 = connect(alias='testdb2', db='testdb2', username='u2') | ||||||
|  |         self.assertIsNot(c1, c2) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     unittest.main() |     unittest.main() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user