Fix changing databases
Conflicts: mongoengine/connection.py mongoengine/queryset.py
This commit is contained in:
		| @@ -4,11 +4,12 @@ import multiprocessing | |||||||
| __all__ = ['ConnectionError', 'connect'] | __all__ = ['ConnectionError', 'connect'] | ||||||
|  |  | ||||||
|  |  | ||||||
| _connection_settings = { | _connection_defaults = { | ||||||
|     'host': 'localhost', |     'host': 'localhost', | ||||||
|     'port': 27017, |     'port': 27017, | ||||||
| } | } | ||||||
| _connection = {} | _connection = {} | ||||||
|  | _connection_settings = _connection_defaults.copy() | ||||||
|  |  | ||||||
| _db_name = None | _db_name = None | ||||||
| _db_username = None | _db_username = None | ||||||
| @@ -20,25 +21,25 @@ class ConnectionError(Exception): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
| def _get_connection(): | def _get_connection(reconnect=False): | ||||||
|     global _connection |     global _connection | ||||||
|     identity = get_identity() |     identity = get_identity() | ||||||
|     # Connect to the database if not already connected |     # Connect to the database if not already connected | ||||||
|     if _connection.get(identity) is None: |     if _connection.get(identity) is None or reconnect: | ||||||
|         try: |         try: | ||||||
|             _connection[identity] = Connection(**_connection_settings) |             _connection[identity] = Connection(**_connection_settings) | ||||||
|         except: |         except: | ||||||
|             raise ConnectionError('Cannot connect to the database') |             raise ConnectionError('Cannot connect to the database') | ||||||
|     return _connection[identity] |     return _connection[identity] | ||||||
|  |  | ||||||
| def _get_db(): | def _get_db(reconnect=False): | ||||||
|     global _db, _connection |     global _db, _connection | ||||||
|     identity = get_identity() |     identity = get_identity() | ||||||
|     # Connect if not already connected |     # Connect if not already connected | ||||||
|     if _connection.get(identity) is None: |     if _connection.get(identity) is None or reconnect: | ||||||
|         _connection[identity] = _get_connection() |         _connection[identity] = _get_connection(reconnect=reconnect) | ||||||
|  |  | ||||||
|     if _db.get(identity) is None: |     if _db.get(identity) is None or reconnect: | ||||||
|         # _db_name will be None if the user hasn't called connect() |         # _db_name will be None if the user hasn't called connect() | ||||||
|         if _db_name is None: |         if _db_name is None: | ||||||
|             raise ConnectionError('Not connected to the database') |             raise ConnectionError('Not connected to the database') | ||||||
| @@ -61,9 +62,10 @@ def connect(db, username=None, password=None, **kwargs): | |||||||
|     the default port on localhost. If authentication is needed, provide |     the default port on localhost. If authentication is needed, provide | ||||||
|     username and password arguments as well. |     username and password arguments as well. | ||||||
|     """ |     """ | ||||||
|     global _connection_settings, _db_name, _db_username, _db_password |     global _connection_settings, _db_name, _db_username, _db_password, _db | ||||||
|     _connection_settings.update(kwargs) |     _connection_settings = dict(_connection_defaults, **kwargs) | ||||||
|     _db_name = db |     _db_name = db | ||||||
|     _db_username = username |     _db_username = username | ||||||
|     _db_password = password |     _db_password = password | ||||||
|     return _get_db() |     return _get_db(reconnect=True) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -977,7 +977,7 @@ class QuerySetManager(object): | |||||||
|  |  | ||||||
|     def __init__(self, manager_func=None): |     def __init__(self, manager_func=None): | ||||||
|         self._manager_func = manager_func |         self._manager_func = manager_func | ||||||
|         self._collection = None |         self._collections = {} | ||||||
|  |  | ||||||
|     def __get__(self, instance, owner): |     def __get__(self, instance, owner): | ||||||
|         """Descriptor for instantiating a new QuerySet object when |         """Descriptor for instantiating a new QuerySet object when | ||||||
| @@ -987,8 +987,8 @@ class QuerySetManager(object): | |||||||
|             # Document class being used rather than a document object |             # Document class being used rather than a document object | ||||||
|             return self |             return self | ||||||
|  |  | ||||||
|         if self._collection is None: |         db = _get_db() | ||||||
|             db = _get_db() |         if db not in self._collections: | ||||||
|             collection = owner._meta['collection'] |             collection = owner._meta['collection'] | ||||||
|  |  | ||||||
|             # Create collection as a capped collection if specified |             # Create collection as a capped collection if specified | ||||||
| @@ -998,10 +998,10 @@ class QuerySetManager(object): | |||||||
|                 max_documents = owner._meta['max_documents'] |                 max_documents = owner._meta['max_documents'] | ||||||
|  |  | ||||||
|                 if collection in db.collection_names(): |                 if collection in db.collection_names(): | ||||||
|                     self._collection = db[collection] |                     self._collections[db] = db[collection] | ||||||
|                     # The collection already exists, check if its capped |                     # The collection already exists, check if its capped | ||||||
|                     # options match the specified capped options |                     # options match the specified capped options | ||||||
|                     options = self._collection.options() |                     options = self._collections[db].options() | ||||||
|                     if options.get('max') != max_documents or \ |                     if options.get('max') != max_documents or \ | ||||||
|                        options.get('size') != max_size: |                        options.get('size') != max_size: | ||||||
|                         msg = ('Cannot create collection "%s" as a capped ' |                         msg = ('Cannot create collection "%s" as a capped ' | ||||||
| @@ -1012,13 +1012,15 @@ class QuerySetManager(object): | |||||||
|                     opts = {'capped': True, 'size': max_size} |                     opts = {'capped': True, 'size': max_size} | ||||||
|                     if max_documents: |                     if max_documents: | ||||||
|                         opts['max'] = max_documents |                         opts['max'] = max_documents | ||||||
|                     self._collection = db.create_collection(collection, **opts) |                     self._collections[db] = db.create_collection( | ||||||
|  |                         collection, **opts | ||||||
|  |                     ) | ||||||
|             else: |             else: | ||||||
|                 self._collection = db[collection] |                 self._collections[db] = db[collection] | ||||||
|  |  | ||||||
|         # owner is the document that contains the QuerySetManager |         # owner is the document that contains the QuerySetManager | ||||||
|         queryset_class = owner._meta['queryset_class'] or QuerySet |         queryset_class = owner._meta['queryset_class'] or QuerySet | ||||||
|         queryset = queryset_class(owner, self._collection) |         queryset = queryset_class(owner, self._collections[db]) | ||||||
|         if self._manager_func: |         if self._manager_func: | ||||||
|             if self._manager_func.func_code.co_argcount == 1: |             if self._manager_func.func_code.co_argcount == 1: | ||||||
|                 queryset = self._manager_func(queryset) |                 queryset = self._manager_func(queryset) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user