Fix the issue that the same MongoClient gets re-used in case we connect to 2 databases on the same host (problematic when different users authenticate)
This commit is contained in:
parent
048a045966
commit
9634e44343
@ -235,7 +235,6 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
raise MongoEngineConnectionError(msg)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# set literal more efficient than calling set function
|
||||
irrelevant_fields_set = {
|
||||
'name', 'username', 'password',
|
||||
'authentication_source', 'authentication_mechanism'
|
||||
@ -245,10 +244,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if k not in irrelevant_fields_set
|
||||
}
|
||||
|
||||
raw_conn_settings = _connection_settings[alias].copy()
|
||||
# Retrieve a copy of the connection settings associated with the requested
|
||||
# alias and remove the database name and authentication info (we don't
|
||||
# care about them at this point).
|
||||
conn_settings = _clean_settings(_connection_settings[alias].copy())
|
||||
conn_settings = _clean_settings(raw_conn_settings)
|
||||
|
||||
# Determine if we should use PyMongo's or mongomock's MongoClient.
|
||||
is_mock = conn_settings.pop('is_mock', False)
|
||||
@ -262,19 +262,8 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
else:
|
||||
connection_class = MongoClient
|
||||
|
||||
# Iterate over all of the connection settings and if a connection with
|
||||
# the same parameters is already established, use it instead of creating
|
||||
# a new one.
|
||||
existing_connection = None
|
||||
connection_settings_iterator = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
for db_alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings = _clean_settings(connection_settings)
|
||||
if conn_settings == connection_settings and _connections.get(db_alias):
|
||||
existing_connection = _connections[db_alias]
|
||||
break
|
||||
# Re-use existing connection if one is suitable
|
||||
existing_connection = _find_existing_connection(raw_conn_settings)
|
||||
|
||||
# If an existing connection was found, assign it to the new alias
|
||||
if existing_connection:
|
||||
@ -291,6 +280,44 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
return _connections[alias]
|
||||
|
||||
|
||||
def _create_connection(connection_class, **connection_settings):
|
||||
# Otherwise, create the new connection for this alias. Raise
|
||||
# MongoEngineConnectionError if it can't be established.
|
||||
try:
|
||||
_connections[alias] = connection_class(**conn_settings)
|
||||
except Exception as e:
|
||||
raise MongoEngineConnectionError(
|
||||
'Cannot connect to database %s :\n%s' % (alias, e))
|
||||
|
||||
|
||||
def _find_existing_connection(connection_settings):
|
||||
"""
|
||||
Check if an existing connection could be reused
|
||||
|
||||
Iterate over all of the connection settings and if an existing connection
|
||||
with the same parameters is suitable, return it
|
||||
|
||||
:param connection_settings: the settings of the new connection
|
||||
:return: An existing connection or None
|
||||
"""
|
||||
connection_settings_iterator = (
|
||||
(db_alias, settings.copy())
|
||||
for db_alias, settings in _connection_settings.items()
|
||||
)
|
||||
|
||||
def _clean_settings(settings_dict):
|
||||
# Only remove the name but it's important to
|
||||
# keep the username/password/authentication_source/authentication_mechanism
|
||||
# to identify if the connection could be shared (cfr https://github.com/MongoEngine/mongoengine/issues/2047)
|
||||
return {k: v for k, v in settings_dict.items() if k != 'name'}
|
||||
|
||||
cleaned_conn_settings = _clean_settings(connection_settings)
|
||||
for db_alias, connection_settings in connection_settings_iterator:
|
||||
db_conn_settings = _clean_settings(connection_settings)
|
||||
if cleaned_conn_settings == db_conn_settings and _connections.get(db_alias):
|
||||
return _connections[db_alias]
|
||||
|
||||
|
||||
def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
if reconnect:
|
||||
disconnect(alias)
|
||||
|
Loading…
x
Reference in New Issue
Block a user