QOL fixes (#141)

- Add missing type annotations
- Various style improvements
- Use constants more consistently
- Enforce Black on benchmark code
James 2020-10-17 18:27:11 +01:00 committed by GitHub
parent bf9412e083
commit 8f7af272cc
16 changed files with 177 additions and 220 deletions


@@ -17,7 +17,7 @@ jobs:
       - name: Run Black
         uses: lgeiger/black-action@master
         with:
-          args: --check src/ tests/
+          args: --check src/ tests/ benchmarks/
       - name: Install rST dependcies
         run: python -m pip install doc8

.gitignore

@@ -15,4 +15,5 @@ output
 .DS_Store
 .tox
 .venv
 .asv
+venv


@@ -8,9 +8,9 @@ class TestMessage(betterproto.Message):
     bar: str = betterproto.string_field(1)
     baz: float = betterproto.float_field(2)

 class BenchMessage:
-    """Test creation and usage a proto message.
-    """
+    """Test creation and usage a proto message."""

     def setup(self):
         self.cls = TestMessage
@@ -18,8 +18,8 @@ class BenchMessage:
         self.instance_filled = TestMessage(0, "test", 0.0)

     def time_overhead(self):
-        """Overhead in class definition.
-        """
+        """Overhead in class definition."""

         @dataclass
         class Message(betterproto.Message):
             foo: int = betterproto.uint32_field(0)
@@ -27,29 +27,25 @@ class BenchMessage:
             baz: float = betterproto.float_field(2)

     def time_instantiation(self):
-        """Time instantiation
-        """
+        """Time instantiation"""
         self.cls()

     def time_attribute_access(self):
-        """Time to access an attribute
-        """
+        """Time to access an attribute"""
         self.instance.foo
         self.instance.bar
         self.instance.baz

     def time_init_with_values(self):
-        """Time to set an attribute
-        """
+        """Time to set an attribute"""
         self.cls(0, "test", 0.0)

     def time_attribute_setting(self):
-        """Time to set attributes
-        """
+        """Time to set attributes"""
         self.instance.foo = 0
         self.instance.bar = "test"
         self.instance.baz = 0.0

     def time_serialize(self):
         """Time serializing a message to wire."""
         bytes(self.instance_filled)
@@ -58,6 +54,6 @@ class BenchMessage:

 class MemSuite:
     def setup(self):
         self.cls = TestMessage

     def mem_instance(self):
         return self.cls()
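Note: the time_* and mem_* methods above follow the airspeed velocity (asv) benchmark convention (the .asv entry added to .gitignore is where asv keeps its environments and results by default). As a rough sketch of how such a suite is exercised, assuming asv is configured for this repository, a class like the one below would be discovered automatically; the names here are illustrative, not part of the commit:

    from dataclasses import dataclass
    import betterproto

    @dataclass
    class Example(betterproto.Message):
        foo: int = betterproto.uint32_field(0)
        bar: str = betterproto.string_field(1)

    class BenchExample:
        def setup(self):
            # setup() runs before every benchmark method in the class
            self.instance = Example(1, "test")

        def time_serialize(self):
            # asv times any method whose name starts with time_
            bytes(self.instance)

Running `asv run` would then time time_serialize in each configured environment.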


@@ -26,7 +26,7 @@ from ._types import T
 from .casing import camel_case, safe_snake_case, snake_case
 from .grpc.grpclib_client import ServiceStub

-if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
+if sys.version_info[:2] < (3, 7):
     # Apply backport of datetime.fromisoformat from 3.7
     from backports.datetime_fromisoformat import MonkeyPatch

@@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
 # Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
-def datetime_default_gen():
+def datetime_default_gen() -> datetime:
     return datetime(1970, 1, 1, tzinfo=timezone.utc)

@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
     @classmethod
     def from_string(cls, name: str) -> "Enum":
-        """
-        Return the value which corresponds to the string name.
+        """Return the value which corresponds to the string name.

         Parameters
         -----------
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
         return encode_varint(value)
     elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
         # Handle zig-zag encoding.
-        if value >= 0:
-            value = value << 1
-        else:
-            value = (value << 1) ^ (~0)
-        return encode_varint(value)
+        return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
     elif proto_type in FIXED_TYPES:
         return struct.pack(_pack_fmt(proto_type), value)
     elif proto_type == TYPE_STRING:
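Note: the collapsed conditional above is standard protobuf zig-zag encoding for sint32/sint64 fields, which interleaves negative and non-negative integers so that values near zero stay small on the wire. A quick self-contained illustration of the same expression:

    def zigzag(value: int) -> int:
        # 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4, ...
        return value << 1 if value >= 0 else (value << 1) ^ (~0)

    assert [zigzag(v) for v in (0, -1, 1, -2, 2)] == [0, 1, 2, 3, 4]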
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
         wire_type = num_wire & 0x7

         decoded: Any = None
-        if wire_type == 0:
+        if wire_type == WIRE_VARINT:
             decoded, i = decode_varint(value, i)
-        elif wire_type == 1:
+        elif wire_type == WIRE_FIXED_64:
             decoded, i = value[i : i + 8], i + 8
-        elif wire_type == 2:
+        elif wire_type == WIRE_LEN_DELIM:
             length, i = decode_varint(value, i)
             decoded = value[i : i + length]
             i += length
-        elif wire_type == 5:
+        elif wire_type == WIRE_FIXED_32:
             decoded, i = value[i : i + 4], i + 4

         yield ParsedField(
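Note: the magic numbers replaced here are protobuf wire types. The constant definitions are outside this hunk, but they presumably carry the standard values from the protobuf encoding spec:

    WIRE_VARINT = 0      # varint: int32/64, uint32/64, sint32/64, bool, enum
    WIRE_FIXED_64 = 1    # 64-bit: fixed64, sfixed64, double
    WIRE_LEN_DELIM = 2   # length-delimited: string, bytes, messages, packed repeated
    WIRE_FIXED_32 = 5    # 32-bit: fixed32, sfixed32, float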
@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
class ProtoClassMetadata: class ProtoClassMetadata:
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = ( __slots__ = (
"oneof_group_by_field", "oneof_group_by_field",
"oneof_field_by_group", "oneof_field_by_group",
@ -446,6 +435,14 @@ class ProtoClassMetadata:
"sorted_field_names", "sorted_field_names",
) )
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
sorted_field_names: Tuple[str, ...]
default_gen: Dict[str, Callable[[], Any]]
cls_by_field: Dict[str, Type]
def __init__(self, cls: Type["Message"]): def __init__(self, cls: Type["Message"]):
by_field = {} by_field = {}
by_group: Dict[str, Set] = {} by_group: Dict[str, Set] = {}
@ -470,23 +467,21 @@ class ProtoClassMetadata:
self.field_name_by_number = by_field_number self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name self.meta_by_field_name = by_field_name
self.sorted_field_names = tuple( self.sorted_field_names = tuple(
by_field_number[number] for number in sorted(by_field_number.keys()) by_field_number[number] for number in sorted(by_field_number)
) )
self.default_gen = self._get_default_gen(cls, fields) self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields)
@staticmethod @staticmethod
def _get_default_gen(cls, fields): def _get_default_gen(
default_gen = {} cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Callable[[], Any]]:
for field in fields: return {field.name: cls._get_field_default_gen(field) for field in fields}
default_gen[field.name] = cls._get_field_default_gen(field)
return default_gen
@staticmethod @staticmethod
def _get_cls_by_field(cls, fields): def _get_cls_by_field(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Type]:
field_cls = {} field_cls = {}
for field in fields: for field in fields:
@ -503,7 +498,7 @@ class ProtoClassMetadata:
], ],
bases=(Message,), bases=(Message,),
) )
field_cls[field.name + ".value"] = vt field_cls[f"{field.name}.value"] = vt
else: else:
field_cls[field.name] = cls._cls_for(field) field_cls[field.name] = cls._cls_for(field)
@@ -612,7 +607,7 @@ class Message(ABC):
         super().__setattr__(attr, value)

     @property
-    def _betterproto(self):
+    def _betterproto(self) -> ProtoClassMetadata:
         """
         Lazy initialize metadata for each protobuf class.
         It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ class Message(ABC):
     @classmethod
     def _type_hints(cls) -> Dict[str, Type]:
-        module = inspect.getmodule(cls)
-        type_hints = get_type_hints(cls, vars(module))
-        return type_hints
+        module = sys.modules[cls.__module__]
+        return get_type_hints(cls, vars(module))

     @classmethod
     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -739,7 +733,7 @@ class Message(ABC):
             field_cls = field_cls.__args__[index]
         return field_cls

-    def _get_field_default(self, field_name):
+    def _get_field_default(self, field_name: str) -> Any:
         return self._betterproto.default_gen[field_name]()

     @classmethod
@@ -762,7 +756,7 @@ class Message(ABC):
         elif issubclass(t, Enum):
             # Enums always default to zero.
             return int
-        elif t == datetime:
+        elif t is datetime:
             # Offsets are relative to 1970-01-01T00:00:00Z
             return datetime_default_gen
         else:
@@ -966,7 +960,7 @@ class Message(ABC):
                     )
                 ):
                     output[cased_name] = value.to_dict(casing, include_default_values)
-            elif meta.proto_type == "map":
+            elif meta.proto_type == TYPE_MAP:
                 for k in value:
                     if hasattr(value[k], "to_dict"):
                         value[k] = value[k].to_dict(casing, include_default_values)
@@ -1032,12 +1026,12 @@ class Message(ABC):
                continue

            if value[key] is not None:
-                if meta.proto_type == "message":
+                if meta.proto_type == TYPE_MESSAGE:
                    v = getattr(self, field_name)
                    if isinstance(v, list):
                        cls = self._betterproto.cls_by_field[field_name]
-                        for i in range(len(value[key])):
-                            v.append(cls().from_dict(value[key][i]))
+                        for item in value[key]:
+                            v.append(cls().from_dict(item))
                    elif isinstance(v, datetime):
                        v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
                        setattr(self, field_name, v)
@@ -1052,7 +1046,7 @@ class Message(ABC):
                        v.from_dict(value[key])
                elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
                    v = getattr(self, field_name)
-                    cls = self._betterproto.cls_by_field[field_name + ".value"]
+                    cls = self._betterproto.cls_by_field[f"{field_name}.value"]
                    for k in value[key]:
                        v[k] = cls().from_dict(value[key][k])
                else:
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
     return message._serialized_on_wire

-def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
+def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
     """
     Return the name and value of a message's one-of field group.
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
     """
     field_name = message._group_current.get(group_name)
     if not field_name:
-        return ("", None)
-    return (field_name, getattr(message, field_name))
+        return "", None
+    return field_name, getattr(message, field_name)

 # Circular import workaround: google.protobuf depends on base classes defined above.
 from .lib.google.protobuf import (  # noqa
-    Duration,
-    Timestamp,
     BoolValue,
     BytesValue,
     DoubleValue,
+    Duration,
     FloatValue,
     Int32Value,
     Int64Value,
     StringValue,
+    Timestamp,
     UInt32Value,
     UInt64Value,
 )
@@ -1174,8 +1168,8 @@ class _Duration(Duration):
         parts = str(delta.total_seconds()).split(".")
         if len(parts) > 1:
             while len(parts[1]) not in [3, 6, 9]:
-                parts[1] = parts[1] + "0"
-        return ".".join(parts) + "s"
+                parts[1] = f"{parts[1]}0"
+        return f"{'.'.join(parts)}s"

 class _Timestamp(Timestamp):
@@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
         if (nanos % 1e9) == 0:
             # If there are 0 fractional digits, the fractional
             # point '.' should be omitted when serializing.
-            return result + "Z"
+            return f"{result}Z"
         if (nanos % 1e6) == 0:
             # Serialize 3 fractional digits.
-            return result + ".%03dZ" % (nanos / 1e6)
+            return f"{result}.{int(nanos // 1e6):03d}Z"
         if (nanos % 1e3) == 0:
             # Serialize 6 fractional digits.
-            return result + ".%06dZ" % (nanos / 1e3)
+            return f"{result}.{int(nanos // 1e3):06d}Z"
         # Serialize 9 fractional digits.
-        return result + ".%09dZ" % nanos
+        return f"{result}.{nanos:09d}Z"

 class _WrappedMessage(Message):
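Note: the fractional-digit rule preserved by this rewrite comes from the protobuf JSON mapping, where a Timestamp serializes to RFC 3339 with exactly 0, 3, 6, or 9 fractional digits. A rough sketch of the expected outputs for an illustrative timestamp (the example values are assumptions, not taken from the diff):

    # nanos == 0           ->  "2020-10-17T17:27:11Z"
    # nanos == 500_000_000 ->  "2020-10-17T17:27:11.500Z"
    # nanos == 500_000     ->  "2020-10-17T17:27:11.000500Z"
    # nanos == 500         ->  "2020-10-17T17:27:11.000000500Z"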


@@ -1,8 +1,8 @@
 from typing import TYPE_CHECKING, TypeVar

 if TYPE_CHECKING:
-    from . import Message
     from grpclib._typing import IProtoMessage
+    from . import Message

 # Bound type variable to allow methods to return `self` of subclasses
 T = TypeVar("T", bound="Message")


@@ -1,10 +1,10 @@
 import os
 import re
-from typing import Dict, List, Set, Type
+from typing import Dict, List, Set, Tuple, Type

-from betterproto import safe_snake_case
-from betterproto.compile.naming import pythonize_class_name
-from betterproto.lib.google import protobuf as google_protobuf
+from ..casing import safe_snake_case
+from ..lib.google import protobuf as google_protobuf
+from .naming import pythonize_class_name

 WRAPPER_TYPES: Dict[str, Type] = {
     ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
@@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
 }

-def parse_source_type_name(field_type_name):
+def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
     """
     Split full source type name into package and type name.
     E.g. 'root.package.Message' -> ('root.package', 'Message')
@@ -50,7 +50,7 @@ def get_type_reference(
     if source_type == ".google.protobuf.Duration":
         return "timedelta"
-    if source_type == ".google.protobuf.Timestamp":
+    elif source_type == ".google.protobuf.Timestamp":
         return "datetime"

     source_package, source_type = parse_source_type_name(source_type)
@@ -79,7 +79,7 @@ def get_type_reference(
     return reference_cousin(current_package, imports, py_package, py_type)

-def reference_absolute(imports, py_package, py_type):
+def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
     """
     Returns a reference to a python type located in the root, i.e. sys.path.
     """


@@ -1,13 +1,13 @@
 from betterproto import casing

-def pythonize_class_name(name):
+def pythonize_class_name(name: str) -> str:
     return casing.pascal_case(name)

-def pythonize_field_name(name: str):
+def pythonize_field_name(name: str) -> str:
     return casing.safe_snake_case(name)

-def pythonize_method_name(name: str):
+def pythonize_method_name(name: str) -> str:
     return casing.safe_snake_case(name)


@@ -1,7 +1,7 @@
-from abc import ABC
 import asyncio
-import grpclib.const
+from abc import ABC
 from typing import (
-    TYPE_CHECKING,
     AsyncIterable,
     AsyncIterator,
     Collection,
@@ -9,11 +9,13 @@ from typing import (
     Mapping,
     Optional,
     Tuple,
+    TYPE_CHECKING,
     Type,
     Union,
 )

-from betterproto._types import ST, T
+import grpclib.const
+
+from .._types import ST, T

 if TYPE_CHECKING:
     from grpclib.client import Channel


@@ -1,12 +1,5 @@
 import asyncio
-from typing import (
-    AsyncIterable,
-    AsyncIterator,
-    Iterable,
-    Optional,
-    TypeVar,
-    Union,
-)
+from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union

 T = TypeVar("T")
@@ -16,8 +9,6 @@ class ChannelClosed(Exception):
     An exception raised on an attempt to send through a closed channel
     """

-    pass

 class ChannelDone(Exception):
     """
@@ -25,8 +16,6 @@ class ChannelDone(Exception):
     and empty.
     """

-    pass

 class AsyncChannel(AsyncIterable[T]):
     """


@@ -5,10 +5,9 @@ try:
     import black
     import jinja2
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -16,7 +15,7 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import OutputTemplate
+from .models import OutputTemplate

 def outputfile_compiler(output_file: OutputTemplate) -> str:
@@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
     )
     template = env.get_template("template.py.j2")

-    res = black.format_str(
+    return black.format_str(
         template.render(output_file=output_file),
         mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
     )
-    return res
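Note: outputfile_compiler renders a Jinja2 template and then runs the result through Black targeting Python 3.7. The same pattern in isolation, as a minimal sketch (render_formatted and the tiny template are illustrative, not part of the plugin):

    import black
    import jinja2

    def render_formatted(template_source: str, **context) -> str:
        # Render a Jinja2 template, then normalize the result with Black.
        rendered = jinja2.Template(template_source).render(**context)
        return black.format_str(
            rendered, mode=black.FileMode(target_versions={black.TargetVersion.PY37})
        )

    print(render_formatted("x = {{ value }}", value=1))  # prints "x = 1"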


@@ -1,13 +1,14 @@
 #!/usr/bin/env python
-import sys
 import os
+import sys

 from google.protobuf.compiler import plugin_pb2 as plugin

 from betterproto.plugin.parser import generate_code

-def main():
+def main() -> None:
     """The plugin's main entry point."""
     # Read request message from stdin
     data = sys.stdin.buffer.read()
@@ -33,7 +34,7 @@ def main():
     sys.stdout.buffer.write(output)

-def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
+def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
     """
     For developers: Supports running plugin.py standalone so its possible to debug it.
     Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
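Note: this file is a protoc plugin, so main() speaks the CodeGeneratorRequest/CodeGeneratorResponse protocol over stdin/stdout. A minimal skeleton of that handshake (error handling and the actual code generation omitted; only the I/O structure is implied by the snippet above):

    import sys
    from google.protobuf.compiler import plugin_pb2 as plugin

    def main() -> None:
        # protoc writes a serialized CodeGeneratorRequest to the plugin's stdin...
        request = plugin.CodeGeneratorRequest()
        request.ParseFromString(sys.stdin.buffer.read())

        # ...and expects a serialized CodeGeneratorResponse back on stdout.
        response = plugin.CodeGeneratorResponse()
        sys.stdout.buffer.write(response.SerializeToString())

    if __name__ == "__main__":
        main()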


@@ -1,14 +1,14 @@
 """Plugin model dataclasses.

 These classes are meant to be an intermediate representation
-of protbuf objects. They are used to organize the data collected during parsing.
+of protobuf objects. They are used to organize the data collected during parsing.

 The general intention is to create a doubly-linked tree-like structure
 with the following types of references:
 - Downwards references: from message -> fields, from output package -> messages
 or from service -> service methods
 - Upwards references: from field -> message, message -> package.
-- Input/ouput message references: from a service method to it's corresponding
+- Input/output message references: from a service method to it's corresponding
 input/output messages, which may even be in another package.

 There are convenience methods to allow climbing up and down this tree, for
@@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
 The instantiation should also attach a reference to the new object
 into the corresponding place within it's parent object. For example,
 instantiating field `A` with parent message `B` should add a
-reference to `A` to `B`'s `fields` attirbute.
+reference to `A` to `B`'s `fields` attribute.
 """
 import re
-from dataclasses import dataclass
-from dataclasses import field
-from typing import (
-    Iterator,
-    Union,
-    Type,
-    List,
-    Dict,
-    Set,
-    Text,
-)
 import textwrap
+from dataclasses import dataclass, field
+from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union

 import betterproto
-from betterproto.compile.importing import (
-    get_type_reference,
-    parse_source_type_name,
-)
-from betterproto.compile.naming import (
+
+from ..casing import sanitize_name
+from ..compile.importing import get_type_reference, parse_source_type_name
+from ..compile.naming import (
     pythonize_class_name,
     pythonize_field_name,
     pythonize_method_name,
 )
-from ..casing import sanitize_name

 try:
     # betterproto[compiler] specific dependencies
     from google.protobuf.compiler import plugin_pb2 as plugin
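Note: the module docstring above describes a doubly-linked tree in which children hold an upwards reference to their parent and register themselves downwards during __post_init__. A stripped-down sketch of that wiring (Package and Msg are illustrative stand-ins, not the real compiler models):

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class Package:
        messages: List["Msg"] = field(default_factory=list)

    @dataclass
    class Msg:
        parent: Package
        name: str = ""

        def __post_init__(self) -> None:
            # Instantiating a child attaches it to its parent,
            # giving both downwards and upwards references.
            self.parent.messages.append(self)

    pkg = Package()
    msg = Msg(parent=pkg, name="Greeting")
    assert pkg.messages == [msg] and msg.parent is pkg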
@@ -67,10 +55,9 @@ try:
         MethodDescriptorProto,
     )
 except ImportError as err:
-    missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
 )

-def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
+def get_comment(
+    proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
+) -> str:
     pad = " " * indent
     for sci in proto_file.source_code_info.location:
-        # print(list(sci.path), path, file=sys.stderr)
         if list(sci.path) == path and sci.leading_comments:
             lines = textwrap.wrap(
                 sci.leading_comments.strip().replace("\n", ""), width=79 - indent
@@ -153,9 +141,9 @@ class ProtoContentBase:
     path: List[int]
     comment_indent: int = 4
-    parent: Union["Messsage", "OutputTemplate"]
+    parent: Union["betterproto.Message", "OutputTemplate"]

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Checks that no fake default fields were left as placeholders."""
         for field_name, field_val in self.__dataclass_fields__.items():
             if field_val is PLACEHOLDER:
@@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
     )
     deprecated: bool = field(default=False, init=False)

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add message to output file
         if isinstance(self.parent, OutputTemplate):
             if isinstance(self, EnumDefinitionCompiler):
@@ -314,17 +302,17 @@ def is_map(
     map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
     if message_type == map_entry:
         for nested in parent_message.nested_type:  # parent message
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    return True
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                return True
     return False

 def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
     """True if proto_field_obj is a OneOf, otherwise False."""
-    if proto_field_obj.HasField("oneof_index"):
-        return True
-    return False
+    return proto_field_obj.HasField("oneof_index")

 @dataclass
@@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
     parent: MessageCompiler = PLACEHOLDER
     proto_obj: FieldDescriptorProto = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Add field to message
         self.parent.fields.append(self)
         # Check for new imports
@@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
             ([""] + self.betterproto_field_args) if self.betterproto_field_args else []
         )
         betterproto_field_type = (
-            f"betterproto.{self.field_type}_field({self.proto_obj.number}"
-            + field_args
-            + ")"
+            f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
         )
-        return name + annotations + " = " + betterproto_field_type
+        return f"{name}{annotations} = {betterproto_field_type}"

     @property
     def betterproto_field_args(self) -> List[str]:
@@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
         return args

     @property
-    def field_wraps(self) -> Union[str, None]:
+    def field_wraps(self) -> Optional[str]:
         """Returns betterproto wrapped field type or None."""
         match_wrapper = re.match(
             r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
@@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):
     @property
     def repeated(self) -> bool:
-        if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
-            self.proto_obj, self.parent
-        ):
-            return True
-        return False
+        return (
+            self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
+            and not is_map(self.proto_obj, self.parent)
+        )

     @property
     def mutable(self) -> bool:
         """True if the field is a mutable type, otherwise False."""
-        annotation = self.annotation
-        return annotation.startswith("List[") or annotation.startswith("Dict[")
+        return self.annotation.startswith(("List[", "Dict["))

     @property
     def field_type(self) -> str:
@@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
     @property
     def packed(self) -> bool:
         """True if the wire representation is a packed format."""
-        if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
-            return True
-        return False
+        return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES

     @property
     def py_name(self) -> str:
@@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
     proto_k_type: str = PLACEHOLDER
     proto_v_type: str = PLACEHOLDER

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         """Explore nested types and set k_type and v_type if unset."""
         map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
         for nested in self.parent.proto_obj.nested_type:
-            if nested.name.replace("_", "").lower() == map_entry:
-                if nested.options.map_entry:
-                    # Get Python types
-                    self.py_k_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[0]  # key
-                    ).py_type
-                    self.py_v_type = FieldCompiler(
-                        parent=self, proto_obj=nested.field[1]  # value
-                    ).py_type
-                    # Get proto types
-                    self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
-                    self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
+            if (
+                nested.name.replace("_", "").lower() == map_entry
+                and nested.options.map_entry
+            ):
+                # Get Python types
+                self.py_k_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[0]  # key
+                ).py_type
+                self.py_v_type = FieldCompiler(
+                    parent=self, proto_obj=nested.field[1]  # value
+                ).py_type
+                # Get proto types
+                self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
+                self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
         super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

     @property
@@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
         return "map"

     @property
-    def annotation(self):
+    def annotation(self) -> str:
         return f"Dict[{self.py_k_type}, {self.py_v_type}]"

     @property
-    def repeated(self):
+    def repeated(self) -> bool:
         return False  # maps cannot be repeated
@@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         value: int
         comment: str

-    def __post_init__(self):
+    def __post_init__(self) -> None:
         # Get entries/allowed values for this Enum
         self.entries = [
             self.EnumEntry(
@@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
         super().__post_init__()  # call MessageCompiler __post_init__

     @property
-    def default_value_string(self) -> int:
+    def default_value_string(self) -> str:
         """Python representation of the default value for Enums.

         As per the spec, this is the first value of the Enum.
@@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
         super().__post_init__()  # check for unset fields

     @property
-    def proto_name(self):
+    def proto_name(self) -> str:
         return self.proto_obj.name

     @property
-    def py_name(self):
+    def py_name(self) -> str:
         return pythonize_class_name(self.proto_name)
@@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             Name and actual default value (as a string)
             for each argument with mutable default values.
         """
-        mutable_default_args = dict()
+        mutable_default_args = {}

         if self.py_input_message:
             for f in self.py_input_message.fields:
@@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):
     @property
     def route(self) -> str:
-        return (
-            f"/{self.output_file.package}."
-            f"{self.parent.proto_name}/{self.proto_name}"
-        )
+        return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"

     @property
-    def py_input_message(self) -> Union[None, MessageCompiler]:
+    def py_input_message(self) -> Optional[MessageCompiler]:
         """Find the input message object.

         Returns
         -------
-        Union[None, MessageCompiler]
+        Optional[MessageCompiler]
             Method instance representing the input message.
             If not input message could be found or there are no
             input messages, None is returned.
@@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):
     @property
     def py_input_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         input message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            input message.
+            String representation of the Python type corresponding to the input message.
         """
         return get_type_reference(
             package=self.output_file.package,
@@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):
     @property
     def py_output_message_type(self) -> str:
-        """String representation of the Python type correspoding to the
+        """String representation of the Python type corresponding to the
         output message.

         Returns
         -------
         str
-            String representation of the Python type correspoding to the
-            output message.
+            String representation of the Python type corresponding to the output message.
         """
         return get_type_reference(
             package=self.output_file.package,


@@ -1,7 +1,7 @@
 import itertools
 import pathlib
 import sys
-from typing import List, Iterator
+from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set

 try:
     # betterproto[compiler] specific dependencies
@@ -13,10 +13,9 @@ try:
         ServiceDescriptorProto,
     )
 except ImportError as err:
-    missing_import = err.args[0][17:-1]
     print(
         "\033[31m"
-        f"Unable to import `{missing_import}` from betterproto plugin! "
+        f"Unable to import `{err.name}` from betterproto plugin! "
         "Please ensure that you've installed betterproto as "
         '`pip install "betterproto[compiler]"` so that compiler dependencies '
         "are included."
@@ -24,26 +23,32 @@ except ImportError as err:
     )
     raise SystemExit(1)

-from betterproto.plugin.models import (
-    PluginRequestCompiler,
-    OutputTemplate,
-    MessageCompiler,
-    FieldCompiler,
-    OneOfFieldCompiler,
-    MapEntryCompiler,
+from .compiler import outputfile_compiler
+from .models import (
     EnumDefinitionCompiler,
+    FieldCompiler,
+    MapEntryCompiler,
+    MessageCompiler,
+    OneOfFieldCompiler,
+    OutputTemplate,
+    PluginRequestCompiler,
     ServiceCompiler,
     ServiceMethodCompiler,
     is_map,
     is_oneof,
 )
-from betterproto.plugin.compiler import outputfile_compiler

+if TYPE_CHECKING:
+    from google.protobuf.descriptor import Descriptor

-def traverse(proto_file: FieldDescriptorProto) -> Iterator:
+def traverse(
+    proto_file: FieldDescriptorProto,
+) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
     # Todo: Keep information about nested hierarchy
-    def _traverse(path, items, prefix=""):
+    def _traverse(
+        path: List[int], items: List["Descriptor"], prefix=""
+    ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
         for i, item in enumerate(items):
             # Adjust the name since we flatten the hierarchy.
             # Todo: don't change the name, but include full name in returned tuple
@@ -104,7 +109,7 @@ def generate_code(
             read_protobuf_service(service, index, output_package)

     # Generate output files
-    output_paths: pathlib.Path = set()
+    output_paths: Set[pathlib.Path] = set()
     for output_package_name, output_package in request_data.output_packages.items():

         # Add files to the response object
@@ -112,20 +117,17 @@ def generate_code(
         output_paths.add(output_path)
         f: response.File = response.file.add()
-        f.name: str = str(output_path)
+        f.name = str(output_path)
         # Render and then format the output file
-        f.content: str = outputfile_compiler(output_file=output_package)
+        f.content = outputfile_compiler(output_file=output_package)

     # Make each output directory a package with __init__ file
-    init_files = (
-        set(
-            directory.joinpath("__init__.py")
-            for path in output_paths
-            for directory in path.parents
-        )
-        - output_paths
-    )
+    init_files = {
+        directory.joinpath("__init__.py")
+        for path in output_paths
+        for directory in path.parents
+    } - output_paths

     for init_file in init_files:
         init = response.file.add()


@@ -27,10 +27,7 @@ class ClientStub:

 async def to_list(generator: AsyncIterator):
-    result = []
-    async for value in generator:
-        result.append(value)
-    return result
+    return [value async for value in generator]

 @pytest.fixture


@@ -6,7 +6,7 @@ from grpclib.client import Channel
 class MockChannel(Channel):
     # noinspection PyMissingConstructor
     def __init__(self, responses=None) -> None:
-        self.responses = responses if responses else []
+        self.responses = responses or []
         self.requests = []
         self._loop = None


@@ -23,8 +23,7 @@ def get_files(path, suffix: str) -> Generator[str, None, None]:

 def get_directories(path):
     for root, directories, files in os.walk(path):
-        for directory in directories:
-            yield directory
+        yield from directories

 async def protoc(
@@ -49,7 +48,7 @@ async def protoc(
 def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
-    test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
+    test_data_file_name = json_file_name or f"{test_case_name}.json"
     test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)

     if not test_data_file_path.exists():
@@ -77,7 +76,7 @@ def find_module(
     module_path = pathlib.Path(*module.__path__)

-    for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
+    for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
         if sub == module_path:
             continue
         sub_module_path = sub.relative_to(module_path)