fix: 3.10 style imports not resolving correctly (#594)
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							5fdd0bb24f
						
					
				
				
					commit
					f96f51650c
				
			| @@ -62,6 +62,13 @@ if TYPE_CHECKING: | ||||
|         SupportsWrite, | ||||
|     ) | ||||
|  | ||||
| if sys.version_info >= (3, 10): | ||||
|     from types import UnionType as _types_UnionType | ||||
| else: | ||||
|  | ||||
|     class _types_UnionType: | ||||
|         ... | ||||
|  | ||||
|  | ||||
| # Proto 3 data types | ||||
| TYPE_ENUM = "enum" | ||||
| @@ -148,6 +155,7 @@ def datetime_default_gen() -> datetime: | ||||
|  | ||||
| DATETIME_ZERO = datetime_default_gen() | ||||
|  | ||||
|  | ||||
| # Special protobuf json doubles | ||||
| INFINITY = "Infinity" | ||||
| NEG_INFINITY = "-Infinity" | ||||
| @@ -1166,30 +1174,29 @@ class Message(ABC): | ||||
|     def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: | ||||
|         t = cls._type_hint(field.name) | ||||
|  | ||||
|         if hasattr(t, "__origin__"): | ||||
|             if t.__origin__ is dict: | ||||
|                 # This is some kind of map (dict in Python). | ||||
|                 return dict | ||||
|             elif t.__origin__ is list: | ||||
|                 # This is some kind of list (repeated) field. | ||||
|                 return list | ||||
|             elif t.__origin__ is Union and t.__args__[1] is type(None): | ||||
|         is_310_union = isinstance(t, _types_UnionType) | ||||
|         if hasattr(t, "__origin__") or is_310_union: | ||||
|             if is_310_union or t.__origin__ is Union: | ||||
|                 # This is an optional field (either wrapped, or using proto3 | ||||
|                 # field presence). For setting the default we really don't care | ||||
|                 # what kind of field it is. | ||||
|                 return type(None) | ||||
|             else: | ||||
|                 return t | ||||
|         elif issubclass(t, Enum): | ||||
|             if t.__origin__ is list: | ||||
|                 # This is some kind of list (repeated) field. | ||||
|                 return list | ||||
|             if t.__origin__ is dict: | ||||
|                 # This is some kind of map (dict in Python). | ||||
|                 return dict | ||||
|             return t | ||||
|         if issubclass(t, Enum): | ||||
|             # Enums always default to zero. | ||||
|             return t.try_value | ||||
|         elif t is datetime: | ||||
|         if t is datetime: | ||||
|             # Offsets are relative to 1970-01-01T00:00:00Z | ||||
|             return datetime_default_gen | ||||
|         else: | ||||
|             # This is either a primitive scalar or another message type. Calling | ||||
|             # it should result in its zero value. | ||||
|             return t | ||||
|         # This is either a primitive scalar or another message type. Calling | ||||
|         # it should result in its zero value. | ||||
|         return t | ||||
|  | ||||
|     def _postprocess_single( | ||||
|         self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any | ||||
|   | ||||
| @@ -1,6 +1,9 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import os | ||||
| import re | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Dict, | ||||
|     List, | ||||
|     Set, | ||||
| @@ -13,6 +16,9 @@ from ..lib.google import protobuf as google_protobuf | ||||
| from .naming import pythonize_class_name | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from ..plugin.typing_compiler import TypingCompiler | ||||
|  | ||||
| WRAPPER_TYPES: Dict[str, Type] = { | ||||
|     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, | ||||
|     ".google.protobuf.FloatValue": google_protobuf.FloatValue, | ||||
| @@ -47,7 +53,7 @@ def get_type_reference( | ||||
|     package: str, | ||||
|     imports: set, | ||||
|     source_type: str, | ||||
|     typing_compiler: "TypingCompiler", | ||||
|     typing_compiler: TypingCompiler, | ||||
|     unwrap: bool = True, | ||||
|     pydantic: bool = False, | ||||
| ) -> str: | ||||
|   | ||||
| @@ -139,29 +139,35 @@ class TypingImportTypingCompiler(TypingCompiler): | ||||
| class NoTyping310TypingCompiler(TypingCompiler): | ||||
|     _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _fmt(type: str) -> str:  # for now this is necessary till 3.14 | ||||
|         if type.startswith('"'): | ||||
|             return type[1:-1] | ||||
|         return type | ||||
|  | ||||
|     def optional(self, type: str) -> str: | ||||
|         return f"{type} | None" | ||||
|         return f'"{self._fmt(type)} | None"' | ||||
|  | ||||
|     def list(self, type: str) -> str: | ||||
|         return f"list[{type}]" | ||||
|         return f'"list[{self._fmt(type)}]"' | ||||
|  | ||||
|     def dict(self, key: str, value: str) -> str: | ||||
|         return f"dict[{key}, {value}]" | ||||
|         return f'"dict[{key}, {self._fmt(value)}]"' | ||||
|  | ||||
|     def union(self, *types: str) -> str: | ||||
|         return " | ".join(types) | ||||
|         return f'"{" | ".join(map(self._fmt, types))}"' | ||||
|  | ||||
|     def iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("Iterable") | ||||
|         return f"Iterable[{type}]" | ||||
|         self._imports["collections.abc"].add("Iterable") | ||||
|         return f'"Iterable[{type}]"' | ||||
|  | ||||
|     def async_iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterable") | ||||
|         return f"AsyncIterable[{type}]" | ||||
|         self._imports["collections.abc"].add("AsyncIterable") | ||||
|         return f'"AsyncIterable[{type}]"' | ||||
|  | ||||
|     def async_iterator(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterator") | ||||
|         return f"AsyncIterator[{type}]" | ||||
|         self._imports["collections.abc"].add("AsyncIterator") | ||||
|         return f'"AsyncIterator[{type}]"' | ||||
|  | ||||
|     def imports(self) -> Dict[str, Optional[Set[str]]]: | ||||
|         return {k: v if v else None for k, v in self._imports.items()} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user