Handle typing collisions and add validation of generated modules for overlapping declarations (#582)
* Fix 'typing' import collisions
* Fix formatting
* Fix self-test issues
* Add validation for modules under the different typing configurations
* Add README notes
* Downgrade collision failures to a warning
* Fix format

---------

Co-authored-by: Scott Hendricks <scott.hendricks@confluent.io>
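The module validation introduced below targets a concrete failure mode: two generated declarations can end up sharing one top-level name, and the later one silently shadows the earlier. A minimal sketch of that failure (hypothetical message names, not code from this commit):

# Hypothetical generated module: two declarations collide on one name,
# so the second silently replaces the first -- valid Python, wrong output.
from dataclasses import dataclass


@dataclass
class Example:
    value: int = 0


@dataclass
class Example:  # same top-level name: shadows the class above
    name: str = ""


print(Example().name)  # only the second definition survives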
@@ -47,6 +47,7 @@ def get_type_reference(
     package: str,
     imports: set,
     source_type: str,
+    typing_compiler: "TypingCompiler",
     unwrap: bool = True,
     pydantic: bool = False,
 ) -> str:
@@ -57,7 +58,7 @@ def get_type_reference(
     if unwrap:
         if source_type in WRAPPER_TYPES:
             wrapped_type = type(WRAPPER_TYPES[source_type]().value)
-            return f"Optional[{wrapped_type.__name__}]"
+            return typing_compiler.optional(wrapped_type.__name__)

     if source_type == ".google.protobuf.Duration":
         return "timedelta"
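A hedged usage sketch of the new parameter, assuming the plugin modules are importable as `betterproto.compile.importing` and `betterproto.plugin.typing_compiler` (the wrapper branch in the hunk above is what produces these strings):

# Sketch (assumed import paths) of the delegated wrapper handling.
from betterproto.compile.importing import get_type_reference
from betterproto.plugin.typing_compiler import (
    DirectImportTypingCompiler,
    NoTyping310TypingCompiler,
)

# .google.protobuf.Int32Value wraps an int, so unwrapping yields an optional
# int spelled in whichever style the configured compiler implements.
for compiler in (DirectImportTypingCompiler(), NoTyping310TypingCompiler()):
    print(
        get_type_reference(
            package="",
            imports=set(),
            source_type=".google.protobuf.Int32Value",
            typing_compiler=compiler,
        )
    )
# Optional[int]
# int | None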
@@ -1,4 +1,7 @@
 import os.path
+import sys
+
+from .module_validation import ModuleValidator
 
 
 try:
@@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
         lstrip_blocks=True,
         loader=jinja2.FileSystemLoader(templates_folder),
     )
-    template = env.get_template("template.py.j2")
+    # Load the body first so we have a complete list of imports needed.
+    body_template = env.get_template("template.py.j2")
+    header_template = env.get_template("header.py.j2")
 
-    code = template.render(output_file=output_file)
+    code = body_template.render(output_file=output_file)
+    code = header_template.render(output_file=output_file) + code
     code = isort.api.sort_code_string(
         code=code,
         show_diff=False,
@@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
         force_grid_wrap=2,
         known_third_party=["grpclib", "betterproto"],
     )
-    return black.format_str(
+    code = black.format_str(
         src_contents=code,
         mode=black.Mode(),
     )
+
+    # Validate the generated code.
+    validator = ModuleValidator(iter(code.splitlines()))
+    if not validator.validate():
+        message_builder = ["[WARNING]: Generated code has collisions in the module:"]
+        for collision, lines in validator.collisions.items():
+            message_builder.append(f'  "{collision}" on lines:')
+            for num, line in lines:
+                message_builder.append(f"    {num}:{line}")
+        print("\n".join(message_builder), file=sys.stderr)
+    return code
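The comment in the hunk above carries the key ordering constraint: the body template must be rendered first so the typing compiler records every import it needs, and only then can the header emit them. A simplified, self-contained sketch of that two-pass render (toy templates and a toy compiler, not the plugin's real ones):

# Rendering the body drives typing-compiler calls, which record the imports
# the header must emit; rendering the header first would miss all of them.
from jinja2 import DictLoader, Environment


class RecordingCompiler:
    """Stand-in for DirectImportTypingCompiler: remembers what it emitted."""

    def __init__(self):
        self.needed = set()

    def optional(self, type: str) -> str:
        self.needed.add("Optional")
        return f"Optional[{type}]"

    def import_lines(self):
        for name in sorted(self.needed):
            yield f"from typing import {name}"


env = Environment(
    loader=DictLoader(
        {
            "body.py.j2": "x: {{ tc.optional('int') }} = None\n",
            "header.py.j2": "{% for l in tc.import_lines() %}{{ l }}\n{% endfor %}",
        }
    )
)

tc = RecordingCompiler()
code = env.get_template("body.py.j2").render(tc=tc)  # records Optional
code = env.get_template("header.py.j2").render(tc=tc) + code
print(code)
# from typing import Optional
# x: Optional[int] = None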
@@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
 reference to `A` to `B`'s `fields` attribute.
 """
 
-
 import builtins
 import re
-import textwrap
 from dataclasses import (
     dataclass,
     field,
@@ -49,12 +47,6 @@ from typing import (
 )
 
 import betterproto
-from betterproto import which_one_of
-from betterproto.casing import sanitize_name
-from betterproto.compile.importing import (
-    get_type_reference,
-    parse_source_type_name,
-)
 from betterproto.compile.naming import (
     pythonize_class_name,
     pythonize_field_name,
@@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
 )
 from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
 
+from .. import which_one_of
 from ..compile.importing import (
     get_type_reference,
     parse_source_type_name,
@@ -82,6 +75,10 @@ from ..compile.naming import (
     pythonize_field_name,
     pythonize_method_name,
 )
+from .typing_compiler import (
+    DirectImportTypingCompiler,
+    TypingCompiler,
+)
 
 
 # Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
     """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
 
     source_file: FileDescriptorProto
+    typing_compiler: TypingCompiler
     path: List[int]
     comment_indent: int = 4
     parent: Union["betterproto.Message", "OutputTemplate"]
@@ -242,7 +240,6 @@ class OutputTemplate:
     input_files: List[str] = field(default_factory=list)
     imports: Set[str] = field(default_factory=set)
     datetime_imports: Set[str] = field(default_factory=set)
-    typing_imports: Set[str] = field(default_factory=set)
     pydantic_imports: Set[str] = field(default_factory=set)
     builtins_import: bool = False
     messages: List["MessageCompiler"] = field(default_factory=list)
@@ -251,6 +248,7 @@ class OutputTemplate:
     imports_type_checking_only: Set[str] = field(default_factory=set)
     pydantic_dataclasses: bool = False
     output: bool = True
+    typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
 
     @property
     def package(self) -> str:
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
     """Representation of a protobuf message."""
 
     source_file: FileDescriptorProto
+    typing_compiler: TypingCompiler
     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
     path: List[int] = PLACEHOLDER
@@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
     @property
     def annotation(self) -> str:
         if self.repeated:
-            return f"List[{self.py_name}]"
+            return self.typing_compiler.list(self.py_name)
         return self.py_name
 
     @property
@@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
             imports.add("datetime")
         return imports
 
-    @property
-    def typing_imports(self) -> Set[str]:
-        imports = set()
-        annotation = self.annotation
-        if "Optional[" in annotation:
-            imports.add("Optional")
-        if "List[" in annotation:
-            imports.add("List")
-        if "Dict[" in annotation:
-            imports.add("Dict")
-        return imports
-
     @property
     def pydantic_imports(self) -> Set[str]:
         return set()
@@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
 
     def add_imports_to(self, output_file: OutputTemplate) -> None:
         output_file.datetime_imports.update(self.datetime_imports)
-        output_file.typing_imports.update(self.typing_imports)
         output_file.pydantic_imports.update(self.pydantic_imports)
         output_file.builtins_import = output_file.builtins_import or self.use_builtins
 
@@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
     @property
     def mutable(self) -> bool:
         """True if the field is a mutable type, otherwise False."""
-        return self.annotation.startswith(("List[", "Dict["))
+        return self.annotation.startswith(
+            ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
+        )
 
     @property
     def field_type(self) -> str:
@@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
                 package=self.output_file.package,
                 imports=self.output_file.imports,
                 source_type=self.proto_obj.type_name,
+                typing_compiler=self.typing_compiler,
                 pydantic=self.output_file.pydantic_dataclasses,
             )
         else:
@@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
         if self.use_builtins:
             py_type = f"builtins.{py_type}"
         if self.repeated:
-            return f"List[{py_type}]"
+            return self.typing_compiler.list(py_type)
         if self.optional:
-            return f"Optional[{py_type}]"
+            return self.typing_compiler.optional(py_type)
         return py_type
 
 
@@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
             source_file=self.source_file,
             parent=self,
             proto_obj=nested.field[0],  # key
+            typing_compiler=self.typing_compiler,
         ).py_type
         self.py_v_type = FieldCompiler(
             source_file=self.source_file,
             parent=self,
            proto_obj=nested.field[1],  # value
+            typing_compiler=self.typing_compiler,
         ).py_type
 
         # Get proto types
@@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
 
     @property
     def annotation(self) -> str:
-        return f"Dict[{self.py_k_type}, {self.py_v_type}]"
+        return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
 
     @property
     def repeated(self) -> bool:
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
     def __post_init__(self) -> None:
         # Add service to output file
         self.output_file.services.append(self)
-        self.output_file.typing_imports.add("Dict")
         super().__post_init__()  # check for unset fields
 
     @property
@@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
         # Add method to service
         self.parent.methods.append(self)
 
-        # Check for imports
-        if "Optional" in self.py_output_message_type:
-            self.output_file.typing_imports.add("Optional")
-
-        # Check for Async imports
-        if self.client_streaming:
-            self.output_file.typing_imports.add("AsyncIterable")
-            self.output_file.typing_imports.add("Iterable")
-            self.output_file.typing_imports.add("Union")
-
-        # Required by both client and server
-        if self.client_streaming or self.server_streaming:
-            self.output_file.typing_imports.add("AsyncIterator")
-
-        # add imports required for request arguments timeout, deadline and metadata
-        self.output_file.typing_imports.add("Optional")
         self.output_file.imports_type_checking_only.add("import grpclib.server")
         self.output_file.imports_type_checking_only.add(
             "from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             package=self.output_file.package,
             imports=self.output_file.imports,
             source_type=self.proto_obj.input_type,
+            typing_compiler=self.output_file.typing_compiler,
             unwrap=False,
             pydantic=self.output_file.pydantic_dataclasses,
         ).strip('"')
@@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             package=self.output_file.package,
             imports=self.output_file.imports,
             source_type=self.proto_obj.output_type,
+            typing_compiler=self.output_file.typing_compiler,
            unwrap=False,
             pydantic=self.output_file.pydantic_dataclasses,
         ).strip('"')
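The widened `mutable` check above has to recognize every spelling the three typing compilers can produce. A quick sanity sketch of the prefix tuple it matches:

# The prefix tuple from the hunk above, checked against annotations each
# typing compiler can produce.
MUTABLE_PREFIXES = ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")

for annotation in ("List[int]", "typing.Dict[str, int]", "dict[str, int]", "int | None"):
    print(annotation, "->", annotation.startswith(MUTABLE_PREFIXES))
# List[int] -> True
# typing.Dict[str, int] -> True
# dict[str, int] -> True
# int | None -> False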
src/betterproto/plugin/module_validation.py (new file, 163 lines)

import re
from collections import defaultdict
from dataclasses import (
    dataclass,
    field,
)
from typing import (
    Dict,
    Iterator,
    List,
    Tuple,
)


@dataclass
class ModuleValidator:
    line_iterator: Iterator[str]
    line_number: int = field(init=False, default=0)

    collisions: Dict[str, List[Tuple[int, str]]] = field(
        init=False, default_factory=lambda: defaultdict(list)
    )

    def add_import(self, imp: str, number: int, full_line: str):
        """
        Adds an import to be tracked.
        """
        self.collisions[imp].append((number, full_line))

    def process_import(self, imp: str):
        """
        Filters out the import to its actual value.
        """
        if " as " in imp:
            imp = imp[imp.index(" as ") + 4 :]

        imp = imp.strip()
        assert " " not in imp, imp
        return imp

    def evaluate_multiline_import(self, line: str):
        """
        Evaluates a multiline import from a starting line
        """
        # Filter the first line and remove anything before the import statement.
        full_line = line
        line = line.split("import", 1)[1]
        if "(" in line:
            conditional = lambda line: ")" not in line
        else:
            conditional = lambda line: "\\" in line

        # Remove open parenthesis if it exists.
        if "(" in line:
            line = line[line.index("(") + 1 :]

        # Choose the conditional based on how multiline imports are formatted.
        while conditional(line):
            # Split the line by commas
            imports = line.split(",")

            for imp in imports:
                # Add the import to the namespace
                imp = self.process_import(imp)
                if imp:
                    self.add_import(imp, self.line_number, full_line)
            # Get the next line
            full_line = line = next(self.line_iterator)
            # Increment the line number
            self.line_number += 1

        # validate the last line
        if ")" in line:
            line = line[: line.index(")")]
        imports = line.split(",")
        for imp in imports:
            imp = self.process_import(imp)
            if imp:
                self.add_import(imp, self.line_number, full_line)

    def evaluate_import(self, line: str):
        """
        Extracts an import from a line.
        """
        whole_line = line
        line = line[line.index("import") + 6 :]
        values = line.split(",")
        for v in values:
            self.add_import(self.process_import(v), self.line_number, whole_line)

    def next(self):
        """
        Evaluate each line for names in the module.
        """
        line = next(self.line_iterator)

        # Skip lines with indentation or comments
        if (
            # Skip indents and whitespace.
            line.startswith(" ")
            or line == "\n"
            or line.startswith("\t")
            or
            # Skip comments
            line.startswith("#")
            or
            # Skip decorators
            line.startswith("@")
        ):
            self.line_number += 1
            return

        # Skip docstrings.
        if line.startswith('"""') or line.startswith("'''"):
            quote = line[0] * 3
            line = line[3:]
            while quote not in line:
                line = next(self.line_iterator)
            self.line_number += 1
            return

        # Evaluate Imports.
        if line.startswith("from ") or line.startswith("import "):
            if "(" in line or "\\" in line:
                self.evaluate_multiline_import(line)
            else:
                self.evaluate_import(line)

        # Evaluate Classes.
        elif line.startswith("class "):
            class_name = re.search(r"class (\w+)", line).group(1)
            if class_name:
                self.add_import(class_name, self.line_number, line)

        # Evaluate Functions.
        elif line.startswith("def "):
            function_name = re.search(r"def (\w+)", line).group(1)
            if function_name:
                self.add_import(function_name, self.line_number, line)

        # Evaluate direct assignments.
        elif "=" in line:
            assignment = re.search(r"(\w+)\s*=", line).group(1)
            if assignment:
                self.add_import(assignment, self.line_number, line)

        self.line_number += 1

    def validate(self) -> bool:
        """
        Run Validation.
        """
        try:
            while True:
                self.next()
        except StopIteration:
            pass

        # Filter collisions for those with more than one value.
        self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}

        # Return True if no collisions are found.
        return not bool(self.collisions)
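A usage sketch of the validator, assuming it is importable as `betterproto.plugin.module_validation`; the duplicate class below mirrors the collisions the compiler now warns about:

# Usage sketch; the import path is assumed from the file header above.
from betterproto.plugin.module_validation import ModuleValidator

source = """\
from enum import Enum

class Color(Enum):
    RED = 1

class Color(Enum):  # collides with the class above
    GREEN = 2
"""

validator = ModuleValidator(iter(source.splitlines()))
assert not validator.validate()
for name, lines in validator.collisions.items():
    print(name, [num for num, _ in lines])  # Color [2, 5] (0-indexed lines)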
@@ -37,6 +37,12 @@ from .models import (
     is_map,
     is_oneof,
 )
+from .typing_compiler import (
+    DirectImportTypingCompiler,
+    NoTyping310TypingCompiler,
+    TypingCompiler,
+    TypingImportTypingCompiler,
+)
 
 
 def traverse(
@@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
                 output_package_name
             ].pydantic_dataclasses = True
 
+        # Gather any typing generation options.
+        typing_opts = [
+            opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
+        ]
+
+        if len(typing_opts) > 1:
+            raise ValueError("Multiple typing options provided")
+        # Set the compiler type.
+        typing_opt = typing_opts[0] if typing_opts else "direct"
+        if typing_opt == "direct":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = DirectImportTypingCompiler()
+        elif typing_opt == "root":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = TypingImportTypingCompiler()
+        elif typing_opt == "310":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = NoTyping310TypingCompiler()
+
     # Read Messages and Enums
     # We need to read Messages before Services in so that we can
     # get the references to input/output messages for each service
@@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
         parent=parent,
         proto_obj=proto_obj,
         path=path,
+        typing_compiler=output_package.typing_compiler,
     )
 
 
@@ -181,7 +210,11 @@ def read_protobuf_type(
         return
     # Process Message
     message_data = MessageCompiler(
-        source_file=source_file, parent=output_package, proto_obj=item, path=path
+        source_file=source_file,
+        parent=output_package,
+        proto_obj=item,
+        path=path,
+        typing_compiler=output_package.typing_compiler,
     )
     for index, field in enumerate(item.field):
         if is_map(field, item):
@@ -190,6 +223,7 @@ def read_protobuf_type(
                 parent=message_data,
                 proto_obj=field,
                 path=path + [2, index],
+                typing_compiler=output_package.typing_compiler,
             )
         elif is_oneof(field):
             _make_one_of_field_compiler(
@@ -201,11 +235,16 @@ def read_protobuf_type(
                 parent=message_data,
                 proto_obj=field,
                 path=path + [2, index],
+                typing_compiler=output_package.typing_compiler,
             )
     elif isinstance(item, EnumDescriptorProto):
         # Enum
         EnumDefinitionCompiler(
-            source_file=source_file, parent=output_package, proto_obj=item, path=path
+            source_file=source_file,
+            parent=output_package,
+            proto_obj=item,
+            path=path,
+            typing_compiler=output_package.typing_compiler,
         )
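A sketch of the option parsing above; `plugin_options` is a stand-in list here (betterproto derives the real values from the code-generation request, and the exact protoc flag is not shown in this diff):

# plugin_options is a stand-in; betterproto derives it from the
# CodeGeneratorRequest parameter string.
plugin_options = ["pydantic_dataclasses", "typing.310"]

typing_opts = [
    opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
]
if len(typing_opts) > 1:
    raise ValueError("Multiple typing options provided")
typing_opt = typing_opts[0] if typing_opts else "direct"
print(typing_opt)  # "310" -> NoTyping310TypingCompiler; the default is "direct"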
src/betterproto/plugin/typing_compiler.py (new file, 167 lines)

import abc
from collections import defaultdict
from dataclasses import (
    dataclass,
    field,
)
from typing import (
    Dict,
    Iterator,
    Optional,
    Set,
)


class TypingCompiler(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def optional(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def list(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def dict(self, key: str, value: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def union(self, *types: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def iterable(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterable(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterator(self, type: str) -> str:
        raise NotImplementedError()

    @abc.abstractmethod
    def imports(self) -> Dict[str, Optional[Set[str]]]:
        """
        Returns either the direct import as a key with None as value, or a set of
        values to import from the key.
        """
        raise NotImplementedError()

    def import_lines(self) -> Iterator:
        imports = self.imports()
        for key, value in imports.items():
            if value is None:
                yield f"import {key}"
            else:
                yield f"from {key} import ("
                for v in sorted(value):
                    yield f"    {v},"
                yield ")"


@dataclass
class DirectImportTypingCompiler(TypingCompiler):
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    def optional(self, type: str) -> str:
        self._imports["typing"].add("Optional")
        return f"Optional[{type}]"

    def list(self, type: str) -> str:
        self._imports["typing"].add("List")
        return f"List[{type}]"

    def dict(self, key: str, value: str) -> str:
        self._imports["typing"].add("Dict")
        return f"Dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        self._imports["typing"].add("Union")
        return f"Union[{', '.join(types)}]"

    def iterable(self, type: str) -> str:
        self._imports["typing"].add("Iterable")
        return f"Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterable")
        return f"AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterator")
        return f"AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {k: v if v else None for k, v in self._imports.items()}


@dataclass
class TypingImportTypingCompiler(TypingCompiler):
    _imported: bool = False

    def optional(self, type: str) -> str:
        self._imported = True
        return f"typing.Optional[{type}]"

    def list(self, type: str) -> str:
        self._imported = True
        return f"typing.List[{type}]"

    def dict(self, key: str, value: str) -> str:
        self._imported = True
        return f"typing.Dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        self._imported = True
        return f"typing.Union[{', '.join(types)}]"

    def iterable(self, type: str) -> str:
        self._imported = True
        return f"typing.Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imported = True
        return f"typing.AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imported = True
        return f"typing.AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        if self._imported:
            return {"typing": None}
        return {}


@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    def optional(self, type: str) -> str:
        return f"{type} | None"

    def list(self, type: str) -> str:
        return f"list[{type}]"

    def dict(self, key: str, value: str) -> str:
        return f"dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        return " | ".join(types)

    def iterable(self, type: str) -> str:
        self._imports["typing"].add("Iterable")
        return f"Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterable")
        return f"AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imports["typing"].add("AsyncIterator")
        return f"AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {k: v if v else None for k, v in self._imports.items()}
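A quick comparison of what the three compilers defined above emit for the same annotation, and which imports they record along the way:

# Same annotation, three spellings (classes as defined in the file above).
direct = DirectImportTypingCompiler()
rooted = TypingImportTypingCompiler()
py310 = NoTyping310TypingCompiler()

print(direct.dict("str", direct.optional("int")))  # Dict[str, Optional[int]]
print(rooted.dict("str", rooted.optional("int")))  # typing.Dict[str, typing.Optional[int]]
print(py310.dict("str", py310.optional("int")))    # dict[str, int | None]

print(direct.imports())  # {'typing': {'Dict', 'Optional'}}
print(rooted.imports())  # {'typing': None} -- a bare `import typing`
print(py310.imports())   # {} -- PEP 604/585 syntax needs no imports here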
src/betterproto/templates/header.py.j2 (new file, 54 lines)

# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% set type_checking_imported = False %}

{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
{% set type_checking_imported = True %}

if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}

{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}

import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}

{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}

{% if output_file.imports_type_checking_only and not type_checking_imported %}
from typing import TYPE_CHECKING

if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
{% endfor %}
{% endif %}
@@ -1,53 +1,3 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# sources: {{ ', '.join(output_file.input_filenames) }}
-# plugin: python-betterproto
-# This file has been @generated
-{% for i in output_file.python_module_imports|sort %}
-import {{ i }}
-{% endfor %}
-
-{% if output_file.pydantic_dataclasses %}
-from typing import TYPE_CHECKING
-if TYPE_CHECKING:
-    from dataclasses import dataclass
-else:
-    from pydantic.dataclasses import dataclass
-{%- else -%}
-from dataclasses import dataclass
-{% endif %}
-
-{% if output_file.datetime_imports %}
-from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-{% if output_file.typing_imports %}
-from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-
-{% if output_file.pydantic_imports %}
-from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-
-import betterproto
-{% if output_file.services %}
-from betterproto.grpc.grpclib_server import ServiceBase
-import grpclib
-{% endif %}
-
-{% for i in output_file.imports|sort %}
-{{ i }}
-{% endfor %}
-
-{% if output_file.imports_type_checking_only %}
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
-{% endfor %}
-{% endif %}
-
 {% if output_file.enums %}{% for enum in output_file.enums %}
 class {{ enum.py_name }}(betterproto.Enum):
     {% if enum.comment %}
@@ -116,14 +66,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
 {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
 {%- else -%}
 {# Client streaming: need a request iterator instead #}
-, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
+, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
 {%- endif -%}
 ,
 *
-, timeout: Optional[float] = None
-, deadline: Optional["Deadline"] = None
-, metadata: Optional["MetadataLike"] = None
-) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
+, timeout: {{ output_file.typing_compiler.optional("float") }} = None
+, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
+, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
+) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
 {% if method.comment %}
 {{ method.comment }}
 
@@ -191,9 +141,9 @@ class {{ service.py_name }}Base(ServiceBase):
 {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
 {%- else -%}
 {# Client streaming: need a request iterator instead #}
-, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
+, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
 {%- endif -%}
-) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
+) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
 {% if method.comment %}
 {{ method.comment }}
 
@@ -225,7 +175,7 @@ class {{ service.py_name }}Base(ServiceBase):
 
 {% endfor %}
 
-    def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
+    def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
         return {
             {% for method in service.methods %}
             "{{ method.route }}": grpclib.const.Handler(