fix: 3.10 style imports not resolving correctly (#594)

This commit is contained in:
James Hilton-Balfe 2024-08-14 08:01:31 +01:00 committed by GitHub
parent 5fdd0bb24f
commit f96f51650c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 54 additions and 37 deletions

View File

@ -62,6 +62,13 @@ if TYPE_CHECKING:
SupportsWrite, SupportsWrite,
) )
if sys.version_info >= (3, 10):
from types import UnionType as _types_UnionType
else:
class _types_UnionType:
...
# Proto 3 data types # Proto 3 data types
TYPE_ENUM = "enum" TYPE_ENUM = "enum"
@ -148,6 +155,7 @@ def datetime_default_gen() -> datetime:
DATETIME_ZERO = datetime_default_gen() DATETIME_ZERO = datetime_default_gen()
# Special protobuf json doubles # Special protobuf json doubles
INFINITY = "Infinity" INFINITY = "Infinity"
NEG_INFINITY = "-Infinity" NEG_INFINITY = "-Infinity"
@ -1166,30 +1174,29 @@ class Message(ABC):
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
t = cls._type_hint(field.name) t = cls._type_hint(field.name)
if hasattr(t, "__origin__"): is_310_union = isinstance(t, _types_UnionType)
if t.__origin__ is dict: if hasattr(t, "__origin__") or is_310_union:
# This is some kind of map (dict in Python). if is_310_union or t.__origin__ is Union:
return dict
elif t.__origin__ is list:
# This is some kind of list (repeated) field.
return list
elif t.__origin__ is Union and t.__args__[1] is type(None):
# This is an optional field (either wrapped, or using proto3 # This is an optional field (either wrapped, or using proto3
# field presence). For setting the default we really don't care # field presence). For setting the default we really don't care
# what kind of field it is. # what kind of field it is.
return type(None) return type(None)
else: if t.__origin__ is list:
return t # This is some kind of list (repeated) field.
elif issubclass(t, Enum): return list
if t.__origin__ is dict:
# This is some kind of map (dict in Python).
return dict
return t
if issubclass(t, Enum):
# Enums always default to zero. # Enums always default to zero.
return t.try_value return t.try_value
elif t is datetime: if t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z # Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen return datetime_default_gen
else: # This is either a primitive scalar or another message type. Calling
# This is either a primitive scalar or another message type. Calling # it should result in its zero value.
# it should result in its zero value. return t
return t
def _postprocess_single( def _postprocess_single(
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any

View File

@ -1,6 +1,9 @@
from __future__ import annotations
import os import os
import re import re
from typing import ( from typing import (
TYPE_CHECKING,
Dict, Dict,
List, List,
Set, Set,
@ -13,6 +16,9 @@ from ..lib.google import protobuf as google_protobuf
from .naming import pythonize_class_name from .naming import pythonize_class_name
if TYPE_CHECKING:
from ..plugin.typing_compiler import TypingCompiler
WRAPPER_TYPES: Dict[str, Type] = { WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue, ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
".google.protobuf.FloatValue": google_protobuf.FloatValue, ".google.protobuf.FloatValue": google_protobuf.FloatValue,
@ -47,7 +53,7 @@ def get_type_reference(
package: str, package: str,
imports: set, imports: set,
source_type: str, source_type: str,
typing_compiler: "TypingCompiler", typing_compiler: TypingCompiler,
unwrap: bool = True, unwrap: bool = True,
pydantic: bool = False, pydantic: bool = False,
) -> str: ) -> str:

View File

@ -139,29 +139,35 @@ class TypingImportTypingCompiler(TypingCompiler):
class NoTyping310TypingCompiler(TypingCompiler): class NoTyping310TypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
@staticmethod
def _fmt(type: str) -> str: # for now this is necessary till 3.14
if type.startswith('"'):
return type[1:-1]
return type
def optional(self, type: str) -> str: def optional(self, type: str) -> str:
return f"{type} | None" return f'"{self._fmt(type)} | None"'
def list(self, type: str) -> str: def list(self, type: str) -> str:
return f"list[{type}]" return f'"list[{self._fmt(type)}]"'
def dict(self, key: str, value: str) -> str: def dict(self, key: str, value: str) -> str:
return f"dict[{key}, {value}]" return f'"dict[{key}, {self._fmt(value)}]"'
def union(self, *types: str) -> str: def union(self, *types: str) -> str:
return " | ".join(types) return f'"{" | ".join(map(self._fmt, types))}"'
def iterable(self, type: str) -> str: def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable") self._imports["collections.abc"].add("Iterable")
return f"Iterable[{type}]" return f'"Iterable[{type}]"'
def async_iterable(self, type: str) -> str: def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable") self._imports["collections.abc"].add("AsyncIterable")
return f"AsyncIterable[{type}]" return f'"AsyncIterable[{type}]"'
def async_iterator(self, type: str) -> str: def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator") self._imports["collections.abc"].add("AsyncIterator")
return f"AsyncIterator[{type}]" return f'"AsyncIterator[{type}]"'
def imports(self) -> Dict[str, Optional[Set[str]]]: def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()} return {k: v if v else None for k, v in self._imports.items()}

View File

@ -62,19 +62,17 @@ def test_typing_import_typing_compiler():
def test_no_typing_311_typing_compiler(): def test_no_typing_311_typing_compiler():
compiler = NoTyping310TypingCompiler() compiler = NoTyping310TypingCompiler()
assert compiler.imports() == {} assert compiler.imports() == {}
assert compiler.optional("str") == "str | None" assert compiler.optional("str") == '"str | None"'
assert compiler.imports() == {} assert compiler.imports() == {}
assert compiler.list("str") == "list[str]" assert compiler.list("str") == '"list[str]"'
assert compiler.imports() == {} assert compiler.imports() == {}
assert compiler.dict("str", "int") == "dict[str, int]" assert compiler.dict("str", "int") == '"dict[str, int]"'
assert compiler.imports() == {} assert compiler.imports() == {}
assert compiler.union("str", "int") == "str | int" assert compiler.union("str", "int") == '"str | int"'
assert compiler.imports() == {} assert compiler.imports() == {}
assert compiler.iterable("str") == "Iterable[str]" assert compiler.iterable("str") == '"Iterable[str]"'
assert compiler.imports() == {"typing": {"Iterable"}} assert compiler.async_iterable("str") == '"AsyncIterable[str]"'
assert compiler.async_iterable("str") == "AsyncIterable[str]" assert compiler.async_iterator("str") == '"AsyncIterator[str]"'
assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
assert compiler.async_iterator("str") == "AsyncIterator[str]"
assert compiler.imports() == { assert compiler.imports() == {
"typing": {"Iterable", "AsyncIterable", "AsyncIterator"} "collections.abc": {"Iterable", "AsyncIterable", "AsyncIterator"}
} }