QOL fixes (#141)
- Add missing type annotations
- Various style improvements
- Use constants more consistently
- Enforce black on benchmark code
@@ -5,10 +5,9 @@ try:
     import black
     import jinja2
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
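The same cleanup is repeated below in the models and parser modules: rather than slicing the exception message (`err.args[0][17:-1]`), the code now reads `ImportError.name`, which holds the missing module's name directly. A minimal sketch of the difference, using `jinja2` purely as an example of a missing dependency:

# Minimal sketch (not from the diff): ImportError.name already carries the
# module that failed to import, so no fragile string slicing is needed.
try:
    import jinja2  # example dependency; any missing module behaves the same
except ImportError as err:
    print("sliced message:", err.args[0][17:-1])  # brittle: depends on exact wording
    print("err.name:      ", err.name)            # robust: just the module name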
@@ -16,7 +15,7 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import OutputTemplate
+from .models import OutputTemplate


 def outputfile_compiler(output_file: OutputTemplate) -> str:
@@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
     )
     template = env.get_template("template.py.j2")

-    res = black.format_str(
+    return black.format_str(
         template.render(output_file=output_file),
         mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
     )
-
-    return res
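For context, `black.format_str` formats a source string in-process, the same as running black on a file; the snippet below is not part of the diff and just shows the call with the same `FileMode` the compiler passes:

# Standalone illustration (not from the diff) of the black API used above.
import black

source = "x =   {'a':1}"
formatted = black.format_str(
    source,
    mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)
print(formatted)  # black normalizes spacing and quotes: x = {"a": 1}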
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
-import sys

+import os
+import sys

 from google.protobuf.compiler import plugin_pb2 as plugin

 from betterproto.plugin.parser import generate_code


-def main():
+def main() -> None:
     """The plugin's main entry point."""
     # Read request message from stdin
     data = sys.stdin.buffer.read()
@@ -33,7 +34,7 @@ def main():
     sys.stdout.buffer.write(output)


-def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
+def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
     """
     For developers: Supports running plugin.py standalone so its possible to debug it.
     Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
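As the docstring above says, running protoc (or generate.py) with BETTERPROTO_DUMP set writes the raw CodeGeneratorRequest to a file. A small, hypothetical snippet for loading such a dump back when debugging the plugin standalone (the filename is the placeholder from the docstring):

# Hypothetical debugging helper: load a request previously written via
# BETTERPROTO_DUMP="yourfile.bin" and inspect which .proto files it contains.
from google.protobuf.compiler import plugin_pb2 as plugin

with open("yourfile.bin", "rb") as f:
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(f.read())

print([proto.name for proto in request.proto_file])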
@@ -1,14 +1,14 @@
 """Plugin model dataclasses.

 These classes are meant to be an intermediate representation
-of protbuf objects. They are used to organize the data collected during parsing.
+of protobuf objects. They are used to organize the data collected during parsing.

 The general intention is to create a doubly-linked tree-like structure
 with the following types of references:
 - Downwards references: from message -> fields, from output package -> messages
 or from service -> service methods
 - Upwards references: from field -> message, message -> package.
-- Input/ouput message references: from a service method to it's corresponding
+- Input/output message references: from a service method to it's corresponding
 input/output messages, which may even be in another package.

 There are convenience methods to allow climbing up and down this tree, for
@@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
 The instantiation should also attach a reference to the new object
 into the corresponding place within it's parent object. For example,
 instantiating field `A` with parent message `B` should add a
-reference to `A` to `B`'s `fields` attirbute.
+reference to `A` to `B`'s `fields` attribute.
 """

 import re
-from dataclasses import dataclass
-from dataclasses import field
-from typing import (
-    Iterator,
-    Union,
-    Type,
-    List,
-    Dict,
-    Set,
-    Text,
-)
 import textwrap
+from dataclasses import dataclass, field
+from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union

 import betterproto
-from betterproto.compile.importing import (
-    get_type_reference,
-    parse_source_type_name,
-)
-from betterproto.compile.naming import (
+
+from ..casing import sanitize_name
+from ..compile.importing import get_type_reference, parse_source_type_name
+from ..compile.naming import (
     pythonize_class_name,
     pythonize_field_name,
     pythonize_method_name,
 )

-from ..casing import sanitize_name
-
 try:
     # betterproto[compiler] specific dependencies
     from google.protobuf.compiler import plugin_pb2 as plugin
@@ -67,10 +55,9 @@ try:
         MethodDescriptorProto,
     )
 except ImportError as err:
-    missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
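The registration convention described in the docstring, where instantiating a child attaches it to its parent so the tree can be walked in both directions, looks roughly like the following minimal sketch; `Message` and `Field` here are hypothetical stand-ins, not the real compiler classes:

# Minimal sketch of the doubly-linked structure the docstring describes:
# a child registers itself on its parent in __post_init__, so the tree can be
# walked downwards (message.fields) and upwards (field.parent).
from dataclasses import dataclass, field
from typing import List


@dataclass
class Message:  # hypothetical stand-in for a message compiler object
    fields: List["Field"] = field(default_factory=list)


@dataclass
class Field:  # hypothetical stand-in for a field compiler object
    parent: Message
    name: str = ""

    def __post_init__(self) -> None:
        # Attach this field to its parent, mirroring the convention above.
        self.parent.fields.append(self)


msg = Message()
Field(parent=msg, name="id")
assert msg.fields[0].parent is msg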
@@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
 )


-def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
+def get_comment(
+    proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
+) -> str:
     pad = " " * indent
     for sci in proto_file.source_code_info.location:
-        # print(list(sci.path), path, file=sys.stderr)
         if list(sci.path) == path and sci.leading_comments:
             lines = textwrap.wrap(
                 sci.leading_comments.strip().replace("\n", ""), width=79 - indent
@@ -153,9 +141,9 @@ class ProtoContentBase:

     path: List[int]
     comment_indent: int = 4
-    parent: Union["Messsage", "OutputTemplate"]
+    parent: Union["betterproto.Message", "OutputTemplate"]

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Checks that no fake default fields were left as placeholders."""
         for field_name, field_val in self.__dataclass_fields__.items():
             if field_val is PLACEHOLDER:
@@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
     )
     deprecated: bool = field(default=False, init=False)

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add message to output file
         if isinstance(self.parent, OutputTemplate):
             if isinstance(self, EnumDefinitionCompiler):
@@ -314,17 +302,17 @@ def is_map(
     map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
     if message_type == map_entry:
         for nested in parent_message.nested_type:  # parent message
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    return True
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                return True
     return False


 def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
     """True if proto_field_obj is a OneOf, otherwise False."""
-    if proto_field_obj.HasField("oneof_index"):
-        return True
-    return False
+    return proto_field_obj.HasField("oneof_index")


 @dataclass
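The simplified `is_oneof` relies on `oneof_index` being set only for fields that belong to a oneof, so `HasField` already yields the boolean we want. A standalone check of that behaviour (not part of the diff):

# Standalone check (not from the diff) of the HasField behaviour is_oneof uses.
from google.protobuf.descriptor_pb2 import FieldDescriptorProto

plain = FieldDescriptorProto(name="id", number=1)
member = FieldDescriptorProto(name="choice", number=2, oneof_index=0)

assert not plain.HasField("oneof_index")
assert member.HasField("oneof_index")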
@@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
     parent: MessageCompiler = PLACEHOLDER
     proto_obj: FieldDescriptorProto = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add field to message
         self.parent.fields.append(self)
         # Check for new imports
@@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
             ([""] + self.betterproto_field_args) if self.betterproto_field_args else []
         )
         betterproto_field_type = (
-            f"betterproto.{self.field_type}_field({self.proto_obj.number}"
-            + field_args
-            + ")"
+            f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
         )
-        return name + annotations + " = " + betterproto_field_type
+        return f"{name}{annotations} = {betterproto_field_type}"

     @property
     def betterproto_field_args(self) -> List[str]:
@@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
         return args

     @property
-    def field_wraps(self) -> Union[str, None]:
+    def field_wraps(self) -> Optional[str]:
         """Returns betterproto wrapped field type or None."""
         match_wrapper = re.match(
             r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
@@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):

     @property
     def repeated(self) -> bool:
-        if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
-            self.proto_obj, self.parent
-        ):
-            return True
-        return False
+        return (
+            self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
+            and not is_map(self.proto_obj, self.parent)
+        )

     @property
     def mutable(self) -> bool:
         """True if the field is a mutable type, otherwise False."""
-        annotation = self.annotation
-        return annotation.startswith("List[") or annotation.startswith("Dict[")
+        return self.annotation.startswith(("List[", "Dict["))

     @property
     def field_type(self) -> str:
@@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
     @property
     def packed(self) -> bool:
         """True if the wire representation is a packed format."""
-        if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
-            return True
-        return False
+        return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES

     @property
     def py_name(self) -> str:
@@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
     proto_k_type: str = PLACEHOLDER
     proto_v_type: str = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Explore nested types and set k_type and v_type if unset."""
         map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
         for nested in self.parent.proto_obj.nested_type:
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    # Get Python types
-                    self.py_k_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[0]  # key
-                    ).py_type
-                    self.py_v_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[1]  # value
-                    ).py_type
-                    # Get proto types
-                    self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
-                    self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                # Get Python types
+                self.py_k_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[0]  # key
+                ).py_type
+                self.py_v_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[1]  # value
+                ).py_type
+                # Get proto types
+                self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
+                self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
         super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

     @property
@@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
         return "map"

     @property
-    def annotation(self):
+    def annotation(self) -> str:
         return f"Dict[{self.py_k_type}, {self.py_v_type}]"

     @property
-    def repeated(self):
+    def repeated(self) -> bool:
         return False  # maps cannot be repeated


@@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         value: int
         comment: str

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Get entries/allowed values for this Enum
         self.entries = [
             self.EnumEntry(
@@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         super().__post_init__()  # call MessageCompiler __post_init__

     @property
-    def default_value_string(self) -> int:
+    def default_value_string(self) -> str:
         """Python representation of the default value for Enums.

         As per the spec, this is the first value of the Enum.
@@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
         super().__post_init__()  # check for unset fields

     @property
-    def proto_name(self):
+    def proto_name(self) -> str:
         return self.proto_obj.name

     @property
-    def py_name(self):
+    def py_name(self) -> str:
         return pythonize_class_name(self.proto_name)


@@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             Name and actual default value (as a string)
             for each argument with mutable default values.
         """
-        mutable_default_args = dict()
+        mutable_default_args = {}

         if self.py_input_message:
             for f in self.py_input_message.fields:
@@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def route(self) -> str:
-        return (
-            f"/{self.output_file.package}."
-            f"{self.parent.proto_name}/{self.proto_name}"
-        )
+        return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"

     @property
-    def py_input_message(self) -> Union[None, MessageCompiler]:
+    def py_input_message(self) -> Optional[MessageCompiler]:
         """Find the input message object.

         Returns
         -------
-        Union[None, MessageCompiler]
+        Optional[MessageCompiler]
             Method instance representing the input message.
             If not input message could be found or there are no
             input messages, None is returned.
@@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def py_input_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         input message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            input message.
+            String representation of the Python type corresponding to the input message.
         """
         return get_type_reference(
             package=self.output_file.package,
@@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def py_output_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         output message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            output message.
+            String representation of the Python type corresponding to the output message.
         """
         return get_type_reference(
             package=self.output_file.package,
@@ -1,7 +1,7 @@
 import itertools
 import pathlib
 import sys
-from typing import List, Iterator
+from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set

 try:
     # betterproto[compiler] specific dependencies
@@ -13,10 +13,9 @@ try:
         ServiceDescriptorProto,
     )
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -24,26 +23,32 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import (
-    PluginRequestCompiler,
-    OutputTemplate,
-    MessageCompiler,
-    FieldCompiler,
-    OneOfFieldCompiler,
-    MapEntryCompiler,
+from .compiler import outputfile_compiler
+from .models import (
     EnumDefinitionCompiler,
+    FieldCompiler,
+    MapEntryCompiler,
+    MessageCompiler,
+    OneOfFieldCompiler,
+    OutputTemplate,
+    PluginRequestCompiler,
     ServiceCompiler,
     ServiceMethodCompiler,
     is_map,
     is_oneof,
 )

-from betterproto.plugin.compiler import outputfile_compiler
+if TYPE_CHECKING:
+    from google.protobuf.descriptor import Descriptor


-def traverse(proto_file: FieldDescriptorProto) -> Iterator:
+def traverse(
+    proto_file: FieldDescriptorProto,
+) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
     # Todo: Keep information about nested hierarchy
-    def _traverse(path, items, prefix=""):
+    def _traverse(
+        path: List[int], items: List["Descriptor"], prefix=""
+    ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
         for i, item in enumerate(items):
             # Adjust the name since we flatten the hierarchy.
             # Todo: don't change the name, but include full name in returned tuple
@@ -104,7 +109,7 @@ def generate_code(
             read_protobuf_service(service, index, output_package)

     # Generate output files
-    output_paths: pathlib.Path = set()
+    output_paths: Set[pathlib.Path] = set()
     for output_package_name, output_package in request_data.output_packages.items():

         # Add files to the response object
@@ -112,20 +117,17 @@ def generate_code(
         output_paths.add(output_path)

         f: response.File = response.file.add()
-        f.name: str = str(output_path)
+        f.name = str(output_path)

         # Render and then format the output file
-        f.content: str = outputfile_compiler(output_file=output_package)
+        f.content = outputfile_compiler(output_file=output_package)

     # Make each output directory a package with __init__ file
-    init_files = (
-        set(
-            directory.joinpath("__init__.py")
-            for path in output_paths
-            for directory in path.parents
-        )
-        - output_paths
-    )
+    init_files = {
+        directory.joinpath("__init__.py")
+        for path in output_paths
+        for directory in path.parents
+    } - output_paths

     for init_file in init_files:
         init = response.file.add()
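To see what the new set comprehension yields, here is a standalone run with a made-up output path (not part of the diff):

# Hypothetical example of the __init__.py computation above.
import pathlib

output_paths = {pathlib.Path("foo/bar/baz.py")}
init_files = {
    directory.joinpath("__init__.py")
    for path in output_paths
    for directory in path.parents
} - output_paths
print(sorted(init_files))
# On POSIX: [PosixPath('__init__.py'), PosixPath('foo/__init__.py'), PosixPath('foo/bar/__init__.py')]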