diff --git a/docs/apireference.rst b/docs/apireference.rst index 267b22aa..4fff317a 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -64,3 +64,5 @@ Fields .. autoclass:: mongoengine.ReferenceField .. autoclass:: mongoengine.GenericReferenceField + +.. autoclass:: mongoengine.FileField diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 127f029f..7d9c47f4 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -7,12 +7,15 @@ import re import pymongo import datetime import decimal +import gridfs +import warnings +import types __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', - 'DecimalField', 'URLField', 'GenericReferenceField', + 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'BinaryField', 'SortedListField', 'EmailField', 'GeoLocationField'] RECURSIVE_REFERENCE_CONSTANT = 'self' @@ -520,3 +523,87 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') + +class GridFSProxy(object): + """Proxy object to handle writing and reading of files to and from GridFS + """ + + def __init__(self): + self.fs = gridfs.GridFS(_get_db()) # Filesystem instance + self.newfile = None # Used for partial writes + self.grid_id = None # Store GridFS id for file + + def __getattr__(self, name): + obj = self.fs.get(self.grid_id) + if name in dir(obj): + return getattr(obj, name) + + def __get__(self, instance, value): + return self + + def new_file(self, **kwargs): + self.newfile = self.fs.new_file(**kwargs) + self.grid_id = self.newfile._id + + def put(self, file, **kwargs): + self.grid_id = self.fs.put(file, **kwargs) + + def write(self, string): + if not self.newfile: + self.new_file() + self.grid_id = self.newfile._id + self.newfile.write(string) + + def writelines(self, lines): + if not self.newfile: + self.new_file() + self.grid_id = self.newfile._id + self.newfile.writelines(lines) + + def read(self): + return self.fs.get(self.grid_id).read() + + def delete(self): + # Delete file from GridFS + self.fs.delete(self.grid_id) + + def close(self): + if self.newfile: + self.newfile.close() + else: + msg = "The close() method is only necessary after calling write()" + warnings.warn(msg) + +class FileField(BaseField): + """A GridFS storage field. + """ + + def __init__(self, **kwargs): + self.gridfs = GridFSProxy() + super(FileField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + + return self.gridfs + + def __set__(self, instance, value): + if isinstance(value, file) or isinstance(value, str): + # using "FileField() = file/string" notation + self.gridfs.put(value) + else: + instance._data[self.name] = value + + def to_mongo(self, value): + # Store the GridFS file id in MongoDB + return self.gridfs.grid_id + + def to_python(self, value): + # Use stored value (id) to lookup file in GridFS + return self.gridfs.fs.get(value) + + def validate(self, value): + assert isinstance(value, GridFSProxy) + assert isinstance(value.grid_id, pymongo.objectid.ObjectId) + diff --git a/tests/fields.py b/tests/fields.py index 4050e264..e22e6ddd 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -3,6 +3,7 @@ import datetime from decimal import Decimal import pymongo +import gridfs from mongoengine import * from mongoengine.connection import _get_db @@ -607,6 +608,61 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_file_fields(self): + """Ensure that file fields can be written to and their data retrieved + """ + class PutFile(Document): + file = FileField() + + class StreamFile(Document): + file = FileField() + + class SetFile(Document): + file = FileField() + + text = 'Hello, World!' + more_text = 'Foo Bar' + content_type = 'text/plain' + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() + + putfile = PutFile() + putfile.file.put(text, content_type=content_type) + putfile.save() + putfile.validate() + result = PutFile.objects.first() + self.assertTrue(putfile == result) + self.assertEquals(result.file.read(), text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() # Remove file from GridFS + + streamfile = StreamFile() + streamfile.file.new_file(content_type=content_type) + streamfile.file.write(text) + streamfile.file.write(more_text) + streamfile.file.close() + streamfile.save() + streamfile.validate() + result = StreamFile.objects.first() + self.assertTrue(streamfile == result) + self.assertEquals(result.file.read(), text + more_text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() # Remove file from GridFS + + setfile = SetFile() + setfile.file = text + setfile.save() + setfile.validate() + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEquals(result.file.read(), text) + result.file.delete() # Remove file from GridFS + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() if __name__ == '__main__':