QOL fixes (#141)

- Add missing type annotations
- Various style improvements
- Use constants more consistently
- enforce black on benchmark code
This commit is contained in:
James 2020-10-17 18:27:11 +01:00 committed by GitHub
parent bf9412e083
commit 8f7af272cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 177 additions and 220 deletions

View File

@ -17,7 +17,7 @@ jobs:
- name: Run Black
uses: lgeiger/black-action@master
with:
args: --check src/ tests/
args: --check src/ tests/ benchmarks/
- name: Install rST dependcies
run: python -m pip install doc8

1
.gitignore vendored
View File

@ -16,3 +16,4 @@ output
.tox
.venv
.asv
venv

View File

@ -8,9 +8,9 @@ class TestMessage(betterproto.Message):
bar: str = betterproto.string_field(1)
baz: float = betterproto.float_field(2)
class BenchMessage:
"""Test creation and usage a proto message.
"""
"""Test creation and usage a proto message."""
def setup(self):
self.cls = TestMessage
@ -18,8 +18,8 @@ class BenchMessage:
self.instance_filled = TestMessage(0, "test", 0.0)
def time_overhead(self):
"""Overhead in class definition.
"""
"""Overhead in class definition."""
@dataclass
class Message(betterproto.Message):
foo: int = betterproto.uint32_field(0)
@ -27,25 +27,21 @@ class BenchMessage:
baz: float = betterproto.float_field(2)
def time_instantiation(self):
"""Time instantiation
"""
"""Time instantiation"""
self.cls()
def time_attribute_access(self):
"""Time to access an attribute
"""
"""Time to access an attribute"""
self.instance.foo
self.instance.bar
self.instance.baz
def time_init_with_values(self):
"""Time to set an attribute
"""
"""Time to set an attribute"""
self.cls(0, "test", 0.0)
def time_attribute_setting(self):
"""Time to set attributes
"""
"""Time to set attributes"""
self.instance.foo = 0
self.instance.bar = "test"
self.instance.baz = 0.0

View File

@ -26,7 +26,7 @@ from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
if sys.version_info[:2] < (3, 7):
# Apply backport of datetime.fromisoformat from 3.7
from backports.datetime_fromisoformat import MonkeyPatch
@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen():
def datetime_default_gen() -> datetime:
return datetime(1970, 1, 1, tzinfo=timezone.utc)
@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
@classmethod
def from_string(cls, name: str) -> "Enum":
"""
Return the value which corresponds to the string name.
"""Return the value which corresponds to the string name.
Parameters
-----------
@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
return encode_varint(value)
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Handle zig-zag encoding.
if value >= 0:
value = value << 1
else:
value = (value << 1) ^ (~0)
return encode_varint(value)
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
elif proto_type in FIXED_TYPES:
return struct.pack(_pack_fmt(proto_type), value)
elif proto_type == TYPE_STRING:
@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
wire_type = num_wire & 0x7
decoded: Any = None
if wire_type == 0:
if wire_type == WIRE_VARINT:
decoded, i = decode_varint(value, i)
elif wire_type == 1:
elif wire_type == WIRE_FIXED_64:
decoded, i = value[i : i + 8], i + 8
elif wire_type == 2:
elif wire_type == WIRE_LEN_DELIM:
length, i = decode_varint(value, i)
decoded = value[i : i + length]
i += length
elif wire_type == 5:
elif wire_type == WIRE_FIXED_32:
decoded, i = value[i : i + 4], i + 4
yield ParsedField(
@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
class ProtoClassMetadata:
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = (
"oneof_group_by_field",
"oneof_field_by_group",
@ -446,6 +435,14 @@ class ProtoClassMetadata:
"sorted_field_names",
)
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
sorted_field_names: Tuple[str, ...]
default_gen: Dict[str, Callable[[], Any]]
cls_by_field: Dict[str, Type]
def __init__(self, cls: Type["Message"]):
by_field = {}
by_group: Dict[str, Set] = {}
@ -470,23 +467,21 @@ class ProtoClassMetadata:
self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name
self.sorted_field_names = tuple(
by_field_number[number] for number in sorted(by_field_number.keys())
by_field_number[number] for number in sorted(by_field_number)
)
self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields)
@staticmethod
def _get_default_gen(cls, fields):
default_gen = {}
for field in fields:
default_gen[field.name] = cls._get_field_default_gen(field)
return default_gen
def _get_default_gen(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Callable[[], Any]]:
return {field.name: cls._get_field_default_gen(field) for field in fields}
@staticmethod
def _get_cls_by_field(cls, fields):
def _get_cls_by_field(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Type]:
field_cls = {}
for field in fields:
@ -503,7 +498,7 @@ class ProtoClassMetadata:
],
bases=(Message,),
)
field_cls[field.name + ".value"] = vt
field_cls[f"{field.name}.value"] = vt
else:
field_cls[field.name] = cls._cls_for(field)
@ -612,7 +607,7 @@ class Message(ABC):
super().__setattr__(attr, value)
@property
def _betterproto(self):
def _betterproto(self) -> ProtoClassMetadata:
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
@ -726,9 +721,8 @@ class Message(ABC):
@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module))
return type_hints
module = sys.modules[cls.__module__]
return get_type_hints(cls, vars(module))
@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@ -739,7 +733,7 @@ class Message(ABC):
field_cls = field_cls.__args__[index]
return field_cls
def _get_field_default(self, field_name):
def _get_field_default(self, field_name: str) -> Any:
return self._betterproto.default_gen[field_name]()
@classmethod
@ -762,7 +756,7 @@ class Message(ABC):
elif issubclass(t, Enum):
# Enums always default to zero.
return int
elif t == datetime:
elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen
else:
@ -966,7 +960,7 @@ class Message(ABC):
)
):
output[cased_name] = value.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
elif meta.proto_type == TYPE_MAP:
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)
@ -1032,12 +1026,12 @@ class Message(ABC):
continue
if value[key] is not None:
if meta.proto_type == "message":
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name]
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
for item in value[key]:
v.append(cls().from_dict(item))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field_name, v)
@ -1052,7 +1046,7 @@ class Message(ABC):
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[field_name + ".value"]
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
return message._serialized_on_wire
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
"""
Return the name and value of a message's one-of field group.
@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
"""
field_name = message._group_current.get(group_name)
if not field_name:
return ("", None)
return (field_name, getattr(message, field_name))
return "", None
return field_name, getattr(message, field_name)
# Circular import workaround: google.protobuf depends on base classes defined above.
from .lib.google.protobuf import ( # noqa
Duration,
Timestamp,
BoolValue,
BytesValue,
DoubleValue,
Duration,
FloatValue,
Int32Value,
Int64Value,
StringValue,
Timestamp,
UInt32Value,
UInt64Value,
)
@ -1174,8 +1168,8 @@ class _Duration(Duration):
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
parts[1] = f"{parts[1]}0"
return f"{'.'.join(parts)}s"
class _Timestamp(Timestamp):
@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
return f"{result}Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
return f"{result}.{int(nanos // 1e6) :03d}Z"
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
return f"{result}.{int(nanos // 1e3) :06d}Z"
# Serialize 9 fractional digits.
return result + ".%09dZ" % nanos
return f"{result}.{nanos:09d}"
class _WrappedMessage(Message):

View File

@ -1,8 +1,8 @@
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from . import Message
from grpclib._typing import IProtoMessage
from . import Message
# Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message")

View File

@ -1,10 +1,10 @@
import os
import re
from typing import Dict, List, Set, Type
from typing import Dict, List, Set, Tuple, Type
from betterproto import safe_snake_case
from betterproto.compile.naming import pythonize_class_name
from betterproto.lib.google import protobuf as google_protobuf
from ..casing import safe_snake_case
from ..lib.google import protobuf as google_protobuf
from .naming import pythonize_class_name
WRAPPER_TYPES: Dict[str, Type] = {
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
}
def parse_source_type_name(field_type_name):
def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
"""
Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message')
@ -50,7 +50,7 @@ def get_type_reference(
if source_type == ".google.protobuf.Duration":
return "timedelta"
if source_type == ".google.protobuf.Timestamp":
elif source_type == ".google.protobuf.Timestamp":
return "datetime"
source_package, source_type = parse_source_type_name(source_type)
@ -79,7 +79,7 @@ def get_type_reference(
return reference_cousin(current_package, imports, py_package, py_type)
def reference_absolute(imports, py_package, py_type):
def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
"""
Returns a reference to a python type located in the root, i.e. sys.path.
"""

View File

@ -1,13 +1,13 @@
from betterproto import casing
def pythonize_class_name(name):
def pythonize_class_name(name: str) -> str:
return casing.pascal_case(name)
def pythonize_field_name(name: str):
def pythonize_field_name(name: str) -> str:
return casing.safe_snake_case(name)
def pythonize_method_name(name: str):
def pythonize_method_name(name: str) -> str:
return casing.safe_snake_case(name)

View File

@ -1,7 +1,7 @@
from abc import ABC
import asyncio
import grpclib.const
from abc import ABC
from typing import (
TYPE_CHECKING,
AsyncIterable,
AsyncIterator,
Collection,
@ -9,11 +9,13 @@ from typing import (
Mapping,
Optional,
Tuple,
TYPE_CHECKING,
Type,
Union,
)
from betterproto._types import ST, T
import grpclib.const
from .._types import ST, T
if TYPE_CHECKING:
from grpclib.client import Channel

View File

@ -1,12 +1,5 @@
import asyncio
from typing import (
AsyncIterable,
AsyncIterator,
Iterable,
Optional,
TypeVar,
Union,
)
from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union
T = TypeVar("T")
@ -16,8 +9,6 @@ class ChannelClosed(Exception):
An exception raised on an attempt to send through a closed channel
"""
pass
class ChannelDone(Exception):
"""
@ -25,8 +16,6 @@ class ChannelDone(Exception):
and empty.
"""
pass
class AsyncChannel(AsyncIterable[T]):
"""

View File

@ -5,10 +5,9 @@ try:
import black
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@ -16,7 +15,7 @@ except ImportError as err:
)
raise SystemExit(1)
from betterproto.plugin.models import OutputTemplate
from .models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str:
@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
)
template = env.get_template("template.py.j2")
res = black.format_str(
return black.format_str(
template.render(output_file=output_file),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)
return res

View File

@ -1,13 +1,14 @@
#!/usr/bin/env python
import sys
import os
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.plugin.parser import generate_code
def main():
def main() -> None:
"""The plugin's main entry point."""
# Read request message from stdin
data = sys.stdin.buffer.read()
@ -33,7 +34,7 @@ def main():
sys.stdout.buffer.write(output)
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
"""
For developers: Supports running plugin.py standalone so its possible to debug it.
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.

View File

@ -1,14 +1,14 @@
"""Plugin model dataclasses.
These classes are meant to be an intermediate representation
of protbuf objects. They are used to organize the data collected during parsing.
of protobuf objects. They are used to organize the data collected during parsing.
The general intention is to create a doubly-linked tree-like structure
with the following types of references:
- Downwards references: from message -> fields, from output package -> messages
or from service -> service methods
- Upwards references: from field -> message, message -> package.
- Input/ouput message references: from a service method to it's corresponding
- Input/output message references: from a service method to it's corresponding
input/output messages, which may even be in another package.
There are convenience methods to allow climbing up and down this tree, for
@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
The instantiation should also attach a reference to the new object
into the corresponding place within it's parent object. For example,
instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attirbute.
reference to `A` to `B`'s `fields` attribute.
"""
import re
from dataclasses import dataclass
from dataclasses import field
from typing import (
Iterator,
Union,
Type,
List,
Dict,
Set,
Text,
)
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
import betterproto
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from ..casing import sanitize_name
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
@ -67,10 +55,9 @@ try:
MethodDescriptorProto,
)
except ImportError as err:
missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
)
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
# print(list(sci.path), path, file=sys.stderr)
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
@ -153,9 +141,9 @@ class ProtoContentBase:
path: List[int]
comment_indent: int = 4
parent: Union["Messsage", "OutputTemplate"]
parent: Union["betterproto.Message", "OutputTemplate"]
def __post_init__(self):
def __post_init__(self) -> None:
"""Checks that no fake default fields were left as placeholders."""
for field_name, field_val in self.__dataclass_fields__.items():
if field_val is PLACEHOLDER:
@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
)
deprecated: bool = field(default=False, init=False)
def __post_init__(self):
def __post_init__(self) -> None:
# Add message to output file
if isinstance(self.parent, OutputTemplate):
if isinstance(self, EnumDefinitionCompiler):
@ -314,17 +302,17 @@ def is_map(
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in parent_message.nested_type: # parent message
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
return True
return False
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False."""
if proto_field_obj.HasField("oneof_index"):
return True
return False
return proto_field_obj.HasField("oneof_index")
@dataclass
@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
parent: MessageCompiler = PLACEHOLDER
proto_obj: FieldDescriptorProto = PLACEHOLDER
def __post_init__(self):
def __post_init__(self) -> None:
# Add field to message
self.parent.fields.append(self)
# Check for new imports
@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
)
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
+ field_args
+ ")"
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
)
return name + annotations + " = " + betterproto_field_type
return f"{name}{annotations} = {betterproto_field_type}"
@property
def betterproto_field_args(self) -> List[str]:
@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
return args
@property
def field_wraps(self) -> Union[str, None]:
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):
@property
def repeated(self) -> bool:
if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
self.proto_obj, self.parent
):
return True
return False
return (
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
and not is_map(self.proto_obj, self.parent)
)
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
annotation = self.annotation
return annotation.startswith("List[") or annotation.startswith("Dict[")
return self.annotation.startswith(("List[", "Dict["))
@property
def field_type(self) -> str:
@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
@property
def packed(self) -> bool:
"""True if the wire representation is a packed format."""
if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
return True
return False
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
@property
def py_name(self) -> str:
@ -486,12 +468,14 @@ class MapEntryCompiler(FieldCompiler):
proto_k_type: str = PLACEHOLDER
proto_v_type: str = PLACEHOLDER
def __post_init__(self):
def __post_init__(self) -> None:
"""Explore nested types and set k_type and v_type if unset."""
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
for nested in self.parent.proto_obj.nested_type:
if nested.name.replace("_", "").lower() == map_entry:
if nested.options.map_entry:
if (
nested.name.replace("_", "").lower() == map_entry
and nested.options.map_entry
):
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
return "map"
@property
def annotation(self):
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
@property
def repeated(self):
def repeated(self) -> bool:
return False # maps cannot be repeated
@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
value: int
comment: str
def __post_init__(self):
def __post_init__(self) -> None:
# Get entries/allowed values for this Enum
self.entries = [
self.EnumEntry(
@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
super().__post_init__() # call MessageCompiler __post_init__
@property
def default_value_string(self) -> int:
def default_value_string(self) -> str:
"""Python representation of the default value for Enums.
As per the spec, this is the first value of the Enum.
@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
super().__post_init__() # check for unset fields
@property
def proto_name(self):
def proto_name(self) -> str:
return self.proto_obj.name
@property
def py_name(self):
def py_name(self) -> str:
return pythonize_class_name(self.proto_name)
@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = dict()
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def route(self) -> str:
return (
f"/{self.output_file.package}."
f"{self.parent.proto_name}/{self.proto_name}"
)
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
@property
def py_input_message(self) -> Union[None, MessageCompiler]:
def py_input_message(self) -> Optional[MessageCompiler]:
"""Find the input message object.
Returns
-------
Union[None, MessageCompiler]
Optional[MessageCompiler]
Method instance representing the input message.
If not input message could be found or there are no
input messages, None is returned.
@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def py_input_message_type(self) -> str:
"""String representation of the Python type correspoding to the
"""String representation of the Python type corresponding to the
input message.
Returns
-------
str
String representation of the Python type correspoding to the
input message.
String representation of the Python type corresponding to the input message.
"""
return get_type_reference(
package=self.output_file.package,
@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):
@property
def py_output_message_type(self) -> str:
"""String representation of the Python type correspoding to the
"""String representation of the Python type corresponding to the
output message.
Returns
-------
str
String representation of the Python type correspoding to the
output message.
String representation of the Python type corresponding to the output message.
"""
return get_type_reference(
package=self.output_file.package,

View File

@ -1,7 +1,7 @@
import itertools
import pathlib
import sys
from typing import List, Iterator
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
try:
# betterproto[compiler] specific dependencies
@ -13,10 +13,9 @@ try:
ServiceDescriptorProto,
)
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
@ -24,26 +23,32 @@ except ImportError as err:
)
raise SystemExit(1)
from betterproto.plugin.models import (
PluginRequestCompiler,
OutputTemplate,
MessageCompiler,
FieldCompiler,
OneOfFieldCompiler,
MapEntryCompiler,
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
FieldCompiler,
MapEntryCompiler,
MessageCompiler,
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
is_oneof,
)
from betterproto.plugin.compiler import outputfile_compiler
if TYPE_CHECKING:
from google.protobuf.descriptor import Descriptor
def traverse(proto_file: FieldDescriptorProto) -> Iterator:
def traverse(
proto_file: FieldDescriptorProto,
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
# Todo: Keep information about nested hierarchy
def _traverse(path, items, prefix=""):
def _traverse(
path: List[int], items: List["Descriptor"], prefix=""
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
for i, item in enumerate(items):
# Adjust the name since we flatten the hierarchy.
# Todo: don't change the name, but include full name in returned tuple
@ -104,7 +109,7 @@ def generate_code(
read_protobuf_service(service, index, output_package)
# Generate output files
output_paths: pathlib.Path = set()
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
# Add files to the response object
@ -112,20 +117,17 @@ def generate_code(
output_paths.add(output_path)
f: response.File = response.file.add()
f.name: str = str(output_path)
f.name = str(output_path)
# Render and then format the output file
f.content: str = outputfile_compiler(output_file=output_package)
f.content = outputfile_compiler(output_file=output_package)
# Make each output directory a package with __init__ file
init_files = (
set(
init_files = {
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
)
- output_paths
)
} - output_paths
for init_file in init_files:
init = response.file.add()

View File

@ -27,10 +27,7 @@ class ClientStub:
async def to_list(generator: AsyncIterator):
result = []
async for value in generator:
result.append(value)
return result
return [value async for value in generator]
@pytest.fixture

View File

@ -6,7 +6,7 @@ from grpclib.client import Channel
class MockChannel(Channel):
# noinspection PyMissingConstructor
def __init__(self, responses=None) -> None:
self.responses = responses if responses else []
self.responses = responses or []
self.requests = []
self._loop = None

View File

@ -23,8 +23,7 @@ def get_files(path, suffix: str) -> Generator[str, None, None]:
def get_directories(path):
for root, directories, files in os.walk(path):
for directory in directories:
yield directory
yield from directories
async def protoc(
@ -49,7 +48,7 @@ async def protoc(
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
test_data_file_name = json_file_name or f"{test_case_name}.json"
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
if not test_data_file_path.exists():
@ -77,7 +76,7 @@ def find_module(
module_path = pathlib.Path(*module.__path__)
for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
if sub == module_path:
continue
sub_module_path = sub.relative_to(module_path)