Merge branch 'master' of github.com:MongoEngine/mongoengine into fix_complex_datetime_field_invalid_string_set

This commit is contained in:
Bastien Gérard
2020-04-25 22:12:35 +02:00
62 changed files with 588 additions and 741 deletions

View File

@@ -5,14 +5,14 @@ import re
import socket
import time
import uuid
from io import BytesIO
from operator import itemgetter
from bson import Binary, DBRef, ObjectId, SON
from bson.int64 import Int64
import gridfs
import pymongo
from pymongo import ReturnDocument
import six
from six import iteritems
try:
import dateutil
@@ -21,11 +21,6 @@ except ImportError:
else:
import dateutil.parser
try:
from bson.int64 import Int64
except ImportError:
Int64 = long
from mongoengine.base import (
BaseDocument,
@@ -42,7 +37,6 @@ from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.document import Document, EmbeddedDocument
from mongoengine.errors import DoesNotExist, InvalidQueryError, ValidationError
from mongoengine.mongodb_support import MONGODB_36, get_mongodb_version
from mongoengine.python_support import StringIO
from mongoengine.queryset import DO_NOTHING
from mongoengine.queryset.base import BaseQuerySet
from mongoengine.queryset.transform import STRING_OPERATORS
@@ -53,11 +47,6 @@ except ImportError:
Image = None
ImageOps = None
if six.PY3:
# Useless as long as 2to3 gets executed
# as it turns `long` into `int` blindly
long = int
__all__ = (
"StringField",
@@ -114,10 +103,10 @@ class StringField(BaseField):
self.regex = re.compile(regex) if regex else None
self.max_length = max_length
self.min_length = min_length
super(StringField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
if isinstance(value, six.text_type):
if isinstance(value, str):
return value
try:
value = value.decode("utf-8")
@@ -126,7 +115,7 @@ class StringField(BaseField):
return value
def validate(self, value):
if not isinstance(value, six.string_types):
if not isinstance(value, str):
self.error("StringField only accepts string values")
if self.max_length is not None and len(value) > self.max_length:
@@ -142,7 +131,7 @@ class StringField(BaseField):
return None
def prepare_query_value(self, op, value):
if not isinstance(op, six.string_types):
if not isinstance(op, str):
return value
if op in STRING_OPERATORS:
@@ -162,7 +151,7 @@ class StringField(BaseField):
# escape unsafe characters which could lead to a re.error
value = re.escape(value)
value = re.compile(regex % value, flags)
return super(StringField, self).prepare_query_value(op, value)
return super().prepare_query_value(op, value)
class URLField(StringField):
@@ -186,17 +175,17 @@ class URLField(StringField):
def __init__(self, url_regex=None, schemes=None, **kwargs):
self.url_regex = url_regex or self._URL_REGEX
self.schemes = schemes or self._URL_SCHEMES
super(URLField, self).__init__(**kwargs)
super().__init__(**kwargs)
def validate(self, value):
# Check first if the scheme is valid
scheme = value.split("://")[0].lower()
if scheme not in self.schemes:
self.error(u"Invalid scheme {} in URL: {}".format(scheme, value))
self.error("Invalid scheme {} in URL: {}".format(scheme, value))
# Then check full URL
if not self.url_regex.match(value):
self.error(u"Invalid URL: {}".format(value))
self.error("Invalid URL: {}".format(value))
class EmailField(StringField):
@@ -214,7 +203,7 @@ class EmailField(StringField):
)
UTF8_USER_REGEX = LazyRegexCompiler(
six.u(
(
# RFC 6531 Section 3.3 extends `atext` (used by dot-atom) to
# include `UTF8-non-ascii`.
r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+)*\Z"
@@ -229,7 +218,7 @@ class EmailField(StringField):
re.IGNORECASE,
)
error_msg = u"Invalid email address: %s"
error_msg = "Invalid email address: %s"
def __init__(
self,
@@ -253,7 +242,7 @@ class EmailField(StringField):
self.domain_whitelist = domain_whitelist or []
self.allow_utf8_user = allow_utf8_user
self.allow_ip_domain = allow_ip_domain
super(EmailField, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def validate_user_part(self, user_part):
"""Validate the user part of the email address. Return True if
@@ -280,13 +269,13 @@ class EmailField(StringField):
try:
socket.inet_pton(addr_family, domain_part[1:-1])
return True
except (socket.error, UnicodeEncodeError):
except (OSError, UnicodeEncodeError):
pass
return False
def validate(self, value):
super(EmailField, self).validate(value)
super().validate(value)
if "@" not in value:
self.error(self.error_msg % value)
@@ -303,12 +292,16 @@ class EmailField(StringField):
domain_part = domain_part.encode("idna").decode("ascii")
except UnicodeError:
self.error(
"%s %s" % (self.error_msg % value, "(domain failed IDN encoding)")
"{} {}".format(
self.error_msg % value, "(domain failed IDN encoding)"
)
)
else:
if not self.validate_domain_part(domain_part):
self.error(
"%s %s" % (self.error_msg % value, "(domain validation failed)")
"{} {}".format(
self.error_msg % value, "(domain validation failed)"
)
)
@@ -317,7 +310,7 @@ class IntField(BaseField):
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(IntField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
try:
@@ -342,19 +335,19 @@ class IntField(BaseField):
if value is None:
return value
return super(IntField, self).prepare_query_value(op, int(value))
return super().prepare_query_value(op, int(value))
class LongField(BaseField):
"""64-bit integer field."""
"""64-bit integer field. (Equivalent to IntField since the support to Python2 was dropped)"""
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(LongField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
try:
value = long(value)
value = int(value)
except (TypeError, ValueError):
pass
return value
@@ -364,7 +357,7 @@ class LongField(BaseField):
def validate(self, value):
try:
value = long(value)
value = int(value)
except (TypeError, ValueError):
self.error("%s could not be converted to long" % value)
@@ -378,7 +371,7 @@ class LongField(BaseField):
if value is None:
return value
return super(LongField, self).prepare_query_value(op, long(value))
return super().prepare_query_value(op, int(value))
class FloatField(BaseField):
@@ -386,7 +379,7 @@ class FloatField(BaseField):
def __init__(self, min_value=None, max_value=None, **kwargs):
self.min_value, self.max_value = min_value, max_value
super(FloatField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
try:
@@ -396,7 +389,7 @@ class FloatField(BaseField):
return value
def validate(self, value):
if isinstance(value, six.integer_types):
if isinstance(value, int):
try:
value = float(value)
except OverflowError:
@@ -415,7 +408,7 @@ class FloatField(BaseField):
if value is None:
return value
return super(FloatField, self).prepare_query_value(op, float(value))
return super().prepare_query_value(op, float(value))
class DecimalField(BaseField):
@@ -462,7 +455,7 @@ class DecimalField(BaseField):
self.precision = precision
self.rounding = rounding
super(DecimalField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
if value is None:
@@ -481,13 +474,13 @@ class DecimalField(BaseField):
if value is None:
return value
if self.force_string:
return six.text_type(self.to_python(value))
return str(self.to_python(value))
return float(self.to_python(value))
def validate(self, value):
if not isinstance(value, decimal.Decimal):
if not isinstance(value, six.string_types):
value = six.text_type(value)
if not isinstance(value, str):
value = str(value)
try:
value = decimal.Decimal(value)
except (TypeError, ValueError, decimal.InvalidOperation) as exc:
@@ -500,7 +493,7 @@ class DecimalField(BaseField):
self.error("Decimal value is too large")
def prepare_query_value(self, op, value):
return super(DecimalField, self).prepare_query_value(op, self.to_mongo(value))
return super().prepare_query_value(op, self.to_mongo(value))
class BooleanField(BaseField):
@@ -540,7 +533,7 @@ class DateTimeField(BaseField):
def validate(self, value):
new_value = self.to_mongo(value)
if not isinstance(new_value, (datetime.datetime, datetime.date)):
self.error(u'cannot parse date "%s"' % value)
self.error('cannot parse date "%s"' % value)
def to_mongo(self, value):
if value is None:
@@ -552,7 +545,7 @@ class DateTimeField(BaseField):
if callable(value):
return value()
if not isinstance(value, six.string_types):
if not isinstance(value, str):
return None
return self._parse_datetime(value)
@@ -597,19 +590,19 @@ class DateTimeField(BaseField):
return None
def prepare_query_value(self, op, value):
return super(DateTimeField, self).prepare_query_value(op, self.to_mongo(value))
return super().prepare_query_value(op, self.to_mongo(value))
class DateField(DateTimeField):
def to_mongo(self, value):
value = super(DateField, self).to_mongo(value)
value = super().to_mongo(value)
# drop hours, minutes, seconds
if isinstance(value, datetime.datetime):
value = datetime.datetime(value.year, value.month, value.day)
return value
def to_python(self, value):
value = super(DateField, self).to_python(value)
value = super().to_python(value)
# convert datetime to date
if isinstance(value, datetime.datetime):
value = datetime.date(value.year, value.month, value.day)
@@ -643,7 +636,7 @@ class ComplexDateTimeField(StringField):
"""
self.separator = separator
self.format = separator.join(["%Y", "%m", "%d", "%H", "%M", "%S", "%f"])
super(ComplexDateTimeField, self).__init__(**kwargs)
super().__init__(**kwargs)
def _convert_from_datetime(self, val):
"""
@@ -674,14 +667,14 @@ class ComplexDateTimeField(StringField):
if instance is None:
return self
data = super(ComplexDateTimeField, self).__get__(instance, owner)
data = super().__get__(instance, owner)
if isinstance(data, datetime.datetime) or data is None:
return data
return self._convert_from_string(data)
def __set__(self, instance, value):
super(ComplexDateTimeField, self).__set__(instance, value)
super().__set__(instance, value)
value = instance._data[self.name]
if value is not None:
if isinstance(value, datetime.datetime):
@@ -706,9 +699,7 @@ class ComplexDateTimeField(StringField):
return self._convert_from_datetime(value)
def prepare_query_value(self, op, value):
return super(ComplexDateTimeField, self).prepare_query_value(
op, self._convert_from_datetime(value)
)
return super().prepare_query_value(op, self._convert_from_datetime(value))
class EmbeddedDocumentField(BaseField):
@@ -719,7 +710,7 @@ class EmbeddedDocumentField(BaseField):
def __init__(self, document_type, **kwargs):
# XXX ValidationError raised outside of the "validate" method.
if not (
isinstance(document_type, six.string_types)
isinstance(document_type, str)
or issubclass(document_type, EmbeddedDocument)
):
self.error(
@@ -728,11 +719,11 @@ class EmbeddedDocumentField(BaseField):
)
self.document_type_obj = document_type
super(EmbeddedDocumentField, self).__init__(**kwargs)
super().__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if isinstance(self.document_type_obj, str):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
resolved_document_type = self.owner_document
else:
@@ -789,7 +780,7 @@ class EmbeddedDocumentField(BaseField):
"Querying the embedded document '%s' failed, due to an invalid query value"
% (self.document_type._class_name,)
)
super(EmbeddedDocumentField, self).prepare_query_value(op, value)
super().prepare_query_value(op, value)
return self.to_mongo(value)
@@ -805,9 +796,7 @@ class GenericEmbeddedDocumentField(BaseField):
"""
def prepare_query_value(self, op, value):
return super(GenericEmbeddedDocumentField, self).prepare_query_value(
op, self.to_mongo(value)
)
return super().prepare_query_value(op, self.to_mongo(value))
def to_python(self, value):
if isinstance(value, dict):
@@ -858,7 +847,7 @@ class DynamicField(BaseField):
"""Convert a Python type to a MongoDB compatible type.
"""
if isinstance(value, six.string_types):
if isinstance(value, str):
return value
if hasattr(value, "to_mongo"):
@@ -880,12 +869,12 @@ class DynamicField(BaseField):
value = {k: v for k, v in enumerate(value)}
data = {}
for k, v in iteritems(value):
for k, v in value.items():
data[k] = self.to_mongo(v, use_db_field, fields)
value = data
if is_list: # Convert back to a list
value = [v for k, v in sorted(iteritems(data), key=itemgetter(0))]
value = [v for k, v in sorted(data.items(), key=itemgetter(0))]
return value
def to_python(self, value):
@@ -895,15 +884,15 @@ class DynamicField(BaseField):
value = doc_cls._get_db().dereference(value["_ref"])
return doc_cls._from_son(value)
return super(DynamicField, self).to_python(value)
return super().to_python(value)
def lookup_member(self, member_name):
return member_name
def prepare_query_value(self, op, value):
if isinstance(value, six.string_types):
if isinstance(value, str):
return StringField().prepare_query_value(op, value)
return super(DynamicField, self).prepare_query_value(op, self.to_mongo(value))
return super().prepare_query_value(op, self.to_mongo(value))
def validate(self, value, clean=True):
if hasattr(value, "validate"):
@@ -924,7 +913,7 @@ class ListField(ComplexBaseField):
self.field = field
self.max_length = max_length
kwargs.setdefault("default", lambda: [])
super(ListField, self).__init__(**kwargs)
super().__init__(**kwargs)
def __get__(self, instance, owner):
if instance is None:
@@ -938,7 +927,7 @@ class ListField(ComplexBaseField):
and value
):
instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
return super(ListField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def validate(self, value):
"""Make sure that a list of valid fields is being used."""
@@ -952,7 +941,7 @@ class ListField(ComplexBaseField):
if self.max_length is not None and len(value) > self.max_length:
self.error("List is too long")
super(ListField, self).validate(value)
super().validate(value)
def prepare_query_value(self, op, value):
# Validate that the `set` operator doesn't contain more items than `max_length`.
@@ -966,14 +955,14 @@ class ListField(ComplexBaseField):
if (
op in ("set", "unset", None)
and hasattr(value, "__iter__")
and not isinstance(value, six.string_types)
and not isinstance(value, str)
and not isinstance(value, BaseDocument)
):
return [self.field.prepare_query_value(op, v) for v in value]
return self.field.prepare_query_value(op, value)
return super(ListField, self).prepare_query_value(op, value)
return super().prepare_query_value(op, value)
class EmbeddedDocumentListField(ListField):
@@ -994,9 +983,7 @@ class EmbeddedDocumentListField(ListField):
:param kwargs: Keyword arguments passed directly into the parent
:class:`~mongoengine.ListField`.
"""
super(EmbeddedDocumentListField, self).__init__(
field=EmbeddedDocumentField(document_type), **kwargs
)
super().__init__(field=EmbeddedDocumentField(document_type), **kwargs)
class SortedListField(ListField):
@@ -1022,10 +1009,10 @@ class SortedListField(ListField):
self._ordering = kwargs.pop("ordering")
if "reverse" in kwargs.keys():
self._order_reverse = kwargs.pop("reverse")
super(SortedListField, self).__init__(field, **kwargs)
super().__init__(field, **kwargs)
def to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
value = super().to_mongo(value, use_db_field, fields)
if self._ordering is not None:
return sorted(
value, key=itemgetter(self._ordering), reverse=self._order_reverse
@@ -1038,9 +1025,7 @@ def key_not_string(d):
dictionary is not a string.
"""
for k, v in d.items():
if not isinstance(k, six.string_types) or (
isinstance(v, dict) and key_not_string(v)
):
if not isinstance(k, str) or (isinstance(v, dict) and key_not_string(v)):
return True
@@ -1080,7 +1065,7 @@ class DictField(ComplexBaseField):
self._auto_dereference = False
kwargs.setdefault("default", lambda: {})
super(DictField, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def validate(self, value):
"""Make sure that a list of valid fields is being used."""
@@ -1100,7 +1085,7 @@ class DictField(ComplexBaseField):
self.error(
'Invalid dictionary key name - keys may not startswith "$" characters'
)
super(DictField, self).validate(value)
super().validate(value)
def lookup_member(self, member_name):
return DictField(db_field=member_name)
@@ -1117,7 +1102,7 @@ class DictField(ComplexBaseField):
"iexact",
]
if op in match_operators and isinstance(value, six.string_types):
if op in match_operators and isinstance(value, str):
return StringField().prepare_query_value(op, value)
if hasattr(
@@ -1129,7 +1114,7 @@ class DictField(ComplexBaseField):
}
return self.field.prepare_query_value(op, value)
return super(DictField, self).prepare_query_value(op, value)
return super().prepare_query_value(op, value)
class MapField(DictField):
@@ -1144,7 +1129,7 @@ class MapField(DictField):
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(field, BaseField):
self.error("Argument to MapField constructor must be a valid field")
super(MapField, self).__init__(field=field, *args, **kwargs)
super().__init__(field=field, *args, **kwargs)
class ReferenceField(BaseField):
@@ -1204,7 +1189,7 @@ class ReferenceField(BaseField):
:class:`~pymongo.dbref.DBRef`, regardless of the value of `dbref`.
"""
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(document_type, six.string_types) and not issubclass(
if not isinstance(document_type, str) and not issubclass(
document_type, Document
):
self.error(
@@ -1215,11 +1200,11 @@ class ReferenceField(BaseField):
self.dbref = dbref
self.document_type_obj = document_type
self.reverse_delete_rule = reverse_delete_rule
super(ReferenceField, self).__init__(**kwargs)
super().__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if isinstance(self.document_type_obj, str):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
@@ -1248,7 +1233,7 @@ class ReferenceField(BaseField):
else:
instance._data[self.name] = cls._from_son(dereferenced)
return super(ReferenceField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def to_mongo(self, document):
if isinstance(document, DBRef):
@@ -1299,7 +1284,7 @@ class ReferenceField(BaseField):
def prepare_query_value(self, op, value):
if value is None:
return None
super(ReferenceField, self).prepare_query_value(op, value)
super().prepare_query_value(op, value)
return self.to_mongo(value)
def validate(self, value):
@@ -1335,7 +1320,7 @@ class CachedReferenceField(BaseField):
fields = []
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(document_type, six.string_types) and not issubclass(
if not isinstance(document_type, str) and not issubclass(
document_type, Document
):
self.error(
@@ -1346,7 +1331,7 @@ class CachedReferenceField(BaseField):
self.auto_sync = auto_sync
self.document_type_obj = document_type
self.fields = fields
super(CachedReferenceField, self).__init__(**kwargs)
super().__init__(**kwargs)
def start_listener(self):
from mongoengine import signals
@@ -1358,7 +1343,7 @@ class CachedReferenceField(BaseField):
return None
update_kwargs = {
"set__%s__%s" % (self.name, key): val
"set__{}__{}".format(self.name, key): val
for key, val in document._delta()[0].items()
if key in self.fields
}
@@ -1380,7 +1365,7 @@ class CachedReferenceField(BaseField):
@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if isinstance(self.document_type_obj, str):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
@@ -1404,7 +1389,7 @@ class CachedReferenceField(BaseField):
else:
instance._data[self.name] = self.document_type._from_son(dereferenced)
return super(CachedReferenceField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def to_mongo(self, document, use_db_field=True, fields=None):
id_field_name = self.document_type._meta["id_field"]
@@ -1503,12 +1488,12 @@ class GenericReferenceField(BaseField):
def __init__(self, *args, **kwargs):
choices = kwargs.pop("choices", None)
super(GenericReferenceField, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.choices = []
# Keep the choices as a list of allowed Document class names
if choices:
for choice in choices:
if isinstance(choice, six.string_types):
if isinstance(choice, str):
self.choices.append(choice)
elif isinstance(choice, type) and issubclass(choice, Document):
self.choices.append(choice._class_name)
@@ -1517,7 +1502,7 @@ class GenericReferenceField(BaseField):
# method.
self.error(
"Invalid choices provided: must be a list of"
"Document subclasses and/or six.string_typess"
"Document subclasses and/or str"
)
def _validate_choices(self, value):
@@ -1527,7 +1512,7 @@ class GenericReferenceField(BaseField):
value = value.get("_cls")
elif isinstance(value, Document):
value = value._class_name
super(GenericReferenceField, self)._validate_choices(value)
super()._validate_choices(value)
def __get__(self, instance, owner):
if instance is None:
@@ -1543,7 +1528,7 @@ class GenericReferenceField(BaseField):
else:
instance._data[self.name] = dereferenced
return super(GenericReferenceField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def validate(self, value):
if not isinstance(value, (Document, DBRef, dict, SON)):
@@ -1607,22 +1592,22 @@ class BinaryField(BaseField):
def __init__(self, max_bytes=None, **kwargs):
self.max_bytes = max_bytes
super(BinaryField, self).__init__(**kwargs)
super().__init__(**kwargs)
def __set__(self, instance, value):
"""Handle bytearrays in python 3.1"""
if six.PY3 and isinstance(value, bytearray):
value = six.binary_type(value)
return super(BinaryField, self).__set__(instance, value)
if isinstance(value, bytearray):
value = bytes(value)
return super().__set__(instance, value)
def to_mongo(self, value):
return Binary(value)
def validate(self, value):
if not isinstance(value, (six.binary_type, Binary)):
if not isinstance(value, (bytes, Binary)):
self.error(
"BinaryField only accepts instances of "
"(%s, %s, Binary)" % (six.binary_type.__name__, Binary.__name__)
"(%s, %s, Binary)" % (bytes.__name__, Binary.__name__)
)
if self.max_bytes is not None and len(value) > self.max_bytes:
@@ -1631,14 +1616,14 @@ class BinaryField(BaseField):
def prepare_query_value(self, op, value):
if value is None:
return value
return super(BinaryField, self).prepare_query_value(op, self.to_mongo(value))
return super().prepare_query_value(op, self.to_mongo(value))
class GridFSError(Exception):
pass
class GridFSProxy(object):
class GridFSProxy:
"""Proxy object to handle writing and reading of files to and from GridFS
.. versionadded:: 0.4
@@ -1688,8 +1673,6 @@ class GridFSProxy(object):
def __bool__(self):
return bool(self.grid_id)
__nonzero__ = __bool__ # For Py2 support
def __getstate__(self):
self_dict = self.__dict__
self_dict["_fs"] = None
@@ -1704,12 +1687,12 @@ class GridFSProxy(object):
return self.__copy__()
def __repr__(self):
return "<%s: %s>" % (self.__class__.__name__, self.grid_id)
return "<{}: {}>".format(self.__class__.__name__, self.grid_id)
def __str__(self):
gridout = self.get()
filename = getattr(gridout, "filename") if gridout else "<no file>"
return "<%s: %s (%s)>" % (self.__class__.__name__, filename, self.grid_id)
return "<{}: {} ({})>".format(self.__class__.__name__, filename, self.grid_id)
def __eq__(self, other):
if isinstance(other, GridFSProxy):
@@ -1820,7 +1803,7 @@ class FileField(BaseField):
def __init__(
self, db_alias=DEFAULT_CONNECTION_NAME, collection_name="fs", **kwargs
):
super(FileField, self).__init__(**kwargs)
super().__init__(**kwargs)
self.collection_name = collection_name
self.db_alias = db_alias
@@ -1843,7 +1826,7 @@ class FileField(BaseField):
key = self.name
if (
hasattr(value, "read") and not isinstance(value, GridFSProxy)
) or isinstance(value, (six.binary_type, six.string_types)):
) or isinstance(value, (bytes, str)):
# using "FileField() = file/string" notation
grid_file = instance._data.get(self.name)
# If a file already exists, delete it
@@ -1961,11 +1944,11 @@ class ImageGridFsProxy(GridFSProxy):
w, h = img.size
io = StringIO()
io = BytesIO()
img.save(io, img_format, progressive=progressive)
io.seek(0)
return super(ImageGridFsProxy, self).put(
return super().put(
io, width=w, height=h, format=img_format, thumbnail_id=thumb_id, **kwargs
)
@@ -1975,12 +1958,12 @@ class ImageGridFsProxy(GridFSProxy):
if out and out.thumbnail_id:
self.fs.delete(out.thumbnail_id)
return super(ImageGridFsProxy, self).delete()
return super().delete()
def _put_thumbnail(self, thumbnail, format, progressive, **kwargs):
w, h = thumbnail.size
io = StringIO()
io = BytesIO()
thumbnail.save(io, format, progressive=progressive)
io.seek(0)
@@ -2050,16 +2033,11 @@ class ImageField(FileField):
for att_name, att in extra_args.items():
value = None
if isinstance(att, (tuple, list)):
if six.PY3:
value = dict(
itertools.zip_longest(params_size, att, fillvalue=None)
)
else:
value = dict(map(None, params_size, att))
value = dict(itertools.zip_longest(params_size, att, fillvalue=None))
setattr(self, att_name, value)
super(ImageField, self).__init__(collection_name=collection_name, **kwargs)
super().__init__(collection_name=collection_name, **kwargs)
class SequenceField(BaseField):
@@ -2111,14 +2089,14 @@ class SequenceField(BaseField):
self.value_decorator = (
value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
)
super(SequenceField, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def generate(self):
"""
Generate and Increment the counter
"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
sequence_id = "{}.{}".format(sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_one_and_update(
@@ -2132,7 +2110,7 @@ class SequenceField(BaseField):
def set_next_value(self, value):
"""Helper method to set the next sequence value"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
sequence_id = "{}.{}".format(sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
counter = collection.find_one_and_update(
filter={"_id": sequence_id},
@@ -2149,7 +2127,7 @@ class SequenceField(BaseField):
as it is only fixed on set.
"""
sequence_name = self.get_sequence_name()
sequence_id = "%s.%s" % (sequence_name, self.name)
sequence_id = "{}.{}".format(sequence_name, self.name)
collection = get_db(alias=self.db_alias)[self.collection_name]
data = collection.find_one({"_id": sequence_id})
@@ -2172,7 +2150,7 @@ class SequenceField(BaseField):
)
def __get__(self, instance, owner):
value = super(SequenceField, self).__get__(instance, owner)
value = super().__get__(instance, owner)
if value is None and instance._initialised:
value = self.generate()
instance._data[self.name] = value
@@ -2185,7 +2163,7 @@ class SequenceField(BaseField):
if value is None and instance._initialised:
value = self.generate()
return super(SequenceField, self).__set__(instance, value)
return super().__set__(instance, value)
def prepare_query_value(self, op, value):
"""
@@ -2219,14 +2197,14 @@ class UUIDField(BaseField):
.. versionchanged:: 0.6.19
"""
self._binary = binary
super(UUIDField, self).__init__(**kwargs)
super().__init__(**kwargs)
def to_python(self, value):
if not self._binary:
original_value = value
try:
if not isinstance(value, six.string_types):
value = six.text_type(value)
if not isinstance(value, str):
value = str(value)
return uuid.UUID(value)
except (ValueError, TypeError, AttributeError):
return original_value
@@ -2234,8 +2212,8 @@ class UUIDField(BaseField):
def to_mongo(self, value):
if not self._binary:
return six.text_type(value)
elif isinstance(value, six.string_types):
return str(value)
elif isinstance(value, str):
return uuid.UUID(value)
return value
@@ -2246,7 +2224,7 @@ class UUIDField(BaseField):
def validate(self, value):
if not isinstance(value, uuid.UUID):
if not isinstance(value, six.string_types):
if not isinstance(value, str):
value = str(value)
try:
uuid.UUID(value)
@@ -2445,7 +2423,7 @@ class LazyReferenceField(BaseField):
document. Note this only work getting field (not setting or deleting).
"""
# XXX ValidationError raised outside of the "validate" method.
if not isinstance(document_type, six.string_types) and not issubclass(
if not isinstance(document_type, str) and not issubclass(
document_type, Document
):
self.error(
@@ -2457,11 +2435,11 @@ class LazyReferenceField(BaseField):
self.passthrough = passthrough
self.document_type_obj = document_type
self.reverse_delete_rule = reverse_delete_rule
super(LazyReferenceField, self).__init__(**kwargs)
super().__init__(**kwargs)
@property
def document_type(self):
if isinstance(self.document_type_obj, six.string_types):
if isinstance(self.document_type_obj, str):
if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
self.document_type_obj = self.owner_document
else:
@@ -2500,7 +2478,7 @@ class LazyReferenceField(BaseField):
if value:
instance._data[self.name] = value
return super(LazyReferenceField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def to_mongo(self, value):
if isinstance(value, LazyReference):
@@ -2564,7 +2542,7 @@ class LazyReferenceField(BaseField):
def prepare_query_value(self, op, value):
if value is None:
return None
super(LazyReferenceField, self).prepare_query_value(op, value)
super().prepare_query_value(op, value)
return self.to_mongo(value)
def lookup_member(self, member_name):
@@ -2591,12 +2569,12 @@ class GenericLazyReferenceField(GenericReferenceField):
def __init__(self, *args, **kwargs):
self.passthrough = kwargs.pop("passthrough", False)
super(GenericLazyReferenceField, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
def _validate_choices(self, value):
if isinstance(value, LazyReference):
value = value.document_type._class_name
super(GenericLazyReferenceField, self)._validate_choices(value)
super()._validate_choices(value)
def build_lazyref(self, value):
if isinstance(value, LazyReference):
@@ -2625,7 +2603,7 @@ class GenericLazyReferenceField(GenericReferenceField):
if value:
instance._data[self.name] = value
return super(GenericLazyReferenceField, self).__get__(instance, owner)
return super().__get__(instance, owner)
def validate(self, value):
if isinstance(value, LazyReference) and value.pk is None:
@@ -2633,7 +2611,7 @@ class GenericLazyReferenceField(GenericReferenceField):
"You can only reference documents once they have been"
" saved to the database"
)
return super(GenericLazyReferenceField, self).validate(value)
return super().validate(value)
def to_mongo(self, document):
if document is None:
@@ -2652,4 +2630,4 @@ class GenericLazyReferenceField(GenericReferenceField):
)
)
else:
return super(GenericLazyReferenceField, self).to_mongo(document)
return super().to_mongo(document)