QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
This commit is contained in:
		| @@ -5,10 +5,9 @@ try: | ||||
|     import black | ||||
|     import jinja2 | ||||
| except ImportError as err: | ||||
|     missing_import = err.args[0][17:-1] | ||||
|     print( | ||||
|         "\033[31m" | ||||
|         f"Unable to import `{missing_import}` from betterproto plugin! " | ||||
|         f"Unable to import `{err.name}` from betterproto plugin! " | ||||
|         "Please ensure that you've installed betterproto as " | ||||
|         '`pip install "betterproto[compiler]"` so that compiler dependencies ' | ||||
|         "are included." | ||||
| @@ -16,7 +15,7 @@ except ImportError as err: | ||||
|     ) | ||||
|     raise SystemExit(1) | ||||
|  | ||||
| from betterproto.plugin.models import OutputTemplate | ||||
| from .models import OutputTemplate | ||||
|  | ||||
|  | ||||
| def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
| @@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|     ) | ||||
|     template = env.get_template("template.py.j2") | ||||
|  | ||||
|     res = black.format_str( | ||||
|     return black.format_str( | ||||
|         template.render(output_file=output_file), | ||||
|         mode=black.FileMode(target_versions={black.TargetVersion.PY37}), | ||||
|     ) | ||||
|  | ||||
|     return res | ||||
|   | ||||
| @@ -1,13 +1,14 @@ | ||||
| #!/usr/bin/env python | ||||
| import sys | ||||
|  | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| from google.protobuf.compiler import plugin_pb2 as plugin | ||||
|  | ||||
| from betterproto.plugin.parser import generate_code | ||||
|  | ||||
|  | ||||
| def main(): | ||||
| def main() -> None: | ||||
|     """The plugin's main entry point.""" | ||||
|     # Read request message from stdin | ||||
|     data = sys.stdin.buffer.read() | ||||
| @@ -33,7 +34,7 @@ def main(): | ||||
|     sys.stdout.buffer.write(output) | ||||
|  | ||||
|  | ||||
| def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): | ||||
| def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None: | ||||
|     """ | ||||
|     For developers: Supports running plugin.py standalone so its possible to debug it. | ||||
|     Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. | ||||
|   | ||||
| @@ -1,14 +1,14 @@ | ||||
| """Plugin model dataclasses. | ||||
|  | ||||
| These classes are meant to be an intermediate representation | ||||
| of protbuf objects. They are used to organize the data collected during parsing. | ||||
| of protobuf objects. They are used to organize the data collected during parsing. | ||||
|  | ||||
| The general intention is to create a doubly-linked tree-like structure | ||||
| with the following types of references: | ||||
| - Downwards references: from message -> fields, from output package -> messages | ||||
| or from service -> service methods | ||||
| - Upwards references: from field -> message, message -> package. | ||||
| - Input/ouput message references: from a service method to it's corresponding | ||||
| - Input/output message references: from a service method to it's corresponding | ||||
| input/output messages, which may even be in another package. | ||||
|  | ||||
| There are convenience methods to allow climbing up and down this tree, for | ||||
| @@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj. | ||||
| The instantiation should also attach a reference to the new object | ||||
| into the corresponding place within it's parent object. For example, | ||||
| instantiating field `A` with parent message `B` should add a | ||||
| reference to `A` to `B`'s `fields` attirbute. | ||||
| reference to `A` to `B`'s `fields` attribute. | ||||
| """ | ||||
|  | ||||
| import re | ||||
| from dataclasses import dataclass | ||||
| from dataclasses import field | ||||
| from typing import ( | ||||
|     Iterator, | ||||
|     Union, | ||||
|     Type, | ||||
|     List, | ||||
|     Dict, | ||||
|     Set, | ||||
|     Text, | ||||
| ) | ||||
| import textwrap | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union | ||||
|  | ||||
| import betterproto | ||||
| from betterproto.compile.importing import ( | ||||
|     get_type_reference, | ||||
|     parse_source_type_name, | ||||
| ) | ||||
| from betterproto.compile.naming import ( | ||||
|  | ||||
| from ..casing import sanitize_name | ||||
| from ..compile.importing import get_type_reference, parse_source_type_name | ||||
| from ..compile.naming import ( | ||||
|     pythonize_class_name, | ||||
|     pythonize_field_name, | ||||
|     pythonize_method_name, | ||||
| ) | ||||
|  | ||||
| from ..casing import sanitize_name | ||||
|  | ||||
| try: | ||||
|     # betterproto[compiler] specific dependencies | ||||
|     from google.protobuf.compiler import plugin_pb2 as plugin | ||||
| @@ -67,10 +55,9 @@ try: | ||||
|         MethodDescriptorProto, | ||||
|     ) | ||||
| except ImportError as err: | ||||
|     missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) | ||||
|     print( | ||||
|         "\033[31m" | ||||
|         f"Unable to import `{missing_import}` from betterproto plugin! " | ||||
|         f"Unable to import `{err.name}` from betterproto plugin! " | ||||
|         "Please ensure that you've installed betterproto as " | ||||
|         '`pip install "betterproto[compiler]"` so that compiler dependencies ' | ||||
|         "are included." | ||||
| @@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = ( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def get_comment(proto_file, path: List[int], indent: int = 4) -> str: | ||||
| def get_comment( | ||||
|     proto_file: "FileDescriptorProto", path: List[int], indent: int = 4 | ||||
| ) -> str: | ||||
|     pad = " " * indent | ||||
|     for sci in proto_file.source_code_info.location: | ||||
|         # print(list(sci.path), path, file=sys.stderr) | ||||
|         if list(sci.path) == path and sci.leading_comments: | ||||
|             lines = textwrap.wrap( | ||||
|                 sci.leading_comments.strip().replace("\n", ""), width=79 - indent | ||||
| @@ -153,9 +141,9 @@ class ProtoContentBase: | ||||
|  | ||||
|     path: List[int] | ||||
|     comment_indent: int = 4 | ||||
|     parent: Union["Messsage", "OutputTemplate"] | ||||
|     parent: Union["betterproto.Message", "OutputTemplate"] | ||||
|  | ||||
|     def __post_init__(self): | ||||
|     def __post_init__(self) -> None: | ||||
|         """Checks that no fake default fields were left as placeholders.""" | ||||
|         for field_name, field_val in self.__dataclass_fields__.items(): | ||||
|             if field_val is PLACEHOLDER: | ||||
| @@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|     ) | ||||
|     deprecated: bool = field(default=False, init=False) | ||||
|  | ||||
|     def __post_init__(self): | ||||
|     def __post_init__(self) -> None: | ||||
|         # Add message to output file | ||||
|         if isinstance(self.parent, OutputTemplate): | ||||
|             if isinstance(self, EnumDefinitionCompiler): | ||||
| @@ -314,17 +302,17 @@ def is_map( | ||||
|         map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" | ||||
|         if message_type == map_entry: | ||||
|             for nested in parent_message.nested_type:  # parent message | ||||
|                 if nested.name.replace("_", "").lower() == map_entry: | ||||
|                     if nested.options.map_entry: | ||||
|                         return True | ||||
|                 if ( | ||||
|                     nested.name.replace("_", "").lower() == map_entry | ||||
|                     and nested.options.map_entry | ||||
|                 ): | ||||
|                     return True | ||||
|     return False | ||||
|  | ||||
|  | ||||
| def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: | ||||
|     """True if proto_field_obj is a OneOf, otherwise False.""" | ||||
|     if proto_field_obj.HasField("oneof_index"): | ||||
|         return True | ||||
|     return False | ||||
|     return proto_field_obj.HasField("oneof_index") | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| @@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler): | ||||
|     parent: MessageCompiler = PLACEHOLDER | ||||
|     proto_obj: FieldDescriptorProto = PLACEHOLDER | ||||
|  | ||||
|     def __post_init__(self): | ||||
|     def __post_init__(self) -> None: | ||||
|         # Add field to message | ||||
|         self.parent.fields.append(self) | ||||
|         # Check for new imports | ||||
| @@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler): | ||||
|             ([""] + self.betterproto_field_args) if self.betterproto_field_args else [] | ||||
|         ) | ||||
|         betterproto_field_type = ( | ||||
|             f"betterproto.{self.field_type}_field({self.proto_obj.number}" | ||||
|             + field_args | ||||
|             + ")" | ||||
|             f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})" | ||||
|         ) | ||||
|         return name + annotations + " = " + betterproto_field_type | ||||
|         return f"{name}{annotations} = {betterproto_field_type}" | ||||
|  | ||||
|     @property | ||||
|     def betterproto_field_args(self) -> List[str]: | ||||
| @@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler): | ||||
|         return args | ||||
|  | ||||
|     @property | ||||
|     def field_wraps(self) -> Union[str, None]: | ||||
|     def field_wraps(self) -> Optional[str]: | ||||
|         """Returns betterproto wrapped field type or None.""" | ||||
|         match_wrapper = re.match( | ||||
|             r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name | ||||
| @@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler): | ||||
|  | ||||
|     @property | ||||
|     def repeated(self) -> bool: | ||||
|         if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( | ||||
|             self.proto_obj, self.parent | ||||
|         ): | ||||
|             return True | ||||
|         return False | ||||
|         return ( | ||||
|             self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED | ||||
|             and not is_map(self.proto_obj, self.parent) | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def mutable(self) -> bool: | ||||
|         """True if the field is a mutable type, otherwise False.""" | ||||
|         annotation = self.annotation | ||||
|         return annotation.startswith("List[") or annotation.startswith("Dict[") | ||||
|         return self.annotation.startswith(("List[", "Dict[")) | ||||
|  | ||||
|     @property | ||||
|     def field_type(self) -> str: | ||||
| @@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler): | ||||
|     @property | ||||
|     def packed(self) -> bool: | ||||
|         """True if the wire representation is a packed format.""" | ||||
|         if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: | ||||
|             return True | ||||
|         return False | ||||
|         return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES | ||||
|  | ||||
|     @property | ||||
|     def py_name(self) -> str: | ||||
| @@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler): | ||||
|     proto_k_type: str = PLACEHOLDER | ||||
|     proto_v_type: str = PLACEHOLDER | ||||
|  | ||||
|     def __post_init__(self): | ||||
|     def __post_init__(self) -> None: | ||||
|         """Explore nested types and set k_type and v_type if unset.""" | ||||
|         map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" | ||||
|         for nested in self.parent.proto_obj.nested_type: | ||||
|             if nested.name.replace("_", "").lower() == map_entry: | ||||
|                 if nested.options.map_entry: | ||||
|                     # Get Python types | ||||
|                     self.py_k_type = FieldCompiler( | ||||
|                         parent=self, proto_obj=nested.field[0]  # key | ||||
|                     ).py_type | ||||
|                     self.py_v_type = FieldCompiler( | ||||
|                         parent=self, proto_obj=nested.field[1]  # value | ||||
|                     ).py_type | ||||
|                     # Get proto types | ||||
|                     self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) | ||||
|                     self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) | ||||
|             if ( | ||||
|                 nested.name.replace("_", "").lower() == map_entry | ||||
|                 and nested.options.map_entry | ||||
|             ): | ||||
|                 # Get Python types | ||||
|                 self.py_k_type = FieldCompiler( | ||||
|                     parent=self, proto_obj=nested.field[0]  # key | ||||
|                 ).py_type | ||||
|                 self.py_v_type = FieldCompiler( | ||||
|                     parent=self, proto_obj=nested.field[1]  # value | ||||
|                 ).py_type | ||||
|                 # Get proto types | ||||
|                 self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) | ||||
|                 self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) | ||||
|         super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__ | ||||
|  | ||||
|     @property | ||||
| @@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler): | ||||
|         return "map" | ||||
|  | ||||
|     @property | ||||
|     def annotation(self): | ||||
|     def annotation(self) -> str: | ||||
|         return f"Dict[{self.py_k_type}, {self.py_v_type}]" | ||||
|  | ||||
|     @property | ||||
|     def repeated(self): | ||||
|     def repeated(self) -> bool: | ||||
|         return False  # maps cannot be repeated | ||||
|  | ||||
|  | ||||
| @@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler): | ||||
|         value: int | ||||
|         comment: str | ||||
|  | ||||
|     def __post_init__(self): | ||||
|     def __post_init__(self) -> None: | ||||
|         # Get entries/allowed values for this Enum | ||||
|         self.entries = [ | ||||
|             self.EnumEntry( | ||||
| @@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler): | ||||
|         super().__post_init__()  # call MessageCompiler __post_init__ | ||||
|  | ||||
|     @property | ||||
|     def default_value_string(self) -> int: | ||||
|     def default_value_string(self) -> str: | ||||
|         """Python representation of the default value for Enums. | ||||
|  | ||||
|         As per the spec, this is the first value of the Enum. | ||||
| @@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase): | ||||
|         super().__post_init__()  # check for unset fields | ||||
|  | ||||
|     @property | ||||
|     def proto_name(self): | ||||
|     def proto_name(self) -> str: | ||||
|         return self.proto_obj.name | ||||
|  | ||||
|     @property | ||||
|     def py_name(self): | ||||
|     def py_name(self) -> str: | ||||
|         return pythonize_class_name(self.proto_name) | ||||
|  | ||||
|  | ||||
| @@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|             Name and actual default value (as a string) | ||||
|             for each argument with mutable default values. | ||||
|         """ | ||||
|         mutable_default_args = dict() | ||||
|         mutable_default_args = {} | ||||
|  | ||||
|         if self.py_input_message: | ||||
|             for f in self.py_input_message.fields: | ||||
| @@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|  | ||||
|     @property | ||||
|     def route(self) -> str: | ||||
|         return ( | ||||
|             f"/{self.output_file.package}." | ||||
|             f"{self.parent.proto_name}/{self.proto_name}" | ||||
|         ) | ||||
|         return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}" | ||||
|  | ||||
|     @property | ||||
|     def py_input_message(self) -> Union[None, MessageCompiler]: | ||||
|     def py_input_message(self) -> Optional[MessageCompiler]: | ||||
|         """Find the input message object. | ||||
|  | ||||
|         Returns | ||||
|         ------- | ||||
|         Union[None, MessageCompiler] | ||||
|         Optional[MessageCompiler] | ||||
|             Method instance representing the input message. | ||||
|             If not input message could be found or there are no | ||||
|             input messages, None is returned. | ||||
| @@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|  | ||||
|     @property | ||||
|     def py_input_message_type(self) -> str: | ||||
|         """String representation of the Python type correspoding to the | ||||
|         """String representation of the Python type corresponding to the | ||||
|         input message. | ||||
|  | ||||
|         Returns | ||||
|         ------- | ||||
|         str | ||||
|             String representation of the Python type correspoding to the | ||||
|         input message. | ||||
|             String representation of the Python type corresponding to the input message. | ||||
|         """ | ||||
|         return get_type_reference( | ||||
|             package=self.output_file.package, | ||||
| @@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|  | ||||
|     @property | ||||
|     def py_output_message_type(self) -> str: | ||||
|         """String representation of the Python type correspoding to the | ||||
|         """String representation of the Python type corresponding to the | ||||
|         output message. | ||||
|  | ||||
|         Returns | ||||
|         ------- | ||||
|         str | ||||
|             String representation of the Python type correspoding to the | ||||
|         output message. | ||||
|             String representation of the Python type corresponding to the output message. | ||||
|         """ | ||||
|         return get_type_reference( | ||||
|             package=self.output_file.package, | ||||
|   | ||||
| @@ -1,7 +1,7 @@ | ||||
| import itertools | ||||
| import pathlib | ||||
| import sys | ||||
| from typing import List, Iterator | ||||
| from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set | ||||
|  | ||||
| try: | ||||
|     # betterproto[compiler] specific dependencies | ||||
| @@ -13,10 +13,9 @@ try: | ||||
|         ServiceDescriptorProto, | ||||
|     ) | ||||
| except ImportError as err: | ||||
|     missing_import = err.args[0][17:-1] | ||||
|     print( | ||||
|         "\033[31m" | ||||
|         f"Unable to import `{missing_import}` from betterproto plugin! " | ||||
|         f"Unable to import `{err.name}` from betterproto plugin! " | ||||
|         "Please ensure that you've installed betterproto as " | ||||
|         '`pip install "betterproto[compiler]"` so that compiler dependencies ' | ||||
|         "are included." | ||||
| @@ -24,26 +23,32 @@ except ImportError as err: | ||||
|     ) | ||||
|     raise SystemExit(1) | ||||
|  | ||||
| from betterproto.plugin.models import ( | ||||
|     PluginRequestCompiler, | ||||
|     OutputTemplate, | ||||
|     MessageCompiler, | ||||
|     FieldCompiler, | ||||
|     OneOfFieldCompiler, | ||||
|     MapEntryCompiler, | ||||
| from .compiler import outputfile_compiler | ||||
| from .models import ( | ||||
|     EnumDefinitionCompiler, | ||||
|     FieldCompiler, | ||||
|     MapEntryCompiler, | ||||
|     MessageCompiler, | ||||
|     OneOfFieldCompiler, | ||||
|     OutputTemplate, | ||||
|     PluginRequestCompiler, | ||||
|     ServiceCompiler, | ||||
|     ServiceMethodCompiler, | ||||
|     is_map, | ||||
|     is_oneof, | ||||
| ) | ||||
|  | ||||
| from betterproto.plugin.compiler import outputfile_compiler | ||||
| if TYPE_CHECKING: | ||||
|     from google.protobuf.descriptor import Descriptor | ||||
|  | ||||
|  | ||||
| def traverse(proto_file: FieldDescriptorProto) -> Iterator: | ||||
| def traverse( | ||||
|     proto_file: FieldDescriptorProto, | ||||
| ) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": | ||||
|     # Todo: Keep information about nested hierarchy | ||||
|     def _traverse(path, items, prefix=""): | ||||
|     def _traverse( | ||||
|         path: List[int], items: List["Descriptor"], prefix="" | ||||
|     ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: | ||||
|         for i, item in enumerate(items): | ||||
|             # Adjust the name since we flatten the hierarchy. | ||||
|             # Todo: don't change the name, but include full name in returned tuple | ||||
| @@ -104,7 +109,7 @@ def generate_code( | ||||
|                 read_protobuf_service(service, index, output_package) | ||||
|  | ||||
|     # Generate output files | ||||
|     output_paths: pathlib.Path = set() | ||||
|     output_paths: Set[pathlib.Path] = set() | ||||
|     for output_package_name, output_package in request_data.output_packages.items(): | ||||
|  | ||||
|         # Add files to the response object | ||||
| @@ -112,20 +117,17 @@ def generate_code( | ||||
|         output_paths.add(output_path) | ||||
|  | ||||
|         f: response.File = response.file.add() | ||||
|         f.name: str = str(output_path) | ||||
|         f.name = str(output_path) | ||||
|  | ||||
|         # Render and then format the output file | ||||
|         f.content: str = outputfile_compiler(output_file=output_package) | ||||
|         f.content = outputfile_compiler(output_file=output_package) | ||||
|  | ||||
|     # Make each output directory a package with __init__ file | ||||
|     init_files = ( | ||||
|         set( | ||||
|             directory.joinpath("__init__.py") | ||||
|             for path in output_paths | ||||
|             for directory in path.parents | ||||
|         ) | ||||
|         - output_paths | ||||
|     ) | ||||
|     init_files = { | ||||
|         directory.joinpath("__init__.py") | ||||
|         for path in output_paths | ||||
|         for directory in path.parents | ||||
|     } - output_paths | ||||
|  | ||||
|     for init_file in init_files: | ||||
|         init = response.file.add() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user