diff --git a/.github/workflows/code-quality.yml b/.github/workflows/code-quality.yml index 03221ab..a3d483f 100644 --- a/.github/workflows/code-quality.yml +++ b/.github/workflows/code-quality.yml @@ -17,7 +17,7 @@ jobs: - name: Run Black uses: lgeiger/black-action@master with: - args: --check src/ tests/ + args: --check src/ tests/ benchmarks/ - name: Install rST dependcies run: python -m pip install doc8 diff --git a/.gitignore b/.gitignore index b35fe09..67d0768 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,5 @@ output .DS_Store .tox .venv -.asv \ No newline at end of file +.asv +venv diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 3cbde53..76fb906 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -8,9 +8,9 @@ class TestMessage(betterproto.Message): bar: str = betterproto.string_field(1) baz: float = betterproto.float_field(2) + class BenchMessage: - """Test creation and usage a proto message. - """ + """Test creation and usage a proto message.""" def setup(self): self.cls = TestMessage @@ -18,8 +18,8 @@ class BenchMessage: self.instance_filled = TestMessage(0, "test", 0.0) def time_overhead(self): - """Overhead in class definition. - """ + """Overhead in class definition.""" + @dataclass class Message(betterproto.Message): foo: int = betterproto.uint32_field(0) @@ -27,29 +27,25 @@ class BenchMessage: baz: float = betterproto.float_field(2) def time_instantiation(self): - """Time instantiation - """ + """Time instantiation""" self.cls() def time_attribute_access(self): - """Time to access an attribute - """ + """Time to access an attribute""" self.instance.foo self.instance.bar self.instance.baz - + def time_init_with_values(self): - """Time to set an attribute - """ + """Time to set an attribute""" self.cls(0, "test", 0.0) def time_attribute_setting(self): - """Time to set attributes - """ + """Time to set attributes""" self.instance.foo = 0 self.instance.bar = "test" self.instance.baz = 0.0 - + def time_serialize(self): """Time serializing a message to wire.""" bytes(self.instance_filled) @@ -58,6 +54,6 @@ class BenchMessage: class MemSuite: def setup(self): self.cls = TestMessage - + def mem_instance(self): return self.cls() diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 9a46fe1..b90dd05 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -26,7 +26,7 @@ from ._types import T from .casing import camel_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub -if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): +if sys.version_info[:2] < (3, 7): # Apply backport of datetime.fromisoformat from 3.7 from backports.datetime_fromisoformat import MonkeyPatch @@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. -def datetime_default_gen(): +def datetime_default_gen() -> datetime: return datetime(1970, 1, 1, tzinfo=timezone.utc) @@ -256,8 +256,7 @@ class Enum(enum.IntEnum): @classmethod def from_string(cls, name: str) -> "Enum": - """ - Return the value which corresponds to the string name. + """Return the value which corresponds to the string name. Parameters ----------- @@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: return encode_varint(value) elif proto_type in [TYPE_SINT32, TYPE_SINT64]: # Handle zig-zag encoding. - if value >= 0: - value = value << 1 - else: - value = (value << 1) ^ (~0) - return encode_varint(value) + return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0)) elif proto_type in FIXED_TYPES: return struct.pack(_pack_fmt(proto_type), value) elif proto_type == TYPE_STRING: @@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: wire_type = num_wire & 0x7 decoded: Any = None - if wire_type == 0: + if wire_type == WIRE_VARINT: decoded, i = decode_varint(value, i) - elif wire_type == 1: + elif wire_type == WIRE_FIXED_64: decoded, i = value[i : i + 8], i + 8 - elif wire_type == 2: + elif wire_type == WIRE_LEN_DELIM: length, i = decode_varint(value, i) decoded = value[i : i + length] i += length - elif wire_type == 5: + elif wire_type == WIRE_FIXED_32: decoded, i = value[i : i + 4], i + 4 yield ParsedField( @@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: class ProtoClassMetadata: - oneof_group_by_field: Dict[str, str] - oneof_field_by_group: Dict[str, Set[dataclasses.Field]] - default_gen: Dict[str, Callable] - cls_by_field: Dict[str, Type] - field_name_by_number: Dict[int, str] - meta_by_field_name: Dict[str, FieldMetadata] __slots__ = ( "oneof_group_by_field", "oneof_field_by_group", @@ -446,6 +435,14 @@ class ProtoClassMetadata: "sorted_field_names", ) + oneof_group_by_field: Dict[str, str] + oneof_field_by_group: Dict[str, Set[dataclasses.Field]] + field_name_by_number: Dict[int, str] + meta_by_field_name: Dict[str, FieldMetadata] + sorted_field_names: Tuple[str, ...] + default_gen: Dict[str, Callable[[], Any]] + cls_by_field: Dict[str, Type] + def __init__(self, cls: Type["Message"]): by_field = {} by_group: Dict[str, Set] = {} @@ -470,23 +467,21 @@ class ProtoClassMetadata: self.field_name_by_number = by_field_number self.meta_by_field_name = by_field_name self.sorted_field_names = tuple( - by_field_number[number] for number in sorted(by_field_number.keys()) + by_field_number[number] for number in sorted(by_field_number) ) - self.default_gen = self._get_default_gen(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields) @staticmethod - def _get_default_gen(cls, fields): - default_gen = {} - - for field in fields: - default_gen[field.name] = cls._get_field_default_gen(field) - - return default_gen + def _get_default_gen( + cls: Type["Message"], fields: List[dataclasses.Field] + ) -> Dict[str, Callable[[], Any]]: + return {field.name: cls._get_field_default_gen(field) for field in fields} @staticmethod - def _get_cls_by_field(cls, fields): + def _get_cls_by_field( + cls: Type["Message"], fields: List[dataclasses.Field] + ) -> Dict[str, Type]: field_cls = {} for field in fields: @@ -503,7 +498,7 @@ class ProtoClassMetadata: ], bases=(Message,), ) - field_cls[field.name + ".value"] = vt + field_cls[f"{field.name}.value"] = vt else: field_cls[field.name] = cls._cls_for(field) @@ -612,7 +607,7 @@ class Message(ABC): super().__setattr__(attr, value) @property - def _betterproto(self): + def _betterproto(self) -> ProtoClassMetadata: """ Lazy initialize metadata for each protobuf class. It may be initialized multiple times in a multi-threaded environment, @@ -726,9 +721,8 @@ class Message(ABC): @classmethod def _type_hints(cls) -> Dict[str, Type]: - module = inspect.getmodule(cls) - type_hints = get_type_hints(cls, vars(module)) - return type_hints + module = sys.modules[cls.__module__] + return get_type_hints(cls, vars(module)) @classmethod def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: @@ -739,7 +733,7 @@ class Message(ABC): field_cls = field_cls.__args__[index] return field_cls - def _get_field_default(self, field_name): + def _get_field_default(self, field_name: str) -> Any: return self._betterproto.default_gen[field_name]() @classmethod @@ -762,7 +756,7 @@ class Message(ABC): elif issubclass(t, Enum): # Enums always default to zero. return int - elif t == datetime: + elif t is datetime: # Offsets are relative to 1970-01-01T00:00:00Z return datetime_default_gen else: @@ -966,7 +960,7 @@ class Message(ABC): ) ): output[cased_name] = value.to_dict(casing, include_default_values) - elif meta.proto_type == "map": + elif meta.proto_type == TYPE_MAP: for k in value: if hasattr(value[k], "to_dict"): value[k] = value[k].to_dict(casing, include_default_values) @@ -1032,12 +1026,12 @@ class Message(ABC): continue if value[key] is not None: - if meta.proto_type == "message": + if meta.proto_type == TYPE_MESSAGE: v = getattr(self, field_name) if isinstance(v, list): cls = self._betterproto.cls_by_field[field_name] - for i in range(len(value[key])): - v.append(cls().from_dict(value[key][i])) + for item in value[key]: + v.append(cls().from_dict(item)) elif isinstance(v, datetime): v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) setattr(self, field_name, v) @@ -1052,7 +1046,7 @@ class Message(ABC): v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: v = getattr(self, field_name) - cls = self._betterproto.cls_by_field[field_name + ".value"] + cls = self._betterproto.cls_by_field[f"{field_name}.value"] for k in value[key]: v[k] = cls().from_dict(value[key][k]) else: @@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool: return message._serialized_on_wire -def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: +def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]: """ Return the name and value of a message's one-of field group. @@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: """ field_name = message._group_current.get(group_name) if not field_name: - return ("", None) - return (field_name, getattr(message, field_name)) + return "", None + return field_name, getattr(message, field_name) # Circular import workaround: google.protobuf depends on base classes defined above. from .lib.google.protobuf import ( # noqa - Duration, - Timestamp, BoolValue, BytesValue, DoubleValue, + Duration, FloatValue, Int32Value, Int64Value, StringValue, + Timestamp, UInt32Value, UInt64Value, ) @@ -1174,8 +1168,8 @@ class _Duration(Duration): parts = str(delta.total_seconds()).split(".") if len(parts) > 1: while len(parts[1]) not in [3, 6, 9]: - parts[1] = parts[1] + "0" - return ".".join(parts) + "s" + parts[1] = f"{parts[1]}0" + return f"{'.'.join(parts)}s" class _Timestamp(Timestamp): @@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp): if (nanos % 1e9) == 0: # If there are 0 fractional digits, the fractional # point '.' should be omitted when serializing. - return result + "Z" + return f"{result}Z" if (nanos % 1e6) == 0: # Serialize 3 fractional digits. - return result + ".%03dZ" % (nanos / 1e6) + return f"{result}.{int(nanos // 1e6) :03d}Z" if (nanos % 1e3) == 0: # Serialize 6 fractional digits. - return result + ".%06dZ" % (nanos / 1e3) + return f"{result}.{int(nanos // 1e3) :06d}Z" # Serialize 9 fractional digits. - return result + ".%09dZ" % nanos + return f"{result}.{nanos:09d}" class _WrappedMessage(Message): diff --git a/src/betterproto/_types.py b/src/betterproto/_types.py index bc3748f..26b7344 100644 --- a/src/betterproto/_types.py +++ b/src/betterproto/_types.py @@ -1,8 +1,8 @@ from typing import TYPE_CHECKING, TypeVar if TYPE_CHECKING: - from . import Message from grpclib._typing import IProtoMessage + from . import Message # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index 1e245e4..8d471d1 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -1,10 +1,10 @@ import os import re -from typing import Dict, List, Set, Type +from typing import Dict, List, Set, Tuple, Type -from betterproto import safe_snake_case -from betterproto.compile.naming import pythonize_class_name -from betterproto.lib.google import protobuf as google_protobuf +from ..casing import safe_snake_case +from ..lib.google import protobuf as google_protobuf +from .naming import pythonize_class_name WRAPPER_TYPES: Dict[str, Type] = { ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, @@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = { } -def parse_source_type_name(field_type_name): +def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: """ Split full source type name into package and type name. E.g. 'root.package.Message' -> ('root.package', 'Message') @@ -50,7 +50,7 @@ def get_type_reference( if source_type == ".google.protobuf.Duration": return "timedelta" - if source_type == ".google.protobuf.Timestamp": + elif source_type == ".google.protobuf.Timestamp": return "datetime" source_package, source_type = parse_source_type_name(source_type) @@ -79,7 +79,7 @@ def get_type_reference( return reference_cousin(current_package, imports, py_package, py_type) -def reference_absolute(imports, py_package, py_type): +def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str: """ Returns a reference to a python type located in the root, i.e. sys.path. """ diff --git a/src/betterproto/compile/naming.py b/src/betterproto/compile/naming.py index 3d56852..1c2dbab 100644 --- a/src/betterproto/compile/naming.py +++ b/src/betterproto/compile/naming.py @@ -1,13 +1,13 @@ from betterproto import casing -def pythonize_class_name(name): +def pythonize_class_name(name: str) -> str: return casing.pascal_case(name) -def pythonize_field_name(name: str): +def pythonize_field_name(name: str) -> str: return casing.safe_snake_case(name) -def pythonize_method_name(name: str): +def pythonize_method_name(name: str) -> str: return casing.safe_snake_case(name) diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index 6fa35b4..a22b7e3 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -1,7 +1,7 @@ -from abc import ABC import asyncio -import grpclib.const +from abc import ABC from typing import ( + TYPE_CHECKING, AsyncIterable, AsyncIterator, Collection, @@ -9,11 +9,13 @@ from typing import ( Mapping, Optional, Tuple, - TYPE_CHECKING, Type, Union, ) -from betterproto._types import ST, T + +import grpclib.const + +from .._types import ST, T if TYPE_CHECKING: from grpclib.client import Channel diff --git a/src/betterproto/grpc/util/async_channel.py b/src/betterproto/grpc/util/async_channel.py index 0cda4b2..9b822fe 100644 --- a/src/betterproto/grpc/util/async_channel.py +++ b/src/betterproto/grpc/util/async_channel.py @@ -1,12 +1,5 @@ import asyncio -from typing import ( - AsyncIterable, - AsyncIterator, - Iterable, - Optional, - TypeVar, - Union, -) +from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union T = TypeVar("T") @@ -16,8 +9,6 @@ class ChannelClosed(Exception): An exception raised on an attempt to send through a closed channel """ - pass - class ChannelDone(Exception): """ @@ -25,8 +16,6 @@ class ChannelDone(Exception): and empty. """ - pass - class AsyncChannel(AsyncIterable[T]): """ diff --git a/src/betterproto/plugin/compiler.py b/src/betterproto/plugin/compiler.py index 4fd3b8f..617a650 100644 --- a/src/betterproto/plugin/compiler.py +++ b/src/betterproto/plugin/compiler.py @@ -5,10 +5,9 @@ try: import black import jinja2 except ImportError as err: - missing_import = err.args[0][17:-1] print( "\033[31m" - f"Unable to import `{missing_import}` from betterproto plugin! " + f"Unable to import `{err.name}` from betterproto plugin! " "Please ensure that you've installed betterproto as " '`pip install "betterproto[compiler]"` so that compiler dependencies ' "are included." @@ -16,7 +15,7 @@ except ImportError as err: ) raise SystemExit(1) -from betterproto.plugin.models import OutputTemplate +from .models import OutputTemplate def outputfile_compiler(output_file: OutputTemplate) -> str: @@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: ) template = env.get_template("template.py.j2") - res = black.format_str( + return black.format_str( template.render(output_file=output_file), mode=black.FileMode(target_versions={black.TargetVersion.PY37}), ) - - return res diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py index 2604af2..dc9d04c 100644 --- a/src/betterproto/plugin/main.py +++ b/src/betterproto/plugin/main.py @@ -1,13 +1,14 @@ #!/usr/bin/env python -import sys + import os +import sys from google.protobuf.compiler import plugin_pb2 as plugin from betterproto.plugin.parser import generate_code -def main(): +def main() -> None: """The plugin's main entry point.""" # Read request message from stdin data = sys.stdin.buffer.read() @@ -33,7 +34,7 @@ def main(): sys.stdout.buffer.write(output) -def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): +def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None: """ For developers: Supports running plugin.py standalone so its possible to debug it. Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index bbd21e6..bf31405 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -1,14 +1,14 @@ """Plugin model dataclasses. These classes are meant to be an intermediate representation -of protbuf objects. They are used to organize the data collected during parsing. +of protobuf objects. They are used to organize the data collected during parsing. The general intention is to create a doubly-linked tree-like structure with the following types of references: - Downwards references: from message -> fields, from output package -> messages or from service -> service methods - Upwards references: from field -> message, message -> package. -- Input/ouput message references: from a service method to it's corresponding +- Input/output message references: from a service method to it's corresponding input/output messages, which may even be in another package. There are convenience methods to allow climbing up and down this tree, for @@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj. The instantiation should also attach a reference to the new object into the corresponding place within it's parent object. For example, instantiating field `A` with parent message `B` should add a -reference to `A` to `B`'s `fields` attirbute. +reference to `A` to `B`'s `fields` attribute. """ import re -from dataclasses import dataclass -from dataclasses import field -from typing import ( - Iterator, - Union, - Type, - List, - Dict, - Set, - Text, -) import textwrap +from dataclasses import dataclass, field +from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union import betterproto -from betterproto.compile.importing import ( - get_type_reference, - parse_source_type_name, -) -from betterproto.compile.naming import ( + +from ..casing import sanitize_name +from ..compile.importing import get_type_reference, parse_source_type_name +from ..compile.naming import ( pythonize_class_name, pythonize_field_name, pythonize_method_name, ) -from ..casing import sanitize_name - try: # betterproto[compiler] specific dependencies from google.protobuf.compiler import plugin_pb2 as plugin @@ -67,10 +55,9 @@ try: MethodDescriptorProto, ) except ImportError as err: - missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) print( "\033[31m" - f"Unable to import `{missing_import}` from betterproto plugin! " + f"Unable to import `{err.name}` from betterproto plugin! " "Please ensure that you've installed betterproto as " '`pip install "betterproto[compiler]"` so that compiler dependencies ' "are included." @@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = ( ) -def get_comment(proto_file, path: List[int], indent: int = 4) -> str: +def get_comment( + proto_file: "FileDescriptorProto", path: List[int], indent: int = 4 +) -> str: pad = " " * indent for sci in proto_file.source_code_info.location: - # print(list(sci.path), path, file=sys.stderr) if list(sci.path) == path and sci.leading_comments: lines = textwrap.wrap( sci.leading_comments.strip().replace("\n", ""), width=79 - indent @@ -153,9 +141,9 @@ class ProtoContentBase: path: List[int] comment_indent: int = 4 - parent: Union["Messsage", "OutputTemplate"] + parent: Union["betterproto.Message", "OutputTemplate"] - def __post_init__(self): + def __post_init__(self) -> None: """Checks that no fake default fields were left as placeholders.""" for field_name, field_val in self.__dataclass_fields__.items(): if field_val is PLACEHOLDER: @@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase): ) deprecated: bool = field(default=False, init=False) - def __post_init__(self): + def __post_init__(self) -> None: # Add message to output file if isinstance(self.parent, OutputTemplate): if isinstance(self, EnumDefinitionCompiler): @@ -314,17 +302,17 @@ def is_map( map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" if message_type == map_entry: for nested in parent_message.nested_type: # parent message - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - return True + if ( + nested.name.replace("_", "").lower() == map_entry + and nested.options.map_entry + ): + return True return False def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: """True if proto_field_obj is a OneOf, otherwise False.""" - if proto_field_obj.HasField("oneof_index"): - return True - return False + return proto_field_obj.HasField("oneof_index") @dataclass @@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler): parent: MessageCompiler = PLACEHOLDER proto_obj: FieldDescriptorProto = PLACEHOLDER - def __post_init__(self): + def __post_init__(self) -> None: # Add field to message self.parent.fields.append(self) # Check for new imports @@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler): ([""] + self.betterproto_field_args) if self.betterproto_field_args else [] ) betterproto_field_type = ( - f"betterproto.{self.field_type}_field({self.proto_obj.number}" - + field_args - + ")" + f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})" ) - return name + annotations + " = " + betterproto_field_type + return f"{name}{annotations} = {betterproto_field_type}" @property def betterproto_field_args(self) -> List[str]: @@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler): return args @property - def field_wraps(self) -> Union[str, None]: + def field_wraps(self) -> Optional[str]: """Returns betterproto wrapped field type or None.""" match_wrapper = re.match( r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name @@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler): @property def repeated(self) -> bool: - if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( - self.proto_obj, self.parent - ): - return True - return False + return ( + self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED + and not is_map(self.proto_obj, self.parent) + ) @property def mutable(self) -> bool: """True if the field is a mutable type, otherwise False.""" - annotation = self.annotation - return annotation.startswith("List[") or annotation.startswith("Dict[") + return self.annotation.startswith(("List[", "Dict[")) @property def field_type(self) -> str: @@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler): @property def packed(self) -> bool: """True if the wire representation is a packed format.""" - if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: - return True - return False + return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES @property def py_name(self) -> str: @@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler): proto_k_type: str = PLACEHOLDER proto_v_type: str = PLACEHOLDER - def __post_init__(self): + def __post_init__(self) -> None: """Explore nested types and set k_type and v_type if unset.""" map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" for nested in self.parent.proto_obj.nested_type: - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - # Get Python types - self.py_k_type = FieldCompiler( - parent=self, proto_obj=nested.field[0] # key - ).py_type - self.py_v_type = FieldCompiler( - parent=self, proto_obj=nested.field[1] # value - ).py_type - # Get proto types - self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) - self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) + if ( + nested.name.replace("_", "").lower() == map_entry + and nested.options.map_entry + ): + # Get Python types + self.py_k_type = FieldCompiler( + parent=self, proto_obj=nested.field[0] # key + ).py_type + self.py_v_type = FieldCompiler( + parent=self, proto_obj=nested.field[1] # value + ).py_type + # Get proto types + self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) + self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ @property @@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler): return "map" @property - def annotation(self): + def annotation(self) -> str: return f"Dict[{self.py_k_type}, {self.py_v_type}]" @property - def repeated(self): + def repeated(self) -> bool: return False # maps cannot be repeated @@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler): value: int comment: str - def __post_init__(self): + def __post_init__(self) -> None: # Get entries/allowed values for this Enum self.entries = [ self.EnumEntry( @@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler): super().__post_init__() # call MessageCompiler __post_init__ @property - def default_value_string(self) -> int: + def default_value_string(self) -> str: """Python representation of the default value for Enums. As per the spec, this is the first value of the Enum. @@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase): super().__post_init__() # check for unset fields @property - def proto_name(self): + def proto_name(self) -> str: return self.proto_obj.name @property - def py_name(self): + def py_name(self) -> str: return pythonize_class_name(self.proto_name) @@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase): Name and actual default value (as a string) for each argument with mutable default values. """ - mutable_default_args = dict() + mutable_default_args = {} if self.py_input_message: for f in self.py_input_message.fields: @@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase): @property def route(self) -> str: - return ( - f"/{self.output_file.package}." - f"{self.parent.proto_name}/{self.proto_name}" - ) + return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}" @property - def py_input_message(self) -> Union[None, MessageCompiler]: + def py_input_message(self) -> Optional[MessageCompiler]: """Find the input message object. Returns ------- - Union[None, MessageCompiler] + Optional[MessageCompiler] Method instance representing the input message. If not input message could be found or there are no input messages, None is returned. @@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase): @property def py_input_message_type(self) -> str: - """String representation of the Python type correspoding to the + """String representation of the Python type corresponding to the input message. Returns ------- str - String representation of the Python type correspoding to the - input message. + String representation of the Python type corresponding to the input message. """ return get_type_reference( package=self.output_file.package, @@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase): @property def py_output_message_type(self) -> str: - """String representation of the Python type correspoding to the + """String representation of the Python type corresponding to the output message. Returns ------- str - String representation of the Python type correspoding to the - output message. + String representation of the Python type corresponding to the output message. """ return get_type_reference( package=self.output_file.package, diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index cb5b654..a1be268 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -1,7 +1,7 @@ import itertools import pathlib import sys -from typing import List, Iterator +from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set try: # betterproto[compiler] specific dependencies @@ -13,10 +13,9 @@ try: ServiceDescriptorProto, ) except ImportError as err: - missing_import = err.args[0][17:-1] print( "\033[31m" - f"Unable to import `{missing_import}` from betterproto plugin! " + f"Unable to import `{err.name}` from betterproto plugin! " "Please ensure that you've installed betterproto as " '`pip install "betterproto[compiler]"` so that compiler dependencies ' "are included." @@ -24,26 +23,32 @@ except ImportError as err: ) raise SystemExit(1) -from betterproto.plugin.models import ( - PluginRequestCompiler, - OutputTemplate, - MessageCompiler, - FieldCompiler, - OneOfFieldCompiler, - MapEntryCompiler, +from .compiler import outputfile_compiler +from .models import ( EnumDefinitionCompiler, + FieldCompiler, + MapEntryCompiler, + MessageCompiler, + OneOfFieldCompiler, + OutputTemplate, + PluginRequestCompiler, ServiceCompiler, ServiceMethodCompiler, is_map, is_oneof, ) -from betterproto.plugin.compiler import outputfile_compiler +if TYPE_CHECKING: + from google.protobuf.descriptor import Descriptor -def traverse(proto_file: FieldDescriptorProto) -> Iterator: +def traverse( + proto_file: FieldDescriptorProto, +) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": # Todo: Keep information about nested hierarchy - def _traverse(path, items, prefix=""): + def _traverse( + path: List[int], items: List["Descriptor"], prefix="" + ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple @@ -104,7 +109,7 @@ def generate_code( read_protobuf_service(service, index, output_package) # Generate output files - output_paths: pathlib.Path = set() + output_paths: Set[pathlib.Path] = set() for output_package_name, output_package in request_data.output_packages.items(): # Add files to the response object @@ -112,20 +117,17 @@ def generate_code( output_paths.add(output_path) f: response.File = response.file.add() - f.name: str = str(output_path) + f.name = str(output_path) # Render and then format the output file - f.content: str = outputfile_compiler(output_file=output_package) + f.content = outputfile_compiler(output_file=output_package) # Make each output directory a package with __init__ file - init_files = ( - set( - directory.joinpath("__init__.py") - for path in output_paths - for directory in path.parents - ) - - output_paths - ) + init_files = { + directory.joinpath("__init__.py") + for path in output_paths + for directory in path.parents + } - output_paths for init_file in init_files: init = response.file.add() diff --git a/tests/grpc/test_stream_stream.py b/tests/grpc/test_stream_stream.py index 2fc9237..020262d 100644 --- a/tests/grpc/test_stream_stream.py +++ b/tests/grpc/test_stream_stream.py @@ -27,10 +27,7 @@ class ClientStub: async def to_list(generator: AsyncIterator): - result = [] - async for value in generator: - result.append(value) - return result + return [value async for value in generator] @pytest.fixture diff --git a/tests/mocks.py b/tests/mocks.py index 9042f78..dc6e117 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -6,7 +6,7 @@ from grpclib.client import Channel class MockChannel(Channel): # noinspection PyMissingConstructor def __init__(self, responses=None) -> None: - self.responses = responses if responses else [] + self.responses = responses or [] self.requests = [] self._loop = None diff --git a/tests/util.py b/tests/util.py index d085eb6..6c63141 100644 --- a/tests/util.py +++ b/tests/util.py @@ -23,8 +23,7 @@ def get_files(path, suffix: str) -> Generator[str, None, None]: def get_directories(path): for root, directories, files in os.walk(path): - for directory in directories: - yield directory + yield from directories async def protoc( @@ -49,7 +48,7 @@ async def protoc( def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): - test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json" + test_data_file_name = json_file_name or f"{test_case_name}.json" test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) if not test_data_file_path.exists(): @@ -77,7 +76,7 @@ def find_module( module_path = pathlib.Path(*module.__path__) - for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")): + for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]: if sub == module_path: continue sub_module_path = sub.relative_to(module_path)