Handle typing collisions and add validation to a files module for overlaping declarations (#582)
* Fix 'typing' import collisions. * Fix formatting. * Fix self-test issues. * Validation for modules, different typing configurations * add readme * make warning * fix format --------- Co-authored-by: Scott Hendricks <scott.hendricks@confluent.io>
This commit is contained in:
		| @@ -47,6 +47,7 @@ def get_type_reference( | ||||
|     package: str, | ||||
|     imports: set, | ||||
|     source_type: str, | ||||
|     typing_compiler: "TypingCompiler", | ||||
|     unwrap: bool = True, | ||||
|     pydantic: bool = False, | ||||
| ) -> str: | ||||
| @@ -57,7 +58,7 @@ def get_type_reference( | ||||
|     if unwrap: | ||||
|         if source_type in WRAPPER_TYPES: | ||||
|             wrapped_type = type(WRAPPER_TYPES[source_type]().value) | ||||
|             return f"Optional[{wrapped_type.__name__}]" | ||||
|             return typing_compiler.optional(wrapped_type.__name__) | ||||
|  | ||||
|         if source_type == ".google.protobuf.Duration": | ||||
|             return "timedelta" | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| import os.path | ||||
| import sys | ||||
|  | ||||
| from .module_validation import ModuleValidator | ||||
|  | ||||
|  | ||||
| try: | ||||
| @@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|         lstrip_blocks=True, | ||||
|         loader=jinja2.FileSystemLoader(templates_folder), | ||||
|     ) | ||||
|     template = env.get_template("template.py.j2") | ||||
|     # Load the body first so we have a compleate list of imports needed. | ||||
|     body_template = env.get_template("template.py.j2") | ||||
|     header_template = env.get_template("header.py.j2") | ||||
|  | ||||
|     code = template.render(output_file=output_file) | ||||
|     code = body_template.render(output_file=output_file) | ||||
|     code = header_template.render(output_file=output_file) + code | ||||
|     code = isort.api.sort_code_string( | ||||
|         code=code, | ||||
|         show_diff=False, | ||||
| @@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|         force_grid_wrap=2, | ||||
|         known_third_party=["grpclib", "betterproto"], | ||||
|     ) | ||||
|     return black.format_str( | ||||
|     code = black.format_str( | ||||
|         src_contents=code, | ||||
|         mode=black.Mode(), | ||||
|     ) | ||||
|  | ||||
|     # Validate the generated code. | ||||
|     validator = ModuleValidator(iter(code.splitlines())) | ||||
|     if not validator.validate(): | ||||
|         message_builder = ["[WARNING]: Generated code has collisions in the module:"] | ||||
|         for collision, lines in validator.collisions.items(): | ||||
|             message_builder.append(f'  "{collision}" on lines:') | ||||
|             for num, line in lines: | ||||
|                 message_builder.append(f"    {num}:{line}") | ||||
|         print("\n".join(message_builder), file=sys.stderr) | ||||
|     return code | ||||
|   | ||||
| @@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a | ||||
| reference to `A` to `B`'s `fields` attribute. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import builtins | ||||
| import re | ||||
| import textwrap | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| @@ -49,12 +47,6 @@ from typing import ( | ||||
| ) | ||||
|  | ||||
| import betterproto | ||||
| from betterproto import which_one_of | ||||
| from betterproto.casing import sanitize_name | ||||
| from betterproto.compile.importing import ( | ||||
|     get_type_reference, | ||||
|     parse_source_type_name, | ||||
| ) | ||||
| from betterproto.compile.naming import ( | ||||
|     pythonize_class_name, | ||||
|     pythonize_field_name, | ||||
| @@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import ( | ||||
| ) | ||||
| from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest | ||||
|  | ||||
| from .. import which_one_of | ||||
| from ..compile.importing import ( | ||||
|     get_type_reference, | ||||
|     parse_source_type_name, | ||||
| @@ -82,6 +75,10 @@ from ..compile.naming import ( | ||||
|     pythonize_field_name, | ||||
|     pythonize_method_name, | ||||
| ) | ||||
| from .typing_compiler import ( | ||||
|     DirectImportTypingCompiler, | ||||
|     TypingCompiler, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Create a unique placeholder to deal with | ||||
| @@ -173,6 +170,7 @@ class ProtoContentBase: | ||||
|     """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" | ||||
|  | ||||
|     source_file: FileDescriptorProto | ||||
|     typing_compiler: TypingCompiler | ||||
|     path: List[int] | ||||
|     comment_indent: int = 4 | ||||
|     parent: Union["betterproto.Message", "OutputTemplate"] | ||||
| @@ -242,7 +240,6 @@ class OutputTemplate: | ||||
|     input_files: List[str] = field(default_factory=list) | ||||
|     imports: Set[str] = field(default_factory=set) | ||||
|     datetime_imports: Set[str] = field(default_factory=set) | ||||
|     typing_imports: Set[str] = field(default_factory=set) | ||||
|     pydantic_imports: Set[str] = field(default_factory=set) | ||||
|     builtins_import: bool = False | ||||
|     messages: List["MessageCompiler"] = field(default_factory=list) | ||||
| @@ -251,6 +248,7 @@ class OutputTemplate: | ||||
|     imports_type_checking_only: Set[str] = field(default_factory=set) | ||||
|     pydantic_dataclasses: bool = False | ||||
|     output: bool = True | ||||
|     typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler) | ||||
|  | ||||
|     @property | ||||
|     def package(self) -> str: | ||||
| @@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|     """Representation of a protobuf message.""" | ||||
|  | ||||
|     source_file: FileDescriptorProto | ||||
|     typing_compiler: TypingCompiler | ||||
|     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER | ||||
|     proto_obj: DescriptorProto = PLACEHOLDER | ||||
|     path: List[int] = PLACEHOLDER | ||||
| @@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|     @property | ||||
|     def annotation(self) -> str: | ||||
|         if self.repeated: | ||||
|             return f"List[{self.py_name}]" | ||||
|             return self.typing_compiler.list(self.py_name) | ||||
|         return self.py_name | ||||
|  | ||||
|     @property | ||||
| @@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler): | ||||
|             imports.add("datetime") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def typing_imports(self) -> Set[str]: | ||||
|         imports = set() | ||||
|         annotation = self.annotation | ||||
|         if "Optional[" in annotation: | ||||
|             imports.add("Optional") | ||||
|         if "List[" in annotation: | ||||
|             imports.add("List") | ||||
|         if "Dict[" in annotation: | ||||
|             imports.add("Dict") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def pydantic_imports(self) -> Set[str]: | ||||
|         return set() | ||||
| @@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler): | ||||
|  | ||||
|     def add_imports_to(self, output_file: OutputTemplate) -> None: | ||||
|         output_file.datetime_imports.update(self.datetime_imports) | ||||
|         output_file.typing_imports.update(self.typing_imports) | ||||
|         output_file.pydantic_imports.update(self.pydantic_imports) | ||||
|         output_file.builtins_import = output_file.builtins_import or self.use_builtins | ||||
|  | ||||
| @@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler): | ||||
|     @property | ||||
|     def mutable(self) -> bool: | ||||
|         """True if the field is a mutable type, otherwise False.""" | ||||
|         return self.annotation.startswith(("List[", "Dict[")) | ||||
|         return self.annotation.startswith( | ||||
|             ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[") | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def field_type(self) -> str: | ||||
| @@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler): | ||||
|                 package=self.output_file.package, | ||||
|                 imports=self.output_file.imports, | ||||
|                 source_type=self.proto_obj.type_name, | ||||
|                 typing_compiler=self.typing_compiler, | ||||
|                 pydantic=self.output_file.pydantic_dataclasses, | ||||
|             ) | ||||
|         else: | ||||
| @@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler): | ||||
|         if self.use_builtins: | ||||
|             py_type = f"builtins.{py_type}" | ||||
|         if self.repeated: | ||||
|             return f"List[{py_type}]" | ||||
|             return self.typing_compiler.list(py_type) | ||||
|         if self.optional: | ||||
|             return f"Optional[{py_type}]" | ||||
|             return self.typing_compiler.optional(py_type) | ||||
|         return py_type | ||||
|  | ||||
|  | ||||
| @@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler): | ||||
|                     source_file=self.source_file, | ||||
|                     parent=self, | ||||
|                     proto_obj=nested.field[0],  # key | ||||
|                     typing_compiler=self.typing_compiler, | ||||
|                 ).py_type | ||||
|                 self.py_v_type = FieldCompiler( | ||||
|                     source_file=self.source_file, | ||||
|                     parent=self, | ||||
|                     proto_obj=nested.field[1],  # value | ||||
|                     typing_compiler=self.typing_compiler, | ||||
|                 ).py_type | ||||
|  | ||||
|                 # Get proto types | ||||
| @@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler): | ||||
|  | ||||
|     @property | ||||
|     def annotation(self) -> str: | ||||
|         return f"Dict[{self.py_k_type}, {self.py_v_type}]" | ||||
|         return self.typing_compiler.dict(self.py_k_type, self.py_v_type) | ||||
|  | ||||
|     @property | ||||
|     def repeated(self) -> bool: | ||||
| @@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase): | ||||
|     def __post_init__(self) -> None: | ||||
|         # Add service to output file | ||||
|         self.output_file.services.append(self) | ||||
|         self.output_file.typing_imports.add("Dict") | ||||
|         super().__post_init__()  # check for unset fields | ||||
|  | ||||
|     @property | ||||
| @@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|         # Add method to service | ||||
|         self.parent.methods.append(self) | ||||
|  | ||||
|         # Check for imports | ||||
|         if "Optional" in self.py_output_message_type: | ||||
|             self.output_file.typing_imports.add("Optional") | ||||
|  | ||||
|         # Check for Async imports | ||||
|         if self.client_streaming: | ||||
|             self.output_file.typing_imports.add("AsyncIterable") | ||||
|             self.output_file.typing_imports.add("Iterable") | ||||
|             self.output_file.typing_imports.add("Union") | ||||
|  | ||||
|         # Required by both client and server | ||||
|         if self.client_streaming or self.server_streaming: | ||||
|             self.output_file.typing_imports.add("AsyncIterator") | ||||
|  | ||||
|         # add imports required for request arguments timeout, deadline and metadata | ||||
|         self.output_file.typing_imports.add("Optional") | ||||
|         self.output_file.imports_type_checking_only.add("import grpclib.server") | ||||
|         self.output_file.imports_type_checking_only.add( | ||||
|             "from betterproto.grpc.grpclib_client import MetadataLike" | ||||
| @@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|             package=self.output_file.package, | ||||
|             imports=self.output_file.imports, | ||||
|             source_type=self.proto_obj.input_type, | ||||
|             typing_compiler=self.output_file.typing_compiler, | ||||
|             unwrap=False, | ||||
|             pydantic=self.output_file.pydantic_dataclasses, | ||||
|         ).strip('"') | ||||
| @@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|             package=self.output_file.package, | ||||
|             imports=self.output_file.imports, | ||||
|             source_type=self.proto_obj.output_type, | ||||
|             typing_compiler=self.output_file.typing_compiler, | ||||
|             unwrap=False, | ||||
|             pydantic=self.output_file.pydantic_dataclasses, | ||||
|         ).strip('"') | ||||
|   | ||||
							
								
								
									
										163
									
								
								src/betterproto/plugin/module_validation.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								src/betterproto/plugin/module_validation.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,163 @@ | ||||
| import re | ||||
| from collections import defaultdict | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| ) | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Iterator, | ||||
|     List, | ||||
|     Tuple, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class ModuleValidator: | ||||
|     line_iterator: Iterator[str] | ||||
|     line_number: int = field(init=False, default=0) | ||||
|  | ||||
|     collisions: Dict[str, List[Tuple[int, str]]] = field( | ||||
|         init=False, default_factory=lambda: defaultdict(list) | ||||
|     ) | ||||
|  | ||||
|     def add_import(self, imp: str, number: int, full_line: str): | ||||
|         """ | ||||
|         Adds an import to be tracked. | ||||
|         """ | ||||
|         self.collisions[imp].append((number, full_line)) | ||||
|  | ||||
|     def process_import(self, imp: str): | ||||
|         """ | ||||
|         Filters out the import to its actual value. | ||||
|         """ | ||||
|         if " as " in imp: | ||||
|             imp = imp[imp.index(" as ") + 4 :] | ||||
|  | ||||
|         imp = imp.strip() | ||||
|         assert " " not in imp, imp | ||||
|         return imp | ||||
|  | ||||
|     def evaluate_multiline_import(self, line: str): | ||||
|         """ | ||||
|         Evaluates a multiline import from a starting line | ||||
|         """ | ||||
|         # Filter the first line and remove anything before the import statement. | ||||
|         full_line = line | ||||
|         line = line.split("import", 1)[1] | ||||
|         if "(" in line: | ||||
|             conditional = lambda line: ")" not in line | ||||
|         else: | ||||
|             conditional = lambda line: "\\" in line | ||||
|  | ||||
|         # Remove open parenthesis if it exists. | ||||
|         if "(" in line: | ||||
|             line = line[line.index("(") + 1 :] | ||||
|  | ||||
|         # Choose the conditional based on how multiline imports are formatted. | ||||
|         while conditional(line): | ||||
|             # Split the line by commas | ||||
|             imports = line.split(",") | ||||
|  | ||||
|             for imp in imports: | ||||
|                 # Add the import to the namespace | ||||
|                 imp = self.process_import(imp) | ||||
|                 if imp: | ||||
|                     self.add_import(imp, self.line_number, full_line) | ||||
|             # Get the next line | ||||
|             full_line = line = next(self.line_iterator) | ||||
|             # Increment the line number | ||||
|             self.line_number += 1 | ||||
|  | ||||
|         # validate the last line | ||||
|         if ")" in line: | ||||
|             line = line[: line.index(")")] | ||||
|         imports = line.split(",") | ||||
|         for imp in imports: | ||||
|             imp = self.process_import(imp) | ||||
|             if imp: | ||||
|                 self.add_import(imp, self.line_number, full_line) | ||||
|  | ||||
|     def evaluate_import(self, line: str): | ||||
|         """ | ||||
|         Extracts an import from a line. | ||||
|         """ | ||||
|         whole_line = line | ||||
|         line = line[line.index("import") + 6 :] | ||||
|         values = line.split(",") | ||||
|         for v in values: | ||||
|             self.add_import(self.process_import(v), self.line_number, whole_line) | ||||
|  | ||||
|     def next(self): | ||||
|         """ | ||||
|         Evaluate each line for names in the module. | ||||
|         """ | ||||
|         line = next(self.line_iterator) | ||||
|  | ||||
|         # Skip lines with indentation or comments | ||||
|         if ( | ||||
|             # Skip indents and whitespace. | ||||
|             line.startswith(" ") | ||||
|             or line == "\n" | ||||
|             or line.startswith("\t") | ||||
|             or | ||||
|             # Skip comments | ||||
|             line.startswith("#") | ||||
|             or | ||||
|             # Skip  decorators | ||||
|             line.startswith("@") | ||||
|         ): | ||||
|             self.line_number += 1 | ||||
|             return | ||||
|  | ||||
|         # Skip docstrings. | ||||
|         if line.startswith('"""') or line.startswith("'''"): | ||||
|             quote = line[0] * 3 | ||||
|             line = line[3:] | ||||
|             while quote not in line: | ||||
|                 line = next(self.line_iterator) | ||||
|             self.line_number += 1 | ||||
|             return | ||||
|  | ||||
|         # Evaluate Imports. | ||||
|         if line.startswith("from ") or line.startswith("import "): | ||||
|             if "(" in line or "\\" in line: | ||||
|                 self.evaluate_multiline_import(line) | ||||
|             else: | ||||
|                 self.evaluate_import(line) | ||||
|  | ||||
|         # Evaluate Classes. | ||||
|         elif line.startswith("class "): | ||||
|             class_name = re.search(r"class (\w+)", line).group(1) | ||||
|             if class_name: | ||||
|                 self.add_import(class_name, self.line_number, line) | ||||
|  | ||||
|         # Evaluate Functions. | ||||
|         elif line.startswith("def "): | ||||
|             function_name = re.search(r"def (\w+)", line).group(1) | ||||
|             if function_name: | ||||
|                 self.add_import(function_name, self.line_number, line) | ||||
|  | ||||
|         # Evaluate direct assignments. | ||||
|         elif "=" in line: | ||||
|             assignment = re.search(r"(\w+)\s*=", line).group(1) | ||||
|             if assignment: | ||||
|                 self.add_import(assignment, self.line_number, line) | ||||
|  | ||||
|         self.line_number += 1 | ||||
|  | ||||
|     def validate(self) -> bool: | ||||
|         """ | ||||
|         Run Validation. | ||||
|         """ | ||||
|         try: | ||||
|             while True: | ||||
|                 self.next() | ||||
|         except StopIteration: | ||||
|             pass | ||||
|  | ||||
|         # Filter collisions for those with more than one value. | ||||
|         self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1} | ||||
|  | ||||
|         # Return True if no collisions are found. | ||||
|         return not bool(self.collisions) | ||||
| @@ -37,6 +37,12 @@ from .models import ( | ||||
|     is_map, | ||||
|     is_oneof, | ||||
| ) | ||||
| from .typing_compiler import ( | ||||
|     DirectImportTypingCompiler, | ||||
|     NoTyping310TypingCompiler, | ||||
|     TypingCompiler, | ||||
|     TypingImportTypingCompiler, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def traverse( | ||||
| @@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | ||||
|                 output_package_name | ||||
|             ].pydantic_dataclasses = True | ||||
|  | ||||
|         # Gather any typing generation options. | ||||
|         typing_opts = [ | ||||
|             opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.") | ||||
|         ] | ||||
|  | ||||
|         if len(typing_opts) > 1: | ||||
|             raise ValueError("Multiple typing options provided") | ||||
|         # Set the compiler type. | ||||
|         typing_opt = typing_opts[0] if typing_opts else "direct" | ||||
|         if typing_opt == "direct": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = DirectImportTypingCompiler() | ||||
|         elif typing_opt == "root": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = TypingImportTypingCompiler() | ||||
|         elif typing_opt == "310": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = NoTyping310TypingCompiler() | ||||
|  | ||||
|     # Read Messages and Enums | ||||
|     # We need to read Messages before Services in so that we can | ||||
|     # get the references to input/output messages for each service | ||||
| @@ -166,6 +194,7 @@ def _make_one_of_field_compiler( | ||||
|         parent=parent, | ||||
|         proto_obj=proto_obj, | ||||
|         path=path, | ||||
|         typing_compiler=output_package.typing_compiler, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @@ -181,7 +210,11 @@ def read_protobuf_type( | ||||
|             return | ||||
|         # Process Message | ||||
|         message_data = MessageCompiler( | ||||
|             source_file=source_file, parent=output_package, proto_obj=item, path=path | ||||
|             source_file=source_file, | ||||
|             parent=output_package, | ||||
|             proto_obj=item, | ||||
|             path=path, | ||||
|             typing_compiler=output_package.typing_compiler, | ||||
|         ) | ||||
|         for index, field in enumerate(item.field): | ||||
|             if is_map(field, item): | ||||
| @@ -190,6 +223,7 @@ def read_protobuf_type( | ||||
|                     parent=message_data, | ||||
|                     proto_obj=field, | ||||
|                     path=path + [2, index], | ||||
|                     typing_compiler=output_package.typing_compiler, | ||||
|                 ) | ||||
|             elif is_oneof(field): | ||||
|                 _make_one_of_field_compiler( | ||||
| @@ -201,11 +235,16 @@ def read_protobuf_type( | ||||
|                     parent=message_data, | ||||
|                     proto_obj=field, | ||||
|                     path=path + [2, index], | ||||
|                     typing_compiler=output_package.typing_compiler, | ||||
|                 ) | ||||
|     elif isinstance(item, EnumDescriptorProto): | ||||
|         # Enum | ||||
|         EnumDefinitionCompiler( | ||||
|             source_file=source_file, parent=output_package, proto_obj=item, path=path | ||||
|             source_file=source_file, | ||||
|             parent=output_package, | ||||
|             proto_obj=item, | ||||
|             path=path, | ||||
|             typing_compiler=output_package.typing_compiler, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										167
									
								
								src/betterproto/plugin/typing_compiler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										167
									
								
								src/betterproto/plugin/typing_compiler.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,167 @@ | ||||
| import abc | ||||
| from collections import defaultdict | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| ) | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Iterator, | ||||
|     Optional, | ||||
|     Set, | ||||
| ) | ||||
|  | ||||
|  | ||||
| class TypingCompiler(metaclass=abc.ABCMeta): | ||||
|     @abc.abstractmethod | ||||
|     def optional(self, type: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def list(self, type: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def dict(self, key: str, value: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def union(self, *types: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def iterable(self, type: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def async_iterable(self, type: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def async_iterator(self, type: str) -> str: | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     @abc.abstractmethod | ||||
|     def imports(self) -> Dict[str, Optional[Set[str]]]: | ||||
|         """ | ||||
|         Returns either the direct import as a key with none as value, or a set of | ||||
|         values to import from the key. | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
|  | ||||
|     def import_lines(self) -> Iterator: | ||||
|         imports = self.imports() | ||||
|         for key, value in imports.items(): | ||||
|             if value is None: | ||||
|                 yield f"import {key}" | ||||
|             else: | ||||
|                 yield f"from {key} import (" | ||||
|                 for v in sorted(value): | ||||
|                     yield f"    {v}," | ||||
|                 yield ")" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class DirectImportTypingCompiler(TypingCompiler): | ||||
|     _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) | ||||
|  | ||||
|     def optional(self, type: str) -> str: | ||||
|         self._imports["typing"].add("Optional") | ||||
|         return f"Optional[{type}]" | ||||
|  | ||||
|     def list(self, type: str) -> str: | ||||
|         self._imports["typing"].add("List") | ||||
|         return f"List[{type}]" | ||||
|  | ||||
|     def dict(self, key: str, value: str) -> str: | ||||
|         self._imports["typing"].add("Dict") | ||||
|         return f"Dict[{key}, {value}]" | ||||
|  | ||||
|     def union(self, *types: str) -> str: | ||||
|         self._imports["typing"].add("Union") | ||||
|         return f"Union[{', '.join(types)}]" | ||||
|  | ||||
|     def iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("Iterable") | ||||
|         return f"Iterable[{type}]" | ||||
|  | ||||
|     def async_iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterable") | ||||
|         return f"AsyncIterable[{type}]" | ||||
|  | ||||
|     def async_iterator(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterator") | ||||
|         return f"AsyncIterator[{type}]" | ||||
|  | ||||
|     def imports(self) -> Dict[str, Optional[Set[str]]]: | ||||
|         return {k: v if v else None for k, v in self._imports.items()} | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class TypingImportTypingCompiler(TypingCompiler): | ||||
|     _imported: bool = False | ||||
|  | ||||
|     def optional(self, type: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.Optional[{type}]" | ||||
|  | ||||
|     def list(self, type: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.List[{type}]" | ||||
|  | ||||
|     def dict(self, key: str, value: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.Dict[{key}, {value}]" | ||||
|  | ||||
|     def union(self, *types: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.Union[{', '.join(types)}]" | ||||
|  | ||||
|     def iterable(self, type: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.Iterable[{type}]" | ||||
|  | ||||
|     def async_iterable(self, type: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.AsyncIterable[{type}]" | ||||
|  | ||||
|     def async_iterator(self, type: str) -> str: | ||||
|         self._imported = True | ||||
|         return f"typing.AsyncIterator[{type}]" | ||||
|  | ||||
|     def imports(self) -> Dict[str, Optional[Set[str]]]: | ||||
|         if self._imported: | ||||
|             return {"typing": None} | ||||
|         return {} | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class NoTyping310TypingCompiler(TypingCompiler): | ||||
|     _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set)) | ||||
|  | ||||
|     def optional(self, type: str) -> str: | ||||
|         return f"{type} | None" | ||||
|  | ||||
|     def list(self, type: str) -> str: | ||||
|         return f"list[{type}]" | ||||
|  | ||||
|     def dict(self, key: str, value: str) -> str: | ||||
|         return f"dict[{key}, {value}]" | ||||
|  | ||||
|     def union(self, *types: str) -> str: | ||||
|         return " | ".join(types) | ||||
|  | ||||
|     def iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("Iterable") | ||||
|         return f"Iterable[{type}]" | ||||
|  | ||||
|     def async_iterable(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterable") | ||||
|         return f"AsyncIterable[{type}]" | ||||
|  | ||||
|     def async_iterator(self, type: str) -> str: | ||||
|         self._imports["typing"].add("AsyncIterator") | ||||
|         return f"AsyncIterator[{type}]" | ||||
|  | ||||
|     def imports(self) -> Dict[str, Optional[Set[str]]]: | ||||
|         return {k: v if v else None for k, v in self._imports.items()} | ||||
							
								
								
									
										54
									
								
								src/betterproto/templates/header.py.j2
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										54
									
								
								src/betterproto/templates/header.py.j2
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,54 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: {{ ', '.join(output_file.input_filenames) }} | ||||
| # plugin: python-betterproto | ||||
| # This file has been @generated | ||||
| {% for i in output_file.python_module_imports|sort %} | ||||
| import {{ i }} | ||||
| {% endfor %} | ||||
| {% set type_checking_imported = False %} | ||||
|  | ||||
| {% if output_file.pydantic_dataclasses %} | ||||
| from typing import TYPE_CHECKING | ||||
| {% set type_checking_imported = True %} | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from dataclasses import dataclass | ||||
| else: | ||||
|     from pydantic.dataclasses import dataclass | ||||
| {%- else -%} | ||||
| from dataclasses import dataclass | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.datetime_imports %} | ||||
| from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif%} | ||||
| {% set typing_imports = output_file.typing_compiler.imports() %} | ||||
| {% if typing_imports %} | ||||
| {% for line in output_file.typing_compiler.import_lines() %} | ||||
| {{ line }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.pydantic_imports %} | ||||
| from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| import betterproto | ||||
| {% if output_file.services %} | ||||
| from betterproto.grpc.grpclib_server import ServiceBase | ||||
| import grpclib | ||||
| {% endif %} | ||||
|  | ||||
| {% for i in output_file.imports|sort %} | ||||
| {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.imports_type_checking_only and not type_checking_imported %} | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| {% for i in output_file.imports_type_checking_only|sort %}    {{ i }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| @@ -1,53 +1,3 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: {{ ', '.join(output_file.input_filenames) }} | ||||
| # plugin: python-betterproto | ||||
| # This file has been @generated | ||||
| {% for i in output_file.python_module_imports|sort %} | ||||
| import {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.pydantic_dataclasses %} | ||||
| from typing import TYPE_CHECKING | ||||
| if TYPE_CHECKING: | ||||
|     from dataclasses import dataclass | ||||
| else: | ||||
|     from pydantic.dataclasses import dataclass | ||||
| {%- else -%} | ||||
| from dataclasses import dataclass | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.datetime_imports %} | ||||
| from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif%} | ||||
| {% if output_file.typing_imports %} | ||||
| from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.pydantic_imports %} | ||||
| from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| import betterproto | ||||
| {% if output_file.services %} | ||||
| from betterproto.grpc.grpclib_server import ServiceBase | ||||
| import grpclib | ||||
| {% endif %} | ||||
|  | ||||
| {% for i in output_file.imports|sort %} | ||||
| {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.imports_type_checking_only %} | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| {% for i in output_file.imports_type_checking_only|sort %}    {{ i }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.enums %}{% for enum in output_file.enums %} | ||||
| class {{ enum.py_name }}(betterproto.Enum): | ||||
|     {% if enum.comment %} | ||||
| @@ -116,14 +66,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||
|             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||
|         {%- else -%} | ||||
|             {# Client streaming: need a request iterator instead #} | ||||
|             , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] | ||||
|             , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }} | ||||
|         {%- endif -%} | ||||
|             , | ||||
|             * | ||||
|             , timeout: Optional[float] = None | ||||
|             , deadline: Optional["Deadline"] = None | ||||
|             , metadata: Optional["MetadataLike"] = None | ||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|             , timeout: {{ output_file.typing_compiler.optional("float") }} = None | ||||
|             , deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None | ||||
|             , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None | ||||
|             ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|         {% if method.comment %} | ||||
| {{ method.comment }} | ||||
|  | ||||
| @@ -191,9 +141,9 @@ class {{ service.py_name }}Base(ServiceBase): | ||||
|             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||
|         {%- else -%} | ||||
|             {# Client streaming: need a request iterator instead #} | ||||
|             , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"] | ||||
|             , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} | ||||
|         {%- endif -%} | ||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|             ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|         {% if method.comment %} | ||||
| {{ method.comment }} | ||||
|  | ||||
| @@ -225,7 +175,7 @@ class {{ service.py_name }}Base(ServiceBase): | ||||
|  | ||||
|     {% endfor %} | ||||
|  | ||||
|     def __mapping__(self) -> Dict[str, grpclib.const.Handler]: | ||||
|     def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}: | ||||
|         return { | ||||
|         {% for method in service.methods %} | ||||
|         "{{ method.route }}": grpclib.const.Handler( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user