Fail fast when db name is invalid

Without this commit save operation on first document would fail instead of immediate failure upon connection attempt. Such later failure is much less obvious.
This commit is contained in:
Yurii Andrieiev 2019-04-07 02:02:26 +03:00
parent 827de76345
commit b5213097e8
4 changed files with 45 additions and 1 deletions

View File

@ -251,3 +251,4 @@ that much better:
* Gleb Voropaev (https://github.com/buggyspace) * Gleb Voropaev (https://github.com/buggyspace)
* Paulo Amaral (https://github.com/pauloAmaral) * Paulo Amaral (https://github.com/pauloAmaral)
* Gaurav Dadhania (https://github.com/GVRV) * Gaurav Dadhania (https://github.com/GVRV)
* Yurii Andrieiev (https://github.com/yandrieiev)

View File

@ -6,6 +6,7 @@ Development
=========== ===========
- mongoengine now requires pymongo>=3.5 #2017 - mongoengine now requires pymongo>=3.5 #2017
- Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020
- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type
- (Fill this out as you fix issues and develop your features). - (Fill this out as you fix issues and develop your features).
Changes in 0.17.0 Changes in 0.17.0

View File

@ -1,4 +1,5 @@
from pymongo import MongoClient, ReadPreference, uri_parser from pymongo import MongoClient, ReadPreference, uri_parser
from pymongo.database import _check_name
import six import six
from mongoengine.pymongo_support import IS_PYMONGO_3 from mongoengine.pymongo_support import IS_PYMONGO_3
@ -28,6 +29,16 @@ _connections = {}
_dbs = {} _dbs = {}
def check_db_name(name):
"""Check if a database name is valid.
This functionality is copied from pymongo Database class constructor.
"""
if not isinstance(name, six.string_types):
raise TypeError('name must be an instance of %s' % six.string_types)
elif name != '$external':
_check_name(name)
def register_connection(alias, db=None, name=None, host=None, port=None, def register_connection(alias, db=None, name=None, host=None, port=None,
read_preference=READ_PREFERENCE, read_preference=READ_PREFERENCE,
username=None, password=None, username=None, password=None,
@ -69,6 +80,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
'authentication_mechanism': authentication_mechanism 'authentication_mechanism': authentication_mechanism
} }
check_db_name(conn_settings['name'])
conn_host = conn_settings['host'] conn_host = conn_settings['host']
# Host can be a list or a string, so if string, force to a list. # Host can be a list or a string, so if string, force to a list.

View File

@ -1,5 +1,5 @@
import datetime import datetime
from pymongo.errors import OperationFailure from pymongo.errors import OperationFailure, InvalidName
try: try:
import unittest2 as unittest import unittest2 as unittest
@ -49,6 +49,36 @@ class ConnectionTest(unittest.TestCase):
conn = get_connection('testdb') conn = get_connection('testdb')
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
def test_connect_with_invalid_db_name(self):
"""Ensure that connect() method fails fast if db name is invalid
"""
with self.assertRaises(InvalidName):
connect('mongomock://localhost')
def test_connect_with_db_name_external(self):
"""Ensure that connect() works if db name is $external
"""
"""Ensure that the connect() method works properly."""
connect('$external')
conn = get_connection()
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
db = get_db()
self.assertIsInstance(db, pymongo.database.Database)
self.assertEqual(db.name, '$external')
connect('$external', alias='testdb')
conn = get_connection('testdb')
self.assertIsInstance(conn, pymongo.mongo_client.MongoClient)
def test_connect_with_invalid_db_name_type(self):
"""Ensure that connect() method fails fast if db name has invalid type
"""
with self.assertRaises(TypeError):
non_string_db_name = ['e. g. list instead of a string']
connect(non_string_db_name)
def test_connect_in_mocking(self): def test_connect_in_mocking(self):
"""Ensure that the connect() method works properly in mocking. """Ensure that the connect() method works properly in mocking.
""" """