Added SwitchDB context manager (#106)

This commit is contained in:
Ross Lawley 2013-01-23 12:54:14 +00:00
parent 6d68ad735c
commit e5e88d792e
5 changed files with 85 additions and 1 deletions

View File

@ -7,6 +7,7 @@ Connecting
.. autofunction:: mongoengine.connect .. autofunction:: mongoengine.connect
.. autofunction:: mongoengine.register_connection .. autofunction:: mongoengine.register_connection
.. autoclass:: mongoengine.SwitchDB
Documents Documents
========= =========

View File

@ -32,6 +32,7 @@ Changes in 0.8.X
- Fixed inheritance and unique index creation (#140) - Fixed inheritance and unique index creation (#140)
- Fixed reverse delete rule with inheritance (#197) - Fixed reverse delete rule with inheritance (#197)
- Fixed validation for GenericReferences which havent been dereferenced - Fixed validation for GenericReferences which havent been dereferenced
- Added SwitchDB context manager (#106)
Changes in 0.7.9 Changes in 0.7.9
================ ================

View File

@ -69,3 +69,21 @@ to point across databases and collections. Below is an example schema, using
book = ReferenceField(Book) book = ReferenceField(Book)
meta = {"db_alias": "users-books-db"} meta = {"db_alias": "users-books-db"}
Switch Database Context Manager
===============================
Sometimes you might want to switch the database to query against for a class.
The SwitchDB context manager allows you to change the database alias for a
class eg ::
from mongoengine import SwitchDB
class User(Document):
name = StringField()
meta = {"db_alias": "user-db"}
with SwitchDB(User, 'archive-user-db') as User:
User(name="Ross").save() # Saves the 'archive-user-db'

View File

@ -3,7 +3,7 @@ from pymongo import Connection, ReplicaSetConnection, uri_parser
__all__ = ['ConnectionError', 'connect', 'register_connection', __all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME'] 'DEFAULT_CONNECTION_NAME', 'SwitchDB']
DEFAULT_CONNECTION_NAME = 'default' DEFAULT_CONNECTION_NAME = 'default'
@ -163,6 +163,47 @@ def connect(db, alias=DEFAULT_CONNECTION_NAME, **kwargs):
return get_connection(alias) return get_connection(alias)
class SwitchDB(object):
""" SwitchDB alias contextmanager.
Example ::
# Register connections
register_connection('default', 'mongoenginetest')
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group(name="test").save() # Saves in the default db
with SwitchDB(Group, 'testdb-1') as Group:
Group(name="hello testdb!").save() # Saves in testdb-1
"""
def __init__(self, cls, db_alias):
""" Construct the query_counter.
:param cls: the class to change the registered db
:param db_alias: the name of the specific database to use
"""
self.cls = cls
self.collection = cls._get_collection()
self.db_alias = db_alias
self.ori_db_alias = cls._meta.get("db_alias", DEFAULT_CONNECTION_NAME)
def __enter__(self):
""" change the db_alias and clear the cached collection """
self.cls._meta["db_alias"] = self.db_alias
self.cls._collection = None
return self.cls
def __exit__(self, t, value, traceback):
""" Reset the db_alias and collection """
self.cls._meta["db_alias"] = self.ori_db_alias
self.cls._collection = self.collection
# Support old naming convention # Support old naming convention
_get_connection = get_connection _get_connection = get_connection
_get_db = get_db _get_db = get_db

View File

@ -93,6 +93,29 @@ class ConnectionTest(unittest.TestCase):
date_doc = DateDoc.objects.first() date_doc = DateDoc.objects.first()
self.assertEqual(d, date_doc.the_date) self.assertEqual(d, date_doc.the_date)
def test_switch_db_context_manager(self):
register_connection('testdb-1', 'mongoenginetest2')
class Group(Document):
name = StringField()
Group.drop_collection()
Group(name="hello - default").save()
self.assertEqual(1, Group.objects.count())
with SwitchDB(Group, 'testdb-1') as Group:
self.assertEqual(0, Group.objects.count())
Group(name="hello").save()
self.assertEqual(1, Group.objects.count())
Group.drop_collection()
self.assertEqual(0, Group.objects.count())
self.assertEqual(1, Group.objects.count())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()