diff --git a/docs/changelog.rst b/docs/changelog.rst index d29e5eb4..356e2b65 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,10 +4,17 @@ Changelog Development =========== +- expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` +- Fix disconnect function #566 #1599 #605 #607 #1213 #565 +- Improve connect/disconnect documentations +- POTENTIAL BREAKING CHANGES: (associated with connect/disconnect fixes) + - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) + - disconnect now clears `mongoengine.connection._connection_settings` + - disconnect now clears the cached attribute `Document._collection` - POTENTIAL BREAKING CHANGE: Aggregate gives wrong results when used with a queryset having limit and skip #2029 - mongoengine now requires pymongo>=3.5 #2017 - Generate Unique Indices for SortedListField and EmbeddedDocumentListFields #2020 -- connect() fails immediately when db name contains invalid characters (e. g. when user mistakenly puts 'mongodb://127.0.0.1:27017' as db name, happened in #1718) or is if db name is of an invalid type +- connect() fails immediately when db name contains invalid characters #2031 #1718 - (Fill this out as you fix issues and develop your features). Changes in 0.17.0 diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 5dac6ae9..1107ee3a 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -4,9 +4,11 @@ Connecting to MongoDB ===================== -To connect to a running instance of :program:`mongod`, use the -:func:`~mongoengine.connect` function. The first argument is the name of the -database to connect to:: +Connections in MongoEngine are registered globally and are identified with aliases. +If no `alias` is provided during the connection, it will use "default" as alias. + +To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` +function. The first argument is the name of the database to connect to:: from mongoengine import connect connect('project1') @@ -42,6 +44,9 @@ the :attr:`host` to will establish connection to ``production`` database using ``admin`` username and ``qwerty`` password. +.. note:: Calling :func:`~mongoengine.connect` without argument will establish + a connection to the "test" database by default + Replica Sets ============ @@ -71,6 +76,8 @@ is used. In the background this uses :func:`~mongoengine.register_connection` to store the data and you can register all aliases up front if required. +Documents defined in different database +--------------------------------------- Individual documents can also support multiple databases by providing a `db_alias` in their meta data. This allows :class:`~pymongo.dbref.DBRef` objects to point across databases and collections. Below is an example schema, @@ -93,6 +100,33 @@ using 3 different databases to store data:: meta = {'db_alias': 'users-books-db'} +Disconnecting an existing connection +------------------------------------ +The function :func:`~mongoengine.disconnect` can be used to +disconnect a particular connection. This can be used to change a +connection globally:: + + from mongoengine import connect, disconnect + connect('a_db', alias='db1') + + class User(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + disconnect(alias='db1') + + connect('another_db', alias='db1') + +.. note:: Calling :func:`~mongoengine.disconnect` without argument + will disconnect the "default" connection + +.. note:: Since connections gets registered globally, it is important + to use the `disconnect` function from MongoEngine and not the + `disconnect()` method of an existing connection (pymongo.MongoClient) + +.. note:: :class:`~mongoengine.Document` are caching the pymongo collection. + using `disconnect` ensures that it gets cleaned as well + Context Managers ================ Sometimes you may want to switch the database or collection to query against. @@ -119,7 +153,7 @@ access to the same User document across databases:: Switch Collection ----------------- -The :class:`~mongoengine.context_managers.switch_collection` context manager +The :func:`~mongoengine.context_managers.switch_collection` context manager allows you to change the collection for a given class allowing quick and easy access to the same Group document across collection:: diff --git a/mongoengine/base/common.py b/mongoengine/base/common.py index d747c8cc..999fd23a 100644 --- a/mongoengine/base/common.py +++ b/mongoengine/base/common.py @@ -13,7 +13,7 @@ _document_registry = {} def get_document(name): - """Get a document class by name.""" + """Get a registered Document class by name.""" doc = _document_registry.get(name, None) if not doc: # Possible old style name @@ -30,3 +30,12 @@ def get_document(name): been imported? """.strip() % name) return doc + + +def _get_documents_by_db(connection_alias, default_connection_alias): + """Get all registered Documents class attached to a given database""" + def get_doc_alias(doc_cls): + return doc_cls._meta.get('db_alias', default_connection_alias) + + return [doc_cls for doc_cls in _document_registry.values() + if get_doc_alias(doc_cls) == connection_alias] diff --git a/mongoengine/connection.py b/mongoengine/connection.py index dda9bbb7..67374d01 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -4,11 +4,15 @@ import six from mongoengine.pymongo_support import IS_PYMONGO_3 -__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection', - 'DEFAULT_CONNECTION_NAME', 'get_db'] +__all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_all', + 'register_connection', 'DEFAULT_CONNECTION_NAME', 'DEFAULT_DATABASE_NAME', + 'get_db', 'get_connection'] DEFAULT_CONNECTION_NAME = 'default' +DEFAULT_DATABASE_NAME = 'test' +DEFAULT_HOST = 'localhost' +DEFAULT_PORT = 27017 if IS_PYMONGO_3: READ_PREFERENCE = ReadPreference.PRIMARY @@ -39,40 +43,39 @@ def check_db_name(name): _check_name(name) -def register_connection(alias, db=None, name=None, host=None, port=None, - read_preference=READ_PREFERENCE, - username=None, password=None, - authentication_source=None, - authentication_mechanism=None, - **kwargs): - """Add a connection. +def _get_connection_settings( + db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Get the connection settings as a dict - :param alias: the name that will be used to refer to this connection - throughout MongoEngine - :param name: the name of the specific database to use - :param db: the name of the database to use, for compatibility with connect - :param host: the host name of the :program:`mongod` instance to connect to - :param port: the port that the :program:`mongod` instance is running on - :param read_preference: The read preference for the collection + : param db: the name of the database to use, for compatibility with connect + : param name: the name of the specific database to use + : param host: the host name of the: program: `mongod` instance to connect to + : param port: the port that the: program: `mongod` instance is running on + : param read_preference: The read preference for the collection ** Added pymongo 2.1 - :param username: username to authenticate with - :param password: password to authenticate with - :param authentication_source: database to authenticate against - :param authentication_mechanism: database authentication mechanisms. + : param username: username to authenticate with + : param password: password to authenticate with + : param authentication_source: database to authenticate against + : param authentication_mechanism: database authentication mechanisms. By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, MONGODB-CR (MongoDB Challenge Response protocol) for older servers. - :param is_mock: explicitly use mongomock for this connection - (can also be done by using `mongomock://` as db host prefix) - :param kwargs: ad-hoc parameters to be passed into the pymongo driver, + : param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock: // ` as db host prefix) + : param kwargs: ad-hoc parameters to be passed into the pymongo driver, for example maxpoolsize, tz_aware, etc. See the documentation for pymongo's `MongoClient` for a full list. .. versionchanged:: 0.10.6 - added mongomock support """ conn_settings = { - 'name': name or db or 'test', - 'host': host or 'localhost', - 'port': port or 27017, + 'name': name or db or DEFAULT_DATABASE_NAME, + 'host': host or DEFAULT_HOST, + 'port': port or DEFAULT_PORT, 'read_preference': read_preference, 'username': username, 'password': password, @@ -137,17 +140,75 @@ def register_connection(alias, db=None, name=None, host=None, port=None, kwargs.pop('is_slave', None) conn_settings.update(kwargs) + return conn_settings + + +def register_connection(alias, db=None, name=None, host=None, port=None, + read_preference=READ_PREFERENCE, + username=None, password=None, + authentication_source=None, + authentication_mechanism=None, + **kwargs): + """Register the connection settings. + + : param alias: the name that will be used to refer to this connection + throughout MongoEngine + : param name: the name of the specific database to use + : param db: the name of the database to use, for compatibility with connect + : param host: the host name of the: program: `mongod` instance to connect to + : param port: the port that the: program: `mongod` instance is running on + : param read_preference: The read preference for the collection + ** Added pymongo 2.1 + : param username: username to authenticate with + : param password: password to authenticate with + : param authentication_source: database to authenticate against + : param authentication_mechanism: database authentication mechanisms. + By default, use SCRAM-SHA-1 with MongoDB 3.0 and later, + MONGODB-CR (MongoDB Challenge Response protocol) for older servers. + : param is_mock: explicitly use mongomock for this connection + (can also be done by using `mongomock: // ` as db host prefix) + : param kwargs: ad-hoc parameters to be passed into the pymongo driver, + for example maxpoolsize, tz_aware, etc. See the documentation + for pymongo's `MongoClient` for a full list. + + .. versionchanged:: 0.10.6 - added mongomock support + """ + conn_settings = _get_connection_settings( + db=db, name=name, host=host, port=port, + read_preference=read_preference, + username=username, password=password, + authentication_source=authentication_source, + authentication_mechanism=authentication_mechanism, + **kwargs) _connection_settings[alias] = conn_settings def disconnect(alias=DEFAULT_CONNECTION_NAME): """Close the connection with a given alias.""" + from mongoengine.base.common import _get_documents_by_db + from mongoengine import Document + if alias in _connections: get_connection(alias=alias).close() del _connections[alias] + if alias in _dbs: + # Detach all cached collections in Documents + for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): + if issubclass(doc_cls, Document): # Skip EmbeddedDocument + doc_cls._disconnect() + del _dbs[alias] + if alias in _connection_settings: + del _connection_settings[alias] + + +def disconnect_all(): + """Close all registered database.""" + for alias in list(_connections.keys()): + disconnect(alias) + def get_connection(alias=DEFAULT_CONNECTION_NAME, reconnect=False): """Return a connection with a given alias.""" @@ -270,14 +331,24 @@ def connect(db=None, alias=DEFAULT_CONNECTION_NAME, **kwargs): provide username and password arguments as well. Multiple databases are supported by using aliases. Provide a separate - `alias` to connect to a different instance of :program:`mongod`. + `alias` to connect to a different instance of: program: `mongod`. + + In order to replace a connection identified by a given alias, you'll + need to call ``disconnect`` first See the docstring for `register_connection` for more details about all supported kwargs. .. versionchanged:: 0.6 - added multiple database support. """ - if alias not in _connections: + if alias in _connections: + prev_conn_setting = _connection_settings[alias] + new_conn_settings = _get_connection_settings(db, **kwargs) + + if new_conn_settings != prev_conn_setting: + raise MongoEngineConnectionError( + 'A different connection with alias `%s` was already registered. Use disconnect() first' % alias) + else: register_connection(alias, db, **kwargs) return get_connection(alias) diff --git a/mongoengine/document.py b/mongoengine/document.py index 328ac299..753520c7 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -188,10 +188,16 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): return get_db(cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)) @classmethod - def _get_collection(cls): - """Return a PyMongo collection for the document.""" - if not hasattr(cls, '_collection') or cls._collection is None: + def _disconnect(cls): + """Detach the Document class from the (cached) database collection""" + cls._collection = None + @classmethod + def _get_collection(cls): + """Return the corresponding PyMongo collection of this document. + Upon the first call, it will ensure that indexes gets created. The returned collection then gets cached + """ + if not hasattr(cls, '_collection') or cls._collection is None: # Get the collection, either capped or regular. if cls._meta.get('max_size') or cls._meta.get('max_documents'): cls._collection = cls._get_capped_collection() @@ -789,13 +795,13 @@ class Document(six.with_metaclass(TopLevelDocumentMetaclass, BaseDocument)): .. versionchanged:: 0.10.7 :class:`OperationError` exception raised if no collection available """ - col_name = cls._get_collection_name() - if not col_name: + coll_name = cls._get_collection_name() + if not coll_name: raise OperationError('Document %s has no collection defined ' '(is it abstract ?)' % cls) cls._collection = None db = cls._get_db() - db.drop_collection(col_name) + db.drop_collection(coll_name) @classmethod def create_index(cls, keys, background=False, **kwargs): diff --git a/tests/test_connection.py b/tests/test_connection.py index fb2a20d7..e5e10479 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1,4 +1,6 @@ import datetime + +from pymongo import MongoClient from pymongo.errors import OperationFailure, InvalidName try: @@ -12,12 +14,12 @@ from bson.tz_util import utc from mongoengine import ( connect, register_connection, - Document, DateTimeField -) + Document, DateTimeField, + disconnect_all, StringField) from mongoengine.pymongo_support import IS_PYMONGO_3 import mongoengine.connection from mongoengine.connection import (MongoEngineConnectionError, get_db, - get_connection) + get_connection, disconnect, DEFAULT_DATABASE_NAME) def get_tz_awareness(connection): @@ -29,6 +31,14 @@ def get_tz_awareness(connection): class ConnectionTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + disconnect_all() + + @classmethod + def tearDownClass(cls): + disconnect_all() + def tearDown(self): mongoengine.connection._connection_settings = {} mongoengine.connection._connections = {} @@ -49,6 +59,117 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb') self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) + def test_connect_disconnect_works_properly(self): + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect('db1', alias='db1') + connect('db2', alias='db2') + + History1.drop_collection() + History2.drop_collection() + + h = History1(name='default').save() + h1 = History2(name='db1').save() + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + disconnect('db1') + disconnect('db2') + + with self.assertRaises(MongoEngineConnectionError): + list(History1.objects().as_pymongo()) + + with self.assertRaises(MongoEngineConnectionError): + list(History2.objects().as_pymongo()) + + connect('db1', alias='db1') + connect('db2', alias='db2') + + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + + def test_connect_different_documents_to_different_database(self): + class History(Document): + name = StringField() + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + class History2(Document): + name = StringField() + meta = {'db_alias': 'db2'} + + connect() + connect('db1', alias='db1') + connect('db2', alias='db2') + + History.drop_collection() + History1.drop_collection() + History2.drop_collection() + + h = History(name='default').save() + h1 = History1(name='db1').save() + h2 = History2(name='db2').save() + + self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) + self.assertEqual(History1._collection.database.name, 'db1') + self.assertEqual(History2._collection.database.name, 'db2') + + self.assertEqual(list(History.objects().as_pymongo()), + [{'_id': h.id, 'name': 'default'}]) + self.assertEqual(list(History1.objects().as_pymongo()), + [{'_id': h1.id, 'name': 'db1'}]) + self.assertEqual(list(History2.objects().as_pymongo()), + [{'_id': h2.id, 'name': 'db2'}]) + + def test_connect_fails_if_connect_2_times_with_default_alias(self): + connect('mongoenginetest') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2') + self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_connect_2_times_with_custom_alias(self): + connect('mongoenginetest', alias='alias1') + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + connect('mongoenginetest2', alias='alias1') + + self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception)) + + def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self): + """Intended to keep the detecton function simple but robust""" + db_name = 'mongoenginetest' + db_alias = 'alias1' + connect(db=db_name, alias=db_alias, host='localhost', port=27017) + + with self.assertRaises(MongoEngineConnectionError): + connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias) + + def test_connect_passes_silently_connect_multiple_times_with_same_config(self): + # test default connection to `test` + connect() + connect() + self.assertEqual(len(mongoengine.connection._connections), 1) + connect('test01', alias='test01') + connect('test01', alias='test01') + self.assertEqual(len(mongoengine.connection._connections), 2) + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') + self.assertEqual(len(mongoengine.connection._connections), 3) + def test_connect_with_invalid_db_name(self): """Ensure that connect() method fails fast if db name is invalid """ @@ -149,13 +270,133 @@ class ConnectionTest(unittest.TestCase): conn = get_connection('testdb6') self.assertIsInstance(conn, mongomock.MongoClient) - def test_disconnect(self): - """Ensure that the disconnect() method works properly - """ + def test_disconnect_cleans_globals(self): + """Ensure that the disconnect() method cleans the globals objects""" + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + + connect('mongoenginetest') + + self.assertEqual(len(connections), 1) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 1) + + class TestDoc(Document): + pass + + TestDoc.drop_collection() # triggers the db + self.assertEqual(len(dbs), 1) + + disconnect() + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + def test_disconnect_cleans_cached_collection_attribute_in_document(self): + """Ensure that the disconnect() method works properly""" conn1 = connect('mongoenginetest') - mongoengine.connection.disconnect() - conn2 = connect('mongoenginetest') - self.assertTrue(conn1 is not conn2) + + class History(Document): + pass + + self.assertIsNone(History._collection) + + History.drop_collection() + + History.objects.first() # will trigger the caching of _collection attribute + self.assertIsNotNone(History._collection) + + disconnect() + + self.assertIsNone(History._collection) + + with self.assertRaises(MongoEngineConnectionError) as ctx_err: + History.objects.first() + self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) + + def test_connect_disconnect_works_on_same_document(self): + """Ensure that the connect/disconnect works properly with a single Document""" + db1 = 'db1' + db2 = 'db2' + + # Ensure freshness of the 2 databases through pymongo + client = MongoClient('localhost', 27017) + client.drop_database(db1) + client.drop_database(db2) + + # Save in db1 + connect(db1) + + class User(Document): + name = StringField(required=True) + + user1 = User(name='John is in db1').save() + disconnect() + + # Make sure save doesnt work at this stage + with self.assertRaises(MongoEngineConnectionError): + User(name='Wont work').save() + + # Save in db2 + connect(db2) + user2 = User(name='Bob is in db2').save() + disconnect() + + db1_users = list(client[db1].user.find()) + self.assertEqual(db1_users, [{'_id': user1.id, 'name': 'John is in db1'}]) + db2_users = list(client[db2].user.find()) + self.assertEqual(db2_users, [{'_id': user2.id, 'name': 'Bob is in db2'}]) + + def test_disconnect_silently_pass_if_alias_does_not_exist(self): + connections = mongoengine.connection._connections + self.assertEqual(len(connections), 0) + disconnect(alias='not_exist') + + def test_disconnect_all(self): + connections = mongoengine.connection._connections + dbs = mongoengine.connection._dbs + connection_settings = mongoengine.connection._connection_settings + + connect('mongoenginetest') + connect('mongoenginetest2', alias='db1') + + class History(Document): + pass + + class History1(Document): + name = StringField() + meta = {'db_alias': 'db1'} + + History.drop_collection() # will trigger the caching of _collection attribute + History.objects.first() + History1.drop_collection() + History1.objects.first() + + self.assertIsNotNone(History._collection) + self.assertIsNotNone(History1._collection) + + self.assertEqual(len(connections), 2) + self.assertEqual(len(dbs), 2) + self.assertEqual(len(connection_settings), 2) + + disconnect_all() + + self.assertIsNone(History._collection) + self.assertIsNone(History1._collection) + + self.assertEqual(len(connections), 0) + self.assertEqual(len(dbs), 0) + self.assertEqual(len(connection_settings), 0) + + with self.assertRaises(MongoEngineConnectionError): + History.objects.first() + + with self.assertRaises(MongoEngineConnectionError): + History1.objects.first() + + def test_disconnect_all_silently_pass_if_no_connection_exist(self): + disconnect_all() def test_sharing_connections(self): """Ensure that connections are shared when the connection settings are exactly the same @@ -372,7 +613,7 @@ class ConnectionTest(unittest.TestCase): with self.assertRaises(MongoEngineConnectionError): c = connect(replicaset='local-rs') - def test_datetime(self): + def test_connect_tz_aware(self): connect('mongoenginetest', tz_aware=True) d = datetime.datetime(2010, 5, 5, tzinfo=utc) diff --git a/tests/test_context_managers.py b/tests/test_context_managers.py index 227031e0..22c33b01 100644 --- a/tests/test_context_managers.py +++ b/tests/test_context_managers.py @@ -37,14 +37,15 @@ class ContextManagersTest(unittest.TestCase): def test_switch_collection_context_manager(self): connect('mongoenginetest') - register_connection('testdb-1', 'mongoenginetest2') + register_connection(alias='testdb-1', db='mongoenginetest2') class Group(Document): name = StringField() - Group.drop_collection() + Group.drop_collection() # drops in default + with switch_collection(Group, 'group1') as Group: - Group.drop_collection() + Group.drop_collection() # drops in group1 Group(name="hello - group").save() self.assertEqual(1, Group.objects.count()) diff --git a/tests/utils.py b/tests/utils.py index 3c41f07d..910601b1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,7 +4,7 @@ import unittest from nose.plugins.skip import SkipTest from mongoengine import connect -from mongoengine.connection import get_db +from mongoengine.connection import get_db, disconnect_all from mongoengine.mongodb_support import get_mongodb_version, MONGODB_26, MONGODB_3, MONGODB_32, MONGODB_34 from mongoengine.pymongo_support import IS_PYMONGO_3 @@ -19,6 +19,7 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def setUpClass(cls): + disconnect_all() cls._connection = connect(db=MONGO_TEST_DB) cls._connection.drop_database(MONGO_TEST_DB) cls.db = get_db() @@ -26,6 +27,7 @@ class MongoDBTestCase(unittest.TestCase): @classmethod def tearDownClass(cls): cls._connection.drop_database(MONGO_TEST_DB) + disconnect_all() def get_as_pymongo(doc):