diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 8aa95daa..4e0c60b0 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -103,7 +103,14 @@ def _get_connection_settings( if entity.startswith("mongomock://"): conn_settings["is_mock"] = True # `mongomock://` is not a valid url prefix and must be replaced by `mongodb://` - resolved_hosts.append(entity.replace("mongomock://", "mongodb://", 1)) + new_entity = entity.replace("mongomock://", "mongodb://", 1) + resolved_hosts.append(new_entity) + + uri_dict = uri_parser.parse_uri(new_entity) + + database = uri_dict.get("database") + if database: + conn_settings["name"] = database # Handle URI style connections, only updating connection params which # were explicitly specified in the URI. @@ -111,8 +118,9 @@ def _get_connection_settings( uri_dict = uri_parser.parse_uri(entity) resolved_hosts.append(entity) - if uri_dict.get("database"): - conn_settings["name"] = uri_dict.get("database") + database = uri_dict.get("database") + if database: + conn_settings["name"] = database for param in ("read_preference", "username", "password"): if uri_dict.get(param): diff --git a/tests/test_connection.py b/tests/test_connection.py index f9c9d098..b7dc9268 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,6 +4,8 @@ from pymongo import MongoClient from pymongo.errors import OperationFailure, InvalidName from pymongo import ReadPreference +from mongoengine import Document + try: import unittest2 as unittest except ImportError: @@ -269,6 +271,26 @@ class ConnectionTest(unittest.TestCase): conn = get_connection("testdb7") self.assertIsInstance(conn, mongomock.MongoClient) + def test_default_database_with_mocking(self): + """Ensure that the default database is correctly set when using mongomock. + """ + try: + import mongomock + except ImportError: + raise SkipTest("you need mongomock installed to run this testcase") + + disconnect_all() + + class SomeDocument(Document): + pass + + conn = connect(host="mongomock://localhost:27017/mongoenginetest") + some_document = SomeDocument() + # database won't exist until we save a document + some_document.save() + self.assertEqual(conn.get_default_database().name, "mongoenginetest") + self.assertEqual(conn.database_names()[0], "mongoenginetest") + def test_connect_with_host_list(self): """Ensure that the connect() method works when host is a list