Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
1f7272d139 | ||
|
f6ba1ad788 | ||
|
294d59c9bb |
@ -1,6 +1,10 @@
|
|||||||
import pymongo
|
import pymongo
|
||||||
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
|
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
|
||||||
|
|
||||||
|
try:
|
||||||
|
import motor
|
||||||
|
except ImportError:
|
||||||
|
motor = None
|
||||||
|
|
||||||
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
||||||
'DEFAULT_CONNECTION_NAME']
|
'DEFAULT_CONNECTION_NAME']
|
||||||
@ -21,6 +25,7 @@ _dbs = {}
|
|||||||
def register_connection(alias, name=None, host=None, port=None,
|
def register_connection(alias, name=None, host=None, port=None,
|
||||||
read_preference=False,
|
read_preference=False,
|
||||||
username=None, password=None, authentication_source=None,
|
username=None, password=None, authentication_source=None,
|
||||||
|
async=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Add a connection.
|
"""Add a connection.
|
||||||
|
|
||||||
@ -35,7 +40,6 @@ def register_connection(alias, name=None, host=None, port=None,
|
|||||||
:param password: password to authenticate with
|
:param password: password to authenticate with
|
||||||
:param authentication_source: database to authenticate against
|
:param authentication_source: database to authenticate against
|
||||||
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
|
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
|
||||||
|
|
||||||
"""
|
"""
|
||||||
global _connection_settings
|
global _connection_settings
|
||||||
|
|
||||||
@ -46,7 +50,8 @@ def register_connection(alias, name=None, host=None, port=None,
|
|||||||
'read_preference': read_preference,
|
'read_preference': read_preference,
|
||||||
'username': username,
|
'username': username,
|
||||||
'password': password,
|
'password': password,
|
||||||
'authentication_source': authentication_source
|
'authentication_source': authentication_source,
|
||||||
|
'async': async
|
||||||
}
|
}
|
||||||
|
|
||||||
# Handle uri style connections
|
# Handle uri style connections
|
||||||
@ -98,8 +103,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
|||||||
conn_settings.pop('username', None)
|
conn_settings.pop('username', None)
|
||||||
conn_settings.pop('password', None)
|
conn_settings.pop('password', None)
|
||||||
conn_settings.pop('authentication_source', None)
|
conn_settings.pop('authentication_source', None)
|
||||||
|
async = conn_settings.pop('async')
|
||||||
|
|
||||||
|
if async:
|
||||||
|
if not motor:
|
||||||
|
raise ImproperlyConfigured("Motor library was not found")
|
||||||
|
|
||||||
|
connection_class = motor.MotorClient
|
||||||
|
|
||||||
|
else:
|
||||||
|
connection_class = MongoClient
|
||||||
|
|
||||||
connection_class = MongoClient
|
|
||||||
if 'replicaSet' in conn_settings:
|
if 'replicaSet' in conn_settings:
|
||||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||||
# Discard port since it can't be used on MongoReplicaSetClient
|
# Discard port since it can't be used on MongoReplicaSetClient
|
||||||
@ -107,12 +121,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
|||||||
# Discard replicaSet if not base string
|
# Discard replicaSet if not base string
|
||||||
if not isinstance(conn_settings['replicaSet'], basestring):
|
if not isinstance(conn_settings['replicaSet'], basestring):
|
||||||
conn_settings.pop('replicaSet', None)
|
conn_settings.pop('replicaSet', None)
|
||||||
connection_class = MongoReplicaSetClient
|
|
||||||
|
if async:
|
||||||
|
connection_class = motor.MotorReplicaSetClient
|
||||||
|
else:
|
||||||
|
connection_class = MongoReplicaSetClient
|
||||||
|
|
||||||
try:
|
try:
|
||||||
connection = None
|
connection = None
|
||||||
# check for shared connections
|
# check for shared connections
|
||||||
connection_settings_iterator = ((db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
connection_settings_iterator = (
|
||||||
|
(db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
||||||
for db_alias, connection_settings in connection_settings_iterator:
|
for db_alias, connection_settings in connection_settings_iterator:
|
||||||
connection_settings.pop('name', None)
|
connection_settings.pop('name', None)
|
||||||
connection_settings.pop('username', None)
|
connection_settings.pop('username', None)
|
||||||
@ -121,9 +140,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
|||||||
connection = _connections[db_alias]
|
connection = _connections[db_alias]
|
||||||
break
|
break
|
||||||
|
|
||||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
_connections[alias] = connection if connection else connection_class(
|
||||||
|
**conn_settings)
|
||||||
except Exception, e:
|
except Exception, e:
|
||||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
raise ConnectionError(
|
||||||
|
"Cannot connect to database %s :\n%s" % (alias, e))
|
||||||
return _connections[alias]
|
return _connections[alias]
|
||||||
|
|
||||||
|
|
||||||
|
0
tests/async/__init__.py
Normal file
0
tests/async/__init__.py
Normal file
36
tests/async/test_connection.py
Normal file
36
tests/async/test_connection.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from mongoengine import *
|
||||||
|
import motor
|
||||||
|
import mongoengine.connection
|
||||||
|
from mongoengine.connection import get_db, get_connection, ConnectionError
|
||||||
|
|
||||||
|
try:
|
||||||
|
import unittest2 as unittest
|
||||||
|
except ImportError:
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionTest(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
mongoengine.connection._connection_settings = {}
|
||||||
|
mongoengine.connection._connections = {}
|
||||||
|
mongoengine.connection._dbs = {}
|
||||||
|
|
||||||
|
def test_register_connection(self):
|
||||||
|
"""
|
||||||
|
Ensure that the connect() method works properly.
|
||||||
|
"""
|
||||||
|
register_connection('asyncdb', 'mongoengineasynctest', async=True)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
mongoengine.connection._connection_settings['asyncdb']['name'],
|
||||||
|
'mongoengineasynctest')
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
mongoengine.connection._connection_settings['asyncdb']['async'])
|
||||||
|
conn = get_connection('asyncdb')
|
||||||
|
self.assertTrue(isinstance(conn, motor.MotorClient))
|
||||||
|
|
||||||
|
db = get_db('asyncdb')
|
||||||
|
self.assertTrue(isinstance(db, motor.MotorDatabase))
|
||||||
|
self.assertEqual(db.name, 'mongoengineasynctest')
|
Loading…
x
Reference in New Issue
Block a user