diff --git a/mongoengine/connection.py b/mongoengine/connection.py index c7d8f893..07b730b8 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,7 +1,8 @@ from pymongo import Connection -__all__ = ['ConnectionError', 'connect', 'register_connection'] +__all__ = ['ConnectionError', 'connect', 'register_connection', + 'DEFAULT_CONNECTION_NAME'] DEFAULT_CONNECTION_NAME = 'default' diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 3e12296f..23d9e453 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField, ValidationError, get_document) from queryset import DO_NOTHING from document import Document, EmbeddedDocument -from connection import get_db +from connection import get_db, DEFAULT_CONNECTION_NAME from operator import itemgetter @@ -779,8 +779,10 @@ class GridFSProxy(object): """ def __init__(self, grid_id=None, key=None, - instance=None, collection_name='fs'): - self.fs = gridfs.GridFS(get_db(), collection_name) # Filesystem instance + instance=None, + db_alias=DEFAULT_CONNECTION_NAME, + collection_name='fs'): + self.fs = gridfs.GridFS(get_db(db_alias), collection_name) # Filesystem instance self.newfile = None # Used for partial writes self.grid_id = grid_id # Store GridFS id for file self.gridout = None @@ -870,12 +872,16 @@ class FileField(BaseField): .. versionadded:: 0.4 .. versionchanged:: 0.5 added optional size param for read + .. versionchanged:: 0.6 added db_alias for multidb support """ proxy_class = GridFSProxy - def __init__(self, collection_name="fs", **kwargs): + def __init__(self, + db_alias=DEFAULT_CONNECTION_NAME, + collection_name="fs", **kwargs): super(FileField, self).__init__(**kwargs) self.collection_name = collection_name + self.db_alias = db_alias def __get__(self, instance, owner): if instance is None: @@ -890,6 +896,7 @@ class FileField(BaseField): self.grid_file.instance = instance return self.grid_file return self.proxy_class(key=self.name, instance=instance, + db_alias=self.db_alias, collection_name=self.collection_name) def __set__(self, instance, value): @@ -924,7 +931,8 @@ class FileField(BaseField): def to_python(self, value): if value is not None: return self.proxy_class(value, - collection_name=self.collection_name) + collection_name=self.collection_name, + db_alias=self.db_alias) def validate(self, value): if value.grid_id is not None: diff --git a/tests/fields.py b/tests/fields.py index 768f18d6..5f19e336 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -1503,6 +1503,34 @@ class FieldTest(unittest.TestCase): t.image.delete() + + def test_file_multidb(self): + register_connection('testfiles', 'testfiles') + class TestFile(Document): + name = StringField() + file = FileField(db_alias="testfiles", + collection_name="macumba") + + TestFile.drop_collection() + + # delete old filesystem + get_db("testfiles").macumba.files.drop() + get_db("testfiles").macumba.chunks.drop() + + # First instance + testfile = TestFile() + testfile.name = "Hello, World!" + testfile.file.put('Hello, World!', + name="hello.txt") + testfile.save() + + data = get_db("testfiles").macumba.files.find_one() + self.assertEquals(data.get('name'), 'hello.txt') + + testfile = TestFile.objects.first() + self.assertEquals(testfile.file.read(), + 'Hello, World!') + def test_geo_indexes(self): """Ensure that indexes are created automatically for GeoPointFields. """