From 548c7438b04d4b4481f874aedd55f3751b9a8715 Mon Sep 17 00:00:00 2001 From: Stefan Wojcik Date: Tue, 6 Dec 2016 16:14:53 -0500 Subject: [PATCH] dont re-implement six --- mongoengine/base/datastructures.py | 5 +++- mongoengine/base/document.py | 8 ++++-- mongoengine/base/fields.py | 12 ++++---- mongoengine/connection.py | 6 ++-- mongoengine/dereference.py | 4 +-- mongoengine/document.py | 10 ++++--- mongoengine/errors.py | 10 ++++--- mongoengine/fields.py | 15 +++++----- mongoengine/python_support.py | 31 ++++++++------------- mongoengine/queryset/base.py | 21 +++++++------- tests/fields/fields.py | 12 ++++---- tests/fields/file_tests.py | 44 ++++++++++++++++-------------- 12 files changed, 93 insertions(+), 85 deletions(-) diff --git a/mongoengine/base/datastructures.py b/mongoengine/base/datastructures.py index 2c6ebc2a..e94a2f24 100644 --- a/mongoengine/base/datastructures.py +++ b/mongoengine/base/datastructures.py @@ -1,6 +1,8 @@ import itertools import weakref +import six + from mongoengine.common import _import_class from mongoengine.errors import DoesNotExist, MultipleObjectsReturned @@ -212,7 +214,8 @@ class EmbeddedDocumentList(BaseList): def __match_all(cls, i, kwargs): items = kwargs.items() return all([ - getattr(i, k) == v or unicode(getattr(i, k)) == v for k, v in items + getattr(i, k) == v or six.text_type(getattr(i, k)) == v + for k, v in items ]) @classmethod diff --git a/mongoengine/base/document.py b/mongoengine/base/document.py index 9d69efd6..15c5c851 100644 --- a/mongoengine/base/document.py +++ b/mongoengine/base/document.py @@ -8,6 +8,7 @@ from bson import ObjectId, json_util from bson.dbref import DBRef from bson.son import SON import pymongo +import six from mongoengine import signals from mongoengine.base.common import ALLOW_INHERITANCE, get_document @@ -18,7 +19,7 @@ from mongoengine.base.fields import ComplexBaseField from mongoengine.common import _import_class from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError, LookUpError, ValidationError) -from mongoengine.python_support import PY3, txt_type +from mongoengine.python_support import PY3 __all__ = ('BaseDocument', 'NON_FIELD_ERRORS') @@ -250,12 +251,13 @@ class BaseDocument(object): return repr_type('<%s: %s>' % (self.__class__.__name__, u)) def __str__(self): + # TODO this could be simpler? if hasattr(self, '__unicode__'): if PY3: return self.__unicode__() else: - return unicode(self).encode('utf-8') - return txt_type('%s object' % self.__class__.__name__) + return six.text_type(self).encode('utf-8') + return six.text_type('%s object' % self.__class__.__name__) def __eq__(self, other): if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None: diff --git a/mongoengine/base/fields.py b/mongoengine/base/fields.py index 9b75fff2..b836458c 100644 --- a/mongoengine/base/fields.py +++ b/mongoengine/base/fields.py @@ -4,6 +4,7 @@ import weakref from bson import DBRef, ObjectId, SON import pymongo +import six from mongoengine.base.common import ALLOW_INHERITANCE from mongoengine.base.datastructures import ( @@ -12,6 +13,7 @@ from mongoengine.base.datastructures import ( from mongoengine.common import _import_class from mongoengine.errors import ValidationError + __all__ = ('BaseField', 'ComplexBaseField', 'ObjectIdField', 'GeoJsonBaseField') @@ -200,11 +202,11 @@ class BaseField(object): if isinstance(value, (Document, EmbeddedDocument)): if not any(isinstance(value, c) for c in choice_list): self.error( - 'Value must be instance of %s' % unicode(choice_list) + 'Value must be instance of %s' % six.text_type(choice_list) ) # Choices which are types other than Documents elif value not in choice_list: - self.error('Value must be one of %s' % unicode(choice_list)) + self.error('Value must be one of %s' % six.text_type(choice_list)) def _validate(self, value, **kwargs): # Check the Choices Constraint @@ -457,10 +459,10 @@ class ObjectIdField(BaseField): def to_mongo(self, value): if not isinstance(value, ObjectId): try: - return ObjectId(unicode(value)) + return ObjectId(six.text_type(value)) except Exception as e: # e.message attribute has been deprecated since Python 2.6 - self.error(unicode(e)) + self.error(six.text_type(e)) return value def prepare_query_value(self, op, value): @@ -468,7 +470,7 @@ class ObjectIdField(BaseField): def validate(self, value): try: - ObjectId(unicode(value)) + ObjectId(six.text_type(value)) except Exception: self.error('Invalid Object ID') diff --git a/mongoengine/connection.py b/mongoengine/connection.py index 28c6886f..826b617a 100644 --- a/mongoengine/connection.py +++ b/mongoengine/connection.py @@ -1,5 +1,7 @@ from pymongo import MongoClient, ReadPreference, uri_parser -from mongoengine.python_support import (IS_PYMONGO_3, str_types) +import six + +from mongoengine.python_support import IS_PYMONGO_3 __all__ = ['ConnectionError', 'connect', 'register_connection', 'DEFAULT_CONNECTION_NAME'] @@ -66,7 +68,7 @@ def register_connection(alias, name=None, host=None, port=None, # Handle uri style connections conn_host = conn_settings['host'] # host can be a list or a string, so if string, force to a list - if isinstance(conn_host, str_types): + if isinstance(conn_host, six.string_types): conn_host = [conn_host] resolved_hosts = [] diff --git a/mongoengine/dereference.py b/mongoengine/dereference.py index 11d3dbe6..c5157d50 100644 --- a/mongoengine/dereference.py +++ b/mongoengine/dereference.py @@ -1,4 +1,5 @@ from bson import DBRef, SON +import six from mongoengine.base.common import get_document from mongoengine.base.datastructures import (BaseDict, BaseList, @@ -7,7 +8,6 @@ from mongoengine.base.metaclasses import TopLevelDocumentMetaclass from mongoengine.connection import get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.fields import DictField, ListField, MapField, ReferenceField -from mongoengine.python_support import txt_type from mongoengine.queryset import QuerySet @@ -227,7 +227,7 @@ class DeReference(object): data[k]._data[field_name] = self.object_map.get( (v['_ref'].collection, v['_ref'].id), v) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: - item_name = txt_type("{0}.{1}.{2}").format(name, k, field_name) + item_name = six.text_type("{0}.{1}.{2}").format(name, k, field_name) data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name) elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth: item_name = '%s.%s' % (name, k) if name else name diff --git a/mongoengine/document.py b/mongoengine/document.py index 3092e003..38fcc63d 100644 --- a/mongoengine/document.py +++ b/mongoengine/document.py @@ -4,6 +4,7 @@ import warnings from bson.dbref import DBRef import pymongo from pymongo.read_preferences import ReadPreference +import six from mongoengine import signals from mongoengine.base.common import ALLOW_INHERITANCE, get_document @@ -391,15 +392,16 @@ class Document(BaseDocument): self.cascade_save(**kwargs) except pymongo.errors.DuplicateKeyError as err: message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) except pymongo.errors.OperationFailure as err: message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): + if re.match('^E1100[01] duplicate key', six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) + raise OperationError(message % six.text_type(err)) + id_field = self._meta['id_field'] if created or id_field not in self._meta.get('shard_key', []): self[id_field] = self._fields[id_field].to_python(object_id) diff --git a/mongoengine/errors.py b/mongoengine/errors.py index 15830b5c..bcc31309 100644 --- a/mongoengine/errors.py +++ b/mongoengine/errors.py @@ -1,7 +1,6 @@ from collections import defaultdict -from mongoengine.python_support import txt_type - +import six __all__ = ('NotRegistered', 'InvalidDocumentError', 'LookUpError', 'DoesNotExist', 'MultipleObjectsReturned', 'InvalidQueryError', @@ -77,7 +76,7 @@ class ValidationError(AssertionError): self.message = message def __str__(self): - return txt_type(self.message) + return six.text_type(self.message) def __repr__(self): return '%s(%s,)' % (self.__class__.__name__, self.message) @@ -111,17 +110,20 @@ class ValidationError(AssertionError): errors_dict = {} if not source: return errors_dict + if isinstance(source, dict): for field_name, error in source.iteritems(): errors_dict[field_name] = build_dict(error) elif isinstance(source, ValidationError) and source.errors: return build_dict(source.errors) else: - return unicode(source) + return six.text_type(source) + return errors_dict if not self.errors: return {} + return build_dict(self.errors) def _format_errors(self): diff --git a/mongoengine/fields.py b/mongoengine/fields.py index 1a9bc497..ec2958db 100644 --- a/mongoengine/fields.py +++ b/mongoengine/fields.py @@ -32,8 +32,7 @@ from mongoengine.base.fields import (BaseField, ComplexBaseField, from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db from mongoengine.document import Document, EmbeddedDocument from mongoengine.errors import DoesNotExist, ValidationError -from mongoengine.python_support import (PY3, StringIO, bin_type, str_types, - txt_type) +from mongoengine.python_support import PY3, StringIO from mongoengine.queryset import DO_NOTHING, QuerySet try: @@ -1294,17 +1293,17 @@ class BinaryField(BaseField): def __set__(self, instance, value): """Handle bytearrays in python 3.1""" if PY3 and isinstance(value, bytearray): - value = bin_type(value) + value = six.binary_type(value) return super(BinaryField, self).__set__(instance, value) def to_mongo(self, value): return Binary(value) def validate(self, value): - if not isinstance(value, (bin_type, txt_type, Binary)): + if not isinstance(value, (six.binary_type, six.text_type, Binary)): self.error("BinaryField only accepts instances of " "(%s, %s, Binary)" % ( - bin_type.__name__, txt_type.__name__)) + six.binary_type.__name__, six.text_type.__name__)) if self.max_bytes is not None and len(value) > self.max_bytes: self.error('Binary value is too long') @@ -1492,8 +1491,10 @@ class FileField(BaseField): def __set__(self, instance, value): key = self.name - if ((hasattr(value, 'read') and not - isinstance(value, GridFSProxy)) or isinstance(value, str_types)): + if ( + (hasattr(value, 'read') and not isinstance(value, GridFSProxy)) or + isinstance(value, (six.binary_type, six.string_types)) + ): # using "FileField() = file/string" notation grid_file = instance._data.get(self.name) # If a file already exists, delete it diff --git a/mongoengine/python_support.py b/mongoengine/python_support.py index 5bb9038d..849f48d2 100644 --- a/mongoengine/python_support.py +++ b/mongoengine/python_support.py @@ -1,7 +1,10 @@ -"""Helper functions and types to aid with Python 2.5 - 3 support.""" - +""" +Helper functions, constants, and types to aid with Python v2.6 - v3.x and +PyMongo v2.7 - v3.x support. +""" import sys import pymongo +import six if pymongo.version_tuple[0] < 3: @@ -9,29 +12,17 @@ if pymongo.version_tuple[0] < 3: else: IS_PYMONGO_3 = True + PY3 = sys.version_info[0] == 3 -if PY3: - import codecs - from io import BytesIO as StringIO - # return s converted to binary. b('test') should be equivalent to b'test' - def b(s): - return codecs.latin_1_encode(s)[0] +# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3. +StringIO = six.BytesIO - bin_type = bytes - txt_type = str -else: +# Additionally for Py2, try to use the faster cStringIO, if available +if not PY3: try: from cStringIO import StringIO except ImportError: - from StringIO import StringIO + pass - # Conversion to binary only necessary in Python 3 - def b(s): - return s - - bin_type = str - txt_type = unicode - -str_types = (bin_type, txt_type) diff --git a/mongoengine/queryset/base.py b/mongoengine/queryset/base.py index bd6fa739..d07b6d20 100644 --- a/mongoengine/queryset/base.py +++ b/mongoengine/queryset/base.py @@ -12,6 +12,7 @@ from bson.code import Code import pymongo import pymongo.errors from pymongo.common import validate_read_preference +import six from mongoengine import signals from mongoengine.base.common import get_document @@ -352,15 +353,15 @@ class BaseQuerySet(object): ids = self._collection.insert(raw, **write_concern) except pymongo.errors.DuplicateKeyError as err: message = 'Could not save document (%s)' - raise NotUniqueError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) except pymongo.errors.OperationFailure as err: message = 'Could not save document (%s)' - if re.match('^E1100[01] duplicate key', unicode(err)): + if re.match('^E1100[01] duplicate key', six.text_type(err)): # E11000 - duplicate key error index # E11001 - duplicate key on update message = u'Tried to save duplicate unique keys (%s)' - raise NotUniqueError(message % unicode(err)) - raise OperationError(message % unicode(err)) + raise NotUniqueError(message % six.text_type(err)) + raise OperationError(message % six.text_type(err)) if not load_bulk: signals.post_bulk_insert.send( @@ -506,12 +507,12 @@ class BaseQuerySet(object): elif result: return result['n'] except pymongo.errors.DuplicateKeyError as err: - raise NotUniqueError(u'Update failed (%s)' % unicode(err)) + raise NotUniqueError(u'Update failed (%s)' % six.text_type(err)) except pymongo.errors.OperationFailure as err: - if unicode(err) == u'multi not coded yet': + if six.text_type(err) == u'multi not coded yet': message = u'update() method requires MongoDB 1.1.3+' raise OperationError(message) - raise OperationError(u'Update failed (%s)' % unicode(err)) + raise OperationError(u'Update failed (%s)' % six.text_type(err)) def upsert_one(self, write_concern=None, **update): """Overwrite or add the first document matched by the query. @@ -1155,13 +1156,13 @@ class BaseQuerySet(object): map_f_scope = {} if isinstance(map_f, Code): map_f_scope = map_f.scope - map_f = unicode(map_f) + map_f = six.text_type(map_f) map_f = Code(queryset._sub_js_fields(map_f), map_f_scope) reduce_f_scope = {} if isinstance(reduce_f, Code): reduce_f_scope = reduce_f.scope - reduce_f = unicode(reduce_f) + reduce_f = six.text_type(reduce_f) reduce_f_code = queryset._sub_js_fields(reduce_f) reduce_f = Code(reduce_f_code, reduce_f_scope) @@ -1171,7 +1172,7 @@ class BaseQuerySet(object): finalize_f_scope = {} if isinstance(finalize_f, Code): finalize_f_scope = finalize_f.scope - finalize_f = unicode(finalize_f) + finalize_f = six.text_type(finalize_f) finalize_f_code = queryset._sub_js_fields(finalize_f) finalize_f = Code(finalize_f_code, finalize_f_scope) mr_args['finalize'] = finalize_f diff --git a/tests/fields/fields.py b/tests/fields/fields.py index 7ae0faae..1f22b1ed 100644 --- a/tests/fields/fields.py +++ b/tests/fields/fields.py @@ -29,7 +29,7 @@ from mongoengine.base.common import _document_registry from mongoengine.base.datastructures import BaseDict, EmbeddedDocumentList from mongoengine.base.fields import BaseField from mongoengine.errors import NotRegistered, DoesNotExist -from mongoengine.python_support import PY3, b, bin_type +from mongoengine.python_support import PY3 __all__ = ("FieldTest", "EmbeddedDocumentListFieldTestCase") @@ -2814,7 +2814,7 @@ class FieldTest(unittest.TestCase): content_type = StringField() blob = BinaryField() - BLOB = b('\xe6\x00\xc4\xff\x07') + BLOB = six.b('\xe6\x00\xc4\xff\x07') MIME_TYPE = 'application/octet-stream' Attachment.drop_collection() @@ -2824,7 +2824,7 @@ class FieldTest(unittest.TestCase): attachment_1 = Attachment.objects().first() self.assertEqual(MIME_TYPE, attachment_1.content_type) - self.assertEqual(BLOB, bin_type(attachment_1.blob)) + self.assertEqual(BLOB, six.binary_type(attachment_1.blob)) Attachment.drop_collection() @@ -2851,13 +2851,13 @@ class FieldTest(unittest.TestCase): attachment_required = AttachmentRequired() self.assertRaises(ValidationError, attachment_required.validate) - attachment_required.blob = Binary(b('\xe6\x00\xc4\xff\x07')) + attachment_required.blob = Binary(six.b('\xe6\x00\xc4\xff\x07')) attachment_required.validate() attachment_size_limit = AttachmentSizeLimit( - blob=b('\xe6\x00\xc4\xff\x07')) + blob=six.b('\xe6\x00\xc4\xff\x07')) self.assertRaises(ValidationError, attachment_size_limit.validate) - attachment_size_limit.blob = b('\xe6\x00\xc4\xff') + attachment_size_limit.blob = six.b('\xe6\x00\xc4\xff') attachment_size_limit.validate() Attachment.drop_collection() diff --git a/tests/fields/file_tests.py b/tests/fields/file_tests.py index ccd31537..88671238 100644 --- a/tests/fields/file_tests.py +++ b/tests/fields/file_tests.py @@ -5,11 +5,12 @@ import unittest import tempfile import gridfs +import six from nose.plugins.skip import SkipTest from mongoengine import * from mongoengine.connection import get_db -from mongoengine.python_support import b, StringIO +from mongoengine.python_support import StringIO try: from PIL import Image @@ -46,7 +47,7 @@ class FileTest(unittest.TestCase): PutFile.drop_collection() - text = b('Hello, World!') + text = six.b('Hello, World!') content_type = 'text/plain' putfile = PutFile() @@ -85,8 +86,8 @@ class FileTest(unittest.TestCase): StreamFile.drop_collection() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') content_type = 'text/plain' streamfile = StreamFile() @@ -120,8 +121,8 @@ class FileTest(unittest.TestCase): StreamFile.drop_collection() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') content_type = 'text/plain' streamfile = StreamFile() @@ -152,8 +153,8 @@ class FileTest(unittest.TestCase): class SetFile(Document): the_file = FileField() - text = b('Hello, World!') - more_text = b('Foo Bar') + text = six.b('Hello, World!') + more_text = six.b('Foo Bar') SetFile.drop_collection() @@ -182,7 +183,7 @@ class FileTest(unittest.TestCase): GridDocument.drop_collection() with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) + f.write(six.b("Hello World!")) f.flush() # Test without default @@ -199,7 +200,7 @@ class FileTest(unittest.TestCase): self.assertEqual(doc_b.the_file.grid_id, doc_c.the_file.grid_id) # Test with default - doc_d = GridDocument(the_file=b('')) + doc_d = GridDocument(the_file=six.b('')) doc_d.save() doc_e = GridDocument.objects.with_id(doc_d.id) @@ -225,7 +226,7 @@ class FileTest(unittest.TestCase): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!')) + test_file.the_file.put(six.b('Hello, World!')) test_file.save() # Second instance @@ -279,7 +280,7 @@ class FileTest(unittest.TestCase): test_file = TestFile() self.assertFalse(bool(test_file.the_file)) - test_file.the_file.put(b('Hello, World!'), content_type='text/plain') + test_file.the_file.put(six.b('Hello, World!'), content_type='text/plain') test_file.save() self.assertTrue(bool(test_file.the_file)) @@ -299,7 +300,7 @@ class FileTest(unittest.TestCase): class TestFile(Document): the_file = FileField() - text = b('Hello, World!') + text = six.b('Hello, World!') content_type = 'text/plain' testfile = TestFile() @@ -343,7 +344,7 @@ class FileTest(unittest.TestCase): testfile.the_file.put(text, content_type=content_type, filename="hello") testfile.save() - text = b('Bonjour, World!') + text = six.b('Bonjour, World!') testfile.the_file.replace(text, content_type=content_type, filename="hello") testfile.save() @@ -369,7 +370,7 @@ class FileTest(unittest.TestCase): TestImage.drop_collection() with tempfile.TemporaryFile() as f: - f.write(b("Hello World!")) + f.write(six.b("Hello World!")) f.flush() t = TestImage() @@ -493,7 +494,7 @@ class FileTest(unittest.TestCase): # First instance test_file = TestFile() test_file.name = "Hello, World!" - test_file.the_file.put(b('Hello, World!'), + test_file.the_file.put(six.b('Hello, World!'), name="hello.txt") test_file.save() @@ -501,16 +502,17 @@ class FileTest(unittest.TestCase): self.assertEqual(data.get('name'), 'hello.txt') test_file = TestFile.objects.first() - self.assertEqual(test_file.the_file.read(), - b('Hello, World!')) + self.assertEqual(test_file.the_file.read(), six.b('Hello, World!')) test_file = TestFile.objects.first() - test_file.the_file = b('HELLO, WORLD!') + test_file.the_file = six.b('HELLO, WORLD!') + print('HERE!!!') + print(test_file.the_file) test_file.save() test_file = TestFile.objects.first() self.assertEqual(test_file.the_file.read(), - b('HELLO, WORLD!')) + six.b('HELLO, WORLD!')) def test_copyable(self): class PutFile(Document): @@ -518,7 +520,7 @@ class FileTest(unittest.TestCase): PutFile.drop_collection() - text = b('Hello, World!') + text = six.b('Hello, World!') content_type = 'text/plain' putfile = PutFile()