From 294d59c9bb1b53b6ee16d1b5c21e7aa62c4a394e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20J=C3=BAnior?= Date: Wed, 3 Sep 2014 00:45:02 -0300 Subject: [PATCH] register a possible async database --- mongoengine/connection.py | 35 ++++++++++++++++++++++++++------- tests/async/__init__.py | 0 tests/async/test_connection.py | 36 ++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) create mode 100644 tests/async/__init__.py create mode 100644 tests/async/test_connection.py diff --git a/mongoengine/connection.py b/mongoengine/connection.py index dcecdd9a..690ee602 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,6 +1,10 @@ import pymongo from pymongo import MongoClient, MongoReplicaSetClient, uri_parser +try: + import motor +except ImportError: + motor = None __all__ = ['ConnectionError', 'connect', 'register_connection', 'DEFAULT_CONNECTION_NAME'] @@ -21,6 +25,7 @@ _dbs = {} def register_connection(alias, name=None, host=None, port=None, read_preference=False, username=None, password=None, authentication_source=None, + async=False, **kwargs): """Add a connection. @@ -35,7 +40,6 @@ def register_connection(alias, name=None, host=None, port=None, :param password: password to authenticate with :param authentication_source: database to authenticate against :param kwargs: allow ad-hoc parameters to be passed into the pymongo driver - """ global _connection_settings @@ -46,7 +50,8 @@ def register_connection(alias, name=None, host=None, port=None, 'read_preference': read_preference, 'username': username, 'password': password, - 'authentication_source': authentication_source + 'authentication_source': authentication_source, + 'async': async } # Handle uri style connections @@ -98,8 +103,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings.pop('username', None) conn_settings.pop('password', None) conn_settings.pop('authentication_source', None) + async = conn_settings.pop('async') + + if async: + if not motor: + raise ImproperlyConfigured("Motor library was not found") + + connection_class = motor.MotorClient + + else: + connection_class = MongoClient - connection_class = MongoClient if 'replicaSet' in conn_settings: conn_settings['hosts_or_uri'] = conn_settings.pop('host', None) # Discard port since it can't be used on MongoReplicaSetClient @@ -107,12 +121,17 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): # Discard replicaSet if not base string if not isinstance(conn_settings['replicaSet'], basestring): conn_settings.pop('replicaSet', None) - connection_class = MongoReplicaSetClient + + if async: + connection_class = MongoReplicaSetClient + else: + connection_class = motor.MotorReplicaSetClient try: connection = None # check for shared connections - connection_settings_iterator = ((db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems()) + connection_settings_iterator = ( + (db_alias, settings.copy()) for db_alias, settings in _connection_settings.iteritems()) for db_alias, connection_settings in connection_settings_iterator: connection_settings.pop('name', None) connection_settings.pop('username', None) @@ -121,9 +140,11 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): connection = _connections[db_alias] break - _connections[alias] = connection if connection else connection_class(**conn_settings) + _connections[alias] = connection if connection else connection_class( + **conn_settings) except Exception, e: - raise ConnectionError("Cannot connect to database %s :\n%s" % (alias, e)) + raise ConnectionError( + "Cannot connect to database %s :\n%s" % (alias, e)) return _connections[alias] diff --git a/tests/async/__init__.py b/tests/async/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/async/test_connection.py b/tests/async/test_connection.py new file mode 100644 index 00000000..1c42c9d5 --- /dev/null +++ b/tests/async/test_connection.py @@ -0,0 +1,36 @@ +from mongoengine import * +import motor +import mongoengine.connection +from mongoengine.connection import get_db, get_connection, ConnectionError + +try: + import unittest2 as unittest +except ImportError: + import unittest + + +class ConnectionTest(unittest.TestCase): + + def setUp(self): + mongoengine.connection._connection_settings = {} + mongoengine.connection._connections = {} + mongoengine.connection._dbs = {} + + def test_register_connection(self): + """ + Ensure that the connect() method works properly. + """ + register_connection('asyncdb', 'mongoengineasynctest', async=True) + + self.assertEqual( + mongoengine.connection._connection_settings['asyncdb']['name'], + 'mongoengineasynctest') + + self.assertTrue( + mongoengine.connection._connection_settings['asyncdb']['async']) + conn = get_connection('asyncdb') + self.assertTrue(isinstance(conn, motor.MotorClient)) + + db = get_db('asyncdb') + self.assertTrue(isinstance(db, motor.MotorDatabase)) + self.assertEqual(db.name, 'mongoengineasynctest')