QOL fixes (#141)
- Add missing type annotations
- Various style improvements
- Use constants more consistently
- Enforce black on benchmark code
This commit is contained in:
parent bf9412e083
commit 8f7af272cc
.github/workflows/code-quality.yml (vendored, 2 changed lines)
@@ -17,7 +17,7 @@ jobs:
       - name: Run Black
         uses: lgeiger/black-action@master
         with:
-          args: --check src/ tests/
+          args: --check src/ tests/ benchmarks/

       - name: Install rST dependcies
         run: python -m pip install doc8
.gitignore (vendored, 3 changed lines)
@@ -15,4 +15,5 @@ output
 .DS_Store
 .tox
 .venv
-.asv
+.asv
+venv
@@ -8,9 +8,9 @@ class TestMessage(betterproto.Message):
     bar: str = betterproto.string_field(1)
     baz: float = betterproto.float_field(2)


 class BenchMessage:
-    """Test creation and usage a proto message.
-    """
+    """Test creation and usage a proto message."""

     def setup(self):
         self.cls = TestMessage
@@ -18,8 +18,8 @@ class BenchMessage:
         self.instance_filled = TestMessage(0, "test", 0.0)

     def time_overhead(self):
-        """Overhead in class definition.
-        """
+        """Overhead in class definition."""

         @dataclass
         class Message(betterproto.Message):
             foo: int = betterproto.uint32_field(0)
@@ -27,29 +27,25 @@ class BenchMessage:
             baz: float = betterproto.float_field(2)

     def time_instantiation(self):
-        """Time instantiation
-        """
+        """Time instantiation"""
         self.cls()

     def time_attribute_access(self):
-        """Time to access an attribute
-        """
+        """Time to access an attribute"""
         self.instance.foo
         self.instance.bar
         self.instance.baz

     def time_init_with_values(self):
-        """Time to set an attribute
-        """
+        """Time to set an attribute"""
         self.cls(0, "test", 0.0)

     def time_attribute_setting(self):
-        """Time to set attributes
-        """
+        """Time to set attributes"""
         self.instance.foo = 0
         self.instance.bar = "test"
         self.instance.baz = 0.0

     def time_serialize(self):
         """Time serializing a message to wire."""
         bytes(self.instance_filled)
@@ -58,6 +54,6 @@ class BenchMessage:
 class MemSuite:
     def setup(self):
         self.cls = TestMessage
-
+
     def mem_instance(self):
         return self.cls()
@@ -26,7 +26,7 @@ from ._types import T
 from .casing import camel_case, safe_snake_case, snake_case
 from .grpc.grpclib_client import ServiceStub

-if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
+if sys.version_info[:2] < (3, 7):
     # Apply backport of datetime.fromisoformat from 3.7
     from backports.datetime_fromisoformat import MonkeyPatch

@@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]


 # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
-def datetime_default_gen():
+def datetime_default_gen() -> datetime:
     return datetime(1970, 1, 1, tzinfo=timezone.utc)


@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):

     @classmethod
     def from_string(cls, name: str) -> "Enum":
-        """
-        Return the value which corresponds to the string name.
+        """Return the value which corresponds to the string name.

         Parameters
         -----------
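Aside (illustrative only, not part of the diff): a minimal sketch of how `from_string` is typically used, assuming it looks members up by name as its docstring describes; the `Color` enum here is hypothetical.

import betterproto

class Color(betterproto.Enum):
    RED = 0
    GREEN = 1

# Look the member up by its string name.
assert Color.from_string("GREEN") is Color.GREEN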
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
         return encode_varint(value)
     elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
         # Handle zig-zag encoding.
-        if value >= 0:
-            value = value << 1
-        else:
-            value = (value << 1) ^ (~0)
-        return encode_varint(value)
+        return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
     elif proto_type in FIXED_TYPES:
         return struct.pack(_pack_fmt(proto_type), value)
     elif proto_type == TYPE_STRING:
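As an aside (not part of the commit), a standalone sketch of the zig-zag mapping the one-liner above implements, which interleaves negative and non-negative integers so small magnitudes stay small on the wire:

def zigzag(value: int) -> int:
    # 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
    return value << 1 if value >= 0 else (value << 1) ^ (~0)

assert [zigzag(v) for v in (0, -1, 1, -2, 2)] == [0, 1, 2, 3, 4]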
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
         wire_type = num_wire & 0x7

         decoded: Any = None
-        if wire_type == 0:
+        if wire_type == WIRE_VARINT:
             decoded, i = decode_varint(value, i)
-        elif wire_type == 1:
+        elif wire_type == WIRE_FIXED_64:
             decoded, i = value[i : i + 8], i + 8
-        elif wire_type == 2:
+        elif wire_type == WIRE_LEN_DELIM:
             length, i = decode_varint(value, i)
             decoded = value[i : i + length]
             i += length
-        elif wire_type == 5:
+        elif wire_type == WIRE_FIXED_32:
             decoded, i = value[i : i + 4], i + 4

         yield ParsedField(
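For reference (an illustrative sketch, not part of the diff): the replaced literals are the standard protobuf wire-type codes, and each field key packs the field number into the high bits with the wire type in the low three bits. The constant definitions below mirror the names used in the diff and are an assumption about their values.

WIRE_VARINT, WIRE_FIXED_64, WIRE_LEN_DELIM, WIRE_FIXED_32 = 0, 1, 2, 5

def split_field_key(num_wire: int) -> tuple:
    # Field number in the high bits, wire type in the low three bits.
    return num_wire >> 3, num_wire & 0x7

assert split_field_key(0x0A) == (1, WIRE_LEN_DELIM)  # field 1, length-delimited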
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:


 class ProtoClassMetadata:
-    oneof_group_by_field: Dict[str, str]
-    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
-    default_gen: Dict[str, Callable]
-    cls_by_field: Dict[str, Type]
-    field_name_by_number: Dict[int, str]
-    meta_by_field_name: Dict[str, FieldMetadata]
     __slots__ = (
         "oneof_group_by_field",
         "oneof_field_by_group",
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
         "sorted_field_names",
     )

+    oneof_group_by_field: Dict[str, str]
+    oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
+    field_name_by_number: Dict[int, str]
+    meta_by_field_name: Dict[str, FieldMetadata]
+    sorted_field_names: Tuple[str, ...]
+    default_gen: Dict[str, Callable[[], Any]]
+    cls_by_field: Dict[str, Type]
+
     def __init__(self, cls: Type["Message"]):
         by_field = {}
         by_group: Dict[str, Set] = {}
@@ -470,23 +467,21 @@ class ProtoClassMetadata:
         self.field_name_by_number = by_field_number
         self.meta_by_field_name = by_field_name
         self.sorted_field_names = tuple(
-            by_field_number[number] for number in sorted(by_field_number.keys())
+            by_field_number[number] for number in sorted(by_field_number)
         )

         self.default_gen = self._get_default_gen(cls, fields)
         self.cls_by_field = self._get_cls_by_field(cls, fields)

     @staticmethod
-    def _get_default_gen(cls, fields):
-        default_gen = {}
-
-        for field in fields:
-            default_gen[field.name] = cls._get_field_default_gen(field)
-
-        return default_gen
+    def _get_default_gen(
+        cls: Type["Message"], fields: List[dataclasses.Field]
+    ) -> Dict[str, Callable[[], Any]]:
+        return {field.name: cls._get_field_default_gen(field) for field in fields}

     @staticmethod
-    def _get_cls_by_field(cls, fields):
+    def _get_cls_by_field(
+        cls: Type["Message"], fields: List[dataclasses.Field]
+    ) -> Dict[str, Type]:
         field_cls = {}

         for field in fields:
@@ -503,7 +498,7 @@ class ProtoClassMetadata:
                     ],
                     bases=(Message,),
                 )
-                field_cls[field.name + ".value"] = vt
+                field_cls[f"{field.name}.value"] = vt
             else:
                 field_cls[field.name] = cls._cls_for(field)

@@ -612,7 +607,7 @@ class Message(ABC):
         super().__setattr__(attr, value)

     @property
-    def _betterproto(self):
+    def _betterproto(self) -> ProtoClassMetadata:
         """
         Lazy initialize metadata for each protobuf class.
         It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ class Message(ABC):

     @classmethod
     def _type_hints(cls) -> Dict[str, Type]:
-        module = inspect.getmodule(cls)
-        type_hints = get_type_hints(cls, vars(module))
-        return type_hints
+        module = sys.modules[cls.__module__]
+        return get_type_hints(cls, vars(module))

     @classmethod
     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -739,7 +733,7 @@ class Message(ABC):
             field_cls = field_cls.__args__[index]
         return field_cls

-    def _get_field_default(self, field_name):
+    def _get_field_default(self, field_name: str) -> Any:
         return self._betterproto.default_gen[field_name]()

     @classmethod
@@ -762,7 +756,7 @@ class Message(ABC):
         elif issubclass(t, Enum):
             # Enums always default to zero.
             return int
-        elif t == datetime:
+        elif t is datetime:
             # Offsets are relative to 1970-01-01T00:00:00Z
             return datetime_default_gen
         else:
@@ -966,7 +960,7 @@ class Message(ABC):
                     )
                 ):
                     output[cased_name] = value.to_dict(casing, include_default_values)
-            elif meta.proto_type == "map":
+            elif meta.proto_type == TYPE_MAP:
                 for k in value:
                     if hasattr(value[k], "to_dict"):
                         value[k] = value[k].to_dict(casing, include_default_values)
@@ -1032,12 +1026,12 @@ class Message(ABC):
                continue

            if value[key] is not None:
-                if meta.proto_type == "message":
+                if meta.proto_type == TYPE_MESSAGE:
                    v = getattr(self, field_name)
                    if isinstance(v, list):
                        cls = self._betterproto.cls_by_field[field_name]
-                        for i in range(len(value[key])):
-                            v.append(cls().from_dict(value[key][i]))
+                        for item in value[key]:
+                            v.append(cls().from_dict(item))
                    elif isinstance(v, datetime):
                        v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
                        setattr(self, field_name, v)
@@ -1052,7 +1046,7 @@ class Message(ABC):
                        v.from_dict(value[key])
                    elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
                        v = getattr(self, field_name)
-                        cls = self._betterproto.cls_by_field[field_name + ".value"]
+                        cls = self._betterproto.cls_by_field[f"{field_name}.value"]
                        for k in value[key]:
                            v[k] = cls().from_dict(value[key][k])
                    else:
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
     return message._serialized_on_wire


-def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
+def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
     """
     Return the name and value of a message's one-of field group.

@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
     """
     field_name = message._group_current.get(group_name)
     if not field_name:
-        return ("", None)
-    return (field_name, getattr(message, field_name))
+        return "", None
+    return field_name, getattr(message, field_name)


 # Circular import workaround: google.protobuf depends on base classes defined above.
 from .lib.google.protobuf import (  # noqa
-    Duration,
-    Timestamp,
     BoolValue,
     BytesValue,
     DoubleValue,
+    Duration,
     FloatValue,
     Int32Value,
     Int64Value,
     StringValue,
+    Timestamp,
     UInt32Value,
     UInt64Value,
 )
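A quick usage sketch of `which_one_of`, which returns an empty name and None when no field in the group is set (hence the Optional[Any] in the new annotation). Illustrative only: the `Test` message and its `group` oneof below are hypothetical, not taken from the diff.

from dataclasses import dataclass

import betterproto

@dataclass
class Test(betterproto.Message):
    name: str = betterproto.string_field(1, group="group")
    count: int = betterproto.int32_field(2, group="group")

test = Test()
assert betterproto.which_one_of(test, "group") == ("", None)  # nothing set yet

test.name = "foo"
assert betterproto.which_one_of(test, "group") == ("name", "foo")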
@@ -1174,8 +1168,8 @@ class _Duration(Duration):
         parts = str(delta.total_seconds()).split(".")
         if len(parts) > 1:
             while len(parts[1]) not in [3, 6, 9]:
-                parts[1] = parts[1] + "0"
-            return ".".join(parts) + "s"
+                parts[1] = f"{parts[1]}0"
+            return f"{'.'.join(parts)}s"


 class _Timestamp(Timestamp):
@@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
         if (nanos % 1e9) == 0:
             # If there are 0 fractional digits, the fractional
             # point '.' should be omitted when serializing.
-            return result + "Z"
+            return f"{result}Z"
         if (nanos % 1e6) == 0:
             # Serialize 3 fractional digits.
-            return result + ".%03dZ" % (nanos / 1e6)
+            return f"{result}.{int(nanos // 1e6) :03d}Z"
         if (nanos % 1e3) == 0:
             # Serialize 6 fractional digits.
-            return result + ".%06dZ" % (nanos / 1e3)
+            return f"{result}.{int(nanos // 1e3) :06d}Z"
         # Serialize 9 fractional digits.
-        return result + ".%09dZ" % nanos
+        return f"{result}.{nanos:09d}"
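Aside (not part of the commit): a rough sketch of the fractional-second rule these branches implement, i.e. emit 0, 3, 6, or 9 digits depending on the precision actually present in the nanosecond value.

def fractional_suffix(nanos: int) -> str:
    # Whole seconds -> no suffix, whole milliseconds -> 3 digits,
    # whole microseconds -> 6 digits, otherwise 9 digits.
    if nanos % 1_000_000_000 == 0:
        return ""
    if nanos % 1_000_000 == 0:
        return ".%03d" % (nanos // 1_000_000)
    if nanos % 1_000 == 0:
        return ".%06d" % (nanos // 1_000)
    return ".%09d" % nanos

assert fractional_suffix(500_000_000) == ".500"
assert fractional_suffix(1_500) == ".000001500"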
@@ -1,8 +1,8 @@
 from typing import TYPE_CHECKING, TypeVar

 if TYPE_CHECKING:
-    from . import Message
     from grpclib._typing import IProtoMessage
+    from . import Message

 # Bound type variable to allow methods to return `self` of subclasses
 T = TypeVar("T", bound="Message")
@@ -1,10 +1,10 @@
 import os
 import re
-from typing import Dict, List, Set, Type
+from typing import Dict, List, Set, Tuple, Type

-from betterproto import safe_snake_case
-from betterproto.compile.naming import pythonize_class_name
-from betterproto.lib.google import protobuf as google_protobuf
+from ..casing import safe_snake_case
+from ..lib.google import protobuf as google_protobuf
+from .naming import pythonize_class_name

 WRAPPER_TYPES: Dict[str, Type] = {
     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
 }


-def parse_source_type_name(field_type_name):
+def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
     """
     Split full source type name into package and type name.
     E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -50,7 +50,7 @@ def get_type_reference(
     if source_type == ".google.protobuf.Duration":
         return "timedelta"

-    if source_type == ".google.protobuf.Timestamp":
+    elif source_type == ".google.protobuf.Timestamp":
         return "datetime"

     source_package, source_type = parse_source_type_name(source_type)
@@ -79,7 +79,7 @@ def get_type_reference(
         return reference_cousin(current_package, imports, py_package, py_type)


-def reference_absolute(imports, py_package, py_type):
+def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
     """
     Returns a reference to a python type located in the root, i.e. sys.path.
     """
@@ -1,13 +1,13 @@
 from betterproto import casing


-def pythonize_class_name(name):
+def pythonize_class_name(name: str) -> str:
     return casing.pascal_case(name)


-def pythonize_field_name(name: str):
+def pythonize_field_name(name: str) -> str:
     return casing.safe_snake_case(name)


-def pythonize_method_name(name: str):
+def pythonize_method_name(name: str) -> str:
     return casing.safe_snake_case(name)
@@ -1,7 +1,7 @@
-from abc import ABC
 import asyncio
-import grpclib.const
+from abc import ABC
 from typing import (
+    TYPE_CHECKING,
     AsyncIterable,
     AsyncIterator,
     Collection,
@@ -9,11 +9,13 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
-    TYPE_CHECKING,
     Type,
     Union,
 )
-from betterproto._types import ST, T
+
+import grpclib.const
+
+from .._types import ST, T

 if TYPE_CHECKING:
     from grpclib.client import Channel
@@ -1,12 +1,5 @@
 import asyncio
-from typing import (
-    AsyncIterable,
-    AsyncIterator,
-    Iterable,
-    Optional,
-    TypeVar,
-    Union,
-)
+from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union

 T = TypeVar("T")

@@ -16,8 +9,6 @@ class ChannelClosed(Exception):
     An exception raised on an attempt to send through a closed channel
     """

-    pass
-

 class ChannelDone(Exception):
     """
@@ -25,8 +16,6 @@ class ChannelDone(Exception):
     and empty.
     """

-    pass
-

 class AsyncChannel(AsyncIterable[T]):
     """
@@ -5,10 +5,9 @@ try:
     import black
     import jinja2
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -16,7 +15,7 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import OutputTemplate
+from .models import OutputTemplate


 def outputfile_compiler(output_file: OutputTemplate) -> str:
@@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
     )
     template = env.get_template("template.py.j2")

-    res = black.format_str(
+    return black.format_str(
         template.render(output_file=output_file),
         mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
     )
-
-    return res
@@ -1,13 +1,14 @@
 #!/usr/bin/env python
-import sys
-
 import os
+import sys

 from google.protobuf.compiler import plugin_pb2 as plugin
+
 from betterproto.plugin.parser import generate_code


-def main():
+def main() -> None:
     """The plugin's main entry point."""
     # Read request message from stdin
     data = sys.stdin.buffer.read()
@@ -33,7 +34,7 @@ def main():
     sys.stdout.buffer.write(output)


-def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
+def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
     """
     For developers: Supports running plugin.py standalone so its possible to debug it.
     Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
@@ -1,14 +1,14 @@
 """Plugin model dataclasses.

 These classes are meant to be an intermediate representation
-of protbuf objects. They are used to organize the data collected during parsing.
+of protobuf objects. They are used to organize the data collected during parsing.

 The general intention is to create a doubly-linked tree-like structure
 with the following types of references:
 - Downwards references: from message -> fields, from output package -> messages
 or from service -> service methods
 - Upwards references: from field -> message, message -> package.
-- Input/ouput message references: from a service method to it's corresponding
+- Input/output message references: from a service method to it's corresponding
 input/output messages, which may even be in another package.

 There are convenience methods to allow climbing up and down this tree, for
@@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
 The instantiation should also attach a reference to the new object
 into the corresponding place within it's parent object. For example,
 instantiating field `A` with parent message `B` should add a
-reference to `A` to `B`'s `fields` attirbute.
+reference to `A` to `B`'s `fields` attribute.
 """

 import re
-from dataclasses import dataclass
-from dataclasses import field
-from typing import (
-    Iterator,
-    Union,
-    Type,
-    List,
-    Dict,
-    Set,
-    Text,
-)
 import textwrap
+from dataclasses import dataclass, field
+from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union

 import betterproto
-from betterproto.compile.importing import (
-    get_type_reference,
-    parse_source_type_name,
-)
-from betterproto.compile.naming import (
+
+from ..casing import sanitize_name
+from ..compile.importing import get_type_reference, parse_source_type_name
+from ..compile.naming import (
     pythonize_class_name,
     pythonize_field_name,
     pythonize_method_name,
 )
-
-from ..casing import sanitize_name

 try:
     # betterproto[compiler] specific dependencies
     from google.protobuf.compiler import plugin_pb2 as plugin
@@ -67,10 +55,9 @@ try:
         MethodDescriptorProto,
     )
 except ImportError as err:
-    missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
 )


-def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
+def get_comment(
+    proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
+) -> str:
     pad = " " * indent
     for sci in proto_file.source_code_info.location:
         # print(list(sci.path), path, file=sys.stderr)
         if list(sci.path) == path and sci.leading_comments:
             lines = textwrap.wrap(
                 sci.leading_comments.strip().replace("\n", ""), width=79 - indent
@@ -153,9 +141,9 @@ class ProtoContentBase:

     path: List[int]
     comment_indent: int = 4
-    parent: Union["Messsage", "OutputTemplate"]
+    parent: Union["betterproto.Message", "OutputTemplate"]

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Checks that no fake default fields were left as placeholders."""
         for field_name, field_val in self.__dataclass_fields__.items():
             if field_val is PLACEHOLDER:
@@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
     )
     deprecated: bool = field(default=False, init=False)

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add message to output file
         if isinstance(self.parent, OutputTemplate):
             if isinstance(self, EnumDefinitionCompiler):
@@ -314,17 +302,17 @@ def is_map(
     map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
     if message_type == map_entry:
         for nested in parent_message.nested_type:  # parent message
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    return True
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                return True
     return False


 def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
     """True if proto_field_obj is a OneOf, otherwise False."""
-    if proto_field_obj.HasField("oneof_index"):
-        return True
-    return False
+    return proto_field_obj.HasField("oneof_index")


 @dataclass
@@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
     parent: MessageCompiler = PLACEHOLDER
     proto_obj: FieldDescriptorProto = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add field to message
         self.parent.fields.append(self)
         # Check for new imports
@@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
             ([""] + self.betterproto_field_args) if self.betterproto_field_args else []
         )
         betterproto_field_type = (
-            f"betterproto.{self.field_type}_field({self.proto_obj.number}"
-            + field_args
-            + ")"
+            f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
         )
-        return name + annotations + " = " + betterproto_field_type
+        return f"{name}{annotations} = {betterproto_field_type}"

     @property
     def betterproto_field_args(self) -> List[str]:
@@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
         return args

     @property
-    def field_wraps(self) -> Union[str, None]:
+    def field_wraps(self) -> Optional[str]:
         """Returns betterproto wrapped field type or None."""
         match_wrapper = re.match(
             r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
@@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):

     @property
     def repeated(self) -> bool:
-        if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
-            self.proto_obj, self.parent
-        ):
-            return True
-        return False
+        return (
+            self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
+            and not is_map(self.proto_obj, self.parent)
+        )

     @property
     def mutable(self) -> bool:
         """True if the field is a mutable type, otherwise False."""
-        annotation = self.annotation
-        return annotation.startswith("List[") or annotation.startswith("Dict[")
+        return self.annotation.startswith(("List[", "Dict["))

     @property
     def field_type(self) -> str:
@@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
     @property
     def packed(self) -> bool:
         """True if the wire representation is a packed format."""
-        if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
-            return True
-        return False
+        return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES

     @property
     def py_name(self) -> str:
@@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
     proto_k_type: str = PLACEHOLDER
     proto_v_type: str = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Explore nested types and set k_type and v_type if unset."""
         map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
         for nested in self.parent.proto_obj.nested_type:
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    # Get Python types
-                    self.py_k_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[0]  # key
-                    ).py_type
-                    self.py_v_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[1]  # value
-                    ).py_type
-                    # Get proto types
-                    self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
-                    self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                # Get Python types
+                self.py_k_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[0]  # key
+                ).py_type
+                self.py_v_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[1]  # value
+                ).py_type
+                # Get proto types
+                self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
+                self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
         super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

     @property
@@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
         return "map"

     @property
-    def annotation(self):
+    def annotation(self) -> str:
         return f"Dict[{self.py_k_type}, {self.py_v_type}]"

     @property
-    def repeated(self):
+    def repeated(self) -> bool:
         return False  # maps cannot be repeated


@@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         value: int
         comment: str

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Get entries/allowed values for this Enum
         self.entries = [
             self.EnumEntry(
@@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         super().__post_init__()  # call MessageCompiler __post_init__

     @property
-    def default_value_string(self) -> int:
+    def default_value_string(self) -> str:
         """Python representation of the default value for Enums.

         As per the spec, this is the first value of the Enum.
@@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
         super().__post_init__()  # check for unset fields

     @property
-    def proto_name(self):
+    def proto_name(self) -> str:
         return self.proto_obj.name

     @property
-    def py_name(self):
+    def py_name(self) -> str:
         return pythonize_class_name(self.proto_name)


@@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
         Name and actual default value (as a string)
         for each argument with mutable default values.
         """
-        mutable_default_args = dict()
+        mutable_default_args = {}

         if self.py_input_message:
             for f in self.py_input_message.fields:
@@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def route(self) -> str:
-        return (
-            f"/{self.output_file.package}."
-            f"{self.parent.proto_name}/{self.proto_name}"
-        )
+        return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"

     @property
-    def py_input_message(self) -> Union[None, MessageCompiler]:
+    def py_input_message(self) -> Optional[MessageCompiler]:
         """Find the input message object.

         Returns
         -------
-        Union[None, MessageCompiler]
+        Optional[MessageCompiler]
             Method instance representing the input message.
             If not input message could be found or there are no
             input messages, None is returned.
@@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def py_input_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         input message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            input message.
+            String representation of the Python type corresponding to the input message.
         """
         return get_type_reference(
             package=self.output_file.package,
@@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):

     @property
     def py_output_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         output message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            output message.
+            String representation of the Python type corresponding to the output message.
         """
         return get_type_reference(
             package=self.output_file.package,
@@ -1,7 +1,7 @@
 import itertools
 import pathlib
 import sys
-from typing import List, Iterator
+from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set

 try:
     # betterproto[compiler] specific dependencies
@@ -13,10 +13,9 @@ try:
         ServiceDescriptorProto,
     )
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -24,26 +23,32 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import (
-    PluginRequestCompiler,
-    OutputTemplate,
-    MessageCompiler,
-    FieldCompiler,
-    OneOfFieldCompiler,
-    MapEntryCompiler,
+from .compiler import outputfile_compiler
+from .models import (
     EnumDefinitionCompiler,
+    FieldCompiler,
+    MapEntryCompiler,
+    MessageCompiler,
+    OneOfFieldCompiler,
+    OutputTemplate,
+    PluginRequestCompiler,
     ServiceCompiler,
     ServiceMethodCompiler,
     is_map,
     is_oneof,
 )

-from betterproto.plugin.compiler import outputfile_compiler
+if TYPE_CHECKING:
+    from google.protobuf.descriptor import Descriptor


-def traverse(proto_file: FieldDescriptorProto) -> Iterator:
+def traverse(
+    proto_file: FieldDescriptorProto,
+) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
     # Todo: Keep information about nested hierarchy
-    def _traverse(path, items, prefix=""):
+    def _traverse(
+        path: List[int], items: List["Descriptor"], prefix=""
+    ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
         for i, item in enumerate(items):
             # Adjust the name since we flatten the hierarchy.
             # Todo: don't change the name, but include full name in returned tuple
@@ -104,7 +109,7 @@ def generate_code(
             read_protobuf_service(service, index, output_package)

     # Generate output files
-    output_paths: pathlib.Path = set()
+    output_paths: Set[pathlib.Path] = set()
     for output_package_name, output_package in request_data.output_packages.items():

         # Add files to the response object
@@ -112,20 +117,17 @@ def generate_code(
         output_paths.add(output_path)

         f: response.File = response.file.add()
-        f.name: str = str(output_path)
+        f.name = str(output_path)

         # Render and then format the output file
-        f.content: str = outputfile_compiler(output_file=output_package)
+        f.content = outputfile_compiler(output_file=output_package)

     # Make each output directory a package with __init__ file
-    init_files = (
-        set(
-            directory.joinpath("__init__.py")
-            for path in output_paths
-            for directory in path.parents
-        )
-        - output_paths
-    )
+    init_files = {
+        directory.joinpath("__init__.py")
+        for path in output_paths
+        for directory in path.parents
+    } - output_paths

     for init_file in init_files:
         init = response.file.add()
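Aside (illustrative, using a made-up output path): the set comprehension above collects an __init__.py for every parent directory of the generated files, then subtracts the generated files themselves.

import pathlib

output_paths = {pathlib.Path("pkg/sub/module.py")}
init_files = {
    directory.joinpath("__init__.py")
    for path in output_paths
    for directory in path.parents
} - output_paths
# init_files == {Path('pkg/sub/__init__.py'), Path('pkg/__init__.py'), Path('__init__.py')}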
@@ -27,10 +27,7 @@ class ClientStub:


 async def to_list(generator: AsyncIterator):
-    result = []
-    async for value in generator:
-        result.append(value)
-    return result
+    return [value async for value in generator]


 @pytest.fixture
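A self-contained sketch of the async-comprehension pattern used above (illustrative only; independent of the test fixtures in the diff):

import asyncio
from typing import AsyncIterator, List


async def gen() -> AsyncIterator[int]:
    for i in range(3):
        yield i


async def to_list(generator: AsyncIterator) -> List:
    # Collect every item the async generator yields.
    return [value async for value in generator]


assert asyncio.run(to_list(gen())) == [0, 1, 2]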
@@ -6,7 +6,7 @@ from grpclib.client import Channel
 class MockChannel(Channel):
     # noinspection PyMissingConstructor
     def __init__(self, responses=None) -> None:
-        self.responses = responses if responses else []
+        self.responses = responses or []
         self.requests = []
         self._loop = None

@@ -23,8 +23,7 @@ def get_files(path, suffix: str) -> Generator[str, None, None]:

 def get_directories(path):
     for root, directories, files in os.walk(path):
-        for directory in directories:
-            yield directory
+        yield from directories


 async def protoc(
@@ -49,7 +48,7 @@ async def protoc(


 def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
-    test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
+    test_data_file_name = json_file_name or f"{test_case_name}.json"
     test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)

     if not test_data_file_path.exists():
@@ -77,7 +76,7 @@ def find_module(

     module_path = pathlib.Path(*module.__path__)

-    for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
+    for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
         if sub == module_path:
             continue
         sub_module_path = sub.relative_to(module_path)