Added multidb support
No change required to upgrade to multiple databases. Aliases are used to describe the database and these can be manually registered or fall through to a default alias using connect. Made get_connection and get_db first class members of the connection class. Old style _get_connection and _get_db still supported. Refs: #84 #87 #93 #215
This commit is contained in:
@@ -1,82 +1,106 @@
|
||||
from pymongo import Connection
|
||||
import multiprocessing
|
||||
import threading
|
||||
|
||||
__all__ = ['ConnectionError', 'connect']
|
||||
|
||||
|
||||
_connection_defaults = {
|
||||
'host': 'localhost',
|
||||
'port': 27017,
|
||||
}
|
||||
_connection = {}
|
||||
_connection_settings = _connection_defaults.copy()
|
||||
__all__ = ['ConnectionError', 'connect', 'register_connection']
|
||||
|
||||
_db_name = None
|
||||
_db_username = None
|
||||
_db_password = None
|
||||
_db = {}
|
||||
|
||||
DEFAULT_CONNECTION_NAME = 'default'
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_connection(reconnect=False):
|
||||
"""Handles the connection to the database
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
|
||||
|
||||
def register_connection(alias, name, host='localhost', port=27017,
|
||||
is_slave=False, slaves=None, username=None,
|
||||
password=None):
|
||||
"""Add a connection.
|
||||
|
||||
:param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
:param name: the name of the specific database to use
|
||||
:param host: the host name of the :program:`mongod` instance to connect to
|
||||
:param port: the port that the :program:`mongod` instance is running on
|
||||
:param is_slave: whether the connection can act as a slave
|
||||
:param slaves: a list of aliases of slave connections; each of these must
|
||||
be a registered connection that has :attr:`is_slave` set to ``True``
|
||||
:param username: username to authenticate with
|
||||
:param password: password to authenticate with
|
||||
"""
|
||||
global _connection
|
||||
identity = get_identity()
|
||||
global _connection_settings
|
||||
_connection_settings[alias] = {
|
||||
'name': name,
|
||||
'host': host,
|
||||
'port': port,
|
||||
'is_slave': is_slave,
|
||||
'slaves': slaves or [],
|
||||
'username': username,
|
||||
'password': password,
|
||||
}
|
||||
|
||||
|
||||
def get_connection(alias=DEFAULT_CONNECTION_NAME):
|
||||
global _connections
|
||||
# Connect to the database if not already connected
|
||||
if _connection.get(identity) is None or reconnect:
|
||||
if alias not in _connections:
|
||||
if alias not in _connection_settings:
|
||||
msg = 'Connection with alias "%s" has not been defined'
|
||||
if alias == DEFAULT_CONNECTION_NAME:
|
||||
msg = 'You have not defined a default connection'
|
||||
raise ConnectionError(msg)
|
||||
conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
# Get all the slave connections
|
||||
slaves = []
|
||||
for slave_alias in conn_settings['slaves']:
|
||||
slaves.append(get_connection(slave_alias))
|
||||
conn_settings['slaves'] = slaves
|
||||
|
||||
try:
|
||||
_connection[identity] = Connection(**_connection_settings)
|
||||
_connections[alias] = Connection(**conn_settings)
|
||||
except Exception, e:
|
||||
raise ConnectionError("Cannot connect to the database:\n%s" % e)
|
||||
return _connection[identity]
|
||||
raise e
|
||||
raise ConnectionError('Cannot connect to database %s' % alias)
|
||||
return _connections[alias]
|
||||
|
||||
def _get_db(reconnect=False):
|
||||
"""Handles database connections and authentication based on the current
|
||||
identity
|
||||
|
||||
def get_db(alias=DEFAULT_CONNECTION_NAME):
|
||||
global _dbs
|
||||
if alias not in _dbs:
|
||||
conn = get_connection(alias)
|
||||
conn_settings = _connection_settings[alias]
|
||||
_dbs[alias] = conn[conn_settings['name']]
|
||||
|
||||
# Authenticate if necessary
|
||||
if conn_settings['username'] and conn_settings['password']:
|
||||
_dbs[alias].authenticate(conn_settings['username'],
|
||||
conn_settings['password'])
|
||||
return _dbs[alias]
|
||||
|
||||
|
||||
def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
||||
"""Connect to the database specified by the 'db' argument.
|
||||
|
||||
Connection settings may be provided here as well if the database is not
|
||||
running on the default port on localhost. If authentication is needed,
|
||||
provide username and password arguments as well.
|
||||
|
||||
Multiple databases are supported by using aliases. Provide a separate
|
||||
`alias` to connect to a different instance of :program:`mongod`.
|
||||
|
||||
.. versionchanged:: 0.6 - added multiple database support.
|
||||
"""
|
||||
global _db, _connection
|
||||
identity = get_identity()
|
||||
# Connect if not already connected
|
||||
if _connection.get(identity) is None or reconnect:
|
||||
_connection[identity] = _get_connection(reconnect=reconnect)
|
||||
global _connections
|
||||
if alias not in _connections:
|
||||
register_connection(alias, db, **kwargs)
|
||||
|
||||
if _db.get(identity) is None or reconnect:
|
||||
# _db_name will be None if the user hasn't called connect()
|
||||
if _db_name is None:
|
||||
raise ConnectionError('Not connected to the database')
|
||||
|
||||
# Get DB from current connection and authenticate if necessary
|
||||
_db[identity] = _connection[identity][_db_name]
|
||||
if _db_username and _db_password:
|
||||
_db[identity].authenticate(_db_username, _db_password)
|
||||
|
||||
return _db[identity]
|
||||
|
||||
def get_identity():
|
||||
"""Creates an identity key based on the current process and thread
|
||||
identity.
|
||||
"""
|
||||
identity = multiprocessing.current_process()._identity
|
||||
identity = 0 if not identity else identity[0]
|
||||
|
||||
identity = (identity, threading.current_thread().ident)
|
||||
return identity
|
||||
|
||||
def connect(db, username=None, password=None, **kwargs):
|
||||
"""Connect to the database specified by the 'db' argument. Connection
|
||||
settings may be provided here as well if the database is not running on
|
||||
the default port on localhost. If authentication is needed, provide
|
||||
username and password arguments as well.
|
||||
"""
|
||||
global _connection_settings, _db_name, _db_username, _db_password, _db
|
||||
_connection_settings = dict(_connection_defaults, **kwargs)
|
||||
_db_name = db
|
||||
_db_username = username
|
||||
_db_password = password
|
||||
return _get_db(reconnect=True)
|
||||
return get_connection(alias)
|
||||
|
||||
# Support old naming convention
|
||||
_get_connection = get_connection
|
||||
_get_db = get_db
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import operator
|
||||
|
||||
import pymongo
|
||||
|
||||
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
|
||||
from fields import ReferenceField
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from queryset import QuerySet
|
||||
from document import Document
|
||||
|
||||
@@ -103,7 +101,7 @@ class DeReference(object):
|
||||
for key, doc in references.iteritems():
|
||||
object_map[key] = doc
|
||||
else: # Generic reference: use the refs data to convert to document
|
||||
references = _get_db()[col].find({'_id': {'$in': refs}})
|
||||
references = get_db()[col].find({'_id': {'$in': refs}})
|
||||
for ref in references:
|
||||
if '_cls' in ref:
|
||||
doc = get_document(ref['_cls'])._from_son(ref)
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
import operator
|
||||
from mongoengine import signals
|
||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||
ValidationError, BaseDict, BaseList, BaseDynamicField)
|
||||
from queryset import OperationError
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
|
||||
import pymongo
|
||||
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument',
|
||||
'ValidationError', 'OperationError', 'InvalidCollectionError']
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument',
|
||||
'DynamicEmbeddedDocument', 'ValidationError', 'OperationError',
|
||||
'InvalidCollectionError']
|
||||
|
||||
|
||||
class InvalidCollectionError(Exception):
|
||||
@@ -76,7 +76,7 @@ class Document(BaseDocument):
|
||||
by setting index_types to False on the meta dictionary for the document.
|
||||
"""
|
||||
__metaclass__ = TopLevelDocumentMetaclass
|
||||
|
||||
|
||||
@apply
|
||||
def pk():
|
||||
"""Primary key alias
|
||||
@@ -91,7 +91,7 @@ class Document(BaseDocument):
|
||||
def _get_collection(self):
|
||||
"""Returns the collection for the document."""
|
||||
if not hasattr(self, '_collection') or self._collection is None:
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
collection_name = self._get_collection_name()
|
||||
# Create collection as a capped collection if specified
|
||||
if self._meta['max_size'] or self._meta['max_documents']:
|
||||
@@ -300,7 +300,7 @@ class Document(BaseDocument):
|
||||
:class:`~mongoengine.Document` type from the database.
|
||||
"""
|
||||
from mongoengine.queryset import QuerySet
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
db.drop_collection(cls._get_collection_name())
|
||||
QuerySet._reset_already_indexed(cls)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField,
|
||||
ValidationError, get_document)
|
||||
from queryset import DO_NOTHING
|
||||
from document import Document, EmbeddedDocument
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from operator import itemgetter
|
||||
|
||||
|
||||
@@ -637,7 +637,7 @@ class ReferenceField(BaseField):
|
||||
value = instance._data.get(self.name)
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||
value = _get_db().dereference(value)
|
||||
value = get_db().dereference(value)
|
||||
if value is not None:
|
||||
instance._data[self.name] = self.document_type._from_son(value)
|
||||
|
||||
@@ -710,7 +710,7 @@ class GenericReferenceField(BaseField):
|
||||
def dereference(self, value):
|
||||
doc_cls = get_document(value['_cls'])
|
||||
reference = value['_ref']
|
||||
doc = _get_db().dereference(reference)
|
||||
doc = get_db().dereference(reference)
|
||||
if doc is not None:
|
||||
doc = doc_cls._from_son(doc)
|
||||
return doc
|
||||
@@ -780,7 +780,7 @@ class GridFSProxy(object):
|
||||
|
||||
def __init__(self, grid_id=None, key=None,
|
||||
instance=None, collection_name='fs'):
|
||||
self.fs = gridfs.GridFS(_get_db(), collection_name) # Filesystem instance
|
||||
self.fs = gridfs.GridFS(get_db(), collection_name) # Filesystem instance
|
||||
self.newfile = None # Used for partial writes
|
||||
self.grid_id = grid_id # Store GridFS id for file
|
||||
self.gridout = None
|
||||
@@ -1138,7 +1138,7 @@ class SequenceField(IntField):
|
||||
"""
|
||||
sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(),
|
||||
self.name)
|
||||
collection = _get_db()[self.collection_name]
|
||||
collection = get_db()[self.collection_name]
|
||||
counter = collection.find_and_modify(query={"_id": sequence_id},
|
||||
update={"$inc": {"next": 1}},
|
||||
new=True,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from mongoengine import signals
|
||||
|
||||
import pprint
|
||||
@@ -481,7 +481,7 @@ class QuerySet(object):
|
||||
if self._document not in QuerySet.__already_indexed:
|
||||
|
||||
# Ensure collection exists
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
if self._collection_obj.name not in db.collection_names():
|
||||
self._document._collection = None
|
||||
self._collection_obj = self._document._get_collection()
|
||||
@@ -1436,7 +1436,7 @@ class QuerySet(object):
|
||||
scope['query'] = query
|
||||
code = pymongo.code.Code(code, scope=scope)
|
||||
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
return db.eval(code, *fields)
|
||||
|
||||
def where(self, where_clause):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
@@ -7,7 +7,7 @@ class query_counter(object):
|
||||
def __init__(self):
|
||||
""" Construct the query_counter. """
|
||||
self.counter = 0
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
def __enter__(self):
|
||||
""" On every with block we need to drop the profile collection. """
|
||||
|
||||
Reference in New Issue
Block a user