Handle typing collisions and add validation to a files module for overlaping declarations (#582)

* Fix 'typing' import collisions.

* Fix formatting.

* Fix self-test issues.

* Validation for modules, different typing configurations

* add readme

* make warning

* fix format

---------

Co-authored-by: Scott Hendricks <scott.hendricks@confluent.io>
This commit is contained in:
Ian McDonald
2024-07-19 16:02:09 -07:00
committed by GitHub
parent 7c6c627938
commit 8b59234856
13 changed files with 899 additions and 169 deletions

View File

@@ -47,6 +47,7 @@ def get_type_reference(
package: str,
imports: set,
source_type: str,
typing_compiler: "TypingCompiler",
unwrap: bool = True,
pydantic: bool = False,
) -> str:
@@ -57,7 +58,7 @@ def get_type_reference(
if unwrap:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
return typing_compiler.optional(wrapped_type.__name__)
if source_type == ".google.protobuf.Duration":
return "timedelta"

View File

@@ -1,4 +1,7 @@
import os.path
import sys
from .module_validation import ModuleValidator
try:
@@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
# Load the body first so we have a compleate list of imports needed.
body_template = env.get_template("template.py.j2")
header_template = env.get_template("header.py.j2")
code = template.render(output_file=output_file)
code = body_template.render(output_file=output_file)
code = header_template.render(output_file=output_file) + code
code = isort.api.sort_code_string(
code=code,
show_diff=False,
@@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
force_grid_wrap=2,
known_third_party=["grpclib", "betterproto"],
)
return black.format_str(
code = black.format_str(
src_contents=code,
mode=black.Mode(),
)
# Validate the generated code.
validator = ModuleValidator(iter(code.splitlines()))
if not validator.validate():
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
for collision, lines in validator.collisions.items():
message_builder.append(f' "{collision}" on lines:')
for num, line in lines:
message_builder.append(f" {num}:{line}")
print("\n".join(message_builder), file=sys.stderr)
return code

View File

@@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import builtins
import re
import textwrap
from dataclasses import (
dataclass,
field,
@@ -49,12 +47,6 @@ from typing import (
)
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
@@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
from .. import which_one_of
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
@@ -82,6 +75,10 @@ from ..compile.naming import (
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import (
DirectImportTypingCompiler,
TypingCompiler,
)
# Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
@@ -242,7 +240,6 @@ class OutputTemplate:
input_files: List[str] = field(default_factory=list)
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
@@ -251,6 +248,7 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@property
def package(self) -> str:
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return self.typing_compiler.list(self.py_name)
return self.py_name
@property
@@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
imports.add("datetime")
return imports
@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
return self.annotation.startswith(("List[", "Dict["))
return self.annotation.startswith(
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
)
@property
def field_type(self) -> str:
@@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
pydantic=self.output_file.pydantic_dataclasses,
)
else:
@@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{py_type}]"
return self.typing_compiler.list(py_type)
if self.optional:
return f"Optional[{py_type}]"
return self.typing_compiler.optional(py_type)
return py_type
@@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
typing_compiler=self.typing_compiler,
).py_type
self.py_v_type = FieldCompiler(
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
typing_compiler=self.typing_compiler,
).py_type
# Get proto types
@@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
@property
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
@property
def repeated(self) -> bool:
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields
@property
@@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
# Add method to service
self.parent.methods.append(self)
# Check for imports
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
# Check for Async imports
if self.client_streaming:
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")
# Required by both client and server
if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")
# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
@@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')

View File

@@ -0,0 +1,163 @@
import re
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
List,
Tuple,
)
@dataclass
class ModuleValidator:
line_iterator: Iterator[str]
line_number: int = field(init=False, default=0)
collisions: Dict[str, List[Tuple[int, str]]] = field(
init=False, default_factory=lambda: defaultdict(list)
)
def add_import(self, imp: str, number: int, full_line: str):
"""
Adds an import to be tracked.
"""
self.collisions[imp].append((number, full_line))
def process_import(self, imp: str):
"""
Filters out the import to its actual value.
"""
if " as " in imp:
imp = imp[imp.index(" as ") + 4 :]
imp = imp.strip()
assert " " not in imp, imp
return imp
def evaluate_multiline_import(self, line: str):
"""
Evaluates a multiline import from a starting line
"""
# Filter the first line and remove anything before the import statement.
full_line = line
line = line.split("import", 1)[1]
if "(" in line:
conditional = lambda line: ")" not in line
else:
conditional = lambda line: "\\" in line
# Remove open parenthesis if it exists.
if "(" in line:
line = line[line.index("(") + 1 :]
# Choose the conditional based on how multiline imports are formatted.
while conditional(line):
# Split the line by commas
imports = line.split(",")
for imp in imports:
# Add the import to the namespace
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
# Get the next line
full_line = line = next(self.line_iterator)
# Increment the line number
self.line_number += 1
# validate the last line
if ")" in line:
line = line[: line.index(")")]
imports = line.split(",")
for imp in imports:
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
def evaluate_import(self, line: str):
"""
Extracts an import from a line.
"""
whole_line = line
line = line[line.index("import") + 6 :]
values = line.split(",")
for v in values:
self.add_import(self.process_import(v), self.line_number, whole_line)
def next(self):
"""
Evaluate each line for names in the module.
"""
line = next(self.line_iterator)
# Skip lines with indentation or comments
if (
# Skip indents and whitespace.
line.startswith(" ")
or line == "\n"
or line.startswith("\t")
or
# Skip comments
line.startswith("#")
or
# Skip decorators
line.startswith("@")
):
self.line_number += 1
return
# Skip docstrings.
if line.startswith('"""') or line.startswith("'''"):
quote = line[0] * 3
line = line[3:]
while quote not in line:
line = next(self.line_iterator)
self.line_number += 1
return
# Evaluate Imports.
if line.startswith("from ") or line.startswith("import "):
if "(" in line or "\\" in line:
self.evaluate_multiline_import(line)
else:
self.evaluate_import(line)
# Evaluate Classes.
elif line.startswith("class "):
class_name = re.search(r"class (\w+)", line).group(1)
if class_name:
self.add_import(class_name, self.line_number, line)
# Evaluate Functions.
elif line.startswith("def "):
function_name = re.search(r"def (\w+)", line).group(1)
if function_name:
self.add_import(function_name, self.line_number, line)
# Evaluate direct assignments.
elif "=" in line:
assignment = re.search(r"(\w+)\s*=", line).group(1)
if assignment:
self.add_import(assignment, self.line_number, line)
self.line_number += 1
def validate(self) -> bool:
"""
Run Validation.
"""
try:
while True:
self.next()
except StopIteration:
pass
# Filter collisions for those with more than one value.
self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
# Return True if no collisions are found.
return not bool(self.collisions)

View File

@@ -37,6 +37,12 @@ from .models import (
is_map,
is_oneof,
)
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)
def traverse(
@@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
output_package_name
].pydantic_dataclasses = True
# Gather any typing generation options.
typing_opts = [
opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
]
if len(typing_opts) > 1:
raise ValueError("Multiple typing options provided")
# Set the compiler type.
typing_opt = typing_opts[0] if typing_opts else "direct"
if typing_opt == "direct":
request_data.output_packages[
output_package_name
].typing_compiler = DirectImportTypingCompiler()
elif typing_opt == "root":
request_data.output_packages[
output_package_name
].typing_compiler = TypingImportTypingCompiler()
elif typing_opt == "310":
request_data.output_packages[
output_package_name
].typing_compiler = NoTyping310TypingCompiler()
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
parent=parent,
proto_obj=proto_obj,
path=path,
typing_compiler=output_package.typing_compiler,
)
@@ -181,7 +210,11 @@ def read_protobuf_type(
return
# Process Message
message_data = MessageCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
)
for index, field in enumerate(item.field):
if is_map(field, item):
@@ -190,6 +223,7 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif is_oneof(field):
_make_one_of_field_compiler(
@@ -201,11 +235,16 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
)

View File

@@ -0,0 +1,167 @@
import abc
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
Optional,
Set,
)
class TypingCompiler(metaclass=abc.ABCMeta):
@abc.abstractmethod
def optional(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def list(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def dict(self, key: str, value: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def union(self, *types: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterator(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def imports(self) -> Dict[str, Optional[Set[str]]]:
"""
Returns either the direct import as a key with none as value, or a set of
values to import from the key.
"""
raise NotImplementedError()
def import_lines(self) -> Iterator:
imports = self.imports()
for key, value in imports.items():
if value is None:
yield f"import {key}"
else:
yield f"from {key} import ("
for v in sorted(value):
yield f" {v},"
yield ")"
@dataclass
class DirectImportTypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
def optional(self, type: str) -> str:
self._imports["typing"].add("Optional")
return f"Optional[{type}]"
def list(self, type: str) -> str:
self._imports["typing"].add("List")
return f"List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imports["typing"].add("Dict")
return f"Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imports["typing"].add("Union")
return f"Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable")
return f"AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
@dataclass
class TypingImportTypingCompiler(TypingCompiler):
_imported: bool = False
def optional(self, type: str) -> str:
self._imported = True
return f"typing.Optional[{type}]"
def list(self, type: str) -> str:
self._imported = True
return f"typing.List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imported = True
return f"typing.Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imported = True
return f"typing.Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imported = True
return f"typing.Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
if self._imported:
return {"typing": None}
return {}
@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
def optional(self, type: str) -> str:
return f"{type} | None"
def list(self, type: str) -> str:
return f"list[{type}]"
def dict(self, key: str, value: str) -> str:
return f"dict[{key}, {value}]"
def union(self, *types: str) -> str:
return " | ".join(types)
def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable")
return f"AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}

View File

@@ -0,0 +1,54 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% set type_checking_imported = False %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
{% set type_checking_imported = True %}
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only and not type_checking_imported %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}

View File

@@ -1,53 +1,3 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if output_file.typing_imports %}
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
@@ -116,14 +66,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
{%- endif -%}
,
*
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@@ -191,9 +141,9 @@ class {{ service.py_name }}Base(ServiceBase):
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@@ -225,7 +175,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %}
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(