QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
This commit is contained in:
@@ -1,14 +1,14 @@
|
||||
"""Plugin model dataclasses.
|
||||
|
||||
These classes are meant to be an intermediate representation
|
||||
of protbuf objects. They are used to organize the data collected during parsing.
|
||||
of protobuf objects. They are used to organize the data collected during parsing.
|
||||
|
||||
The general intention is to create a doubly-linked tree-like structure
|
||||
with the following types of references:
|
||||
- Downwards references: from message -> fields, from output package -> messages
|
||||
or from service -> service methods
|
||||
- Upwards references: from field -> message, message -> package.
|
||||
- Input/ouput message references: from a service method to it's corresponding
|
||||
- Input/output message references: from a service method to it's corresponding
|
||||
input/output messages, which may even be in another package.
|
||||
|
||||
There are convenience methods to allow climbing up and down this tree, for
|
||||
@@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
|
||||
The instantiation should also attach a reference to the new object
|
||||
into the corresponding place within it's parent object. For example,
|
||||
instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attirbute.
|
||||
reference to `A` to `B`'s `fields` attribute.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import (
|
||||
Iterator,
|
||||
Union,
|
||||
Type,
|
||||
List,
|
||||
Dict,
|
||||
Set,
|
||||
Text,
|
||||
)
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
|
||||
|
||||
import betterproto
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.naming import (
|
||||
|
||||
from ..casing import sanitize_name
|
||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||
from ..compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
from ..casing import sanitize_name
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
@@ -67,10 +55,9 @@ try:
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
except ImportError as err:
|
||||
missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
||||
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
@@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
|
||||
)
|
||||
|
||||
|
||||
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
|
||||
def get_comment(
|
||||
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
|
||||
) -> str:
|
||||
pad = " " * indent
|
||||
for sci in proto_file.source_code_info.location:
|
||||
# print(list(sci.path), path, file=sys.stderr)
|
||||
if list(sci.path) == path and sci.leading_comments:
|
||||
lines = textwrap.wrap(
|
||||
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
@@ -153,9 +141,9 @@ class ProtoContentBase:
|
||||
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
parent: Union["Messsage", "OutputTemplate"]
|
||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Checks that no fake default fields were left as placeholders."""
|
||||
for field_name, field_val in self.__dataclass_fields__.items():
|
||||
if field_val is PLACEHOLDER:
|
||||
@@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
|
||||
)
|
||||
deprecated: bool = field(default=False, init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# Add message to output file
|
||||
if isinstance(self.parent, OutputTemplate):
|
||||
if isinstance(self, EnumDefinitionCompiler):
|
||||
@@ -314,17 +302,17 @@ def is_map(
|
||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||
if message_type == map_entry:
|
||||
for nested in parent_message.nested_type: # parent message
|
||||
if nested.name.replace("_", "").lower() == map_entry:
|
||||
if nested.options.map_entry:
|
||||
return True
|
||||
if (
|
||||
nested.name.replace("_", "").lower() == map_entry
|
||||
and nested.options.map_entry
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
"""True if proto_field_obj is a OneOf, otherwise False."""
|
||||
if proto_field_obj.HasField("oneof_index"):
|
||||
return True
|
||||
return False
|
||||
return proto_field_obj.HasField("oneof_index")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
|
||||
parent: MessageCompiler = PLACEHOLDER
|
||||
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# Add field to message
|
||||
self.parent.fields.append(self)
|
||||
# Check for new imports
|
||||
@@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
|
||||
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
|
||||
)
|
||||
betterproto_field_type = (
|
||||
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
|
||||
+ field_args
|
||||
+ ")"
|
||||
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
|
||||
)
|
||||
return name + annotations + " = " + betterproto_field_type
|
||||
return f"{name}{annotations} = {betterproto_field_type}"
|
||||
|
||||
@property
|
||||
def betterproto_field_args(self) -> List[str]:
|
||||
@@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
|
||||
return args
|
||||
|
||||
@property
|
||||
def field_wraps(self) -> Union[str, None]:
|
||||
def field_wraps(self) -> Optional[str]:
|
||||
"""Returns betterproto wrapped field type or None."""
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
||||
@@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):
|
||||
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
|
||||
self.proto_obj, self.parent
|
||||
):
|
||||
return True
|
||||
return False
|
||||
return (
|
||||
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
and not is_map(self.proto_obj, self.parent)
|
||||
)
|
||||
|
||||
@property
|
||||
def mutable(self) -> bool:
|
||||
"""True if the field is a mutable type, otherwise False."""
|
||||
annotation = self.annotation
|
||||
return annotation.startswith("List[") or annotation.startswith("Dict[")
|
||||
return self.annotation.startswith(("List[", "Dict["))
|
||||
|
||||
@property
|
||||
def field_type(self) -> str:
|
||||
@@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
|
||||
@property
|
||||
def packed(self) -> bool:
|
||||
"""True if the wire representation is a packed format."""
|
||||
if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
|
||||
return True
|
||||
return False
|
||||
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
@@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
|
||||
proto_k_type: str = PLACEHOLDER
|
||||
proto_v_type: str = PLACEHOLDER
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
"""Explore nested types and set k_type and v_type if unset."""
|
||||
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
|
||||
for nested in self.parent.proto_obj.nested_type:
|
||||
if nested.name.replace("_", "").lower() == map_entry:
|
||||
if nested.options.map_entry:
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[0] # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[1] # value
|
||||
).py_type
|
||||
# Get proto types
|
||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||
if (
|
||||
nested.name.replace("_", "").lower() == map_entry
|
||||
and nested.options.map_entry
|
||||
):
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[0] # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[1] # value
|
||||
).py_type
|
||||
# Get proto types
|
||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
@@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
|
||||
return "map"
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
def annotation(self) -> str:
|
||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
||||
|
||||
@property
|
||||
def repeated(self):
|
||||
def repeated(self) -> bool:
|
||||
return False # maps cannot be repeated
|
||||
|
||||
|
||||
@@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
||||
value: int
|
||||
comment: str
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# Get entries/allowed values for this Enum
|
||||
self.entries = [
|
||||
self.EnumEntry(
|
||||
@@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
||||
super().__post_init__() # call MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
def default_value_string(self) -> int:
|
||||
def default_value_string(self) -> str:
|
||||
"""Python representation of the default value for Enums.
|
||||
|
||||
As per the spec, this is the first value of the Enum.
|
||||
@@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def proto_name(self):
|
||||
def proto_name(self) -> str:
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_name(self):
|
||||
def py_name(self) -> str:
|
||||
return pythonize_class_name(self.proto_name)
|
||||
|
||||
|
||||
@@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
Name and actual default value (as a string)
|
||||
for each argument with mutable default values.
|
||||
"""
|
||||
mutable_default_args = dict()
|
||||
mutable_default_args = {}
|
||||
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
@@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
@property
|
||||
def route(self) -> str:
|
||||
return (
|
||||
f"/{self.output_file.package}."
|
||||
f"{self.parent.proto_name}/{self.proto_name}"
|
||||
)
|
||||
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
|
||||
|
||||
@property
|
||||
def py_input_message(self) -> Union[None, MessageCompiler]:
|
||||
def py_input_message(self) -> Optional[MessageCompiler]:
|
||||
"""Find the input message object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, MessageCompiler]
|
||||
Optional[MessageCompiler]
|
||||
Method instance representing the input message.
|
||||
If not input message could be found or there are no
|
||||
input messages, None is returned.
|
||||
@@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
@property
|
||||
def py_input_message_type(self) -> str:
|
||||
"""String representation of the Python type correspoding to the
|
||||
"""String representation of the Python type corresponding to the
|
||||
input message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type correspoding to the
|
||||
input message.
|
||||
String representation of the Python type corresponding to the input message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
@@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
@property
|
||||
def py_output_message_type(self) -> str:
|
||||
"""String representation of the Python type correspoding to the
|
||||
"""String representation of the Python type corresponding to the
|
||||
output message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type correspoding to the
|
||||
output message.
|
||||
String representation of the Python type corresponding to the output message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
|
||||
Reference in New Issue
Block a user