minor improvements

This commit is contained in:
Bastien Gérard 2019-04-25 22:11:43 +02:00
parent b1e28d02f7
commit 565e1dc0ed
4 changed files with 18 additions and 8 deletions

View File

@ -5,7 +5,7 @@ Changelog
Development Development
=========== ===========
- expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all` - expose `mongoengine.connection.disconnect` and `mongoengine.connection.disconnect_all`
- POTENTIAL BREAKING CHANGE: Fixes in connect/disconnect methods #565 #566 - POTENTIAL BREAKING CHANGES: Fixes in connect/disconnect methods #565 #566 #605 #607 #1213 #1599
- calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first) - calling `connect` 2 times with the same alias and different parameter will raise an error (should call disconnect first)
- disconnect now clears `mongoengine.connection._connection_settings` - disconnect now clears `mongoengine.connection._connection_settings`
- disconnect now clears the cached attribute `Document._collection` - disconnect now clears the cached attribute `Document._collection`

View File

@ -10,6 +10,8 @@ __all__ = ['MongoEngineConnectionError', 'connect', 'disconnect', 'disconnect_al
DEFAULT_CONNECTION_NAME = 'default' DEFAULT_CONNECTION_NAME = 'default'
DEFAULT_DATABASE_NAME = 'test' DEFAULT_DATABASE_NAME = 'test'
DEFAULT_HOST = 'localhost'
DEFAULT_PORT = 27017
if IS_PYMONGO_3: if IS_PYMONGO_3:
READ_PREFERENCE = ReadPreference.PRIMARY READ_PREFERENCE = ReadPreference.PRIMARY
@ -61,8 +63,8 @@ def _get_connection_settings(
""" """
conn_settings = { conn_settings = {
'name': name or db or DEFAULT_DATABASE_NAME, 'name': name or db or DEFAULT_DATABASE_NAME,
'host': host or 'localhost', 'host': host or DEFAULT_HOST,
'port': port or 27017, 'port': port or DEFAULT_PORT,
'read_preference': read_preference, 'read_preference': read_preference,
'username': username, 'username': username,
'password': password, 'password': password,
@ -172,6 +174,7 @@ def register_connection(alias, db=None, name=None, host=None, port=None,
def disconnect(alias=DEFAULT_CONNECTION_NAME): def disconnect(alias=DEFAULT_CONNECTION_NAME):
"""Close the connection with a given alias.""" """Close the connection with a given alias."""
from mongoengine.base.common import _get_documents_by_db from mongoengine.base.common import _get_documents_by_db
from mongoengine import Document
if alias in _connections: if alias in _connections:
get_connection(alias=alias).close() get_connection(alias=alias).close()
@ -180,7 +183,7 @@ def disconnect(alias=DEFAULT_CONNECTION_NAME):
if alias in _dbs: if alias in _dbs:
# Detach all cached collections in Documents # Detach all cached collections in Documents
for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME): for doc_cls in _get_documents_by_db(alias, DEFAULT_CONNECTION_NAME):
if hasattr(doc_cls, '_disconnect'): if issubclass(doc_cls, Document): # Skip EmbeddedDocument
doc_cls._disconnect() doc_cls._disconnect()
del _dbs[alias] del _dbs[alias]

View File

@ -249,9 +249,15 @@ class ConnectionTest(unittest.TestCase):
connect('mongoenginetest') connect('mongoenginetest')
self.assertEqual(len(connections), 1) self.assertEqual(len(connections), 1)
self.assertEqual(len(dbs), 1) self.assertEqual(len(dbs), 0)
self.assertEqual(len(connection_settings), 1) self.assertEqual(len(connection_settings), 1)
class TestDoc(Document):
pass
TestDoc.drop_collection() # triggers the db
self.assertEqual(len(dbs), 1)
disconnect() disconnect()
self.assertEqual(len(connections), 0) self.assertEqual(len(connections), 0)
self.assertEqual(len(dbs), 0) self.assertEqual(len(dbs), 0)

View File

@ -37,14 +37,15 @@ class ContextManagersTest(unittest.TestCase):
def test_switch_collection_context_manager(self): def test_switch_collection_context_manager(self):
connect('mongoenginetest') connect('mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2') register_connection(alias='testdb-1', db='mongoenginetest2')
class Group(Document): class Group(Document):
name = StringField() name = StringField()
Group.drop_collection() Group.drop_collection() # drops in default
with switch_collection(Group, 'group1') as Group: with switch_collection(Group, 'group1') as Group:
Group.drop_collection() Group.drop_collection() # drops in group1
Group(name="hello - group").save() Group(name="hello - group").save()
self.assertEqual(1, Group.objects.count()) self.assertEqual(1, Group.objects.count())