Added multidb support

No change required to upgrade to multiple databases. Aliases are used
to describe the database and these can be manually registered or fall
through to a default alias using connect.

Made get_connection and get_db first class members of the connection class.
Old style _get_connection and _get_db still supported.

Refs: #84 #87 #93 #215
This commit is contained in:
Ross Lawley
2011-11-18 07:22:37 -08:00
parent 63c5a4dd65
commit e80144e9f2
15 changed files with 180 additions and 103 deletions

48
tests/connection.py Normal file
View File

@@ -0,0 +1,48 @@
import unittest
import pymongo
import mongoengine.connection
from mongoengine import *
from mongoengine.connection import get_db, get_connection
class ConnectionTest(unittest.TestCase):
def tearDown(self):
mongoengine.connection._connection_settings = {}
mongoengine.connection._connections = {}
mongoengine.connection._dbs = {}
def test_connect(self):
"""Ensure that the connect() method works properly.
"""
connect('mongoenginetest')
conn = get_connection()
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
db = get_db()
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest')
connect('mongoenginetest2', alias='testdb')
conn = get_connection('testdb')
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
def test_register_connection(self):
"""Ensure that connections with different aliases may be registered.
"""
register_connection('testdb', 'mongoenginetest2')
self.assertRaises(ConnectionError, get_connection)
conn = get_connection('testdb')
self.assertTrue(isinstance(conn, pymongo.connection.Connection))
db = get_db('testdb')
self.assertTrue(isinstance(db, pymongo.database.Database))
self.assertEqual(db.name, 'mongoenginetest2')
if __name__ == '__main__':
unittest.main()

View File

@@ -1,7 +1,7 @@
import unittest
from mongoengine import *
from mongoengine.connection import _get_db
from mongoengine.connection import get_db
from mongoengine.tests import query_counter
@@ -9,7 +9,7 @@ class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = _get_db()
self.db = get_db()
def test_list_item_dereference(self):
"""Ensure that DBRef items in ListFields are dereferenced.

View File

@@ -5,22 +5,18 @@ import warnings
from datetime import datetime
import pymongo
import pickle
import weakref
from fixtures import Base, Mixin, PickleEmbedded, PickleTest
from mongoengine import *
from mongoengine.base import _document_registry, NotRegistered, InvalidDocumentError
from mongoengine.connection import _get_db
from mongoengine.base import NotRegistered, InvalidDocumentError
from mongoengine.connection import get_db
class DocumentTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = _get_db()
self.db = get_db()
class Person(Document):
name = StringField()

View File

@@ -1,13 +1,14 @@
import unittest
from mongoengine import *
from mongoengine.connection import _get_db
from mongoengine.connection import get_db
class DynamicDocTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = _get_db()
self.db = get_db()
class Person(DynamicDocument):
name = StringField()

View File

@@ -6,7 +6,7 @@ import uuid
from decimal import Decimal
from mongoengine import *
from mongoengine.connection import _get_db
from mongoengine.connection import get_db
from mongoengine.base import _document_registry, NotRegistered
TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png')
@@ -16,7 +16,7 @@ class FieldTest(unittest.TestCase):
def setUp(self):
connect(db='mongoenginetest')
self.db = _get_db()
self.db = get_db()
def test_default_values(self):
"""Ensure that default field values are used when creating a document.

View File

@@ -1,9 +1,6 @@
from datetime import datetime
import pymongo
from mongoengine import *
from mongoengine.base import BaseField
from mongoengine.connection import _get_db
class PickleEmbedded(EmbeddedDocument):

View File

@@ -7,7 +7,7 @@ from mongoengine.queryset import (QuerySet, QuerySetManager,
MultipleObjectsReturned, DoesNotExist,
QueryFieldList)
from mongoengine import *
from mongoengine.connection import _get_connection
from mongoengine.connection import get_connection
from mongoengine.tests import query_counter
@@ -2276,7 +2276,7 @@ class QuerySetTest(unittest.TestCase):
# check that polygon works for users who have a server >= 1.9
server_version = tuple(
_get_connection().server_info()['version'].split('.')
get_connection().server_info()['version'].split('.')
)
required_version = tuple("1.9.0".split("."))
if server_version >= required_version: