Fix changing databases

Conflicts:

	mongoengine/connection.py
	mongoengine/queryset.py
This commit is contained in:
Harry Marr 2010-10-05 00:46:13 +01:00
parent 3acfd90720
commit 92471445ec
2 changed files with 22 additions and 18 deletions

View File

@ -4,11 +4,12 @@ import multiprocessing
__all__ = ['ConnectionError', 'connect'] __all__ = ['ConnectionError', 'connect']
_connection_settings = { _connection_defaults = {
'host': 'localhost', 'host': 'localhost',
'port': 27017, 'port': 27017,
} }
_connection = {} _connection = {}
_connection_settings = _connection_defaults.copy()
_db_name = None _db_name = None
_db_username = None _db_username = None
@ -20,25 +21,25 @@ class ConnectionError(Exception):
pass pass
def _get_connection(): def _get_connection(reconnect=False):
global _connection global _connection
identity = get_identity() identity = get_identity()
# Connect to the database if not already connected # Connect to the database if not already connected
if _connection.get(identity) is None: if _connection.get(identity) is None or reconnect:
try: try:
_connection[identity] = Connection(**_connection_settings) _connection[identity] = Connection(**_connection_settings)
except: except:
raise ConnectionError('Cannot connect to the database') raise ConnectionError('Cannot connect to the database')
return _connection[identity] return _connection[identity]
def _get_db(): def _get_db(reconnect=False):
global _db, _connection global _db, _connection
identity = get_identity() identity = get_identity()
# Connect if not already connected # Connect if not already connected
if _connection.get(identity) is None: if _connection.get(identity) is None or reconnect:
_connection[identity] = _get_connection() _connection[identity] = _get_connection(reconnect=reconnect)
if _db.get(identity) is None: if _db.get(identity) is None or reconnect:
# _db_name will be None if the user hasn't called connect() # _db_name will be None if the user hasn't called connect()
if _db_name is None: if _db_name is None:
raise ConnectionError('Not connected to the database') raise ConnectionError('Not connected to the database')
@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs):
the default port on localhost. If authentication is needed, provide the default port on localhost. If authentication is needed, provide
username and password arguments as well. username and password arguments as well.
""" """
global _connection_settings, _db_name, _db_username, _db_password global _connection_settings, _db_name, _db_username, _db_password, _db
_connection_settings.update(kwargs) _connection_settings = dict(_connection_defaults, **kwargs)
_db_name = db _db_name = db
_db_username = username _db_username = username
_db_password = password _db_password = password
return _get_db() return _get_db(reconnect=True)

View File

@ -977,7 +977,7 @@ class QuerySetManager(object):
def __init__(self, manager_func=None): def __init__(self, manager_func=None):
self._manager_func = manager_func self._manager_func = manager_func
self._collection = None self._collections = {}
def __get__(self, instance, owner): def __get__(self, instance, owner):
"""Descriptor for instantiating a new QuerySet object when """Descriptor for instantiating a new QuerySet object when
@ -987,8 +987,8 @@ class QuerySetManager(object):
# Document class being used rather than a document object # Document class being used rather than a document object
return self return self
if self._collection is None: db = _get_db()
db = _get_db() if db not in self._collections:
collection = owner._meta['collection'] collection = owner._meta['collection']
# Create collection as a capped collection if specified # Create collection as a capped collection if specified
@ -998,10 +998,10 @@ class QuerySetManager(object):
max_documents = owner._meta['max_documents'] max_documents = owner._meta['max_documents']
if collection in db.collection_names(): if collection in db.collection_names():
self._collection = db[collection] self._collections[db] = db[collection]
# The collection already exists, check if its capped # The collection already exists, check if its capped
# options match the specified capped options # options match the specified capped options
options = self._collection.options() options = self._collections[db].options()
if options.get('max') != max_documents or \ if options.get('max') != max_documents or \
options.get('size') != max_size: options.get('size') != max_size:
msg = ('Cannot create collection "%s" as a capped ' msg = ('Cannot create collection "%s" as a capped '
@ -1012,13 +1012,15 @@ class QuerySetManager(object):
opts = {'capped': True, 'size': max_size} opts = {'capped': True, 'size': max_size}
if max_documents: if max_documents:
opts['max'] = max_documents opts['max'] = max_documents
self._collection = db.create_collection(collection, **opts) self._collections[db] = db.create_collection(
collection, **opts
)
else: else:
self._collection = db[collection] self._collections[db] = db[collection]
# owner is the document that contains the QuerySetManager # owner is the document that contains the QuerySetManager
queryset_class = owner._meta['queryset_class'] or QuerySet queryset_class = owner._meta['queryset_class'] or QuerySet
queryset = queryset_class(owner, self._collection) queryset = queryset_class(owner, self._collections[db])
if self._manager_func: if self._manager_func:
if self._manager_func.func_code.co_argcount == 1: if self._manager_func.func_code.co_argcount == 1:
queryset = self._manager_func(queryset) queryset = self._manager_func(queryset)