Fix connect/disconnect functions
- expose disconnect - disconnect cleans _connection_settings - disconnect cleans cached collection in Document._collection - re-connecting with the same alias raise an error (must call disconnect in between)
This commit is contained in:
		| @@ -12,12 +12,12 @@ from bson.tz_util import utc | ||||
|  | ||||
| from mongoengine import ( | ||||
|     connect, register_connection, | ||||
|     Document, DateTimeField | ||||
| ) | ||||
|     Document, DateTimeField, | ||||
|     disconnect_all, StringField) | ||||
| from mongoengine.pymongo_support import IS_PYMONGO_3 | ||||
| import mongoengine.connection | ||||
| from mongoengine.connection import (MongoEngineConnectionError, get_db, | ||||
|                                     get_connection) | ||||
|                                     get_connection, disconnect, DEFAULT_DATABASE_NAME) | ||||
|  | ||||
|  | ||||
| def get_tz_awareness(connection): | ||||
| @@ -29,6 +29,14 @@ def get_tz_awareness(connection): | ||||
|  | ||||
| class ConnectionTest(unittest.TestCase): | ||||
|  | ||||
|     @classmethod | ||||
|     def setUpClass(cls): | ||||
|         disconnect_all() | ||||
|  | ||||
|     @classmethod | ||||
|     def tearDownClass(cls): | ||||
|         disconnect_all() | ||||
|  | ||||
|     def tearDown(self): | ||||
|         mongoengine.connection._connection_settings = {} | ||||
|         mongoengine.connection._connections = {} | ||||
| @@ -49,6 +57,117 @@ class ConnectionTest(unittest.TestCase): | ||||
|         conn = get_connection('testdb') | ||||
|         self.assertIsInstance(conn, pymongo.mongo_client.MongoClient) | ||||
|  | ||||
|     def test_connect_disconnect_works_properly(self): | ||||
|         class History1(Document): | ||||
|             name = StringField() | ||||
|             meta = {'db_alias': 'db1'} | ||||
|  | ||||
|         class History2(Document): | ||||
|             name = StringField() | ||||
|             meta = {'db_alias': 'db2'} | ||||
|  | ||||
|         connect('db1', alias='db1') | ||||
|         connect('db2', alias='db2') | ||||
|  | ||||
|         History1.drop_collection() | ||||
|         History2.drop_collection() | ||||
|  | ||||
|         h = History1(name='default').save() | ||||
|         h1 = History2(name='db1').save() | ||||
|  | ||||
|         self.assertEqual(list(History1.objects().as_pymongo()), | ||||
|                          [{'_id': h.id, 'name': 'default'}]) | ||||
|         self.assertEqual(list(History2.objects().as_pymongo()), | ||||
|                          [{'_id': h1.id, 'name': 'db1'}]) | ||||
|  | ||||
|         disconnect('db1') | ||||
|         disconnect('db2') | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError): | ||||
|             list(History1.objects().as_pymongo()) | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError): | ||||
|             list(History2.objects().as_pymongo()) | ||||
|  | ||||
|         connect('db1', alias='db1') | ||||
|         connect('db2', alias='db2') | ||||
|  | ||||
|         self.assertEqual(list(History1.objects().as_pymongo()), | ||||
|                          [{'_id': h.id, 'name': 'default'}]) | ||||
|         self.assertEqual(list(History2.objects().as_pymongo()), | ||||
|                          [{'_id': h1.id, 'name': 'db1'}]) | ||||
|  | ||||
|     def test_connect_different_documents_to_different_database(self): | ||||
|         class History(Document): | ||||
|             name = StringField() | ||||
|  | ||||
|         class History1(Document): | ||||
|             name = StringField() | ||||
|             meta = {'db_alias': 'db1'} | ||||
|  | ||||
|         class History2(Document): | ||||
|             name = StringField() | ||||
|             meta = {'db_alias': 'db2'} | ||||
|  | ||||
|         connect() | ||||
|         connect('db1', alias='db1') | ||||
|         connect('db2', alias='db2') | ||||
|  | ||||
|         History.drop_collection() | ||||
|         History1.drop_collection() | ||||
|         History2.drop_collection() | ||||
|  | ||||
|         h = History(name='default').save() | ||||
|         h1 = History1(name='db1').save() | ||||
|         h2 = History2(name='db2').save() | ||||
|  | ||||
|         self.assertEqual(History._collection.database.name, DEFAULT_DATABASE_NAME) | ||||
|         self.assertEqual(History1._collection.database.name, 'db1') | ||||
|         self.assertEqual(History2._collection.database.name, 'db2') | ||||
|  | ||||
|         self.assertEqual(list(History.objects().as_pymongo()), | ||||
|                          [{'_id': h.id, 'name': 'default'}]) | ||||
|         self.assertEqual(list(History1.objects().as_pymongo()), | ||||
|                          [{'_id': h1.id, 'name': 'db1'}]) | ||||
|         self.assertEqual(list(History2.objects().as_pymongo()), | ||||
|                          [{'_id': h2.id, 'name': 'db2'}]) | ||||
|  | ||||
|     def test_connect_fails_if_connect_2_times_with_default_alias(self): | ||||
|         connect('mongoenginetest') | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError) as ctx_err: | ||||
|             connect('mongoenginetest2') | ||||
|         self.assertEqual("A different connection with alias `default` was already registered. Use disconnect() first", str(ctx_err.exception)) | ||||
|  | ||||
|     def test_connect_fails_if_connect_2_times_with_custom_alias(self): | ||||
|         connect('mongoenginetest', alias='alias1') | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError) as ctx_err: | ||||
|             connect('mongoenginetest2', alias='alias1') | ||||
|  | ||||
|         self.assertEqual("A different connection with alias `alias1` was already registered. Use disconnect() first", str(ctx_err.exception)) | ||||
|  | ||||
|     def test_connect_fails_if_similar_connection_settings_arent_defined_the_same_way(self): | ||||
|         """Intended to keep the detecton function simple but robust""" | ||||
|         db_name = 'mongoenginetest' | ||||
|         db_alias = 'alias1' | ||||
|         connect(db=db_name, alias=db_alias, host='localhost', port=27017) | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError): | ||||
|             connect(host='mongodb://localhost:27017/%s' % db_name, alias=db_alias) | ||||
|  | ||||
|     def test_connect_passes_silently_connect_multiple_times_with_same_config(self): | ||||
|         # test default connection to `test` | ||||
|         connect() | ||||
|         connect() | ||||
|         self.assertEqual(len(mongoengine.connection._connections), 1) | ||||
|         connect('test01', alias='test01') | ||||
|         connect('test01', alias='test01') | ||||
|         self.assertEqual(len(mongoengine.connection._connections), 2) | ||||
|         connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') | ||||
|         connect(host='mongodb://localhost:27017/mongoenginetest02', alias='test02') | ||||
|         self.assertEqual(len(mongoengine.connection._connections), 3) | ||||
|  | ||||
|     def test_connect_in_mocking(self): | ||||
|         """Ensure that the connect() method works properly in mocking. | ||||
|         """ | ||||
| @@ -120,13 +239,93 @@ class ConnectionTest(unittest.TestCase): | ||||
|         self.assertIsInstance(conn, mongomock.MongoClient) | ||||
|  | ||||
|     def test_disconnect(self): | ||||
|         """Ensure that the disconnect() method works properly | ||||
|         """ | ||||
|         """Ensure that the disconnect() method works properly""" | ||||
|         connections = mongoengine.connection._connections | ||||
|         dbs = mongoengine.connection._dbs | ||||
|         connection_settings = mongoengine.connection._connection_settings | ||||
|  | ||||
|         conn1 = connect('mongoenginetest') | ||||
|         mongoengine.connection.disconnect() | ||||
|  | ||||
|         class History(Document): | ||||
|             pass | ||||
|  | ||||
|         self.assertIsNone(History._collection) | ||||
|  | ||||
|         History.drop_collection() | ||||
|         History.objects.first()     # will trigger the caching of _collection attribute | ||||
|  | ||||
|         self.assertIsNotNone(History._collection) | ||||
|  | ||||
|         self.assertEqual(len(connections), 1) | ||||
|         self.assertEqual(len(dbs), 1) | ||||
|         self.assertEqual(len(connection_settings), 1) | ||||
|  | ||||
|         disconnect() | ||||
|  | ||||
|         self.assertIsNone(History._collection) | ||||
|  | ||||
|         self.assertEqual(len(connections), 0) | ||||
|         self.assertEqual(len(dbs), 0) | ||||
|         self.assertEqual(len(connection_settings), 0) | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError) as ctx_err: | ||||
|             History.objects.first() | ||||
|         self.assertEqual("You have not defined a default connection", str(ctx_err.exception)) | ||||
|  | ||||
|         conn2 = connect('mongoenginetest') | ||||
|         History.objects.first()     # Make sure its back on track | ||||
|         self.assertTrue(conn1 is not conn2) | ||||
|  | ||||
|     def test_disconnect_silently_pass_if_alias_does_not_exist(self): | ||||
|         connections = mongoengine.connection._connections | ||||
|         self.assertEqual(len(connections), 0) | ||||
|         disconnect(alias='not_exist') | ||||
|  | ||||
|     def test_disconnect_all(self): | ||||
|         connections = mongoengine.connection._connections | ||||
|         dbs = mongoengine.connection._dbs | ||||
|         connection_settings = mongoengine.connection._connection_settings | ||||
|  | ||||
|         connect('mongoenginetest') | ||||
|         connect('mongoenginetest2', alias='db1') | ||||
|  | ||||
|         class History(Document): | ||||
|             pass | ||||
|  | ||||
|         class History1(Document): | ||||
|             name = StringField() | ||||
|             meta = {'db_alias': 'db1'} | ||||
|  | ||||
|         History.drop_collection()   # will trigger the caching of _collection attribute | ||||
|         History.objects.first() | ||||
|         History1.drop_collection() | ||||
|         History1.objects.first() | ||||
|  | ||||
|         self.assertIsNotNone(History._collection) | ||||
|         self.assertIsNotNone(History1._collection) | ||||
|  | ||||
|         self.assertEqual(len(connections), 2) | ||||
|         self.assertEqual(len(dbs), 2) | ||||
|         self.assertEqual(len(connection_settings), 2) | ||||
|  | ||||
|         disconnect_all() | ||||
|  | ||||
|         self.assertIsNone(History._collection) | ||||
|         self.assertIsNone(History1._collection) | ||||
|  | ||||
|         self.assertEqual(len(connections), 0) | ||||
|         self.assertEqual(len(dbs), 0) | ||||
|         self.assertEqual(len(connection_settings), 0) | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError): | ||||
|             History.objects.first() | ||||
|  | ||||
|         with self.assertRaises(MongoEngineConnectionError): | ||||
|             History1.objects.first() | ||||
|  | ||||
|     def test_disconnect_all_silently_pass_if_no_connection_exist(self): | ||||
|         disconnect_all() | ||||
|  | ||||
|     def test_sharing_connections(self): | ||||
|         """Ensure that connections are shared when the connection settings are exactly the same | ||||
|         """ | ||||
| @@ -342,7 +541,7 @@ class ConnectionTest(unittest.TestCase): | ||||
|             with self.assertRaises(MongoEngineConnectionError): | ||||
|                 c = connect(replicaset='local-rs') | ||||
|  | ||||
|     def test_datetime(self): | ||||
|     def test_connect_tz_aware(self): | ||||
|         connect('mongoenginetest', tz_aware=True) | ||||
|         d = datetime.datetime(2010, 5, 5, tzinfo=utc) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user