From 18a518efa7016ac6fe811388f3ee464a1808ac38 Mon Sep 17 00:00:00 2001 From: Arun Babu Neelicattu Date: Sun, 13 Mar 2022 23:34:11 +0100 Subject: [PATCH] Expose timeout, deadline and metadata parameters from grpclib (#352) --- src/betterproto/grpc/grpclib_client.py | 30 +++++++------- src/betterproto/plugin/models.py | 10 +++++ src/betterproto/templates/template.py.j2 | 26 +++++++++++- tests/grpc/test_grpclib_client.py | 53 +++++++++++++++++++++++- 4 files changed, 101 insertions(+), 18 deletions(-) diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index a22b7e3..960bd3d 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -22,10 +22,10 @@ if TYPE_CHECKING: from grpclib.metadata import Deadline -_Value = Union[str, bytes] -_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] -_MessageLike = Union[T, ST] -_MessageSource = Union[Iterable[ST], AsyncIterable[ST]] +Value = Union[str, bytes] +MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] +MessageLike = Union[T, ST] +MessageSource = Union[Iterable[ST], AsyncIterable[ST]] class ServiceStub(ABC): @@ -39,7 +39,7 @@ class ServiceStub(ABC): *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, + metadata: Optional[MetadataLike] = None, ) -> None: self.channel = channel self.timeout = timeout @@ -50,7 +50,7 @@ class ServiceStub(ABC): self, timeout: Optional[float], deadline: Optional["Deadline"], - metadata: Optional[_MetadataLike], + metadata: Optional[MetadataLike], ): return { "timeout": self.timeout if timeout is None else timeout, @@ -61,12 +61,12 @@ class ServiceStub(ABC): async def _unary_unary( self, route: str, - request: _MessageLike, + request: MessageLike, response_type: Type[T], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, + metadata: Optional[MetadataLike] = None, ) -> T: """Make a unary request and return the response.""" async with self.channel.request( @@ -84,12 +84,12 @@ class ServiceStub(ABC): async def _unary_stream( self, route: str, - request: _MessageLike, + request: MessageLike, response_type: Type[T], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, + metadata: Optional[MetadataLike] = None, ) -> AsyncIterator[T]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( @@ -106,13 +106,13 @@ class ServiceStub(ABC): async def _stream_unary( self, route: str, - request_iterator: _MessageSource, + request_iterator: MessageSource, request_type: Type[ST], response_type: Type[T], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, + metadata: Optional[MetadataLike] = None, ) -> T: """Make a stream request and return the response.""" async with self.channel.request( @@ -130,13 +130,13 @@ class ServiceStub(ABC): async def _stream_stream( self, route: str, - request_iterator: _MessageSource, + request_iterator: MessageSource, request_type: Type[ST], response_type: Type[T], *, timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, + metadata: Optional[MetadataLike] = None, ) -> AsyncIterator[T]: """ Make a stream request and return an AsyncIterator to iterate over response @@ -161,7 +161,7 @@ class ServiceStub(ABC): raise @staticmethod - async def _send_messages(stream, messages: _MessageSource): + async def _send_messages(stream, messages: MessageSource): if isinstance(messages, AsyncIterable): async for message in messages: await stream.send_message(message) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c2fccfc..63161b7 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -232,6 +232,7 @@ class OutputTemplate: messages: List["MessageCompiler"] = field(default_factory=list) enums: List["EnumDefinitionCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list) + imports_type_checking_only: Set[str] = field(default_factory=set) @property def package(self) -> str: @@ -679,6 +680,15 @@ class ServiceMethodCompiler(ProtoContentBase): if self.client_streaming or self.server_streaming: self.output_file.typing_imports.add("AsyncIterator") + # add imports required for request arguments timeout, deadline and metadata + self.output_file.typing_imports.add("Optional") + self.output_file.imports_type_checking_only.add( + "from betterproto.grpc.grpclib_client import MetadataLike" + ) + self.output_file.imports_type_checking_only.add( + "from grpclib.metadata import Deadline" + ) + super().__post_init__() # check for unset fields @property diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 718cda9..8f72b5e 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -20,6 +20,13 @@ from betterproto.grpc.grpclib_server import ServiceBase import grpclib {% endif %} +{% if output_file.imports_type_checking_only %} +from typing import TYPE_CHECKING + +if TYPE_CHECKING: +{% for i in output_file.imports_type_checking_only|sort %} {{ i }} +{% endfor %} +{% endif %} {% if output_file.enums %}{% for enum in output_file.enums %} class {{ enum.py_name }}(betterproto.Enum): @@ -86,6 +93,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {# Client streaming: need a request iterator instead #} , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] {%- endif -%} + , timeout: Optional[float] = None + , deadline: Optional["Deadline"] = None + , metadata: Optional["_MetadataLike"] = None ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} @@ -98,6 +108,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ method.py_input_message_param }}_iterator, {{ method.py_input_message_type }}, {{ method.py_output_message_type.strip('"') }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, ): yield response {% else %}{# i.e. not client streaming #} @@ -105,6 +118,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", {{ method.py_input_message_param }}, {{ method.py_output_message_type.strip('"') }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, ): yield response @@ -115,13 +131,19 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", {{ method.py_input_message_param }}_iterator, {{ method.py_input_message_type }}, - {{ method.py_output_message_type.strip('"') }} + {{ method.py_output_message_type.strip('"') }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, ) {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", {{ method.py_input_message_param }}, - {{ method.py_output_message_type.strip('"') }} + {{ method.py_output_message_type.strip('"') }}, + timeout=timeout, + deadline=deadline, + metadata=metadata, ) {% endif %}{# client streaming #} {% endif %} diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py index 28ce56f..ba0b943 100644 --- a/tests/grpc/test_grpclib_client.py +++ b/tests/grpc/test_grpclib_client.py @@ -1,9 +1,11 @@ import asyncio import sys +import uuid import grpclib import grpclib.metadata import grpclib.server +import grpclib.client import pytest from betterproto.grpc.util.async_channel import AsyncChannel from grpclib.testing import ChannelFor @@ -18,7 +20,7 @@ from .thing_service import ThingService async def _test_client(client: ThingServiceClient, name="clean room", **kwargs): - response = await client.do_thing(DoThingRequest(name=name)) + response = await client.do_thing(DoThingRequest(name=name), **kwargs) assert response.names == [name] @@ -172,6 +174,55 @@ async def test_service_call_lower_level_with_overrides(): assert response.names == [THING_TO_DO] +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("overrides",), + [ + (dict(timeout=10),), + (dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),), + (dict(metadata={"authorization": str(uuid.uuid4())}),), + (dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),), + ], +) +async def test_service_call_high_level_with_overrides(mocker, overrides): + request_spy = mocker.spy(grpclib.client.Channel, "request") + name = str(uuid.uuid4()) + defaults = dict( + timeout=99, + deadline=grpclib.metadata.Deadline.from_timeout(99), + metadata={"authorization": name}, + ) + + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_received( + deadline=grpclib.metadata.Deadline.from_timeout( + overrides.get("timeout", 99) + ), + metadata=overrides.get("metadata", defaults.get("metadata")), + ) + ) + ] + ) as channel: + client = ThingServiceClient(channel, **defaults) + await _test_client(client, name=name, **overrides) + assert request_spy.call_count == 1 + + # for python <3.8 request_spy.call_args.kwargs do not work + _, request_spy_call_kwargs = request_spy.call_args_list[0] + + # ensure all overrides were successful + for key, value in overrides.items(): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == value + + # ensure default values were retained + for key in set(defaults.keys()) - set(overrides.keys()): + assert key in request_spy_call_kwargs + assert request_spy_call_kwargs[key] == defaults[key] + + @pytest.mark.asyncio async def test_async_gen_for_unary_stream_request(): thing_name = "my milkshakes"