From 3fd5a0d6626b54a87494b97b09f413fea0c28a2b Mon Sep 17 00:00:00 2001 From: James Hilton-Balfe Date: Wed, 6 Jul 2022 19:05:40 +0100 Subject: [PATCH] Fix parameters missing from services (#381) --- src/betterproto/__init__.py | 24 +++++++--- src/betterproto/compile/importing.py | 2 +- src/betterproto/grpc/grpclib_client.py | 39 +++++++-------- src/betterproto/plugin/models.py | 3 ++ src/betterproto/plugin/parser.py | 18 +++---- src/betterproto/templates/template.py.j2 | 13 +++-- tests/inputs/config.py | 1 + .../googletypes_request.proto | 29 ++++++++++++ .../test_googletypes_request.py | 47 +++++++++++++++++++ 9 files changed, 136 insertions(+), 40 deletions(-) create mode 100644 tests/inputs/googletypes_request/googletypes_request.proto create mode 100644 tests/inputs/googletypes_request/test_googletypes_request.py diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 384c260..62056e3 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: elif proto_type == TYPE_MESSAGE: if isinstance(value, datetime): # Convert the `datetime` to a timestamp message. - seconds = int(value.timestamp()) - nanos = int(value.microsecond * 1e3) - value = _Timestamp(seconds=seconds, nanos=nanos) + value = _Timestamp.from_datetime(value) elif isinstance(value, timedelta): # Convert the `timedelta` to a duration message. - total_ms = value // timedelta(microseconds=1) - seconds = int(total_ms / 1e6) - nanos = int((total_ms % 1e6) * 1e3) - value = _Duration(seconds=seconds, nanos=nanos) + value = _Duration.from_timedelta(value) elif wraps: if value is None: return b"" @@ -1505,6 +1500,15 @@ from .lib.google.protobuf import ( # noqa class _Duration(Duration): + @classmethod + def from_timedelta( + cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1) + ) -> "_Duration": + total_ms = delta // _1_microsecond + seconds = int(total_ms / 1e6) + nanos = int((total_ms % 1e6) * 1e3) + return cls(seconds, nanos) + def to_timedelta(self) -> timedelta: return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) @@ -1518,6 +1522,12 @@ class _Duration(Duration): class _Timestamp(Timestamp): + @classmethod + def from_datetime(cls, dt: datetime) -> "_Timestamp": + seconds = int(dt.timestamp()) + nanos = int(dt.microsecond * 1e3) + return cls(seconds, nanos) + def to_datetime(self) -> datetime: ts = self.seconds + (self.nanos / 1e9) return datetime.fromtimestamp(ts, tz=timezone.utc) diff --git a/src/betterproto/compile/importing.py b/src/betterproto/compile/importing.py index a28f555..96ede7b 100644 --- a/src/betterproto/compile/importing.py +++ b/src/betterproto/compile/importing.py @@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: def get_type_reference( - package: str, imports: set, source_type: str, unwrap: bool = True + *, package: str, imports: set, source_type: str, unwrap: bool = True ) -> str: """ Return a Python type name for a proto type reference. Adds the import if diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index 28d47bd..54e5797 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -15,21 +15,22 @@ from typing import ( import grpclib.const -from .._types import ( - ST, - T, -) - if TYPE_CHECKING: from grpclib.client import Channel from grpclib.metadata import Deadline + from .._types import ( + ST, + IProtoMessage, + Message, + T, + ) + Value = Union[str, bytes] MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] -MessageLike = Union[T, ST] -MessageSource = Union[Iterable[ST], AsyncIterable[ST]] +MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] class ServiceStub(ABC): @@ -65,13 +66,13 @@ class ServiceStub(ABC): async def _unary_unary( self, route: str, - request: MessageLike, - response_type: Type[T], + request: "IProtoMessage", + response_type: Type["T"], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[MetadataLike] = None, - ) -> T: + ) -> "T": """Make a unary request and return the response.""" async with self.channel.request( route, @@ -88,13 +89,13 @@ class ServiceStub(ABC): async def _unary_stream( self, route: str, - request: MessageLike, - response_type: Type[T], + request: "IProtoMessage", + response_type: Type["T"], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[MetadataLike] = None, - ) -> AsyncIterator[T]: + ) -> AsyncIterator["T"]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( route, @@ -111,13 +112,13 @@ class ServiceStub(ABC): self, route: str, request_iterator: MessageSource, - request_type: Type[ST], - response_type: Type[T], + request_type: Type["IProtoMessage"], + response_type: Type["T"], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[MetadataLike] = None, - ) -> T: + ) -> "T": """Make a stream request and return the response.""" async with self.channel.request( route, @@ -135,13 +136,13 @@ class ServiceStub(ABC): self, route: str, request_iterator: MessageSource, - request_type: Type[ST], - response_type: Type[T], + request_type: Type["IProtoMessage"], + response_type: Type["T"], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[MetadataLike] = None, - ) -> AsyncIterator[T]: + ) -> AsyncIterator["T"]: """ Make a stream request and return an AsyncIterator to iterate over response messages. diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 71c5471..9c835fa 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -252,6 +252,7 @@ class OutputTemplate: enums: List["EnumDefinitionCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list) imports_type_checking_only: Set[str] = field(default_factory=set) + output: bool = True @property def package(self) -> str: @@ -704,6 +705,7 @@ class ServiceMethodCompiler(ProtoContentBase): # add imports required for request arguments timeout, deadline and metadata self.output_file.typing_imports.add("Optional") + self.output_file.imports_type_checking_only.add("import grpclib.server") self.output_file.imports_type_checking_only.add( "from betterproto.grpc.grpclib_client import MetadataLike" ) @@ -768,6 +770,7 @@ class ServiceMethodCompiler(ProtoContentBase): package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.input_type, + unwrap=False, ).strip('"') @property diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index d503ee0..c0d32f6 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: request_data = PluginRequestCompiler(plugin_request_obj=request) # Gather output packages for proto_file in request.proto_file: - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): - # If not INCLUDE_GOOGLE, - # skip re-compiling Google's well-known types - continue - output_package_name = proto_file.package if output_package_name not in request_data.output_packages: # Create a new output if there is no output for this package @@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # Add this input file to the output corresponding to this package request_data.output_packages[output_package_name].input_files.append(proto_file) + if ( + proto_file.package == "google.protobuf" + and "INCLUDE_GOOGLE" not in plugin_options + ): + # If not INCLUDE_GOOGLE, + # skip outputting Google's well-known types + request_data.output_packages[output_package_name].output = False + # Read Messages and Enums # We need to read Messages before Services in so that we can # get the references to input/output messages for each service @@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: # Generate output files output_paths: Set[pathlib.Path] = set() for output_package_name, output_package in request_data.output_packages.items(): + if not output_package.output: + continue # Add files to the response object output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index e615350..ce6fe90 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} import betterproto +{% if output_file.services %} from betterproto.grpc.grpclib_server import ServiceBase +import grpclib +{% endif %} + {% for i in output_file.imports|sort %} {{ i }} {% endfor %} -{% if output_file.services %} -import grpclib -{% endif %} {% if output_file.imports_type_checking_only %} from typing import TYPE_CHECKING @@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {# Client streaming: need a request iterator instead #} , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] {%- endif -%} + , + * , timeout: Optional[float] = None , deadline: Optional["Deadline"] = None - , metadata: Optional["_MetadataLike"] = None + , metadata: Optional["MetadataLike"] = None ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} @@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase): {% endfor %} {% for method in service.methods %} - async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: + async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None: {% if not method.client_streaming %} request = await stream.recv_message() {% else %} diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 49882b0..b0e6486 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -9,6 +9,7 @@ xfail = { } services = { + "googletypes_request", "googletypes_response", "googletypes_response_embedded", "service", diff --git a/tests/inputs/googletypes_request/googletypes_request.proto b/tests/inputs/googletypes_request/googletypes_request.proto new file mode 100644 index 0000000..1cedcaa --- /dev/null +++ b/tests/inputs/googletypes_request/googletypes_request.proto @@ -0,0 +1,29 @@ +syntax = "proto3"; + +package googletypes_request; + +import "google/protobuf/duration.proto"; +import "google/protobuf/empty.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +// Tests that google types can be used as params + +service Test { + rpc SendDouble (google.protobuf.DoubleValue) returns (Input); + rpc SendFloat (google.protobuf.FloatValue) returns (Input); + rpc SendInt64 (google.protobuf.Int64Value) returns (Input); + rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input); + rpc SendInt32 (google.protobuf.Int32Value) returns (Input); + rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input); + rpc SendBool (google.protobuf.BoolValue) returns (Input); + rpc SendString (google.protobuf.StringValue) returns (Input); + rpc SendBytes (google.protobuf.BytesValue) returns (Input); + rpc SendDatetime (google.protobuf.Timestamp) returns (Input); + rpc SendTimedelta (google.protobuf.Duration) returns (Input); + rpc SendEmpty (google.protobuf.Empty) returns (Input); +} + +message Input { + +} diff --git a/tests/inputs/googletypes_request/test_googletypes_request.py b/tests/inputs/googletypes_request/test_googletypes_request.py new file mode 100644 index 0000000..ffb2608 --- /dev/null +++ b/tests/inputs/googletypes_request/test_googletypes_request.py @@ -0,0 +1,47 @@ +from datetime import ( + datetime, + timedelta, +) +from typing import ( + Any, + Callable, +) + +import pytest + +import betterproto.lib.google.protobuf as protobuf +from tests.mocks import MockChannel +from tests.output_betterproto.googletypes_request import ( + Input, + TestStub, +) + + +test_cases = [ + (TestStub.send_double, protobuf.DoubleValue, 2.5), + (TestStub.send_float, protobuf.FloatValue, 2.5), + (TestStub.send_int64, protobuf.Int64Value, -64), + (TestStub.send_u_int64, protobuf.UInt64Value, 64), + (TestStub.send_int32, protobuf.Int32Value, -32), + (TestStub.send_u_int32, protobuf.UInt32Value, 32), + (TestStub.send_bool, protobuf.BoolValue, True), + (TestStub.send_string, protobuf.StringValue, "string"), + (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), + (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)), + (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)), +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_channel_receives_wrapped_type( + service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value +): + wrapped_value = wrapper_class() + wrapped_value.value = value + channel = MockChannel(responses=[Input()]) + service = TestStub(channel) + + await service_method(service, wrapped_value) + + assert channel.requests[0]["request"] == type(wrapped_value)