Added multidb support
No change required to upgrade to multiple databases. Aliases are used to describe the database and these can be manually registered or fall through to a default alias using connect. Made get_connection and get_db first class members of the connection class. Old style _get_connection and _get_db still supported. Refs: #84 #87 #93 #215
This commit is contained in:
parent
63c5a4dd65
commit
e80144e9f2
@ -6,6 +6,7 @@ Connecting
|
||||
==========
|
||||
|
||||
.. autofunction:: mongoengine.connect
|
||||
.. autofunction:: mongoengine.register_connection
|
||||
|
||||
Documents
|
||||
=========
|
||||
|
@ -3,6 +3,7 @@
|
||||
=====================
|
||||
Connecting to MongoDB
|
||||
=====================
|
||||
|
||||
To connect to a running instance of :program:`mongod`, use the
|
||||
:func:`~mongoengine.connect` function. The first argument is the name of the
|
||||
database to connect to. If the database does not exist, it will be created. If
|
||||
@ -18,3 +19,14 @@ provide :attr:`host` and :attr:`port` arguments to
|
||||
:func:`~mongoengine.connect`::
|
||||
|
||||
connect('project1', host='192.168.1.35', port=12345)
|
||||
|
||||
|
||||
Multiple Databases
|
||||
==================
|
||||
|
||||
Multiple database support was added in MongoEngine 0.6. To use multiple
|
||||
databases you can use :func:`~mongoengine.connect` and provide an `alias` name
|
||||
for the connection - if no `alias` is provided then "default" is used.
|
||||
|
||||
In the background this uses :func:`~mongoengine.register_connection` to
|
||||
store the data and you can register all aliases up front if required.
|
||||
|
@ -1,82 +1,106 @@
|
||||
from pymongo import Connection
|
||||
import multiprocessing
|
||||
import threading
|
||||
|
||||
__all__ = ['ConnectionError', 'connect']
|
||||
|
||||
|
||||
_connection_defaults = {
|
||||
'host': 'localhost',
|
||||
'port': 27017,
|
||||
}
|
||||
_connection = {}
|
||||
_connection_settings = _connection_defaults.copy()
|
||||
__all__ = ['ConnectionError', 'connect', 'register_connection']
|
||||
|
||||
_db_name = None
|
||||
_db_username = None
|
||||
_db_password = None
|
||||
_db = {}
|
||||
|
||||
DEFAULT_CONNECTION_NAME = 'default'
|
||||
|
||||
|
||||
class ConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_connection(reconnect=False):
|
||||
"""Handles the connection to the database
|
||||
_connection_settings = {}
|
||||
_connections = {}
|
||||
_dbs = {}
|
||||
|
||||
|
||||
def register_connection(alias, name, host='localhost', port=27017,
|
||||
is_slave=False, slaves=None, username=None,
|
||||
password=None):
|
||||
"""Add a connection.
|
||||
|
||||
:param alias: the name that will be used to refer to this connection
|
||||
throughout MongoEngine
|
||||
:param name: the name of the specific database to use
|
||||
:param host: the host name of the :program:`mongod` instance to connect to
|
||||
:param port: the port that the :program:`mongod` instance is running on
|
||||
:param is_slave: whether the connection can act as a slave
|
||||
:param slaves: a list of aliases of slave connections; each of these must
|
||||
be a registered connection that has :attr:`is_slave` set to ``True``
|
||||
:param username: username to authenticate with
|
||||
:param password: password to authenticate with
|
||||
"""
|
||||
global _connection
|
||||
identity = get_identity()
|
||||
global _connection_settings
|
||||
_connection_settings[alias] = {
|
||||
'name': name,
|
||||
'host': host,
|
||||
'port': port,
|
||||
'is_slave': is_slave,
|
||||
'slaves': slaves or [],
|
||||
'username': username,
|
||||
'password': password,
|
||||
}
|
||||
|
||||
|
||||
def get_connection(alias=DEFAULT_CONNECTION_NAME):
|
||||
global _connections
|
||||
# Connect to the database if not already connected
|
||||
if _connection.get(identity) is None or reconnect:
|
||||
if alias not in _connections:
|
||||
if alias not in _connection_settings:
|
||||
msg = 'Connection with alias "%s" has not been defined'
|
||||
if alias == DEFAULT_CONNECTION_NAME:
|
||||
msg = 'You have not defined a default connection'
|
||||
raise ConnectionError(msg)
|
||||
conn_settings = _connection_settings[alias].copy()
|
||||
|
||||
# Get all the slave connections
|
||||
slaves = []
|
||||
for slave_alias in conn_settings['slaves']:
|
||||
slaves.append(get_connection(slave_alias))
|
||||
conn_settings['slaves'] = slaves
|
||||
|
||||
try:
|
||||
_connection[identity] = Connection(**_connection_settings)
|
||||
_connections[alias] = Connection(**conn_settings)
|
||||
except Exception, e:
|
||||
raise ConnectionError("Cannot connect to the database:\n%s" % e)
|
||||
return _connection[identity]
|
||||
raise e
|
||||
raise ConnectionError('Cannot connect to database %s' % alias)
|
||||
return _connections[alias]
|
||||
|
||||
def _get_db(reconnect=False):
|
||||
"""Handles database connections and authentication based on the current
|
||||
identity
|
||||
|
||||
def get_db(alias=DEFAULT_CONNECTION_NAME):
|
||||
global _dbs
|
||||
if alias not in _dbs:
|
||||
conn = get_connection(alias)
|
||||
conn_settings = _connection_settings[alias]
|
||||
_dbs[alias] = conn[conn_settings['name']]
|
||||
|
||||
# Authenticate if necessary
|
||||
if conn_settings['username'] and conn_settings['password']:
|
||||
_dbs[alias].authenticate(conn_settings['username'],
|
||||
conn_settings['password'])
|
||||
return _dbs[alias]
|
||||
|
||||
|
||||
def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
|
||||
"""Connect to the database specified by the 'db' argument.
|
||||
|
||||
Connection settings may be provided here as well if the database is not
|
||||
running on the default port on localhost. If authentication is needed,
|
||||
provide username and password arguments as well.
|
||||
|
||||
Multiple databases are supported by using aliases. Provide a separate
|
||||
`alias` to connect to a different instance of :program:`mongod`.
|
||||
|
||||
.. versionchanged:: 0.6 - added multiple database support.
|
||||
"""
|
||||
global _db, _connection
|
||||
identity = get_identity()
|
||||
# Connect if not already connected
|
||||
if _connection.get(identity) is None or reconnect:
|
||||
_connection[identity] = _get_connection(reconnect=reconnect)
|
||||
global _connections
|
||||
if alias not in _connections:
|
||||
register_connection(alias, db, **kwargs)
|
||||
|
||||
if _db.get(identity) is None or reconnect:
|
||||
# _db_name will be None if the user hasn't called connect()
|
||||
if _db_name is None:
|
||||
raise ConnectionError('Not connected to the database')
|
||||
|
||||
# Get DB from current connection and authenticate if necessary
|
||||
_db[identity] = _connection[identity][_db_name]
|
||||
if _db_username and _db_password:
|
||||
_db[identity].authenticate(_db_username, _db_password)
|
||||
|
||||
return _db[identity]
|
||||
|
||||
def get_identity():
|
||||
"""Creates an identity key based on the current process and thread
|
||||
identity.
|
||||
"""
|
||||
identity = multiprocessing.current_process()._identity
|
||||
identity = 0 if not identity else identity[0]
|
||||
|
||||
identity = (identity, threading.current_thread().ident)
|
||||
return identity
|
||||
|
||||
def connect(db, username=None, password=None, **kwargs):
|
||||
"""Connect to the database specified by the 'db' argument. Connection
|
||||
settings may be provided here as well if the database is not running on
|
||||
the default port on localhost. If authentication is needed, provide
|
||||
username and password arguments as well.
|
||||
"""
|
||||
global _connection_settings, _db_name, _db_username, _db_password, _db
|
||||
_connection_settings = dict(_connection_defaults, **kwargs)
|
||||
_db_name = db
|
||||
_db_username = username
|
||||
_db_password = password
|
||||
return _get_db(reconnect=True)
|
||||
return get_connection(alias)
|
||||
|
||||
# Support old naming convention
|
||||
_get_connection = get_connection
|
||||
_get_db = get_db
|
||||
|
@ -1,10 +1,8 @@
|
||||
import operator
|
||||
|
||||
import pymongo
|
||||
|
||||
from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass
|
||||
from fields import ReferenceField
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from queryset import QuerySet
|
||||
from document import Document
|
||||
|
||||
@ -103,7 +101,7 @@ class DeReference(object):
|
||||
for key, doc in references.iteritems():
|
||||
object_map[key] = doc
|
||||
else: # Generic reference: use the refs data to convert to document
|
||||
references = _get_db()[col].find({'_id': {'$in': refs}})
|
||||
references = get_db()[col].find({'_id': {'$in': refs}})
|
||||
for ref in references:
|
||||
if '_cls' in ref:
|
||||
doc = get_document(ref['_cls'])._from_son(ref)
|
||||
|
@ -1,14 +1,14 @@
|
||||
import operator
|
||||
from mongoengine import signals
|
||||
from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument,
|
||||
ValidationError, BaseDict, BaseList, BaseDynamicField)
|
||||
from queryset import OperationError
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
|
||||
import pymongo
|
||||
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument',
|
||||
'ValidationError', 'OperationError', 'InvalidCollectionError']
|
||||
__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument',
|
||||
'DynamicEmbeddedDocument', 'ValidationError', 'OperationError',
|
||||
'InvalidCollectionError']
|
||||
|
||||
|
||||
class InvalidCollectionError(Exception):
|
||||
@ -91,7 +91,7 @@ class Document(BaseDocument):
|
||||
def _get_collection(self):
|
||||
"""Returns the collection for the document."""
|
||||
if not hasattr(self, '_collection') or self._collection is None:
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
collection_name = self._get_collection_name()
|
||||
# Create collection as a capped collection if specified
|
||||
if self._meta['max_size'] or self._meta['max_documents']:
|
||||
@ -300,7 +300,7 @@ class Document(BaseDocument):
|
||||
:class:`~mongoengine.Document` type from the database.
|
||||
"""
|
||||
from mongoengine.queryset import QuerySet
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
db.drop_collection(cls._get_collection_name())
|
||||
QuerySet._reset_already_indexed(cls)
|
||||
|
||||
|
@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField,
|
||||
ValidationError, get_document)
|
||||
from queryset import DO_NOTHING
|
||||
from document import Document, EmbeddedDocument
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from operator import itemgetter
|
||||
|
||||
|
||||
@ -637,7 +637,7 @@ class ReferenceField(BaseField):
|
||||
value = instance._data.get(self.name)
|
||||
# Dereference DBRefs
|
||||
if isinstance(value, (pymongo.dbref.DBRef)):
|
||||
value = _get_db().dereference(value)
|
||||
value = get_db().dereference(value)
|
||||
if value is not None:
|
||||
instance._data[self.name] = self.document_type._from_son(value)
|
||||
|
||||
@ -710,7 +710,7 @@ class GenericReferenceField(BaseField):
|
||||
def dereference(self, value):
|
||||
doc_cls = get_document(value['_cls'])
|
||||
reference = value['_ref']
|
||||
doc = _get_db().dereference(reference)
|
||||
doc = get_db().dereference(reference)
|
||||
if doc is not None:
|
||||
doc = doc_cls._from_son(doc)
|
||||
return doc
|
||||
@ -780,7 +780,7 @@ class GridFSProxy(object):
|
||||
|
||||
def __init__(self, grid_id=None, key=None,
|
||||
instance=None, collection_name='fs'):
|
||||
self.fs = gridfs.GridFS(_get_db(), collection_name) # Filesystem instance
|
||||
self.fs = gridfs.GridFS(get_db(), collection_name) # Filesystem instance
|
||||
self.newfile = None # Used for partial writes
|
||||
self.grid_id = grid_id # Store GridFS id for file
|
||||
self.gridout = None
|
||||
@ -1138,7 +1138,7 @@ class SequenceField(IntField):
|
||||
"""
|
||||
sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(),
|
||||
self.name)
|
||||
collection = _get_db()[self.collection_name]
|
||||
collection = get_db()[self.collection_name]
|
||||
counter = collection.find_and_modify(query={"_id": sequence_id},
|
||||
update={"$inc": {"next": 1}},
|
||||
new=True,
|
||||
|
@ -1,4 +1,4 @@
|
||||
from connection import _get_db
|
||||
from connection import get_db
|
||||
from mongoengine import signals
|
||||
|
||||
import pprint
|
||||
@ -481,7 +481,7 @@ class QuerySet(object):
|
||||
if self._document not in QuerySet.__already_indexed:
|
||||
|
||||
# Ensure collection exists
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
if self._collection_obj.name not in db.collection_names():
|
||||
self._document._collection = None
|
||||
self._collection_obj = self._document._get_collection()
|
||||
@ -1436,7 +1436,7 @@ class QuerySet(object):
|
||||
scope['query'] = query
|
||||
code = pymongo.code.Code(code, scope=scope)
|
||||
|
||||
db = _get_db()
|
||||
db = get_db()
|
||||
return db.eval(code, *fields)
|
||||
|
||||
def where(self, where_clause):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
|
||||
class query_counter(object):
|
||||
@ -7,7 +7,7 @@ class query_counter(object):
|
||||
def __init__(self):
|
||||
""" Construct the query_counter. """
|
||||
self.counter = 0
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
def __enter__(self):
|
||||
""" On every with block we need to drop the profile collection. """
|
||||
|
48
tests/connection.py
Normal file
48
tests/connection.py
Normal file
@ -0,0 +1,48 @@
|
||||
import unittest
|
||||
import pymongo
|
||||
|
||||
import mongoengine.connection
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import get_db, get_connection
|
||||
|
||||
|
||||
class ConnectionTest(unittest.TestCase):
|
||||
|
||||
def tearDown(self):
|
||||
mongoengine.connection._connection_settings = {}
|
||||
mongoengine.connection._connections = {}
|
||||
mongoengine.connection._dbs = {}
|
||||
|
||||
def test_connect(self):
|
||||
"""Ensure that the connect() method works properly.
|
||||
"""
|
||||
connect('mongoenginetest')
|
||||
|
||||
conn = get_connection()
|
||||
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
|
||||
|
||||
db = get_db()
|
||||
self.assertTrue(isinstance(db, pymongo.database.Database))
|
||||
self.assertEqual(db.name, 'mongoenginetest')
|
||||
|
||||
connect('mongoenginetest2', alias='testdb')
|
||||
conn = get_connection('testdb')
|
||||
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
|
||||
|
||||
def test_register_connection(self):
|
||||
"""Ensure that connections with different aliases may be registered.
|
||||
"""
|
||||
register_connection('testdb', 'mongoenginetest2')
|
||||
|
||||
self.assertRaises(ConnectionError, get_connection)
|
||||
conn = get_connection('testdb')
|
||||
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
|
||||
|
||||
db = get_db('testdb')
|
||||
self.assertTrue(isinstance(db, pymongo.database.Database))
|
||||
self.assertEqual(db.name, 'mongoenginetest2')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.tests import query_counter
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
def test_list_item_dereference(self):
|
||||
"""Ensure that DBRef items in ListFields are dereferenced.
|
||||
|
@ -5,22 +5,18 @@ import warnings
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import pymongo
|
||||
import pickle
|
||||
import weakref
|
||||
|
||||
from fixtures import Base, Mixin, PickleEmbedded, PickleTest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.base import _document_registry, NotRegistered, InvalidDocumentError
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.base import NotRegistered, InvalidDocumentError
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
|
||||
class DocumentTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
class Person(Document):
|
||||
name = StringField()
|
||||
|
@ -1,13 +1,14 @@
|
||||
import unittest
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.connection import get_db
|
||||
|
||||
|
||||
class DynamicDocTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
class Person(DynamicDocument):
|
||||
name = StringField()
|
||||
|
@ -6,7 +6,7 @@ import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import _get_db
|
||||
from mongoengine.connection import get_db
|
||||
from mongoengine.base import _document_registry, NotRegistered
|
||||
|
||||
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
|
||||
@ -16,7 +16,7 @@ class FieldTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
connect(db='mongoenginetest')
|
||||
self.db = _get_db()
|
||||
self.db = get_db()
|
||||
|
||||
def test_default_values(self):
|
||||
"""Ensure that default field values are used when creating a document.
|
||||
|
@ -1,9 +1,6 @@
|
||||
from datetime import datetime
|
||||
import pymongo
|
||||
|
||||
from mongoengine import *
|
||||
from mongoengine.base import BaseField
|
||||
from mongoengine.connection import _get_db
|
||||
|
||||
|
||||
class PickleEmbedded(EmbeddedDocument):
|
||||
|
@ -7,7 +7,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager,
|
||||
MultipleObjectsReturned, DoesNotExist,
|
||||
QueryFieldList)
|
||||
from mongoengine import *
|
||||
from mongoengine.connection import _get_connection
|
||||
from mongoengine.connection import get_connection
|
||||
from mongoengine.tests import query_counter
|
||||
|
||||
|
||||
@ -2276,7 +2276,7 @@ class QuerySetTest(unittest.TestCase):
|
||||
|
||||
# check that polygon works for users who have a server >= 1.9
|
||||
server_version = tuple(
|
||||
_get_connection().server_info()['version'].split('.')
|
||||
get_connection().server_info()['version'].split('.')
|
||||
)
|
||||
required_version = tuple("1.9.0".split("."))
|
||||
if server_version >= required_version:
|
||||
|
Loading…
x
Reference in New Issue
Block a user