QOL fixes (#141)

- Add missing type annotations
- Various style improvements
- Use constants more consistently
- enforce black on benchmark code
This commit is contained in:
James
2020-10-17 18:27:11 +01:00
committed by GitHub
parent bf9412e083
commit 8f7af272cc
16 changed files with 177 additions and 220 deletions

View File

@@ -5,10 +5,9 @@ try:
import black
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@@ -16,7 +15,7 @@ except ImportError as err:
)
raise SystemExit(1)
from betterproto.plugin.models import OutputTemplate
from .models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str:
@@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
)
template = env.get_template("template.py.j2")
res = black.format_str(
return black.format_str(
template.render(output_file=output_file),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)
return res

View File

@@ -1,13 +1,14 @@
#!/usr/bin/env python
import sys
import os
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.plugin.parser import generate_code
def main():
def main() -> None:
"""The plugin's main entry point."""
# Read request message from stdin
data = sys.stdin.buffer.read()
@@ -33,7 +34,7 @@ def main():
sys.stdout.buffer.write(output)
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
"""
For developers: Supports running plugin.py standalone so its possible to debug it.
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.

View File

@@ -1,14 +1,14 @@
"""Plugin model dataclasses.
These classes are meant to be an intermediate representation
of protbuf objects. They are used to organize the data collected during parsing.
of protobuf objects. They are used to organize the data collected during parsing.
The general intention is to create a doubly-linked tree-like structure
with the following types of references:
- Downwards references: from message -> fields, from output package -> messages
or from service -> service methods
- Upwards references: from field -> message, message -> package.
- Input/ouput message references: from a service method to it's corresponding
- Input/output message references: from a service method to it's corresponding
input/output messages, which may even be in another package.
There are convenience methods to allow climbing up and down this tree, for
@@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
The instantiation should also attach a reference to the new object
into the corresponding place within it's parent object. For example,
instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attirbute.
reference to `A` to `B`'s `fields` attribute.
"""
import re
from dataclasses import dataclass
from dataclasses import field
from typing import (
Iterator,
Union,
Type,
List,
Dict,
Set,
Text,
)
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
import betterproto
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from ..casing import sanitize_name
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
@@ -67,10 +55,9 @@ try:
MethodDescriptorProto,
)
except ImportError as err:
missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
)
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
# print(list(sci.path), path, file=sys.stderr)
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
@@ -153,9 +141,9 @@ class ProtoContentBase:
path: List[int]
comment_indent: int = 4
parent: Union["Messsage", "OutputTemplate"]
parent: Union["betterproto.Message", "OutputTemplate"]
def __post_init__(self):
def __post_init__(self) -> None:
"""Checks that no fake default fields were left as placeholders."""
for field_name, field_val in self.__dataclass_fields__.items():
if field_val is PLACEHOLDER:
@@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
)
deprecated: bool = field(default=False, init=False)
def __post_init__(self):
def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
@@ -314,17 +302,17 @@ def is_map(
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in parent_message.nested_type: # parent message
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
return True
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
return True
return False
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False."""
if proto_field_obj.HasField("oneof_index"):
return True
return False
return proto_field_obj.HasField("oneof_index")
@dataclass
@@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
parent: MessageCompiler = PLACEHOLDER
proto_obj: FieldDescriptorProto = PLACEHOLDER
def __post_init__(self):
def __post_init__(self) -> None:
# Add field to message
self.parent.fields.append(self)
# Check for new imports
@@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
)
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
+ field_args
+ ")"
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
)
return name + annotations + " = " + betterproto_field_type
return f"{name}{annotations} = {betterproto_field_type}"
@property
def betterproto_field_args(self) -> List[str]:
@@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
return args
@property
def field_wraps(self) -> Union[str, None]:
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
@@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):
@property
def repeated(self) -> bool:
if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
self.proto_obj, self.parent
):
return True
return False
return (
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
and not is_map(self.proto_obj, self.parent)
)
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
annotation = self.annotation
return annotation.startswith("List[") or annotation.startswith("Dict[")
return self.annotation.startswith(("List[", "Dict["))
@property
def field_type(self) -> str:
@@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
return True
return False
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
@property
def py_name(self) -> str:
@@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
proto_k_type: str = PLACEHOLDER
proto_v_type: str = PLACEHOLDER
def __post_init__(self):
def __post_init__(self) -> None:
"""Explore nested types and set k_type and v_type if unset."""
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
for nested in self.parent.proto_obj.nested_type:
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1] # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1] # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
@property
@@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
return "map"
@property
def annotation(self):
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
@property
def repeated(self):
def repeated(self) -> bool:
return False # maps cannot be repeated
@@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
value: int
comment: str
def __post_init__(self):
def __post_init__(self) -> None:
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
@@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> int:
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
@@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
super().__post_init__() # check for unset fields
@property
def proto_name(self):
def proto_name(self) -> str:
return self.proto_obj.name
@property
def py_name(self):
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
@@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = dict()
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
@@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def route(self) -> str:
return (
f"/{self.output_file.package}."
f"{self.parent.proto_name}/{self.proto_name}"
)
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Union[None, MessageCompiler]:
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Union[None, MessageCompiler]
Optional[MessageCompiler]
Method instance representing the input message.
If not input message could be found or there are no
input messages, None is returned.
@@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type correspoding to the
"""String representation of the Python type corresponding to the
input message.
Returns
-------
str
String representation of the Python type correspoding to the
input message.
String representation of the Python type corresponding to the input message.
"""
return get_type_reference(
package=self.output_file.package,
@@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def py_output_message_type(self) -> str:
"""String representation of the Python type correspoding to the
"""String representation of the Python type corresponding to the
output message.
Returns
-------
str
String representation of the Python type correspoding to the
output message.
String representation of the Python type corresponding to the output message.
"""
return get_type_reference(
package=self.output_file.package,

View File

@@ -1,7 +1,7 @@
import itertools
import pathlib
import sys
from typing import List, Iterator
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
try:
# betterproto[compiler] specific dependencies
@@ -13,10 +13,9 @@ try:
ServiceDescriptorProto,
)
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@@ -24,26 +23,32 @@ except ImportError as err:
)
raise SystemExit(1)
from betterproto.plugin.models import (
PluginRequestCompiler,
OutputTemplate,
MessageCompiler,
FieldCompiler,
OneOfFieldCompiler,
MapEntryCompiler,
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
FieldCompiler,
MapEntryCompiler,
MessageCompiler,
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
is_oneof,
)
from betterproto.plugin.compiler import outputfile_compiler
if TYPE_CHECKING:
from google.protobuf.descriptor import Descriptor
def traverse(proto_file: FieldDescriptorProto) -> Iterator:
def traverse(
proto_file: FieldDescriptorProto,
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
# Todo: Keep information about nested hierarchy
def _traverse(path, items, prefix=""):
def _traverse(
path: List[int], items: List["Descriptor"], prefix=""
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy.
# Todo: don't change the name, but include full name in returned tuple
@@ -104,7 +109,7 @@ def generate_code(
read_protobuf_service(service, index, output_package)
# Generate output files
output_paths: pathlib.Path = set()
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
# Add files to the response object
@@ -112,20 +117,17 @@ def generate_code(
output_paths.add(output_path)
f: response.File = response.file.add()
f.name: str = str(output_path)
f.name = str(output_path)
# Render and then format the output file
f.content: str = outputfile_compiler(output_file=output_package)
f.content = outputfile_compiler(output_file=output_package)
# Make each output directory a package with __init__ file
init_files = (
set(
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
)
- output_paths
)
init_files = {
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
} - output_paths
for init_file in init_files:
init = response.file.add()