diff --git a/docs/apireference.rst b/docs/apireference.rst index 267b22aa..4fff317a 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -64,3 +64,5 @@ Fields .. autoclass:: mongoengine.ReferenceField .. autoclass:: mongoengine.GenericReferenceField + +.. autoclass:: mongoengine.FileField diff --git a/docs/guide/defining-documents.rst b/docs/guide/defining-documents.rst index 3c276869..7b8dcd5b 100644 --- a/docs/guide/defining-documents.rst +++ b/docs/guide/defining-documents.rst @@ -46,6 +46,12 @@ are as follows: * :class:`~mongoengine.EmbeddedDocumentField` * :class:`~mongoengine.ReferenceField` * :class:`~mongoengine.GenericReferenceField` +* :class:`~mongoengine.BooleanField` +* :class:`~mongoengine.GeoPointField` +* :class:`~mongoengine.FileField` +* :class:`~mongoengine.EmailField` +* :class:`~mongoengine.SortedListField` +* :class:`~mongoengine.BinaryField` Field arguments --------------- diff --git a/docs/guide/querying.rst b/docs/guide/querying.rst index 113ee431..1fd2ed57 100644 --- a/docs/guide/querying.rst +++ b/docs/guide/querying.rst @@ -174,7 +174,7 @@ custom manager methods as you like:: @queryset_manager def live_posts(doc_cls, queryset): - return queryset(published=True).filter(published=True) + return queryset.filter(published=True) BlogPost(title='test1', published=False).save() BlogPost(title='test2', published=True).save() diff --git a/mongoengine/base.py b/mongoengine/base.py index c1306ff5..086c7874 100644 --- a/mongoengine/base.py +++ b/mongoengine/base.py @@ -25,8 +25,8 @@ class BaseField(object): _geo_index = False def __init__(self, db_field=None, name=None, required=False, default=None, - unique=False, unique_with=None, primary_key=False, validation=None, - choices=None): + unique=False, unique_with=None, primary_key=False, + validation=None, choices=None): self.db_field = (db_field or name) if not primary_key else '_id' if name: import warnings @@ -87,13 +87,15 @@ class BaseField(object): # check choices if self.choices 
is not None: if value not in self.choices: - raise ValidationError("Value must be one of %s."%unicode(self.choices)) + raise ValidationError("Value must be one of %s." + % unicode(self.choices)) # check validation argument if self.validation is not None: if callable(self.validation): if not self.validation(value): - raise ValidationError('Value does not match custom validation method.') + raise ValidationError('Value does not match custom' \ + 'validation method.') else: raise ValueError('validation argument must be a callable.') @@ -337,8 +339,8 @@ class BaseDocument(object): try: field._validate(value) except (ValueError, AttributeError, AssertionError), e: - raise ValidationError('Invalid value for field of type "' + - field.__class__.__name__ + '"') + raise ValidationError('Invalid value for field of type "%s": %s' + % (field.__class__.__name__, value)) elif field.required: raise ValidationError('Field "%s" is required' % field.name) diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f9fa4dee..f84f751b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -7,17 +7,19 @@ import re import pymongo import datetime import decimal +import gridfs +import warnings +import types __all__ = ['StringField', 'IntField', 'FloatField', 'BooleanField', 'DateTimeField', 'EmbeddedDocumentField', 'ListField', 'DictField', 'ObjectIdField', 'ReferenceField', 'ValidationError', - 'DecimalField', 'URLField', 'GenericReferenceField', + 'DecimalField', 'URLField', 'GenericReferenceField', 'FileField', 'BinaryField', 'SortedListField', 'EmailField', 'GeoPointField'] RECURSIVE_REFERENCE_CONSTANT = 'self' - class StringField(BaseField): """A unicode string field. 
""" @@ -261,6 +263,7 @@ class ListField(BaseField): raise ValidationError('Argument to ListField constructor must be ' 'a valid field') self.field = field + kwargs.setdefault('default', []) super(ListField, self).__init__(**kwargs) def __get__(self, instance, owner): @@ -353,6 +356,7 @@ class DictField(BaseField): def __init__(self, basecls=None, *args, **kwargs): self.basecls = basecls or BaseField assert issubclass(self.basecls, BaseField) + kwargs.setdefault('default', {}) super(DictField, self).__init__(*args, **kwargs) def validate(self, value): @@ -369,7 +373,6 @@ class DictField(BaseField): def lookup_member(self, member_name): return self.basecls(db_field=member_name) - class ReferenceField(BaseField): """A reference to a document that will be automatically dereferenced on access (lazily). @@ -436,7 +439,6 @@ class ReferenceField(BaseField): def lookup_member(self, member_name): return self.document_type._fields.get(member_name) - class GenericReferenceField(BaseField): """A reference to *any* :class:`~mongoengine.document.Document` subclass that will be automatically dereferenced on access (lazily). 
@@ -505,6 +507,104 @@ class BinaryField(BaseField): if self.max_bytes is not None and len(value) > self.max_bytes: raise ValidationError('Binary value is too long') +class GridFSProxy(object): + """Proxy object to handle writing and reading of files to and from GridFS + """ + + def __init__(self): + self.fs = gridfs.GridFS(_get_db()) # Filesystem instance + self.newfile = None # Used for partial writes + self.grid_id = None # Store GridFS id for file + + def __getattr__(self, name): + obj = self.get() + if name in dir(obj): + return getattr(obj, name) + + def __get__(self, instance, value): + return self + + def get(self, id=None): + if id: self.grid_id = id + try: return self.fs.get(id or self.grid_id) + except: return None # File has been deleted + + def new_file(self, **kwargs): + self.newfile = self.fs.new_file(**kwargs) + self.grid_id = self.newfile._id + + def put(self, file, **kwargs): + self.grid_id = self.fs.put(file, **kwargs) + + def write(self, string): + if not self.newfile: + self.new_file() + self.grid_id = self.newfile._id + self.newfile.write(string) + + def writelines(self, lines): + if not self.newfile: + self.new_file() + self.grid_id = self.newfile._id + self.newfile.writelines(lines) + + def read(self): + try: return self.get().read() + except: return None + + def delete(self): + # Delete file from GridFS, FileField still remains + self.fs.delete(self.grid_id) + self.grid_id = None + + def replace(self, file, **kwargs): + self.delete() + self.put(file, **kwargs) + + def close(self): + if self.newfile: + self.newfile.close() + else: + msg = "The close() method is only necessary after calling write()" + warnings.warn(msg) + +class FileField(BaseField): + """A GridFS storage field. 
+ """ + + def __init__(self, **kwargs): + self.gridfs = GridFSProxy() + super(FileField, self).__init__(**kwargs) + + def __get__(self, instance, owner): + if instance is None: + return self + + return self.gridfs + + def __set__(self, instance, value): + if isinstance(value, file) or isinstance(value, str): + # using "FileField() = file/string" notation + self.gridfs.put(value) + else: + instance._data[self.name] = value + + def to_mongo(self, value): + # Store the GridFS file id in MongoDB + if self.gridfs.grid_id is not None: + return self.gridfs.grid_id + return None + + def to_python(self, value): + # Use stored value (id) to lookup file in GridFS + if self.gridfs.grid_id is not None: + return self.gridfs.get(id=value) + return None + + def validate(self, value): + if value.grid_id is not None: + assert isinstance(value, GridFSProxy) + assert isinstance(value.grid_id, pymongo.objectid.ObjectId) class GeoPointField(BaseField): """A list storing a latitude and longitude. @@ -523,5 +623,4 @@ class GeoPointField(BaseField): raise ValidationError('Value must be a two-dimensional point.') if (not isinstance(value[0], (float, int)) and not isinstance(value[1], (float, int))): - raise ValidationError('Both values in point must be float or int.') - + raise ValidationError('Both values in point must be float or int.') \ No newline at end of file diff --git a/tests/document.py b/tests/document.py index 8bc907c5..1160b353 100644 --- a/tests/document.py +++ b/tests/document.py @@ -264,11 +264,12 @@ class DocumentTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) + in info) + self.assertTrue([('_types', 1), ('addDate', -1)] in info) # tags is a list field so it 
shouldn't have _types in the index - self.assertTrue([('tags', 1)] in info.values()) + self.assertTrue([('tags', 1)] in info) class ExtendedBlogPost(BlogPost): title = StringField() @@ -278,10 +279,11 @@ class DocumentTest(unittest.TestCase): list(ExtendedBlogPost.objects) info = ExtendedBlogPost.objects._collection.index_information() + info = [value['key'] for key, value in info.iteritems()] self.assertTrue([('_types', 1), ('category', 1), ('addDate', -1)] - in info.values()) - self.assertTrue([('_types', 1), ('addDate', -1)] in info.values()) - self.assertTrue([('_types', 1), ('title', 1)] in info.values()) + in info) + self.assertTrue([('_types', 1), ('addDate', -1)] in info) + self.assertTrue([('_types', 1), ('title', 1)] in info) BlogPost.drop_collection() diff --git a/tests/fields.py b/tests/fields.py index 80ce3b67..136437b8 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -3,6 +3,7 @@ import datetime from decimal import Decimal import pymongo +import gridfs from mongoengine import * from mongoengine.connection import _get_db @@ -607,6 +608,73 @@ class FieldTest(unittest.TestCase): Shirt.drop_collection() + def test_file_fields(self): + """Ensure that file fields can be written to and their data retrieved + """ + class PutFile(Document): + file = FileField() + + class StreamFile(Document): + file = FileField() + + class SetFile(Document): + file = FileField() + + text = 'Hello, World!' 
+ more_text = 'Foo Bar' + content_type = 'text/plain' + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() + + putfile = PutFile() + putfile.file.put(text, content_type=content_type) + putfile.save() + putfile.validate() + result = PutFile.objects.first() + self.assertTrue(putfile == result) + self.assertEquals(result.file.read(), text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() # Remove file from GridFS + + streamfile = StreamFile() + streamfile.file.new_file(content_type=content_type) + streamfile.file.write(text) + streamfile.file.write(more_text) + streamfile.file.close() + streamfile.save() + streamfile.validate() + result = StreamFile.objects.first() + self.assertTrue(streamfile == result) + self.assertEquals(result.file.read(), text + more_text) + self.assertEquals(result.file.content_type, content_type) + result.file.delete() + + # Ensure deleted file returns None + self.assertTrue(result.file.read() == None) + + setfile = SetFile() + setfile.file = text + setfile.save() + setfile.validate() + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEquals(result.file.read(), text) + + # Try replacing file with new one + result.file.replace(more_text) + result.save() + result.validate() + result = SetFile.objects.first() + self.assertTrue(setfile == result) + self.assertEquals(result.file.read(), more_text) + result.file.delete() + + PutFile.drop_collection() + StreamFile.drop_collection() + SetFile.drop_collection() + def test_geo_indexes(self): """Ensure that indexes are created automatically for GeoPointFields. 
""" @@ -621,11 +689,9 @@ class FieldTest(unittest.TestCase): info = Event.objects._collection.index_information() self.assertTrue(u'location_2d' in info) - self.assertTrue(info[u'location_2d'] == [(u'location', u'2d')]) + self.assertTrue(info[u'location_2d']['key'] == [(u'location', u'2d')]) Event.drop_collection() - - if __name__ == '__main__': unittest.main() diff --git a/tests/queryset.py b/tests/queryset.py index 4187d550..0424d323 100644 --- a/tests/queryset.py +++ b/tests/queryset.py @@ -1087,8 +1087,9 @@ class QuerySetTest(unittest.TestCase): # Indexes are lazy so use list() to perform query list(BlogPost.objects) info = BlogPost.objects._collection.index_information() - self.assertTrue([('_types', 1)] in info.values()) - self.assertTrue([('_types', 1), ('date', -1)] in info.values()) + info = [value['key'] for key, value in info.iteritems()] + self.assertTrue([('_types', 1)] in info) + self.assertTrue([('_types', 1), ('date', -1)] in info) BlogPost.drop_collection()