diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 94cc6ea1..814fde13 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -4,11 +4,12 @@ import multiprocessing __all__ = ['ConnectionError', 'connect'] -_connection_settings = { +_connection_defaults = { 'host': 'localhost', 'port': 27017, } _connection = {} +_connection_settings = _connection_defaults.copy() _db_name = None _db_username = None @@ -20,25 +21,25 @@ class ConnectionError(Exception): pass -def _get_connection(): +def _get_connection(reconnect=False): global _connection identity = get_identity() # Connect to the database if not already connected - if _connection.get(identity) is None: + if _connection.get(identity) is None or reconnect: try: _connection[identity] = Connection(**_connection_settings) except: raise ConnectionError('Cannot connect to the database') return _connection[identity] -def _get_db(): +def _get_db(reconnect=False): global _db, _connection identity = get_identity() # Connect if not already connected - if _connection.get(identity) is None: - _connection[identity] = _get_connection() + if _connection.get(identity) is None or reconnect: + _connection[identity] = _get_connection(reconnect=reconnect) - if _db.get(identity) is None: + if _db.get(identity) is None or reconnect: # _db_name will be None if the user hasn't called connect() if _db_name is None: raise ConnectionError('Not connected to the database') @@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs): the default port on localhost. If authentication is needed, provide username and password arguments as well. """ - global _connection_settings, _db_name, _db_username, _db_password - _connection_settings.update(kwargs) + global _connection_settings, _db_name, _db_username, _db_password, _db + _connection_settings = dict(_connection_defaults, **kwargs) _db_name = db _db_username = username _db_password = password - return _get_db() \ No newline at end of file + return _get_db(reconnect=True) + diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 99417850..fae2aabf 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -977,7 +977,7 @@ class QuerySetManager(object): def __init__(self, manager_func=None): self._manager_func = manager_func - self._collection = None + self._collections = {} def __get__(self, instance, owner): """Descriptor for instantiating a new QuerySet object when @@ -987,8 +987,8 @@ class QuerySetManager(object): # Document class being used rather than a document object return self - if self._collection is None: - db = _get_db() + db = _get_db() + if db not in self._collections: collection = owner._meta['collection'] # Create collection as a capped collection if specified @@ -998,10 +998,10 @@ class QuerySetManager(object): max_documents = owner._meta['max_documents'] if collection in db.collection_names(): - self._collection = db[collection] + self._collections[db] = db[collection] # The collection already exists, check if its capped # options match the specified capped options - options = self._collection.options() + options = self._collections[db].options() if options.get('max') != max_documents or \ options.get('size') != max_size: msg = ('Cannot create collection "%s" as a capped ' @@ -1012,13 +1012,15 @@ class QuerySetManager(object): opts = {'capped': True, 'size': max_size} if max_documents: opts['max'] = max_documents - self._collection = db.create_collection(collection, **opts) + self._collections[db] = db.create_collection( + collection, **opts + ) else: - self._collection = db[collection] + self._collections[db] = db[collection] # owner is the document that contains the QuerySetManager queryset_class = owner._meta['queryset_class'] or QuerySet - queryset = queryset_class(owner, self._collection) + queryset = queryset_class(owner, self._collections[db]) if self._manager_func: if self._manager_func.func_code.co_argcount == 1: queryset = self._manager_func(queryset)