dont re-implement six

This commit is contained in:
Stefan Wojcik 2016-12-06 16:14:53 -05:00
parent 50df653768
commit 548c7438b0
12 changed files with 93 additions and 85 deletions

View File

@ -1,6 +1,8 @@
import itertools import itertools
import weakref import weakref
import six
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.errors import DoesNotExist, MultipleObjectsReturned from mongoengine.errors import DoesNotExist, MultipleObjectsReturned
@ -212,7 +214,8 @@ class EmbeddedDocumentList(BaseList):
def __match_all(cls, i, kwargs): def __match_all(cls, i, kwargs):
items = kwargs.items() items = kwargs.items()
return all([ return all([
getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items getattr(i, k) == v or six.text_type(getattr(i, k)) == v
for k, v in items
]) ])
@classmethod @classmethod

View File

@ -8,6 +8,7 @@ from bson import ObjectId, json_util
from bson.dbref import DBRef from bson.dbref import DBRef
from bson.son import SON from bson.son import SON
import pymongo import pymongo
import six
from mongoengine import signals from mongoengine import signals
from mongoengine.base.common import ALLOW_INHERITANCE, get_document from mongoengine.base.common import ALLOW_INHERITANCE, get_document
@ -18,7 +19,7 @@ from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
LookUpError, ValidationError) LookUpError, ValidationError)
from mongoengine.python_support import PY3, txt_type from mongoengine.python_support import PY3
__all__ = ('BaseDocument', 'NON_FIELD_ERRORS') __all__ = ('BaseDocument', 'NON_FIELD_ERRORS')
@ -250,12 +251,13 @@ class BaseDocument(object):
return repr_type('<%s: %s>' % (self.__class__.__name__, u)) return repr_type('<%s: %s>' % (self.__class__.__name__, u))
def __str__(self): def __str__(self):
# TODO this could be simpler?
if hasattr(self, '__unicode__'): if hasattr(self, '__unicode__'):
if PY3: if PY3:
return self.__unicode__() return self.__unicode__()
else: else:
return unicode(self).encode('utf-8') return six.text_type(self).encode('utf-8')
return txt_type('%s object' % self.__class__.__name__) return six.text_type('%s object' % self.__class__.__name__)
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:

View File

@ -4,6 +4,7 @@ import weakref
from bson import DBRef, ObjectId, SON from bson import DBRef, ObjectId, SON
import pymongo import pymongo
import six
from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.common import ALLOW_INHERITANCE
from mongoengine.base.datastructures import ( from mongoengine.base.datastructures import (
@ -12,6 +13,7 @@ from mongoengine.base.datastructures import (
from mongoengine.common import _import_class from mongoengine.common import _import_class
from mongoengine.errors import ValidationError from mongoengine.errors import ValidationError
__all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField',
'GeoJsonBaseField') 'GeoJsonBaseField')
@ -200,11 +202,11 @@ class BaseField(object):
if isinstance(value, (Document, EmbeddedDocument)): if isinstance(value, (Document, EmbeddedDocument)):
if not any(isinstance(value, c) for c in choice_list): if not any(isinstance(value, c) for c in choice_list):
self.error( self.error(
'Value must be instance of %s' % unicode(choice_list) 'Value must be instance of %s' % six.text_type(choice_list)
) )
# Choices which are types other than Documents # Choices which are types other than Documents
elif value not in choice_list: elif value not in choice_list:
self.error('Value must be one of %s' % unicode(choice_list)) self.error('Value must be one of %s' % six.text_type(choice_list))
def _validate(self, value, **kwargs): def _validate(self, value, **kwargs):
# Check the Choices Constraint # Check the Choices Constraint
@ -457,10 +459,10 @@ class ObjectIdField(BaseField):
def to_mongo(self, value): def to_mongo(self, value):
if not isinstance(value, ObjectId): if not isinstance(value, ObjectId):
try: try:
return ObjectId(unicode(value)) return ObjectId(six.text_type(value))
except Exception as e: except Exception as e:
# e.message attribute has been deprecated since Python 2.6 # e.message attribute has been deprecated since Python 2.6
self.error(unicode(e)) self.error(six.text_type(e))
return value return value
def prepare_query_value(self, op, value): def prepare_query_value(self, op, value):
@ -468,7 +470,7 @@ class ObjectIdField(BaseField):
def validate(self, value): def validate(self, value):
try: try:
ObjectId(unicode(value)) ObjectId(six.text_type(value))
except Exception: except Exception:
self.error('Invalid Object ID') self.error('Invalid Object ID')

View File

@ -1,5 +1,7 @@
from pymongo import MongoClient, ReadPreference, uri_parser from pymongo import MongoClient, ReadPreference, uri_parser
from mongoengine.python_support import (IS_PYMONGO_3, str_types) import six
from mongoengine.python_support import IS_PYMONGO_3
__all__ = ['ConnectionError', 'connect', 'register_connection', __all__ = ['ConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME'] 'DEFAULT_CONNECTION_NAME']
@ -66,7 +68,7 @@ def register_connection(alias, name=None, host=None, port=None,
# Handle uri style connections # Handle uri style connections
conn_host = conn_settings['host'] conn_host = conn_settings['host']
# host can be a list or a string, so if string, force to a list # host can be a list or a string, so if string, force to a list
if isinstance(conn_host, str_types): if isinstance(conn_host, six.string_types):
conn_host = [conn_host] conn_host = [conn_host]
resolved_hosts = [] resolved_hosts = []

View File

@ -1,4 +1,5 @@
from bson import DBRef, SON from bson import DBRef, SON
import six
from mongoengine.base.common import get_document from mongoengine.base.common import get_document
from mongoengine.base.datastructures import (BaseDict, BaseList, from mongoengine.base.datastructures import (BaseDict, BaseList,
@ -7,7 +8,6 @@ from mongoengine.base.metaclasses import TopLevelDocumentMetaclass
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.document import Document, EmbeddedDocument from mongoengine.document import Document, EmbeddedDocument
from mongoengine.fields import DictField, ListField, MapField, ReferenceField from mongoengine.fields import DictField, ListField, MapField, ReferenceField
from mongoengine.python_support import txt_type
from mongoengine.queryset import QuerySet from mongoengine.queryset import QuerySet
@ -227,7 +227,7 @@ class DeReference(object):
data[k]._data[field_name] = self.object_map.get( data[k]._data[field_name] = self.object_map.get(
(v['_ref'].collection, v['_ref'].id), v) (v['_ref'].collection, v['_ref'].id), v)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) item_name = six.text_type("{0}.{1}.{2}").format(name, k, field_name)
data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
item_name = '%s.%s' % (name, k) if name else name item_name = '%s.%s' % (name, k) if name else name

View File

@ -4,6 +4,7 @@ import warnings
from bson.dbref import DBRef from bson.dbref import DBRef
import pymongo import pymongo
from pymongo.read_preferences import ReadPreference from pymongo.read_preferences import ReadPreference
import six
from mongoengine import signals from mongoengine import signals
from mongoengine.base.common import ALLOW_INHERITANCE, get_document from mongoengine.base.common import ALLOW_INHERITANCE, get_document
@ -391,15 +392,16 @@ class Document(BaseDocument):
self.cascade_save(**kwargs) self.cascade_save(**kwargs)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err: except pymongo.errors.OperationFailure as err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', unicode(err)): if re.match('^E1100[01] duplicate key', six.text_type(err)):
# E11000 - duplicate key error index # E11000 - duplicate key error index
# E11001 - duplicate key on update # E11001 - duplicate key on update
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % unicode(err)) raise OperationError(message % six.text_type(err))
id_field = self._meta['id_field'] id_field = self._meta['id_field']
if created or id_field not in self._meta.get('shard_key', []): if created or id_field not in self._meta.get('shard_key', []):
self[id_field] = self._fields[id_field].to_python(object_id) self[id_field] = self._fields[id_field].to_python(object_id)

View File

@ -1,7 +1,6 @@
from collections import defaultdict from collections import defaultdict
from mongoengine.python_support import txt_type import six
__all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError',
'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError',
@ -77,7 +76,7 @@ class ValidationError(AssertionError):
self.message = message self.message = message
def __str__(self): def __str__(self):
return txt_type(self.message) return six.text_type(self.message)
def __repr__(self): def __repr__(self):
return '%s(%s,)' % (self.__class__.__name__, self.message) return '%s(%s,)' % (self.__class__.__name__, self.message)
@ -111,17 +110,20 @@ class ValidationError(AssertionError):
errors_dict = {} errors_dict = {}
if not source: if not source:
return errors_dict return errors_dict
if isinstance(source, dict): if isinstance(source, dict):
for field_name, error in source.iteritems(): for field_name, error in source.iteritems():
errors_dict[field_name] = build_dict(error) errors_dict[field_name] = build_dict(error)
elif isinstance(source, ValidationError) and source.errors: elif isinstance(source, ValidationError) and source.errors:
return build_dict(source.errors) return build_dict(source.errors)
else: else:
return unicode(source) return six.text_type(source)
return errors_dict return errors_dict
if not self.errors: if not self.errors:
return {} return {}
return build_dict(self.errors) return build_dict(self.errors)
def _format_errors(self): def _format_errors(self):

View File

@ -32,8 +32,7 @@ from mongoengine.base.fields import (BaseField, ComplexBaseField,
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, ValidationError from mongoengine.errors import DoesNotExist, ValidationError
from mongoengine.python_support import (PY3, StringIO, bin_type, str_types, from mongoengine.python_support import PY3, StringIO
txt_type)
from mongoengine.queryset import DO_NOTHING, QuerySet from mongoengine.queryset import DO_NOTHING, QuerySet
try: try:
@ -1294,17 +1293,17 @@ class BinaryField(BaseField):
def __set__(self, instance, value): def __set__(self, instance, value):
"""Handle bytearrays in python 3.1""" """Handle bytearrays in python 3.1"""
if PY3 and isinstance(value, bytearray): if PY3 and isinstance(value, bytearray):
value = bin_type(value) value = six.binary_type(value)
return super(BinaryField, self).__set__(instance, value) return super(BinaryField, self).__set__(instance, value)
def to_mongo(self, value): def to_mongo(self, value):
return Binary(value) return Binary(value)
def validate(self, value): def validate(self, value):
if not isinstance(value, (bin_type, txt_type, Binary)): if not isinstance(value, (six.binary_type, six.text_type, Binary)):
self.error("BinaryField only accepts instances of " self.error("BinaryField only accepts instances of "
"(%s, %s, Binary)" % ( "(%s, %s, Binary)" % (
bin_type.__name__, txt_type.__name__)) six.binary_type.__name__, six.text_type.__name__))
if self.max_bytes is not None and len(value) > self.max_bytes: if self.max_bytes is not None and len(value) > self.max_bytes:
self.error('Binary value is too long') self.error('Binary value is too long')
@ -1492,8 +1491,10 @@ class FileField(BaseField):
def __set__(self, instance, value): def __set__(self, instance, value):
key = self.name key = self.name
if ((hasattr(value, 'read') and not if (
isinstance(value, GridFSProxy)) or isinstance(value, str_types)): (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or
isinstance(value, (six.binary_type, six.string_types))
):
# using "FileField() = file/string" notation # using "FileField() = file/string" notation
grid_file = instance._data.get(self.name) grid_file = instance._data.get(self.name)
# If a file already exists, delete it # If a file already exists, delete it

View File

@ -1,7 +1,10 @@
"""Helper functions and types to aid with Python 2.5 - 3 support.""" """
Helper functions, constants, and types to aid with Python v2.6 - v3.x and
PyMongo v2.7 - v3.x support.
"""
import sys import sys
import pymongo import pymongo
import six
if pymongo.version_tuple[0] < 3: if pymongo.version_tuple[0] < 3:
@ -9,29 +12,17 @@ if pymongo.version_tuple[0] < 3:
else: else:
IS_PYMONGO_3 = True IS_PYMONGO_3 = True
PY3 = sys.version_info[0] == 3 PY3 = sys.version_info[0] == 3
if PY3:
import codecs
from io import BytesIO as StringIO
# return s converted to binary. b('test') should be equivalent to b'test' # six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
def b(s): StringIO = six.BytesIO
return codecs.latin_1_encode(s)[0]
bin_type = bytes # Additionally for Py2, try to use the faster cStringIO, if available
txt_type = str if not PY3:
else:
try: try:
from cStringIO import StringIO from cStringIO import StringIO
except ImportError: except ImportError:
from StringIO import StringIO pass
# Conversion to binary only necessary in Python 3
def b(s):
return s
bin_type = str
txt_type = unicode
str_types = (bin_type, txt_type)

View File

@ -12,6 +12,7 @@ from bson.code import Code
import pymongo import pymongo
import pymongo.errors import pymongo.errors
from pymongo.common import validate_read_preference from pymongo.common import validate_read_preference
import six
from mongoengine import signals from mongoengine import signals
from mongoengine.base.common import get_document from mongoengine.base.common import get_document
@ -352,15 +353,15 @@ class BaseQuerySet(object):
ids = self._collection.insert(raw, **write_concern) ids = self._collection.insert(raw, **write_concern)
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % six.text_type(err))
except pymongo.errors.OperationFailure as err: except pymongo.errors.OperationFailure as err:
message = 'Could not save document (%s)' message = 'Could not save document (%s)'
if re.match('^E1100[01] duplicate key', unicode(err)): if re.match('^E1100[01] duplicate key', six.text_type(err)):
# E11000 - duplicate key error index # E11000 - duplicate key error index
# E11001 - duplicate key on update # E11001 - duplicate key on update
message = u'Tried to save duplicate unique keys (%s)' message = u'Tried to save duplicate unique keys (%s)'
raise NotUniqueError(message % unicode(err)) raise NotUniqueError(message % six.text_type(err))
raise OperationError(message % unicode(err)) raise OperationError(message % six.text_type(err))
if not load_bulk: if not load_bulk:
signals.post_bulk_insert.send( signals.post_bulk_insert.send(
@ -506,12 +507,12 @@ class BaseQuerySet(object):
elif result: elif result:
return result['n'] return result['n']
except pymongo.errors.DuplicateKeyError as err: except pymongo.errors.DuplicateKeyError as err:
raise NotUniqueError(u'Update failed (%s)' % unicode(err)) raise NotUniqueError(u'Update failed (%s)' % six.text_type(err))
except pymongo.errors.OperationFailure as err: except pymongo.errors.OperationFailure as err:
if unicode(err) == u'multi not coded yet': if six.text_type(err) == u'multi not coded yet':
message = u'update() method requires MongoDB 1.1.3+' message = u'update() method requires MongoDB 1.1.3+'
raise OperationError(message) raise OperationError(message)
raise OperationError(u'Update failed (%s)' % unicode(err)) raise OperationError(u'Update failed (%s)' % six.text_type(err))
def upsert_one(self, write_concern=None, **update): def upsert_one(self, write_concern=None, **update):
"""Overwrite or add the first document matched by the query. """Overwrite or add the first document matched by the query.
@ -1155,13 +1156,13 @@ class BaseQuerySet(object):
map_f_scope = {} map_f_scope = {}
if isinstance(map_f, Code): if isinstance(map_f, Code):
map_f_scope = map_f.scope map_f_scope = map_f.scope
map_f = unicode(map_f) map_f = six.text_type(map_f)
map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) map_f = Code(queryset._sub_js_fields(map_f), map_f_scope)
reduce_f_scope = {} reduce_f_scope = {}
if isinstance(reduce_f, Code): if isinstance(reduce_f, Code):
reduce_f_scope = reduce_f.scope reduce_f_scope = reduce_f.scope
reduce_f = unicode(reduce_f) reduce_f = six.text_type(reduce_f)
reduce_f_code = queryset._sub_js_fields(reduce_f) reduce_f_code = queryset._sub_js_fields(reduce_f)
reduce_f = Code(reduce_f_code, reduce_f_scope) reduce_f = Code(reduce_f_code, reduce_f_scope)
@ -1171,7 +1172,7 @@ class BaseQuerySet(object):
finalize_f_scope = {} finalize_f_scope = {}
if isinstance(finalize_f, Code): if isinstance(finalize_f, Code):
finalize_f_scope = finalize_f.scope finalize_f_scope = finalize_f.scope
finalize_f = unicode(finalize_f) finalize_f = six.text_type(finalize_f)
finalize_f_code = queryset._sub_js_fields(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f)
finalize_f = Code(finalize_f_code, finalize_f_scope) finalize_f = Code(finalize_f_code, finalize_f_scope)
mr_args['finalize'] = finalize_f mr_args['finalize'] = finalize_f

View File

@ -29,7 +29,7 @@ from mongoengine.base.common import _document_registry
from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList
from mongoengine.base.fields import BaseField from mongoengine.base.fields import BaseField
from mongoengine.errors import NotRegistered, DoesNotExist from mongoengine.errors import NotRegistered, DoesNotExist
from mongoengine.python_support import PY3, b, bin_type from mongoengine.python_support import PY3
__all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase")
@ -2814,7 +2814,7 @@ class FieldTest(unittest.TestCase):
content_type = StringField() content_type = StringField()
blob = BinaryField() blob = BinaryField()
BLOB = b('\xe6\x00\xc4\xff\x07') BLOB = six.b('\xe6\x00\xc4\xff\x07')
MIME_TYPE = 'application/octet-stream' MIME_TYPE = 'application/octet-stream'
Attachment.drop_collection() Attachment.drop_collection()
@ -2824,7 +2824,7 @@ class FieldTest(unittest.TestCase):
attachment_1 = Attachment.objects().first() attachment_1 = Attachment.objects().first()
self.assertEqual(MIME_TYPE, attachment_1.content_type) self.assertEqual(MIME_TYPE, attachment_1.content_type)
self.assertEqual(BLOB, bin_type(attachment_1.blob)) self.assertEqual(BLOB, six.binary_type(attachment_1.blob))
Attachment.drop_collection() Attachment.drop_collection()
@ -2851,13 +2851,13 @@ class FieldTest(unittest.TestCase):
attachment_required = AttachmentRequired() attachment_required = AttachmentRequired()
self.assertRaises(ValidationError, attachment_required.validate) self.assertRaises(ValidationError, attachment_required.validate)
attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07'))
attachment_required.validate() attachment_required.validate()
attachment_size_limit = AttachmentSizeLimit( attachment_size_limit = AttachmentSizeLimit(
blob=b('\xe6\x00\xc4\xff\x07')) blob=six.b('\xe6\x00\xc4\xff\x07'))
self.assertRaises(ValidationError, attachment_size_limit.validate) self.assertRaises(ValidationError, attachment_size_limit.validate)
attachment_size_limit.blob = b('\xe6\x00\xc4\xff') attachment_size_limit.blob = six.b('\xe6\x00\xc4\xff')
attachment_size_limit.validate() attachment_size_limit.validate()
Attachment.drop_collection() Attachment.drop_collection()

View File

@ -5,11 +5,12 @@ import unittest
import tempfile import tempfile
import gridfs import gridfs
import six
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db
from mongoengine.python_support import b, StringIO from mongoengine.python_support import StringIO
try: try:
from PIL import Image from PIL import Image
@ -46,7 +47,7 @@ class FileTest(unittest.TestCase):
PutFile.drop_collection() PutFile.drop_collection()
text = b('Hello, World!') text = six.b('Hello, World!')
content_type = 'text/plain' content_type = 'text/plain'
putfile = PutFile() putfile = PutFile()
@ -85,8 +86,8 @@ class FileTest(unittest.TestCase):
StreamFile.drop_collection() StreamFile.drop_collection()
text = b('Hello, World!') text = six.b('Hello, World!')
more_text = b('Foo Bar') more_text = six.b('Foo Bar')
content_type = 'text/plain' content_type = 'text/plain'
streamfile = StreamFile() streamfile = StreamFile()
@ -120,8 +121,8 @@ class FileTest(unittest.TestCase):
StreamFile.drop_collection() StreamFile.drop_collection()
text = b('Hello, World!') text = six.b('Hello, World!')
more_text = b('Foo Bar') more_text = six.b('Foo Bar')
content_type = 'text/plain' content_type = 'text/plain'
streamfile = StreamFile() streamfile = StreamFile()
@ -152,8 +153,8 @@ class FileTest(unittest.TestCase):
class SetFile(Document): class SetFile(Document):
the_file = FileField() the_file = FileField()
text = b('Hello, World!') text = six.b('Hello, World!')
more_text = b('Foo Bar') more_text = six.b('Foo Bar')
SetFile.drop_collection() SetFile.drop_collection()
@ -182,7 +183,7 @@ class FileTest(unittest.TestCase):
GridDocument.drop_collection() GridDocument.drop_collection()
with tempfile.TemporaryFile() as f: with tempfile.TemporaryFile() as f:
f.write(b("Hello World!")) f.write(six.b("Hello World!"))
f.flush() f.flush()
# Test without default # Test without default
@ -199,7 +200,7 @@ class FileTest(unittest.TestCase):
self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id)
# Test with default # Test with default
doc_d = GridDocument(the_file=b('')) doc_d = GridDocument(the_file=six.b(''))
doc_d.save() doc_d.save()
doc_e = GridDocument.objects.with_id(doc_d.id) doc_e = GridDocument.objects.with_id(doc_d.id)
@ -225,7 +226,7 @@ class FileTest(unittest.TestCase):
# First instance # First instance
test_file = TestFile() test_file = TestFile()
test_file.name = "Hello, World!" test_file.name = "Hello, World!"
test_file.the_file.put(b('Hello, World!')) test_file.the_file.put(six.b('Hello, World!'))
test_file.save() test_file.save()
# Second instance # Second instance
@ -279,7 +280,7 @@ class FileTest(unittest.TestCase):
test_file = TestFile() test_file = TestFile()
self.assertFalse(bool(test_file.the_file)) self.assertFalse(bool(test_file.the_file))
test_file.the_file.put(b('Hello, World!'), content_type='text/plain') test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain')
test_file.save() test_file.save()
self.assertTrue(bool(test_file.the_file)) self.assertTrue(bool(test_file.the_file))
@ -299,7 +300,7 @@ class FileTest(unittest.TestCase):
class TestFile(Document): class TestFile(Document):
the_file = FileField() the_file = FileField()
text = b('Hello, World!') text = six.b('Hello, World!')
content_type = 'text/plain' content_type = 'text/plain'
testfile = TestFile() testfile = TestFile()
@ -343,7 +344,7 @@ class FileTest(unittest.TestCase):
testfile.the_file.put(text, content_type=content_type, filename="hello") testfile.the_file.put(text, content_type=content_type, filename="hello")
testfile.save() testfile.save()
text = b('Bonjour, World!') text = six.b('Bonjour, World!')
testfile.the_file.replace(text, content_type=content_type, filename="hello") testfile.the_file.replace(text, content_type=content_type, filename="hello")
testfile.save() testfile.save()
@ -369,7 +370,7 @@ class FileTest(unittest.TestCase):
TestImage.drop_collection() TestImage.drop_collection()
with tempfile.TemporaryFile() as f: with tempfile.TemporaryFile() as f:
f.write(b("Hello World!")) f.write(six.b("Hello World!"))
f.flush() f.flush()
t = TestImage() t = TestImage()
@ -493,7 +494,7 @@ class FileTest(unittest.TestCase):
# First instance # First instance
test_file = TestFile() test_file = TestFile()
test_file.name = "Hello, World!" test_file.name = "Hello, World!"
test_file.the_file.put(b('Hello, World!'), test_file.the_file.put(six.b('Hello, World!'),
name="hello.txt") name="hello.txt")
test_file.save() test_file.save()
@ -501,16 +502,17 @@ class FileTest(unittest.TestCase):
self.assertEqual(data.get('name'), 'hello.txt') self.assertEqual(data.get('name'), 'hello.txt')
test_file = TestFile.objects.first() test_file = TestFile.objects.first()
self.assertEqual(test_file.the_file.read(), self.assertEqual(test_file.the_file.read(), six.b('Hello, World!'))
b('Hello, World!'))
test_file = TestFile.objects.first() test_file = TestFile.objects.first()
test_file.the_file = b('HELLO, WORLD!') test_file.the_file = six.b('HELLO, WORLD!')
print('HERE!!!')
print(test_file.the_file)
test_file.save() test_file.save()
test_file = TestFile.objects.first() test_file = TestFile.objects.first()
self.assertEqual(test_file.the_file.read(), self.assertEqual(test_file.the_file.read(),
b('HELLO, WORLD!')) six.b('HELLO, WORLD!'))
def test_copyable(self): def test_copyable(self):
class PutFile(Document): class PutFile(Document):
@ -518,7 +520,7 @@ class FileTest(unittest.TestCase):
PutFile.drop_collection() PutFile.drop_collection()
text = b('Hello, World!') text = six.b('Hello, World!')
content_type = 'text/plain' content_type = 'text/plain'
putfile = PutFile() putfile = PutFile()