cleaner connection code
This commit is contained in:
parent
fa6949eca2
commit
c86155e571
@ -17,6 +17,9 @@ else:
|
|||||||
|
|
||||||
|
|
||||||
class ConnectionError(Exception):
|
class ConnectionError(Exception):
|
||||||
|
"""Error raised when the database connection can't be established or
|
||||||
|
when a connection with a requested alias can't be retrieved.
|
||||||
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@ -125,65 +128,88 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME):
|
|||||||
|
|
||||||
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||||
global _connections
|
global _connections
|
||||||
|
|
||||||
# Connect to the database if not already connected
|
# Connect to the database if not already connected
|
||||||
if reconnect:
|
if reconnect:
|
||||||
disconnect(alias)
|
disconnect(alias)
|
||||||
|
|
||||||
if alias not in _connections:
|
# If the requested alias already exists in the _connections list, return
|
||||||
if alias not in _connection_settings:
|
# it immediately.
|
||||||
msg = 'Connection with alias "%s" has not been defined' % alias
|
if alias in _connections:
|
||||||
if alias == DEFAULT_CONNECTION_NAME:
|
return _connections[alias]
|
||||||
msg = 'You have not defined a default connection'
|
|
||||||
raise ConnectionError(msg)
|
|
||||||
conn_settings = _connection_settings[alias].copy()
|
|
||||||
|
|
||||||
conn_settings.pop('name', None)
|
# Validate that the requested alias exists in the _connection_settings.
|
||||||
conn_settings.pop('username', None)
|
# Raise ConnectionError if it doesn't.
|
||||||
conn_settings.pop('password', None)
|
if alias not in _connection_settings:
|
||||||
conn_settings.pop('authentication_source', None)
|
msg = 'Connection with alias "%s" has not been defined' % alias
|
||||||
conn_settings.pop('authentication_mechanism', None)
|
if alias == DEFAULT_CONNECTION_NAME:
|
||||||
|
msg = 'You have not defined a default connection'
|
||||||
|
raise ConnectionError(msg)
|
||||||
|
|
||||||
is_mock = conn_settings.pop('is_mock', None)
|
def _clean_settings(settings_dict):
|
||||||
if is_mock:
|
irrelevant_fields = (
|
||||||
# Use MongoClient from mongomock
|
'name', 'username', 'password', 'authentication_source',
|
||||||
try:
|
'authentication_mechanism'
|
||||||
import mongomock
|
)
|
||||||
except ImportError:
|
return dict(
|
||||||
raise RuntimeError('You need mongomock installed '
|
(k, v) for k, v in settings_dict.items()
|
||||||
'to mock MongoEngine.')
|
if k not in irrelevant_fields
|
||||||
connection_class = mongomock.MongoClient
|
)
|
||||||
else:
|
|
||||||
# Use MongoClient from pymongo
|
|
||||||
connection_class = MongoClient
|
|
||||||
|
|
||||||
if 'replicaSet' in conn_settings:
|
# Retrieve a copy of the connection settings associated with the requested
|
||||||
# Discard port since it can't be used on MongoReplicaSetClient
|
# alias and remove the database name and authentication info (we don't
|
||||||
conn_settings.pop('port', None)
|
# care about them at this point).
|
||||||
# Discard replicaSet if not base string
|
conn_settings = _clean_settings(_connection_settings[alias].copy())
|
||||||
if not isinstance(conn_settings['replicaSet'], six.string_types):
|
|
||||||
conn_settings.pop('replicaSet', None)
|
|
||||||
if not IS_PYMONGO_3:
|
|
||||||
connection_class = MongoReplicaSetClient
|
|
||||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
|
||||||
|
|
||||||
|
# Determine if we should use PyMongo's or mongomock's MongoClient.
|
||||||
|
is_mock = conn_settings.pop('is_mock', None)
|
||||||
|
if is_mock:
|
||||||
try:
|
try:
|
||||||
connection = None
|
import mongomock
|
||||||
# check for shared connections
|
except ImportError:
|
||||||
connection_settings_iterator = (
|
raise RuntimeError('You need mongomock installed to mock '
|
||||||
(db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
'MongoEngine.')
|
||||||
for db_alias, connection_settings in connection_settings_iterator:
|
connection_class = mongomock.MongoClient
|
||||||
connection_settings.pop('name', None)
|
else:
|
||||||
connection_settings.pop('username', None)
|
connection_class = MongoClient
|
||||||
connection_settings.pop('password', None)
|
|
||||||
connection_settings.pop('authentication_source', None)
|
|
||||||
connection_settings.pop('authentication_mechanism', None)
|
|
||||||
if conn_settings == connection_settings and _connections.get(db_alias, None):
|
|
||||||
connection = _connections[db_alias]
|
|
||||||
break
|
|
||||||
|
|
||||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
# For replica set connections with PyMongo 2.x, use MongoReplicaSetClient
|
||||||
|
# TODO remove this block once we stop supporting PyMongo 2.x.
|
||||||
|
if 'replicaSet' in conn_settings:
|
||||||
|
# Discard port since it can't be used on MongoReplicaSetClient
|
||||||
|
conn_settings.pop('port', None)
|
||||||
|
# Discard replicaSet if it's not a string
|
||||||
|
if not isinstance(conn_settings['replicaSet'], six.string_types):
|
||||||
|
conn_settings.pop('replicaSet', None)
|
||||||
|
if not IS_PYMONGO_3:
|
||||||
|
connection_class = MongoReplicaSetClient
|
||||||
|
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||||
|
|
||||||
|
# Iterate over all of the connection settings and if a connection with
|
||||||
|
# the same parameters is already established, use it instead of creating
|
||||||
|
# a new one.
|
||||||
|
existing_connection = None
|
||||||
|
connection_settings_iterator = (
|
||||||
|
(db_alias, settings.copy())
|
||||||
|
for db_alias, settings in _connection_settings.items()
|
||||||
|
)
|
||||||
|
for db_alias, connection_settings in connection_settings_iterator:
|
||||||
|
connection_settings = _clean_settings(connection_settings)
|
||||||
|
if conn_settings == connection_settings and _connections.get(db_alias):
|
||||||
|
existing_connection = _connections[db_alias]
|
||||||
|
break
|
||||||
|
|
||||||
|
# If an existing connection was found, assign it to the new alias
|
||||||
|
if existing_connection:
|
||||||
|
_connections[alias] = existing_connection
|
||||||
|
else:
|
||||||
|
# Otherwise, create the new connection for this alias. Raise
|
||||||
|
# ConnectionError if it can't be established.
|
||||||
|
try:
|
||||||
|
_connections[alias] = connection_class(**conn_settings)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
raise ConnectionError('Cannot connect to database %s :\n%s' % (alias, e))
|
||||||
|
|
||||||
return _connections[alias]
|
return _connections[alias]
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user