diff --git a/AUTHORS b/AUTHORS index 21b0ec64..45a754cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -251,3 +251,4 @@ that much better: * Gleb Voropaev (https://github.com/buggyspace) * Paulo Amaral (https://github.com/pauloAmaral) * Gaurav Dadhania (https://github.com/GVRV) + * Yurii Andrieiev (https://github.com/yandrieiev) diff --git a/docs/changelog.rst b/docs/changelog.rst index 5a472eb5..d29e5eb4 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -7,6 +7,7 @@ Development - POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 +- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 diff --git a/mongoengine/connection.py b/mongoengine/connection.py index c0cfde31..dda9bbb7 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,4 +1,5 @@ from pymongo import MongoClient, ReadPreference, uri_parser +from pymongo.database import _check_name import six from mongoengine.pymongo_support import IS_PYMONGO_3 @@ -28,6 +29,16 @@ _connections = {} _dbs = {} +def check_db_name(name): + """Check if a database name is valid. + This functionality is copied from pymongo Database class constructor. + """ + if not isinstance(name, six.string_types): + raise TypeError('name must be an instance of %s' % six.string_types) + elif name != '$external': + _check_name(name) + + def register_connection(alias, db=None, name=None, host=None, port=None, read_preference=READ_PREFERENCE, username=None, password=None, @@ -69,6 +80,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, 'authentication_mechanism': authentication_mechanism } + check_db_name(conn_settings['name']) conn_host = conn_settings['host'] # Host can be a list or a string, so if string, force to a list. diff --git a/tests/test_connection.py b/tests/test_connection.py index 0a7271df..fb2a20d7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,5 @@ import datetime -from pymongo.errors import OperationFailure +from pymongo.errors import OperationFailure, InvalidName try: import unittest2 as unittest @@ -49,6 +49,36 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_with_invalid_db_name(self): + """Ensure that connect() method fails fast if db name is invalid + """ + with self.assertRaises(InvalidName): + connect('mongomock://localhost') + + def test_connect_with_db_name_external(self): + """Ensure that connect() works if db name is $external + """ + """Ensure that the connect() method works properly.""" + connect('$external') + + conn = get_connection() + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, '$external') + + connect('$external', alias='testdb') + conn = get_connection('testdb') + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + def test_connect_with_invalid_db_name_type(self): + """Ensure that connect() method fails fast if db name has invalid type + """ + with self.assertRaises(TypeError): + non_string_db_name = ['e. g. list instead of a string'] + connect(non_string_db_name) + def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """