diff --git a/docs/changelog.rst b/docs/changelog.rst index b9ab42c3..c6493038 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -25,6 +25,7 @@ Changes in 0.8.X - Fixed GridFSProxy __getattr__ behaviour (#196) - Fix Django timezone support (#151) - Simplified Q objects, removed QueryTreeTransformerVisitor (#98) (#171) +- FileFields now copyable (#198) Changes in 0.7.9 ================ diff --git a/mongoengine/fields.py b/mongoengine/fields.py index f6c03119..5f11ae3b 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -986,14 +986,22 @@ class GridFSProxy(object): self_dict['_fs'] = None return self_dict + def __copy__(self): + copied = GridFSProxy() + copied.__dict__.update(self.__getstate__()) + return copied + + def __deepcopy__(self, memo): + return self.__copy__() + def __repr__(self): return '<%s: %s>' % (self.__class__.__name__, self.grid_id) def __eq__(self, other): if isinstance(other, GridFSProxy): - return ((self.grid_id == other.grid_id) and - (self.collection_name == other.collection_name) and - (self.db_alias == other.db_alias)) + return ((self.grid_id == other.grid_id) and + (self.collection_name == other.collection_name) and + (self.db_alias == other.db_alias)) else: return False diff --git a/tests/fields/file.py b/tests/fields/file.py index 17d9ec37..a39dadbf 100644 --- a/tests/fields/file.py +++ b/tests/fields/file.py @@ -3,23 +3,17 @@ from __future__ import with_statement import sys sys.path[0:0] = [""] -import datetime +import copy import os import unittest -import uuid import tempfile -from decimal import Decimal - -from bson import Binary, DBRef, ObjectId import gridfs from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.base import _document_registry -from mongoengine.errors import NotRegistered -from mongoengine.python_support import PY3, b, StringIO, bin_type +from mongoengine.python_support import PY3, b, StringIO TEST_IMAGE_PATH = os.path.join(os.path.dirname(__file__), 'mongoengine.png') @@ -50,13 +44,12 @@ class FileTest(unittest.TestCase): PutFile.drop_collection() text = b('Hello, World!') - more_text = b('Foo Bar') content_type = 'text/plain' putfile = PutFile() putfile.the_file.put(text, content_type=content_type) putfile.save() - putfile.validate() + result = PutFile.objects.first() self.assertTrue(putfile == result) self.assertEqual(result.the_file.read(), text) @@ -73,7 +66,7 @@ class FileTest(unittest.TestCase): putstring.seek(0) putfile.the_file.put(putstring, content_type=content_type) putfile.save() - putfile.validate() + result = PutFile.objects.first() self.assertTrue(putfile == result) self.assertEqual(result.the_file.read(), text) @@ -98,7 +91,7 @@ class FileTest(unittest.TestCase): streamfile.the_file.write(more_text) streamfile.the_file.close() streamfile.save() - streamfile.validate() + result = StreamFile.objects.first() self.assertTrue(streamfile == result) self.assertEqual(result.the_file.read(), text + more_text) @@ -135,7 +128,7 @@ class FileTest(unittest.TestCase): # Try replacing file with new one result.the_file.replace(more_text) result.save() - result.validate() + result = SetFile.objects.first() self.assertTrue(setfile == result) self.assertEqual(result.the_file.read(), more_text) @@ -366,5 +359,25 @@ class FileTest(unittest.TestCase): self.assertEqual(test_file.the_file.read(), b('Hello, World!')) + def test_copyable(self): + class PutFile(Document): + the_file = FileField() + + PutFile.drop_collection() + + text = b('Hello, World!') + content_type = 'text/plain' + + putfile = PutFile() + putfile.the_file.put(text, content_type=content_type) + putfile.save() + + class TestFile(Document): + name = StringField() + + self.assertEqual(putfile, copy.copy(putfile)) + self.assertEqual(putfile, copy.deepcopy(putfile)) + + if __name__ == '__main__': unittest.main()