Added multidb support
No change required to upgrade to multiple databases. Aliases are used to describe the database and these can be manually registered or fall through to a default alias using connect. Made get_connection and get_db first class members of the connection class. Old style _get_connection and _get_db still supported. Refs: #84 #87 #93 #215
This commit is contained in:
		| @@ -1,82 +1,106 @@ | ||||
| from pymongo import Connection | ||||
| import multiprocessing | ||||
| import threading | ||||
|  | ||||
| __all__ = ['ConnectionError', 'connect'] | ||||
|  | ||||
|  | ||||
| _connection_defaults = { | ||||
|     'host': 'localhost', | ||||
|     'port': 27017, | ||||
| } | ||||
| _connection = {} | ||||
| _connection_settings = _connection_defaults.copy() | ||||
| __all__ = ['ConnectionError', 'connect', 'register_connection'] | ||||
|  | ||||
| _db_name = None | ||||
| _db_username = None | ||||
| _db_password = None | ||||
| _db = {} | ||||
|  | ||||
| DEFAULT_CONNECTION_NAME = 'default' | ||||
|  | ||||
|  | ||||
| class ConnectionError(Exception): | ||||
|     pass | ||||
|  | ||||
|  | ||||
| def _get_connection(reconnect=False): | ||||
|     """Handles the connection to the database | ||||
| _connection_settings = {} | ||||
| _connections = {} | ||||
| _dbs = {} | ||||
|  | ||||
|  | ||||
| def register_connection(alias, name, host='localhost', port=27017, | ||||
|                         is_slave=False, slaves=None, username=None, | ||||
|                         password=None): | ||||
|     """Add a connection. | ||||
|  | ||||
|     :param alias: the name that will be used to refer to this connection | ||||
|         throughout MongoEngine | ||||
|     :param name: the name of the specific database to use | ||||
|     :param host: the host name of the :program:`mongod` instance to connect to | ||||
|     :param port: the port that the :program:`mongod` instance is running on | ||||
|     :param is_slave: whether the connection can act as a slave | ||||
|     :param slaves: a list of aliases of slave connections; each of these must | ||||
|         be a registered connection that has :attr:`is_slave` set to ``True`` | ||||
|     :param username: username to authenticate with | ||||
|     :param password: password to authenticate with | ||||
|     """ | ||||
|     global _connection | ||||
|     identity = get_identity() | ||||
|     global _connection_settings | ||||
|     _connection_settings[alias] = { | ||||
|         'name': name, | ||||
|         'host': host, | ||||
|         'port': port, | ||||
|         'is_slave': is_slave, | ||||
|         'slaves': slaves or [], | ||||
|         'username': username, | ||||
|         'password': password, | ||||
|     } | ||||
|  | ||||
|  | ||||
| def get_connection(alias=DEFAULT_CONNECTION_NAME): | ||||
|     global _connections | ||||
|     # Connect to the database if not already connected | ||||
|     if _connection.get(identity) is None or reconnect: | ||||
|     if alias not in _connections: | ||||
|         if alias not in _connection_settings: | ||||
|             msg = 'Connection with alias "%s" has not been defined' | ||||
|             if alias == DEFAULT_CONNECTION_NAME: | ||||
|                 msg = 'You have not defined a default connection' | ||||
|             raise ConnectionError(msg) | ||||
|         conn_settings = _connection_settings[alias].copy() | ||||
|  | ||||
|         # Get all the slave connections | ||||
|         slaves = [] | ||||
|         for slave_alias in conn_settings['slaves']: | ||||
|             slaves.append(get_connection(slave_alias)) | ||||
|         conn_settings['slaves'] = slaves | ||||
|  | ||||
|         try: | ||||
|             _connection[identity] = Connection(**_connection_settings) | ||||
|             _connections[alias] = Connection(**conn_settings) | ||||
|         except Exception, e: | ||||
|             raise ConnectionError("Cannot connect to the database:\n%s" % e) | ||||
|     return _connection[identity] | ||||
|             raise e | ||||
|             raise ConnectionError('Cannot connect to database %s' % alias) | ||||
|     return _connections[alias] | ||||
|  | ||||
| def _get_db(reconnect=False): | ||||
|     """Handles database connections and authentication based on the current | ||||
|     identity | ||||
|  | ||||
| def get_db(alias=DEFAULT_CONNECTION_NAME): | ||||
|     global _dbs | ||||
|     if alias not in _dbs: | ||||
|         conn = get_connection(alias) | ||||
|         conn_settings = _connection_settings[alias] | ||||
|         _dbs[alias] = conn[conn_settings['name']] | ||||
|  | ||||
|         # Authenticate if necessary | ||||
|         if conn_settings['username'] and conn_settings['password']: | ||||
|             _dbs[alias].authenticate(conn_settings['username'], | ||||
|                                      conn_settings['password']) | ||||
|     return _dbs[alias] | ||||
|  | ||||
|  | ||||
| def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): | ||||
|     """Connect to the database specified by the 'db' argument. | ||||
|  | ||||
|     Connection settings may be provided here as well if the database is not | ||||
|     running on the default port on localhost. If authentication is needed, | ||||
|     provide username and password arguments as well. | ||||
|  | ||||
|     Multiple databases are supported by using aliases.  Provide a separate | ||||
|     `alias` to connect to a different instance of :program:`mongod`. | ||||
|  | ||||
|     .. versionchanged:: 0.6 - added multiple database support. | ||||
|     """ | ||||
|     global _db, _connection | ||||
|     identity = get_identity() | ||||
|     # Connect if not already connected | ||||
|     if _connection.get(identity) is None or reconnect: | ||||
|         _connection[identity] = _get_connection(reconnect=reconnect) | ||||
|     global _connections | ||||
|     if alias not in _connections: | ||||
|         register_connection(alias, db, **kwargs) | ||||
|  | ||||
|     if _db.get(identity) is None or reconnect: | ||||
|         # _db_name will be None if the user hasn't called connect() | ||||
|         if _db_name is None: | ||||
|             raise ConnectionError('Not connected to the database') | ||||
|  | ||||
|         # Get DB from current connection and authenticate if necessary | ||||
|         _db[identity] = _connection[identity][_db_name] | ||||
|         if _db_username and _db_password: | ||||
|             _db[identity].authenticate(_db_username, _db_password) | ||||
|  | ||||
|     return _db[identity] | ||||
|  | ||||
| def get_identity(): | ||||
|     """Creates an identity key based on the current process and thread | ||||
|     identity. | ||||
|     """ | ||||
|     identity = multiprocessing.current_process()._identity | ||||
|     identity = 0 if not identity else identity[0] | ||||
|  | ||||
|     identity = (identity, threading.current_thread().ident) | ||||
|     return identity | ||||
|  | ||||
| def connect(db, username=None, password=None, **kwargs): | ||||
|     """Connect to the database specified by the 'db' argument. Connection | ||||
|     settings may be provided here as well if the database is not running on | ||||
|     the default port on localhost. If authentication is needed, provide | ||||
|     username and password arguments as well. | ||||
|     """ | ||||
|     global _connection_settings, _db_name, _db_username, _db_password, _db | ||||
|     _connection_settings = dict(_connection_defaults, **kwargs) | ||||
|     _db_name = db | ||||
|     _db_username = username | ||||
|     _db_password = password | ||||
|     return _get_db(reconnect=True) | ||||
|     return get_connection(alias) | ||||
|  | ||||
| # Support old naming convention | ||||
| _get_connection = get_connection | ||||
| _get_db = get_db | ||||
|   | ||||
| @@ -1,10 +1,8 @@ | ||||
| import operator | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
| from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass | ||||
| from fields import ReferenceField | ||||
| from connection import _get_db | ||||
| from connection import get_db | ||||
| from queryset import QuerySet | ||||
| from document import Document | ||||
|  | ||||
| @@ -103,7 +101,7 @@ class DeReference(object): | ||||
|                 for key, doc in references.iteritems(): | ||||
|                     object_map[key] = doc | ||||
|             else:  # Generic reference: use the refs data to convert to document | ||||
|                 references = _get_db()[col].find({'_id': {'$in': refs}}) | ||||
|                 references = get_db()[col].find({'_id': {'$in': refs}}) | ||||
|                 for ref in references: | ||||
|                     if '_cls' in ref: | ||||
|                         doc = get_document(ref['_cls'])._from_son(ref) | ||||
|   | ||||
| @@ -1,14 +1,14 @@ | ||||
| import operator | ||||
| from mongoengine import signals | ||||
| from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, | ||||
|                   ValidationError, BaseDict, BaseList, BaseDynamicField) | ||||
| from queryset import OperationError | ||||
| from connection import _get_db | ||||
| from connection import get_db | ||||
|  | ||||
| import pymongo | ||||
|  | ||||
| __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', | ||||
|            'ValidationError', 'OperationError', 'InvalidCollectionError'] | ||||
| __all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', | ||||
|            'DynamicEmbeddedDocument', 'ValidationError', 'OperationError', | ||||
|            'InvalidCollectionError'] | ||||
|  | ||||
|  | ||||
| class InvalidCollectionError(Exception): | ||||
| @@ -76,7 +76,7 @@ class Document(BaseDocument): | ||||
|     by setting index_types to False on the meta dictionary for the document. | ||||
|     """ | ||||
|     __metaclass__ = TopLevelDocumentMetaclass | ||||
|      | ||||
|  | ||||
|     @apply | ||||
|     def pk(): | ||||
|         """Primary key alias | ||||
| @@ -91,7 +91,7 @@ class Document(BaseDocument): | ||||
|     def _get_collection(self): | ||||
|         """Returns the collection for the document.""" | ||||
|         if not hasattr(self, '_collection') or self._collection is None: | ||||
|             db = _get_db() | ||||
|             db = get_db() | ||||
|             collection_name = self._get_collection_name() | ||||
|             # Create collection as a capped collection if specified | ||||
|             if self._meta['max_size'] or self._meta['max_documents']: | ||||
| @@ -300,7 +300,7 @@ class Document(BaseDocument): | ||||
|         :class:`~mongoengine.Document` type from the database. | ||||
|         """ | ||||
|         from mongoengine.queryset import QuerySet | ||||
|         db = _get_db() | ||||
|         db = get_db() | ||||
|         db.drop_collection(cls._get_collection_name()) | ||||
|         QuerySet._reset_already_indexed(cls) | ||||
|  | ||||
|   | ||||
| @@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField, | ||||
|                   ValidationError, get_document) | ||||
| from queryset import DO_NOTHING | ||||
| from document import Document, EmbeddedDocument | ||||
| from connection import _get_db | ||||
| from connection import get_db | ||||
| from operator import itemgetter | ||||
|  | ||||
|  | ||||
| @@ -637,7 +637,7 @@ class ReferenceField(BaseField): | ||||
|         value = instance._data.get(self.name) | ||||
|         # Dereference DBRefs | ||||
|         if isinstance(value, (pymongo.dbref.DBRef)): | ||||
|             value = _get_db().dereference(value) | ||||
|             value = get_db().dereference(value) | ||||
|             if value is not None: | ||||
|                 instance._data[self.name] = self.document_type._from_son(value) | ||||
|  | ||||
| @@ -710,7 +710,7 @@ class GenericReferenceField(BaseField): | ||||
|     def dereference(self, value): | ||||
|         doc_cls = get_document(value['_cls']) | ||||
|         reference = value['_ref'] | ||||
|         doc = _get_db().dereference(reference) | ||||
|         doc = get_db().dereference(reference) | ||||
|         if doc is not None: | ||||
|             doc = doc_cls._from_son(doc) | ||||
|         return doc | ||||
| @@ -780,7 +780,7 @@ class GridFSProxy(object): | ||||
|  | ||||
|     def __init__(self, grid_id=None, key=None, | ||||
|                  instance=None, collection_name='fs'): | ||||
|         self.fs = gridfs.GridFS(_get_db(), collection_name)  # Filesystem instance | ||||
|         self.fs = gridfs.GridFS(get_db(), collection_name)  # Filesystem instance | ||||
|         self.newfile = None                 # Used for partial writes | ||||
|         self.grid_id = grid_id              # Store GridFS id for file | ||||
|         self.gridout = None | ||||
| @@ -1138,7 +1138,7 @@ class SequenceField(IntField): | ||||
|         """ | ||||
|         sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), | ||||
|                                        self.name) | ||||
|         collection = _get_db()[self.collection_name] | ||||
|         collection = get_db()[self.collection_name] | ||||
|         counter = collection.find_and_modify(query={"_id": sequence_id}, | ||||
|                                              update={"$inc": {"next": 1}}, | ||||
|                                              new=True, | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from connection import _get_db | ||||
| from connection import get_db | ||||
| from mongoengine import signals | ||||
|  | ||||
| import pprint | ||||
| @@ -481,7 +481,7 @@ class QuerySet(object): | ||||
|         if self._document not in QuerySet.__already_indexed: | ||||
|  | ||||
|             # Ensure collection exists | ||||
|             db = _get_db() | ||||
|             db = get_db() | ||||
|             if self._collection_obj.name not in db.collection_names(): | ||||
|                 self._document._collection = None | ||||
|                 self._collection_obj = self._document._get_collection() | ||||
| @@ -1436,7 +1436,7 @@ class QuerySet(object): | ||||
|         scope['query'] = query | ||||
|         code = pymongo.code.Code(code, scope=scope) | ||||
|  | ||||
|         db = _get_db() | ||||
|         db = get_db() | ||||
|         return db.eval(code, *fields) | ||||
|  | ||||
|     def where(self, where_clause): | ||||
|   | ||||
| @@ -1,4 +1,4 @@ | ||||
| from mongoengine.connection import _get_db | ||||
| from mongoengine.connection import get_db | ||||
|  | ||||
|  | ||||
| class query_counter(object): | ||||
| @@ -7,7 +7,7 @@ class query_counter(object): | ||||
|     def __init__(self): | ||||
|         """ Construct the query_counter. """ | ||||
|         self.counter = 0 | ||||
|         self.db = _get_db() | ||||
|         self.db = get_db() | ||||
|  | ||||
|     def __enter__(self): | ||||
|         """ On every with block we need to drop the profile collection. """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user