Support for authentication mechanism #905 (#1333)

This commit is contained in:
zeez 2016-12-04 02:08:24 +05:00 committed by Stefan Wójcik
parent 4f87db784e
commit 02fb3b9315

View File

@ -6,6 +6,7 @@ __all__ = ['ConnectionError', 'connect', 'register_connection',
DEFAULT_CONNECTION_NAME = 'default' DEFAULT_CONNECTION_NAME = 'default'
if IS_PYMONGO_3: if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY READ_PREFERENCE = ReadPreference.PRIMARY
else: else:
@ -25,6 +26,7 @@ _dbs = {}
def register_connection(alias, name=None, host=None, port=None, def register_connection(alias, name=None, host=None, port=None,
read_preference=READ_PREFERENCE, read_preference=READ_PREFERENCE,
username=None, password=None, authentication_source=None, username=None, password=None, authentication_source=None,
authentication_mechanism=None,
**kwargs): **kwargs):
"""Add a connection. """Add a connection.
@ -38,6 +40,9 @@ def register_connection(alias, name=None, host=None, port=None,
:param username: username to authenticate with :param username: username to authenticate with
:param password: password to authenticate with :param password: password to authenticate with
:param authentication_source: database to authenticate against :param authentication_source: database to authenticate against
:param authentication_mechanism: database authentication mechanisms.
By default, use SCRAM-SHA-1 with MongoDB 3.0 and later,
MONGODB-CR (MongoDB Challenge Response protocol) for older servers.
:param is_mock: explicitly use mongomock for this connection :param is_mock: explicitly use mongomock for this connection
(can also be done by using `mongomock://` as db host prefix) (can also be done by using `mongomock://` as db host prefix)
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
@ -53,9 +58,11 @@ def register_connection(alias, name=None, host=None, port=None,
'read_preference': read_preference, 'read_preference': read_preference,
'username': username, 'username': username,
'password': password, 'password': password,
'authentication_source': authentication_source 'authentication_source': authentication_source,
'authentication_mechanism': authentication_mechanism
} }
# Handle uri style connections
conn_host = conn_settings['host'] conn_host = conn_settings['host']
# host can be a list or a string, so if string, force to a list # host can be a list or a string, so if string, force to a list
if isinstance(conn_host, str_types): if isinstance(conn_host, str_types):
@ -82,6 +89,8 @@ def register_connection(alias, name=None, host=None, port=None,
conn_settings['replicaSet'] = True conn_settings['replicaSet'] = True
if 'authsource' in uri_options: if 'authsource' in uri_options:
conn_settings['authentication_source'] = uri_options['authsource'] conn_settings['authentication_source'] = uri_options['authsource']
if 'authmechanism' in uri_options:
conn_settings['authentication_mechanism'] = uri_options['authmechanism']
else: else:
resolved_hosts.append(entity) resolved_hosts.append(entity)
conn_settings['host'] = resolved_hosts conn_settings['host'] = resolved_hosts
@ -123,6 +132,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn_settings.pop('username', None) conn_settings.pop('username', None)
conn_settings.pop('password', None) conn_settings.pop('password', None)
conn_settings.pop('authentication_source', None) conn_settings.pop('authentication_source', None)
conn_settings.pop('authentication_mechanism', None)
is_mock = conn_settings.pop('is_mock', None) is_mock = conn_settings.pop('is_mock', None)
if is_mock: if is_mock:
@ -157,6 +167,7 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
connection_settings.pop('username', None) connection_settings.pop('username', None)
connection_settings.pop('password', None) connection_settings.pop('password', None)
connection_settings.pop('authentication_source', None) connection_settings.pop('authentication_source', None)
connection_settings.pop('authentication_mechanism', None)
if conn_settings == connection_settings and _connections.get(db_alias, None): if conn_settings == connection_settings and _connections.get(db_alias, None):
connection = _connections[db_alias] connection = _connections[db_alias]
break break
@ -176,11 +187,13 @@ def get_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
conn = get_connection(alias) conn = get_connection(alias)
conn_settings = _connection_settings[alias] conn_settings = _connection_settings[alias]
db = conn[conn_settings['name']] db = conn[conn_settings['name']]
auth_kwargs = {'source': conn_settings['authentication_source']}
if conn_settings['authentication_mechanism'] is not None:
auth_kwargs['mechanism'] = conn_settings['authentication_mechanism']
# Authenticate if necessary # Authenticate if necessary
if conn_settings['username'] and conn_settings['password']: if conn_settings['username'] and (conn_settings['password'] or
db.authenticate(conn_settings['username'], conn_settings['authentication_mechanism'] == 'MONGODB-X509'):
conn_settings['password'], db.authenticate(conn_settings['username'], conn_settings['password'], **auth_kwargs)
source=conn_settings['authentication_source'])
_dbs[alias] = db _dbs[alias] = db
return _dbs[alias] return _dbs[alias]