Fix changing databases
Conflicts: mongoengine/connection.py mongoengine/queryset.py
This commit is contained in:
		| @@ -4,11 +4,12 @@ import multiprocessing | ||||
| __all__ = ['ConnectionError', 'connect'] | ||||
|  | ||||
|  | ||||
| _connection_settings = { | ||||
| _connection_defaults = { | ||||
|     'host': 'localhost', | ||||
|     'port': 27017, | ||||
| } | ||||
| _connection = {} | ||||
| _connection_settings = _connection_defaults.copy() | ||||
|  | ||||
| _db_name = None | ||||
| _db_username = None | ||||
| @@ -20,25 +21,25 @@ class ConnectionError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def _get_connection(): | ||||
| def _get_connection(reconnect=False): | ||||
|     global _connection | ||||
|     identity = get_identity() | ||||
|     # Connect to the database if not already connected | ||||
|     if _connection.get(identity) is None: | ||||
|     if _connection.get(identity) is None or reconnect: | ||||
|         try: | ||||
|             _connection[identity] = Connection(**_connection_settings) | ||||
|         except: | ||||
|             raise ConnectionError('Cannot connect to the database') | ||||
|     return _connection[identity] | ||||
|  | ||||
| def _get_db(): | ||||
| def _get_db(reconnect=False): | ||||
|     global _db, _connection | ||||
|     identity = get_identity() | ||||
|     # Connect if not already connected | ||||
|     if _connection.get(identity) is None: | ||||
|         _connection[identity] = _get_connection() | ||||
|     if _connection.get(identity) is None or reconnect: | ||||
|         _connection[identity] = _get_connection(reconnect=reconnect) | ||||
|  | ||||
|     if _db.get(identity) is None: | ||||
|     if _db.get(identity) is None or reconnect: | ||||
|         # _db_name will be None if the user hasn't called connect() | ||||
|         if _db_name is None: | ||||
|             raise ConnectionError('Not connected to the database') | ||||
| @@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs): | ||||
|     the default port on localhost. If authentication is needed, provide | ||||
|     username and password arguments as well. | ||||
|     """ | ||||
|     global _connection_settings, _db_name, _db_username, _db_password | ||||
|     _connection_settings.update(kwargs) | ||||
|     global _connection_settings, _db_name, _db_username, _db_password, _db | ||||
|     _connection_settings = dict(_connection_defaults, **kwargs) | ||||
|     _db_name = db | ||||
|     _db_username = username | ||||
|     _db_password = password | ||||
|     return _get_db() | ||||
|     return _get_db(reconnect=True) | ||||
|  | ||||
|   | ||||
| @@ -977,7 +977,7 @@ class QuerySetManager(object): | ||||
|  | ||||
|     def __init__(self, manager_func=None): | ||||
|         self._manager_func = manager_func | ||||
|         self._collection = None | ||||
|         self._collections = {} | ||||
|  | ||||
|     def __get__(self, instance, owner): | ||||
|         """Descriptor for instantiating a new QuerySet object when | ||||
| @@ -987,8 +987,8 @@ class QuerySetManager(object): | ||||
|             # Document class being used rather than a document object | ||||
|             return self | ||||
|  | ||||
|         if self._collection is None: | ||||
|         db = _get_db() | ||||
|         if db not in self._collections: | ||||
|             collection = owner._meta['collection'] | ||||
|  | ||||
|             # Create collection as a capped collection if specified | ||||
| @@ -998,10 +998,10 @@ class QuerySetManager(object): | ||||
|                 max_documents = owner._meta['max_documents'] | ||||
|  | ||||
|                 if collection in db.collection_names(): | ||||
|                     self._collection = db[collection] | ||||
|                     self._collections[db] = db[collection] | ||||
|                     # The collection already exists, check if its capped | ||||
|                     # options match the specified capped options | ||||
|                     options = self._collection.options() | ||||
|                     options = self._collections[db].options() | ||||
|                     if options.get('max') != max_documents or \ | ||||
|                        options.get('size') != max_size: | ||||
|                         msg = ('Cannot create collection "%s" as a capped ' | ||||
| @@ -1012,13 +1012,15 @@ class QuerySetManager(object): | ||||
|                     opts = {'capped': True, 'size': max_size} | ||||
|                     if max_documents: | ||||
|                         opts['max'] = max_documents | ||||
|                     self._collection = db.create_collection(collection, **opts) | ||||
|                     self._collections[db] = db.create_collection( | ||||
|                         collection, **opts | ||||
|                     ) | ||||
|             else: | ||||
|                 self._collection = db[collection] | ||||
|                 self._collections[db] = db[collection] | ||||
|  | ||||
|         # owner is the document that contains the QuerySetManager | ||||
|         queryset_class = owner._meta['queryset_class'] or QuerySet | ||||
|         queryset = queryset_class(owner, self._collection) | ||||
|         queryset = queryset_class(owner, self._collections[db]) | ||||
|         if self._manager_func: | ||||
|             if self._manager_func.func_code.co_argcount == 1: | ||||
|                 queryset = self._manager_func(queryset) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user