added support for db_alias in FileFields

This commit is contained in:
Wilson Júnior 2011-11-22 13:40:01 -02:00
parent e80144e9f2
commit fa4b820931
3 changed files with 43 additions and 6 deletions

View File

@ -1,7 +1,8 @@
from pymongo import Connection
__all__ = ['ConnectionError', 'connect', 'register_connection']
__all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
DEFAULT_CONNECTION_NAME = 'default'

View File

@ -13,7 +13,7 @@ from base import (BaseField, ComplexBaseField, ObjectIdField,
ValidationError, get_document)
from queryset import DO_NOTHING
from document import Document, EmbeddedDocument
from connection import get_db
from connection import get_db, DEFAULT_CONNECTION_NAME
from operator import itemgetter
@ -779,8 +779,10 @@ class GridFSProxy(object):
"""
def __init__(self, grid_id=None, key=None,
instance=None, collection_name='fs'):
self.fs = gridfs.GridFS(get_db(), collection_name) # Filesystem instance
instance=None,
db_alias=DEFAULT_CONNECTION_NAME,
collection_name='fs'):
self.fs = gridfs.GridFS(get_db(db_alias), collection_name) # Filesystem instance
self.newfile = None # Used for partial writes
self.grid_id = grid_id # Store GridFS id for file
self.gridout = None
@ -870,12 +872,16 @@ class FileField(BaseField):
.. versionadded:: 0.4
.. versionchanged:: 0.5 added optional size param for read
.. versionchanged:: 0.6 added db_alias for multidb support
"""
proxy_class = GridFSProxy
def __init__(self, collection_name="fs", **kwargs):
def __init__(self,
db_alias=DEFAULT_CONNECTION_NAME,
collection_name="fs", **kwargs):
super(FileField, self).__init__(**kwargs)
self.collection_name = collection_name
self.db_alias = db_alias
def __get__(self, instance, owner):
if instance is None:
@ -890,6 +896,7 @@ class FileField(BaseField):
self.grid_file.instance = instance
return self.grid_file
return self.proxy_class(key=self.name, instance=instance,
db_alias=self.db_alias,
collection_name=self.collection_name)
def __set__(self, instance, value):
@ -924,7 +931,8 @@ class FileField(BaseField):
def to_python(self, value):
if value is not None:
return self.proxy_class(value,
collection_name=self.collection_name)
collection_name=self.collection_name,
db_alias=self.db_alias)
def validate(self, value):
if value.grid_id is not None:

View File

@ -1503,6 +1503,34 @@ class FieldTest(unittest.TestCase):
t.image.delete()
def test_file_multidb(self):
register_connection('testfiles', 'testfiles')
class TestFile(Document):
name = StringField()
file = FileField(db_alias="testfiles",
collection_name="macumba")
TestFile.drop_collection()
# delete old filesystem
get_db("testfiles").macumba.files.drop()
get_db("testfiles").macumba.chunks.drop()
# First instance
testfile = TestFile()
testfile.name = "Hello, World!"
testfile.file.put('Hello, World!',
name="hello.txt")
testfile.save()
data = get_db("testfiles").macumba.files.find_one()
self.assertEquals(data.get('name'), 'hello.txt')
testfile = TestFile.objects.first()
self.assertEquals(testfile.file.read(),
'Hello, World!')
def test_geo_indexes(self):
"""Ensure that indexes are created automatically for GeoPointFields.
"""