diff --git a/docs/changelog.rst b/docs/changelog.rst index ae680d5a..6fa9ea50 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -4,6 +4,7 @@ Changelog Changes in 0.6.X ================ +- Fixed FileField losing reference when no default set - Removed possible race condition from FileField (grid_file) - Added assignment to save, can now do: b = MyDoc(**kwargs).save() - Added support for pull operations on nested EmbeddedDocuments diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 96e11f5c..f88e5c19 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -946,14 +946,16 @@ class FileField(BaseField): # Check if a file already exists for this model grid_file = instance._data.get(self.name) - if isinstance(grid_file, self.proxy_class): - if not grid_file.key: - grid_file.key = self.name - grid_file.instance = instance - return grid_file - return self.proxy_class(key=self.name, instance=instance, - db_alias=self.db_alias, - collection_name=self.collection_name) + if not isinstance(grid_file, self.proxy_class): + grid_file = self.proxy_class(key=self.name, instance=instance, + db_alias=self.db_alias, + collection_name=self.collection_name) + instance._data[self.name] = grid_file + + if not grid_file.key: + grid_file.key = self.name + grid_file.instance = instance + return grid_file def __set__(self, instance, value): key = self.name diff --git a/tests/fields.py b/tests/fields.py index ea5262db..b75ef0a8 100644 --- a/tests/fields.py +++ b/tests/fields.py @@ -3,6 +3,8 @@ import os import unittest import uuid import StringIO +import tempfile +import gridfs from decimal import Decimal @@ -19,6 +21,10 @@ class FieldTest(unittest.TestCase): connect(db='mongoenginetest') self.db = get_db() + def tearDown(self): + self.db.drop_collection('fs.files') + self.db.drop_collection('fs.chunks') + def test_default_values(self): """Ensure that default field values are used when creating a document. """ @@ -1647,6 +1653,48 @@ class FieldTest(unittest.TestCase): testimage.delete() self.assertFalse(testimage_fs.exists(testimage_grid_id)) + def test_file_field_no_default(self): + + class GridDocument(Document): + the_file = FileField() + + GridDocument.drop_collection() + + with tempfile.TemporaryFile() as f: + f.write("Hello World!") + f.flush() + + # Test without default + doc_a = GridDocument() + doc_a.save() + + + doc_b = GridDocument.objects.with_id(doc_a.id) + doc_b.the_file.replace(f, filename='doc_b') + doc_b.save() + self.assertNotEquals(doc_b.the_file.grid_id, None) + + # Test it matches + doc_c = GridDocument.objects.with_id(doc_b.id) + self.assertEquals(doc_b.the_file.grid_id, doc_c.the_file.grid_id) + + # Test with default + doc_d = GridDocument(the_file='') + doc_d.save() + + doc_e = GridDocument.objects.with_id(doc_d.id) + self.assertEquals(doc_d.the_file.grid_id, doc_e.the_file.grid_id) + + doc_e.the_file.replace(f, filename='doc_e') + doc_e.save() + + doc_f = GridDocument.objects.with_id(doc_e.id) + self.assertEquals(doc_e.the_file.grid_id, doc_f.the_file.grid_id) + + db = GridDocument._get_db() + grid_fs = gridfs.GridFS(db) + self.assertEquals(['doc_b', 'doc_e'], grid_fs.list()) + def test_file_uniqueness(self): """Ensure that each instance of a FileField is unique """