From b5213097e887997a247218d9229c0ad3bec074a9 Mon Sep 17 00:00:00 2001 From: Yurii Andrieiev Date: Sun, 7 Apr 2019 02:02:26 +0300 Subject: [PATCH] Fail fast when db name is invalid Without this commit save operation on first document would fail instead of immediate failure upon connection attempt. Such later failure is much less obvious. --- AUTHORS | 1 + docs/changelog.rst | 1 + mongoengine/connection.py | 12 ++++++++++++ tests/test_connection.py | 32 +++++++++++++++++++++++++++++++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/AUTHORS b/AUTHORS index 21b0ec64..45a754cc 100644 --- a/AUTHORS +++ b/AUTHORS @@ -251,3 +251,4 @@ that much better: * Gleb Voropaev (https://github.com/buggyspace) * Paulo Amaral (https://github.com/pauloAmaral) * Gaurav Dadhania (https://github.com/GVRV) + * Yurii Andrieiev (https://github.com/yandrieiev) diff --git a/docs/changelog.rst b/docs/changelog.rst index 53373302..707182f1 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,7 @@ Development =========== - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 +- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 diff --git a/mongoengine/connection.py b/mongoengine/connection.py index c0cfde31..dda9bbb7 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,4 +1,5 @@ from pymongo import MongoClient, ReadPreference, uri_parser +from pymongo.database import _check_name import six from mongoengine.pymongo_support import IS_PYMONGO_3 @@ -28,6 +29,16 @@ _connections = {} _dbs = {} +def check_db_name(name): + """Check if a database name is valid. + This functionality is copied from pymongo Database class constructor. + """ + if not isinstance(name, six.string_types): + raise TypeError('name must be an instance of %s' % six.string_types) + elif name != '$external': + _check_name(name) + + def register_connection(alias, db=None, name=None, host=None, port=None, read_preference=READ_PREFERENCE, username=None, password=None, @@ -69,6 +80,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None, 'authentication_mechanism': authentication_mechanism } + check_db_name(conn_settings['name']) conn_host = conn_settings['host'] # Host can be a list or a string, so if string, force to a list. diff --git a/tests/test_connection.py b/tests/test_connection.py index 0a7271df..fb2a20d7 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,5 +1,5 @@ import datetime -from pymongo.errors import OperationFailure +from pymongo.errors import OperationFailure, InvalidName try: import unittest2 as unittest @@ -49,6 +49,36 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_with_invalid_db_name(self): + """Ensure that connect() method fails fast if db name is invalid + """ + with self.assertRaises(InvalidName): + connect('mongomock://localhost') + + def test_connect_with_db_name_external(self): + """Ensure that connect() works if db name is $external + """ + """Ensure that the connect() method works properly.""" + connect('$external') + + conn = get_connection() + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + db = get_db() + self.assertIsInstance(db, pymongo.database.Database) + self.assertEqual(db.name, '$external') + + connect('$external', alias='testdb') + conn = get_connection('testdb') + self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + + def test_connect_with_invalid_db_name_type(self): + """Ensure that connect() method fails fast if db name has invalid type + """ + with self.assertRaises(TypeError): + non_string_db_name = ['e. g. list instead of a string'] + connect(non_string_db_name) + def test_connect_in_mocking(self): """Ensure that the connect() method works properly in mocking. """