Finalize python2/3 codebase compatibility and get rid of 2to3

This commit is contained in:
Bastien Gérard 2019-06-14 23:30:01 +02:00
parent 4d6ddb070e
commit 2ca905b6e5
11 changed files with 31 additions and 38 deletions

View File

@ -422,10 +422,10 @@ class StrictDict(object):
return len(list(iteritems(self))) return len(list(iteritems(self)))
def __eq__(self, other): def __eq__(self, other):
return self.items() == other.items() return list(self.items()) == list(other.items())
def __ne__(self, other): def __ne__(self, other):
return self.items() != other.items() return list(self.items()) != list(other.items())
@classmethod @classmethod
def create(cls, allowed_keys): def create(cls, allowed_keys):

View File

@ -92,7 +92,7 @@ class BaseDocument(object):
# if so raise an Exception. # if so raise an Exception.
if not self._dynamic and (self._meta.get("strict", True) or _created): if not self._dynamic and (self._meta.get("strict", True) or _created):
_undefined_fields = set(values.keys()) - set( _undefined_fields = set(values.keys()) - set(
self._fields.keys() + ["id", "pk", "_cls", "_text_score"] list(self._fields.keys()) + ["id", "pk", "_cls", "_text_score"]
) )
if _undefined_fields: if _undefined_fields:
msg = ('The fields "{0}" do not exist on the document "{1}"').format( msg = ('The fields "{0}" do not exist on the document "{1}"').format(
@ -670,7 +670,7 @@ class BaseDocument(object):
del set_data["_id"] del set_data["_id"]
# Determine if any changed items were actually unset. # Determine if any changed items were actually unset.
for path, value in set_data.items(): for path, value in list(set_data.items()):
if value or isinstance( if value or isinstance(
value, (numbers.Number, bool) value, (numbers.Number, bool)
): # Account for 0 and True that are truthy ): # Account for 0 and True that are truthy

View File

@ -8,6 +8,7 @@ import uuid
from operator import itemgetter from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON from bson import Binary, DBRef, ObjectId, SON
from bson.int64 import Int64
import gridfs import gridfs
import pymongo import pymongo
from pymongo import ReturnDocument from pymongo import ReturnDocument
@ -21,11 +22,6 @@ except ImportError:
else: else:
import dateutil.parser import dateutil.parser
try:
from bson.int64 import Int64
except ImportError:
Int64 = long
from mongoengine.base import ( from mongoengine.base import (
BaseDocument, BaseDocument,
@ -53,8 +49,6 @@ except ImportError:
ImageOps = None ImageOps = None
if six.PY3: if six.PY3:
# Useless as long as 2to3 gets executed
# as it turns `long` into `int` blindly
long = int long = int

View File

@ -989,7 +989,7 @@ class BaseQuerySet(object):
.. versionchanged:: 0.5 - Added subfield support .. versionchanged:: 0.5 - Added subfield support
""" """
fields = {f: QueryFieldList.ONLY for f in fields} fields = {f: QueryFieldList.ONLY for f in fields}
self.only_fields = fields.keys() self.only_fields = list(fields.keys())
return self.fields(True, **fields) return self.fields(True, **fields)
def exclude(self, *fields): def exclude(self, *fields):

View File

@ -118,8 +118,8 @@ extra_opts = {
"Pillow>=2.0.0", "Pillow>=2.0.0",
], ],
} }
if sys.version_info[0] == 3: if sys.version_info[0] == 3:
extra_opts["use_2to3"] = True
if "test" in sys.argv: if "test" in sys.argv:
extra_opts["packages"] = find_packages() extra_opts["packages"] = find_packages()
extra_opts["package_data"] = { extra_opts["package_data"] = {
@ -143,7 +143,7 @@ setup(
long_description=LONG_DESCRIPTION, long_description=LONG_DESCRIPTION,
platforms=["any"], platforms=["any"],
classifiers=CLASSIFIERS, classifiers=CLASSIFIERS,
install_requires=["pymongo>=3.4", "six>=1.10.0"], install_requires=['pymongo>=3.4', 'six', 'future'],
cmdclass={"test": PyTest}, cmdclass={"test": PyTest},
**extra_opts **extra_opts
) )

View File

@ -1,4 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from builtins import str
import pytest import pytest
from mongoengine import ( from mongoengine import (
@ -75,7 +77,7 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
# Test non exiting attribute # Test non exiting attribute
with pytest.raises(InvalidQueryError) as exc_info: with pytest.raises(InvalidQueryError) as exc_info:
Person.objects(settings__notexist="bar").first() Person.objects(settings__notexist="bar").first()
assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' assert str(exc_info.value) == u'Cannot resolve field "notexist"'
with pytest.raises(LookUpError): with pytest.raises(LookUpError):
Person.objects.only("settings.notexist") Person.objects.only("settings.notexist")
@ -111,7 +113,7 @@ class TestEmbeddedDocumentField(MongoDBTestCase):
# Test non exiting attribute # Test non exiting attribute
with pytest.raises(InvalidQueryError) as exc_info: with pytest.raises(InvalidQueryError) as exc_info:
assert Person.objects(settings__notexist="bar").first().id == p.id assert Person.objects(settings__notexist="bar").first().id == p.id
assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' assert str(exc_info.value) == u'Cannot resolve field "notexist"'
# Test existing attribute # Test existing attribute
assert Person.objects(settings__base_foo="basefoo").first().id == p.id assert Person.objects(settings__base_foo="basefoo").first().id == p.id
@ -319,7 +321,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
# Test non exiting attribute # Test non exiting attribute
with pytest.raises(InvalidQueryError) as exc_info: with pytest.raises(InvalidQueryError) as exc_info:
Person.objects(settings__notexist="bar").first() Person.objects(settings__notexist="bar").first()
assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' assert str(exc_info.value) == u'Cannot resolve field "notexist"'
with pytest.raises(LookUpError): with pytest.raises(LookUpError):
Person.objects.only("settings.notexist") Person.objects.only("settings.notexist")
@ -347,7 +349,7 @@ class TestGenericEmbeddedDocumentField(MongoDBTestCase):
# Test non exiting attribute # Test non exiting attribute
with pytest.raises(InvalidQueryError) as exc_info: with pytest.raises(InvalidQueryError) as exc_info:
assert Person.objects(settings__notexist="bar").first().id == p.id assert Person.objects(settings__notexist="bar").first().id == p.id
assert unicode(exc_info.value) == u'Cannot resolve field "notexist"' assert str(exc_info.value) == u'Cannot resolve field "notexist"'
# Test existing attribute # Test existing attribute
assert Person.objects(settings__base_foo="basefoo").first().id == p.id assert Person.objects(settings__base_foo="basefoo").first().id == p.id

View File

@ -1,11 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pytest import pytest
import six
try:
from bson.int64 import Int64 from bson.int64 import Int64
except ImportError: import six
Int64 = long
from mongoengine import * from mongoengine import *
from mongoengine.connection import get_db from mongoengine.connection import get_db

View File

@ -21,7 +21,7 @@ class TestSequenceField(MongoDBTestCase):
assert c["next"] == 10 assert c["next"] == 10
ids = [i.id for i in Person.objects] ids = [i.id for i in Person.objects]
assert ids == range(1, 11) assert ids == list(range(1, 11))
c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
assert c["next"] == 10 assert c["next"] == 10
@ -76,7 +76,7 @@ class TestSequenceField(MongoDBTestCase):
assert c["next"] == 10 assert c["next"] == 10
ids = [i.id for i in Person.objects] ids = [i.id for i in Person.objects]
assert ids == range(1, 11) assert ids == list(range(1, 11))
c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"}) c = self.db["mongoengine.counters"].find_one({"_id": "jelly.id"})
assert c["next"] == 10 assert c["next"] == 10
@ -101,10 +101,10 @@ class TestSequenceField(MongoDBTestCase):
assert c["next"] == 10 assert c["next"] == 10
ids = [i.id for i in Person.objects] ids = [i.id for i in Person.objects]
assert ids == range(1, 11) assert ids == list(range(1, 11))
counters = [i.counter for i in Person.objects] counters = [i.counter for i in Person.objects]
assert counters == range(1, 11) assert counters == list(range(1, 11))
c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
assert c["next"] == 10 assert c["next"] == 10
@ -166,10 +166,10 @@ class TestSequenceField(MongoDBTestCase):
assert c["next"] == 10 assert c["next"] == 10
ids = [i.id for i in Person.objects] ids = [i.id for i in Person.objects]
assert ids == range(1, 11) assert ids == list(range(1, 11))
id = [i.id for i in Animal.objects] id = [i.id for i in Animal.objects]
assert id == range(1, 11) assert id == list(range(1, 11))
c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
assert c["next"] == 10 assert c["next"] == 10
@ -193,7 +193,7 @@ class TestSequenceField(MongoDBTestCase):
assert c["next"] == 10 assert c["next"] == 10
ids = [i.id for i in Person.objects] ids = [i.id for i in Person.objects]
assert ids == map(str, range(1, 11)) assert ids == [str(i) for i in range(1, 11)]
c = self.db["mongoengine.counters"].find_one({"_id": "person.id"}) c = self.db["mongoengine.counters"].find_one({"_id": "person.id"})
assert c["next"] == 10 assert c["next"] == 10

View File

@ -1,6 +1,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import pytest import pytest
from builtins import str
from mongoengine import * from mongoengine import *
from tests.utils import MongoDBTestCase from tests.utils import MongoDBTestCase
@ -35,9 +37,8 @@ class TestURLField(MongoDBTestCase):
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
link.validate() link.validate()
assert ( assert (
unicode(exc_info.value) str(exc_info.exception)
== u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])" == u"ValidationError (Link:None) (Invalid URL: http://\u043f\u0440\u0438\u0432\u0435\u0442.com: ['url'])")
)
def test_url_scheme_validation(self): def test_url_scheme_validation(self):
"""Ensure that URLFields validate urls with specific schemes properly. """Ensure that URLFields validate urls with specific schemes properly.

View File

@ -110,7 +110,7 @@ class TestQueryset(unittest.TestCase):
# Filter people by age # Filter people by age
people = self.Person.objects(age=20) people = self.Person.objects(age=20)
assert people.count() == 1 assert people.count() == 1
person = people.next() person = next(people)
assert person == user_a assert person == user_a
assert person.name == "User A" assert person.name == "User A"
assert person.age == 20 assert person.age == 20
@ -2768,7 +2768,7 @@ class TestQueryset(unittest.TestCase):
) )
# start a map/reduce # start a map/reduce
cursor.next() next(cursor)
results = Person.objects.map_reduce( results = Person.objects.map_reduce(
map_f=map_person, map_f=map_person,
@ -4395,7 +4395,7 @@ class TestQueryset(unittest.TestCase):
# Use a query to filter the people found to just person1 # Use a query to filter the people found to just person1
people = self.Person.objects(age=20).scalar("name") people = self.Person.objects(age=20).scalar("name")
assert people.count() == 1 assert people.count() == 1
person = people.next() person = next(people)
assert person == "User A" assert person == "User A"
# Test limit # Test limit
@ -5309,7 +5309,7 @@ class TestQueryset(unittest.TestCase):
if not test: if not test:
raise AssertionError("Cursor has data and returned False") raise AssertionError("Cursor has data and returned False")
queryset.next() next(queryset)
if not queryset: if not queryset:
raise AssertionError( raise AssertionError(
"Cursor has data and it must returns True, even in the last item." "Cursor has data and it must returns True, even in the last item."

View File

@ -58,7 +58,7 @@ class TestSignal(unittest.TestCase):
@classmethod @classmethod
def post_save(cls, sender, document, **kwargs): def post_save(cls, sender, document, **kwargs):
dirty_keys = document._delta()[0].keys() + document._delta()[1].keys() dirty_keys = list(document._delta()[0].keys()) + list(document._delta()[1].keys())
signal_output.append("post_save signal, %s" % document) signal_output.append("post_save signal, %s" % document)
signal_output.append("post_save dirty keys, %s" % dirty_keys) signal_output.append("post_save dirty keys, %s" % dirty_keys)
if kwargs.pop("created", False): if kwargs.pop("created", False):