diff --git a/docs/changelog.rst b/docs/changelog.rst index 75230a28..e4d693bd 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -5,6 +5,7 @@ Changelog Changes in dev ============== +- Added uri support for connections - Added scalar for efficiently returning partial data values (aliased to values_list) - Fixed limit skip bug - Improved Inheritance / Mixin diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 470dcb88..50eb2703 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -20,6 +20,13 @@ provide :attr:`host` and :attr:`port` arguments to connect('project1', host='192.168.1.35', port=12345) +Uri style connections are also supported as long as you include the database +name - just supply the uri as the :attr:`host` to +:func:`~mongoengine.connect`:: + + connect('project1', host='mongodb://localhost/database_name') + + Multiple Databases ================== diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 462c8951..822c604b 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,5 @@ import pymongo -from pymongo import Connection +from pymongo import Connection, uri_parser __all__ = ['ConnectionError', 'connect', 'register_connection', @@ -38,6 +38,17 @@ def register_connection(alias, name, host='localhost', port=27017, """ global _connection_settings + + # Handle uri style connections + if "://" in host: + uri_dict = uri_parser.parse_uri(host) + if 'database' not in uri_dict: + raise ConnectionError("If using URI style connection include "\ + "database name in string") + uri_dict['name'] = uri_dict.get('database') + _connection_settings[alias] = uri_dict + return + _connection_settings[alias] = { 'name': name, 'host': host, @@ -48,8 +59,10 @@ def register_connection(alias, name, host='localhost', port=27017, 'password': password, 'read_preference': read_preference } + _connection_settings[alias].update(kwargs) + def disconnect(alias=DEFAULT_CONNECTION_NAME): global _connections global _dbs @@ -83,11 +96,12 @@ def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): conn_settings.pop('password') else: # Get all the slave connections - slaves = [] - for slave_alias in conn_settings['slaves']: - slaves.append(get_connection(slave_alias)) - conn_settings['slaves'] = slaves - conn_settings.pop('read_preference') + if 'slaves' in conn_settings: + slaves = [] + for slave_alias in conn_settings['slaves']: + slaves.append(get_connection(slave_alias)) + conn_settings['slaves'] = slaves + conn_settings.pop('read_preference') try: _connections[alias] = Connection(**conn_settings) diff --git a/tests/connection.py b/tests/connection.py index e017b388..7ff0998e 100644 --- a/tests/connection.py +++ b/tests/connection.py @@ -30,6 +30,19 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + def test_connect_uri(self): + """Ensure that the connect() method works properly with uri's + """ + + connect("testdb_uri", host='mongodb://username:password@localhost/mongoenginetest') + + conn = get_connection() + self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + + db = get_db() + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'mongoenginetest') + def test_register_connection(self): """Ensure that connections with different aliases may be registered. """