From e80144e9f2399814760a343ee1c34a4e3e785c26 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Fri, 18 Nov 2011 07:22:37 -0800 Subject: [PATCH] Added multidb support No change required to upgrade to multiple databases. Aliases are used to describe the database and these can be manually registered or fall through to a default alias using connect. Made get_connection and get_db first class members of the connection class. Old style _get_connection and _get_db still supported. Refs: #84 #87 #93 #215 --- docs/apireference.rst | 1 + docs/guide/connecting.rst | 12 +++ mongoengine/connection.py | 152 +++++++++++++++++++++---------------- mongoengine/dereference.py | 6 +- mongoengine/document.py | 14 ++-- mongoengine/fields.py | 10 +-- mongoengine/queryset.py | 6 +- mongoengine/tests.py | 4 +- tests/connection.py | 48 ++++++++++++ tests/dereference.py | 4 +- tests/document.py | 10 +-- tests/dynamic_document.py | 5 +- tests/fields.py | 4 +- tests/fixtures.py | 3 - tests/queryset.py | 4 +- 15 files changed, 180 insertions(+), 103 deletions(-) create mode 100644 tests/connection.py diff --git a/docs/apireference.rst b/docs/apireference.rst index 932152fe..9e6ed474 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -6,6 +6,7 @@ Connecting ========== .. autofunction:: mongoengine.connect +.. autofunction:: mongoengine.register_connection Documents ========= diff --git a/docs/guide/connecting.rst b/docs/guide/connecting.rst index 64e7666a..a51d68e6 100644 --- a/docs/guide/connecting.rst +++ b/docs/guide/connecting.rst @@ -3,6 +3,7 @@ ===================== Connecting to MongoDB ===================== + To connect to a running instance of :program:`mongod`, use the :func:`~mongoengine.connect` function. The first argument is the name of the database to connect to. If the database does not exist, it will be created. If @@ -18,3 +19,14 @@ provide :attr:`host` and :attr:`port` arguments to :func:`~mongoengine.connect`:: connect('project1', host='192.168.1.35', port=12345) + + +Multiple Databases +================== + +Multiple database support was added in MongoEngine 0.6. To use multiple +databases you can use :func:`~mongoengine.connect` and provide an `alias` name +for the connection - if no `alias` is provided then "default" is used. + +In the background this uses :func:`~mongoengine.register_connection` to +store the data and you can register all aliases up front if required. diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 7b5cd210..c7d8f893 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,82 +1,106 @@ from pymongo import Connection -import multiprocessing -import threading - -__all__ = ['ConnectionError', 'connect'] -_connection_defaults = { - 'host': 'localhost', - 'port': 27017, -} -_connection = {} -_connection_settings = _connection_defaults.copy() +__all__ = ['ConnectionError', 'connect', 'register_connection'] -_db_name = None -_db_username = None -_db_password = None -_db = {} + +DEFAULT_CONNECTION_NAME = 'default' class ConnectionError(Exception): pass -def _get_connection(reconnect=False): - """Handles the connection to the database +_connection_settings = {} +_connections = {} +_dbs = {} + + +def register_connection(alias, name, host='localhost', port=27017, + is_slave=False, slaves=None, username=None, + password=None): + """Add a connection. + + :param alias: the name that will be used to refer to this connection + throughout MongoEngine + :param name: the name of the specific database to use + :param host: the host name of the :program:`mongod` instance to connect to + :param port: the port that the :program:`mongod` instance is running on + :param is_slave: whether the connection can act as a slave + :param slaves: a list of aliases of slave connections; each of these must + be a registered connection that has :attr:`is_slave` set to ``True`` + :param username: username to authenticate with + :param password: password to authenticate with """ - global _connection - identity = get_identity() + global _connection_settings + _connection_settings[alias] = { + 'name': name, + 'host': host, + 'port': port, + 'is_slave': is_slave, + 'slaves': slaves or [], + 'username': username, + 'password': password, + } + + +def get_connection(alias=DEFAULT_CONNECTION_NAME): + global _connections # Connect to the database if not already connected - if _connection.get(identity) is None or reconnect: + if alias not in _connections: + if alias not in _connection_settings: + msg = 'Connection with alias "%s" has not been defined' + if alias == DEFAULT_CONNECTION_NAME: + msg = 'You have not defined a default connection' + raise ConnectionError(msg) + conn_settings = _connection_settings[alias].copy() + + # Get all the slave connections + slaves = [] + for slave_alias in conn_settings['slaves']: + slaves.append(get_connection(slave_alias)) + conn_settings['slaves'] = slaves + try: - _connection[identity] = Connection(**_connection_settings) + _connections[alias] = Connection(**conn_settings) except Exception, e: - raise ConnectionError("Cannot connect to the database:\n%s" % e) - return _connection[identity] + raise e + raise ConnectionError('Cannot connect to database %s' % alias) + return _connections[alias] -def _get_db(reconnect=False): - """Handles database connections and authentication based on the current - identity + +def get_db(alias=DEFAULT_CONNECTION_NAME): + global _dbs + if alias not in _dbs: + conn = get_connection(alias) + conn_settings = _connection_settings[alias] + _dbs[alias] = conn[conn_settings['name']] + + # Authenticate if necessary + if conn_settings['username'] and conn_settings['password']: + _dbs[alias].authenticate(conn_settings['username'], + conn_settings['password']) + return _dbs[alias] + + +def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs): + """Connect to the database specified by the 'db' argument. + + Connection settings may be provided here as well if the database is not + running on the default port on localhost. If authentication is needed, + provide username and password arguments as well. + + Multiple databases are supported by using aliases. Provide a separate + `alias` to connect to a different instance of :program:`mongod`. + + .. versionchanged:: 0.6 - added multiple database support. """ - global _db, _connection - identity = get_identity() - # Connect if not already connected - if _connection.get(identity) is None or reconnect: - _connection[identity] = _get_connection(reconnect=reconnect) + global _connections + if alias not in _connections: + register_connection(alias, db, **kwargs) - if _db.get(identity) is None or reconnect: - # _db_name will be None if the user hasn't called connect() - if _db_name is None: - raise ConnectionError('Not connected to the database') - - # Get DB from current connection and authenticate if necessary - _db[identity] = _connection[identity][_db_name] - if _db_username and _db_password: - _db[identity].authenticate(_db_username, _db_password) - - return _db[identity] - -def get_identity(): - """Creates an identity key based on the current process and thread - identity. - """ - identity = multiprocessing.current_process()._identity - identity = 0 if not identity else identity[0] - - identity = (identity, threading.current_thread().ident) - return identity - -def connect(db, username=None, password=None, **kwargs): - """Connect to the database specified by the 'db' argument. Connection - settings may be provided here as well if the database is not running on - the default port on localhost. If authentication is needed, provide - username and password arguments as well. - """ - global _connection_settings, _db_name, _db_username, _db_password, _db - _connection_settings = dict(_connection_defaults, **kwargs) - _db_name = db - _db_username = username - _db_password = password - return _get_db(reconnect=True) + return get_connection(alias) +# Support old naming convention +_get_connection = get_connection +_get_db = get_db diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 949bb2f9..d817a037 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,10 +1,8 @@ -import operator - import pymongo from base import BaseDict, BaseList, get_document, TopLevelDocumentMetaclass from fields import ReferenceField -from connection import _get_db +from connection import get_db from queryset import QuerySet from document import Document @@ -103,7 +101,7 @@ class DeReference(object): for key, doc in references.iteritems(): object_map[key] = doc else: # Generic reference: use the refs data to convert to document - references = _get_db()[col].find({'_id': {'$in': refs}}) + references = get_db()[col].find({'_id': {'$in': refs}}) for ref in references: if '_cls' in ref: doc = get_document(ref['_cls'])._from_son(ref) diff --git a/mongoengine/document.py b/mongoengine/document.py index 82b94a3d..f3893ddc 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -1,14 +1,14 @@ -import operator from mongoengine import signals from base import (DocumentMetaclass, TopLevelDocumentMetaclass, BaseDocument, ValidationError, BaseDict, BaseList, BaseDynamicField) from queryset import OperationError -from connection import _get_db +from connection import get_db import pymongo -__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', 'DynamicEmbeddedDocument', - 'ValidationError', 'OperationError', 'InvalidCollectionError'] +__all__ = ['Document', 'EmbeddedDocument', 'DynamicDocument', + 'DynamicEmbeddedDocument', 'ValidationError', 'OperationError', + 'InvalidCollectionError'] class InvalidCollectionError(Exception): @@ -76,7 +76,7 @@ class Document(BaseDocument): by setting index_types to False on the meta dictionary for the document. """ __metaclass__ = TopLevelDocumentMetaclass - + @apply def pk(): """Primary key alias @@ -91,7 +91,7 @@ class Document(BaseDocument): def _get_collection(self): """Returns the collection for the document.""" if not hasattr(self, '_collection') or self._collection is None: - db = _get_db() + db = get_db() collection_name = self._get_collection_name() # Create collection as a capped collection if specified if self._meta['max_size'] or self._meta['max_documents']: @@ -300,7 +300,7 @@ class Document(BaseDocument): :class:`~mongoengine.Document` type from the database. """ from mongoengine.queryset import QuerySet - db = _get_db() + db = get_db() db.drop_collection(cls._get_collection_name()) QuerySet._reset_already_indexed(cls) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 0bfd54d2..3e12296f 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField, ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument -from connection import _get_db +from connection import get_db from operator import itemgetter @@ -637,7 +637,7 @@ class ReferenceField(BaseField): value = instance._data.get(self.name) # Dereference DBRefs if isinstance(value, (pymongo.dbref.DBRef)): - value = _get_db().dereference(value) + value = get_db().dereference(value) if value is not None: instance._data[self.name] = self.document_type._from_son(value) @@ -710,7 +710,7 @@ class GenericReferenceField(BaseField): def dereference(self, value): doc_cls = get_document(value['_cls']) reference = value['_ref'] - doc = _get_db().dereference(reference) + doc = get_db().dereference(reference) if doc is not None: doc = doc_cls._from_son(doc) return doc @@ -780,7 +780,7 @@ class GridFSProxy(object): def __init__(self, grid_id=None, key=None, instance=None, collection_name='fs'): - self.fs = gridfs.GridFS(_get_db(), collection_name) # Filesystem instance + self.fs = gridfs.GridFS(get_db(), collection_name) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file self.gridout = None @@ -1138,7 +1138,7 @@ class SequenceField(IntField): """ sequence_id = "{0}.{1}".format(self.owner_document._get_collection_name(), self.name) - collection = _get_db()[self.collection_name] + collection = get_db()[self.collection_name] counter = collection.find_and_modify(query={"_id": sequence_id}, update={"$inc": {"next": 1}}, new=True, diff --git a/mongoengine/queryset.py b/mongoengine/queryset.py index 6f01b765..4185e39d 100644 --- a/mongoengine/queryset.py +++ b/mongoengine/queryset.py @@ -1,4 +1,4 @@ -from connection import _get_db +from connection import get_db from mongoengine import signals import pprint @@ -481,7 +481,7 @@ class QuerySet(object): if self._document not in QuerySet.__already_indexed: # Ensure collection exists - db = _get_db() + db = get_db() if self._collection_obj.name not in db.collection_names(): self._document._collection = None self._collection_obj = self._document._get_collection() @@ -1436,7 +1436,7 @@ class QuerySet(object): scope['query'] = query code = pymongo.code.Code(code, scope=scope) - db = _get_db() + db = get_db() return db.eval(code, *fields) def where(self, where_clause): diff --git a/mongoengine/tests.py b/mongoengine/tests.py index 9584bc7c..68663772 100644 --- a/mongoengine/tests.py +++ b/mongoengine/tests.py @@ -1,4 +1,4 @@ -from mongoengine.connection import _get_db +from mongoengine.connection import get_db class query_counter(object): @@ -7,7 +7,7 @@ class query_counter(object): def __init__(self): """ Construct the query_counter. """ self.counter = 0 - self.db = _get_db() + self.db = get_db() def __enter__(self): """ On every with block we need to drop the profile collection. """ diff --git a/tests/connection.py b/tests/connection.py new file mode 100644 index 00000000..e017b388 --- /dev/null +++ b/tests/connection.py @@ -0,0 +1,48 @@ +import unittest +import pymongo + +import mongoengine.connection + +from mongoengine import * +from mongoengine.connection import get_db, get_connection + + +class ConnectionTest(unittest.TestCase): + + def tearDown(self): + mongoengine.connection._connection_settings = {} + mongoengine.connection._connections = {} + mongoengine.connection._dbs = {} + + def test_connect(self): + """Ensure that the connect() method works properly. + """ + connect('mongoenginetest') + + conn = get_connection() + self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + + db = get_db() + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'mongoenginetest') + + connect('mongoenginetest2', alias='testdb') + conn = get_connection('testdb') + self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + + def test_register_connection(self): + """Ensure that connections with different aliases may be registered. + """ + register_connection('testdb', 'mongoenginetest2') + + self.assertRaises(ConnectionError, get_connection) + conn = get_connection('testdb') + self.assertTrue(isinstance(conn, pymongo.connection.Connection)) + + db = get_db('testdb') + self.assertTrue(isinstance(db, pymongo.database.Database)) + self.assertEqual(db.name, 'mongoenginetest2') + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/dereference.py b/tests/dereference.py index 088db98e..2e24a61e 100644 --- a/tests/dereference.py +++ b/tests/dereference.py @@ -1,7 +1,7 @@ import unittest from mongoengine import * -from mongoengine.connection import _get_db +from mongoengine.connection import get_db from mongoengine.tests import query_counter @@ -9,7 +9,7 @@ class FieldTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - self.db = _get_db() + self.db = get_db() def test_list_item_dereference(self): """Ensure that DBRef items in ListFields are dereferenced. diff --git a/tests/document.py b/tests/document.py index 9da886c1..dca5fed9 100644 --- a/tests/document.py +++ b/tests/document.py @@ -5,22 +5,18 @@ import warnings from datetime import datetime -import pymongo -import pickle -import weakref - from fixtures import Base, Mixin, PickleEmbedded, PickleTest from mongoengine import * -from mongoengine.base import _document_registry, NotRegistered, InvalidDocumentError -from mongoengine.connection import _get_db +from mongoengine.base import NotRegistered, InvalidDocumentError +from mongoengine.connection import get_db class DocumentTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - self.db = _get_db() + self.db = get_db() class Person(Document): name = StringField() diff --git a/tests/dynamic_document.py b/tests/dynamic_document.py index d76b196c..19cd4665 100644 --- a/tests/dynamic_document.py +++ b/tests/dynamic_document.py @@ -1,13 +1,14 @@ import unittest from mongoengine import * -from mongoengine.connection import _get_db +from mongoengine.connection import get_db + class DynamicDocTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - self.db = _get_db() + self.db = get_db() class Person(DynamicDocument): name = StringField() diff --git a/tests/fields.py b/tests/fields.py index 3b697917..768f18d6 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -6,7 +6,7 @@ import uuid from decimal import Decimal from mongoengine import * -from mongoengine.connection import _get_db +from mongoengine.connection import get_db from mongoengine.base import _document_registry, NotRegistered TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') @@ -16,7 +16,7 @@ class FieldTest(unittest.TestCase): def setUp(self): connect(db='mongoenginetest') - self.db = _get_db() + self.db = get_db() def test_default_values(self): """Ensure that default field values are used when creating a document. diff --git a/tests/fixtures.py b/tests/fixtures.py index 5aaba556..32081fe1 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,9 +1,6 @@ from datetime import datetime -import pymongo from mongoengine import * -from mongoengine.base import BaseField -from mongoengine.connection import _get_db class PickleEmbedded(EmbeddedDocument): diff --git a/tests/queryset.py b/tests/queryset.py index 60adae37..7c1c8dcc 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -7,7 +7,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager, MultipleObjectsReturned, DoesNotExist, QueryFieldList) from mongoengine import * -from mongoengine.connection import _get_connection +from mongoengine.connection import get_connection from mongoengine.tests import query_counter @@ -2276,7 +2276,7 @@ class QuerySetTest(unittest.TestCase): # check that polygon works for users who have a server >= 1.9 server_version = tuple( - _get_connection().server_info()['version'].split('.') + get_connection().server_info()['version'].split('.') ) required_version = tuple("1.9.0".split(".")) if server_version >= required_version: