fix: 3.10 style imports not resolving correctly (#594)
This commit is contained in:
parent
5fdd0bb24f
commit
f96f51650c
@ -62,6 +62,13 @@ if TYPE_CHECKING:
|
|||||||
SupportsWrite,
|
SupportsWrite,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 10):
|
||||||
|
from types import UnionType as _types_UnionType
|
||||||
|
else:
|
||||||
|
|
||||||
|
class _types_UnionType:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
# Proto 3 data types
|
# Proto 3 data types
|
||||||
TYPE_ENUM = "enum"
|
TYPE_ENUM = "enum"
|
||||||
@ -148,6 +155,7 @@ def datetime_default_gen() -> datetime:
|
|||||||
|
|
||||||
DATETIME_ZERO = datetime_default_gen()
|
DATETIME_ZERO = datetime_default_gen()
|
||||||
|
|
||||||
|
|
||||||
# Special protobuf json doubles
|
# Special protobuf json doubles
|
||||||
INFINITY = "Infinity"
|
INFINITY = "Infinity"
|
||||||
NEG_INFINITY = "-Infinity"
|
NEG_INFINITY = "-Infinity"
|
||||||
@ -1166,30 +1174,29 @@ class Message(ABC):
|
|||||||
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
|
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
|
||||||
t = cls._type_hint(field.name)
|
t = cls._type_hint(field.name)
|
||||||
|
|
||||||
if hasattr(t, "__origin__"):
|
is_310_union = isinstance(t, _types_UnionType)
|
||||||
if t.__origin__ is dict:
|
if hasattr(t, "__origin__") or is_310_union:
|
||||||
# This is some kind of map (dict in Python).
|
if is_310_union or t.__origin__ is Union:
|
||||||
return dict
|
|
||||||
elif t.__origin__ is list:
|
|
||||||
# This is some kind of list (repeated) field.
|
|
||||||
return list
|
|
||||||
elif t.__origin__ is Union and t.__args__[1] is type(None):
|
|
||||||
# This is an optional field (either wrapped, or using proto3
|
# This is an optional field (either wrapped, or using proto3
|
||||||
# field presence). For setting the default we really don't care
|
# field presence). For setting the default we really don't care
|
||||||
# what kind of field it is.
|
# what kind of field it is.
|
||||||
return type(None)
|
return type(None)
|
||||||
else:
|
if t.__origin__ is list:
|
||||||
return t
|
# This is some kind of list (repeated) field.
|
||||||
elif issubclass(t, Enum):
|
return list
|
||||||
|
if t.__origin__ is dict:
|
||||||
|
# This is some kind of map (dict in Python).
|
||||||
|
return dict
|
||||||
|
return t
|
||||||
|
if issubclass(t, Enum):
|
||||||
# Enums always default to zero.
|
# Enums always default to zero.
|
||||||
return t.try_value
|
return t.try_value
|
||||||
elif t is datetime:
|
if t is datetime:
|
||||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||||
return datetime_default_gen
|
return datetime_default_gen
|
||||||
else:
|
# This is either a primitive scalar or another message type. Calling
|
||||||
# This is either a primitive scalar or another message type. Calling
|
# it should result in its zero value.
|
||||||
# it should result in its zero value.
|
return t
|
||||||
return t
|
|
||||||
|
|
||||||
def _postprocess_single(
|
def _postprocess_single(
|
||||||
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
|
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Set,
|
Set,
|
||||||
@ -13,6 +16,9 @@ from ..lib.google import protobuf as google_protobuf
|
|||||||
from .naming import pythonize_class_name
|
from .naming import pythonize_class_name
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..plugin.typing_compiler import TypingCompiler
|
||||||
|
|
||||||
WRAPPER_TYPES: Dict[str, Type] = {
|
WRAPPER_TYPES: Dict[str, Type] = {
|
||||||
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||||
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||||
@ -47,7 +53,7 @@ def get_type_reference(
|
|||||||
package: str,
|
package: str,
|
||||||
imports: set,
|
imports: set,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
typing_compiler: "TypingCompiler",
|
typing_compiler: TypingCompiler,
|
||||||
unwrap: bool = True,
|
unwrap: bool = True,
|
||||||
pydantic: bool = False,
|
pydantic: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
@ -139,29 +139,35 @@ class TypingImportTypingCompiler(TypingCompiler):
|
|||||||
class NoTyping310TypingCompiler(TypingCompiler):
|
class NoTyping310TypingCompiler(TypingCompiler):
|
||||||
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fmt(type: str) -> str: # for now this is necessary till 3.14
|
||||||
|
if type.startswith('"'):
|
||||||
|
return type[1:-1]
|
||||||
|
return type
|
||||||
|
|
||||||
def optional(self, type: str) -> str:
|
def optional(self, type: str) -> str:
|
||||||
return f"{type} | None"
|
return f'"{self._fmt(type)} | None"'
|
||||||
|
|
||||||
def list(self, type: str) -> str:
|
def list(self, type: str) -> str:
|
||||||
return f"list[{type}]"
|
return f'"list[{self._fmt(type)}]"'
|
||||||
|
|
||||||
def dict(self, key: str, value: str) -> str:
|
def dict(self, key: str, value: str) -> str:
|
||||||
return f"dict[{key}, {value}]"
|
return f'"dict[{key}, {self._fmt(value)}]"'
|
||||||
|
|
||||||
def union(self, *types: str) -> str:
|
def union(self, *types: str) -> str:
|
||||||
return " | ".join(types)
|
return f'"{" | ".join(map(self._fmt, types))}"'
|
||||||
|
|
||||||
def iterable(self, type: str) -> str:
|
def iterable(self, type: str) -> str:
|
||||||
self._imports["typing"].add("Iterable")
|
self._imports["collections.abc"].add("Iterable")
|
||||||
return f"Iterable[{type}]"
|
return f'"Iterable[{type}]"'
|
||||||
|
|
||||||
def async_iterable(self, type: str) -> str:
|
def async_iterable(self, type: str) -> str:
|
||||||
self._imports["typing"].add("AsyncIterable")
|
self._imports["collections.abc"].add("AsyncIterable")
|
||||||
return f"AsyncIterable[{type}]"
|
return f'"AsyncIterable[{type}]"'
|
||||||
|
|
||||||
def async_iterator(self, type: str) -> str:
|
def async_iterator(self, type: str) -> str:
|
||||||
self._imports["typing"].add("AsyncIterator")
|
self._imports["collections.abc"].add("AsyncIterator")
|
||||||
return f"AsyncIterator[{type}]"
|
return f'"AsyncIterator[{type}]"'
|
||||||
|
|
||||||
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||||
return {k: v if v else None for k, v in self._imports.items()}
|
return {k: v if v else None for k, v in self._imports.items()}
|
||||||
|
@ -62,19 +62,17 @@ def test_typing_import_typing_compiler():
|
|||||||
def test_no_typing_311_typing_compiler():
|
def test_no_typing_311_typing_compiler():
|
||||||
compiler = NoTyping310TypingCompiler()
|
compiler = NoTyping310TypingCompiler()
|
||||||
assert compiler.imports() == {}
|
assert compiler.imports() == {}
|
||||||
assert compiler.optional("str") == "str | None"
|
assert compiler.optional("str") == '"str | None"'
|
||||||
assert compiler.imports() == {}
|
assert compiler.imports() == {}
|
||||||
assert compiler.list("str") == "list[str]"
|
assert compiler.list("str") == '"list[str]"'
|
||||||
assert compiler.imports() == {}
|
assert compiler.imports() == {}
|
||||||
assert compiler.dict("str", "int") == "dict[str, int]"
|
assert compiler.dict("str", "int") == '"dict[str, int]"'
|
||||||
assert compiler.imports() == {}
|
assert compiler.imports() == {}
|
||||||
assert compiler.union("str", "int") == "str | int"
|
assert compiler.union("str", "int") == '"str | int"'
|
||||||
assert compiler.imports() == {}
|
assert compiler.imports() == {}
|
||||||
assert compiler.iterable("str") == "Iterable[str]"
|
assert compiler.iterable("str") == '"Iterable[str]"'
|
||||||
assert compiler.imports() == {"typing": {"Iterable"}}
|
assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
|
||||||
assert compiler.async_iterable("str") == "AsyncIterable[str]"
|
assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
|
||||||
assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
|
|
||||||
assert compiler.async_iterator("str") == "AsyncIterator[str]"
|
|
||||||
assert compiler.imports() == {
|
assert compiler.imports() == {
|
||||||
"typing": {"Iterable", "AsyncIterable", "AsyncIterator"}
|
"collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user