Compare commits
3 Commits
Author | SHA1 | Date | |
---|---|---|---|
|
1f7272d139 | ||
|
f6ba1ad788 | ||
|
294d59c9bb |
@ -1,6 +1,10 @@
|
||||
import pymongo
|
||||
from pymongo import MongoClient, MongoReplicaSetClient, uri_parser
|
||||
|
||||
try:
|
||||
import motor
|
||||
except ImportError:
|
||||
motor = None
|
||||
|
||||
__all__ = ['ConnectionError', 'connect', 'register_connection',
|
||||
'DEFAULT_CONNECTION_NAME']
|
||||
@ -21,6 +25,7 @@ _dbs = {}
|
||||
def register_connection(alias, name=None, host=None, port=None,
|
||||
read_preference=False,
|
||||
username=None, password=None, authentication_source=None,
|
||||
async=False,
|
||||
**kwargs):
|
||||
"""Add a connection.
|
||||
|
||||
@ -35,7 +40,6 @@ def register_connection(alias, name=None, host=None, port=None,
|
||||
:param password: password to authenticate with
|
||||
:param authentication_source: database to authenticate against
|
||||
:param kwargs: allow ad-hoc parameters to be passed into the pymongo driver
|
||||
|
||||
"""
|
||||
global _connection_settings
|
||||
|
||||
@ -46,7 +50,8 @@ def register_connection(alias, name=None, host=None, port=None,
|
||||
'read_preference': read_preference,
|
||||
'username': username,
|
||||
'password': password,
|
||||
'authentication_source': authentication_source
|
||||
'authentication_source': authentication_source,
|
||||
'async': async
|
||||
}
|
||||
|
||||
# Handle uri style connections
|
||||
@ -98,8 +103,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
conn_settings.pop('username', None)
|
||||
conn_settings.pop('password', None)
|
||||
conn_settings.pop('authentication_source', None)
|
||||
async = conn_settings.pop('async')
|
||||
|
||||
if async:
|
||||
if not motor:
|
||||
raise ImproperlyConfigured("Motor library was not found")
|
||||
|
||||
connection_class = motor.MotorClient
|
||||
|
||||
else:
|
||||
connection_class = MongoClient
|
||||
|
||||
connection_class = MongoClient
|
||||
if 'replicaSet' in conn_settings:
|
||||
conn_settings['hosts_or_uri'] = conn_settings.pop('host', None)
|
||||
# Discard port since it can't be used on MongoReplicaSetClient
|
||||
@ -107,12 +121,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
# Discard replicaSet if not base string
|
||||
if not isinstance(conn_settings['replicaSet'], basestring):
|
||||
conn_settings.pop('replicaSet', None)
|
||||
connection_class = MongoReplicaSetClient
|
||||
|
||||
if async:
|
||||
connection_class = motor.MotorReplicaSetClient
|
||||
else:
|
||||
connection_class = MongoReplicaSetClient
|
||||
|
||||
try:
|
||||
connection = None
|
||||
# check for shared connections
|
||||
connection_settings_iterator = ((db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
||||
connection_settings_iterator = (
|
||||
(db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems())
|
||||
for db_alias, connection_settings in connection_settings_iterator:
|
||||
connection_settings.pop('name', None)
|
||||
connection_settings.pop('username', None)
|
||||
@ -121,9 +140,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
|
||||
connection = _connections[db_alias]
|
||||
break
|
||||
|
||||
_connections[alias] = connection if connection else connection_class(**conn_settings)
|
||||
_connections[alias] = connection if connection else connection_class(
|
||||
**conn_settings)
|
||||
except Exception, e:
|
||||
raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e))
|
||||
raise ConnectionError(
|
||||
"Cannot connect to database %s :\n%s" % (alias, e))
|
||||
return _connections[alias]
|
||||
|
||||
|
||||
|
0
tests/async/__init__.py
Normal file
0
tests/async/__init__.py
Normal file
36
tests/async/test_connection.py
Normal file
36
tests/async/test_connection.py
Normal file
@ -0,0 +1,36 @@
|
||||
from mongoengine import *
|
||||
import motor
|
||||
import mongoengine.connection
|
||||
from mongoengine.connection import get_db, get_connection, ConnectionError
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest
|
||||
|
||||
|
||||
class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
mongoengine.connection._connection_settings = {}
|
||||
mongoengine.connection._connections = {}
|
||||
mongoengine.connection._dbs = {}
|
||||
|
||||
def test_register_connection(self):
|
||||
"""
|
||||
Ensure that the connect() method works properly.
|
||||
"""
|
||||
register_connection('asyncdb', 'mongoengineasynctest', async=True)
|
||||
|
||||
self.assertEqual(
|
||||
mongoengine.connection._connection_settings['asyncdb']['name'],
|
||||
'mongoengineasynctest')
|
||||
|
||||
self.assertTrue(
|
||||
mongoengine.connection._connection_settings['asyncdb']['async'])
|
||||
conn = get_connection('asyncdb')
|
||||
self.assertTrue(isinstance(conn, motor.MotorClient))
|
||||
|
||||
db = get_db('asyncdb')
|
||||
self.assertTrue(isinstance(db, motor.MotorDatabase))
|
||||
self.assertEqual(db.name, 'mongoengineasynctest')
|
Loading…
x
Reference in New Issue
Block a user