diff --git a/betterproto/__init__.py b/betterproto/__init__.py index f082fa6..dc2566c 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -11,10 +11,12 @@ from typing import ( Any, AsyncGenerator, Callable, + Collection, Dict, Generator, Iterable, List, + Mapping, Optional, SupportsBytes, Tuple, @@ -1000,20 +1002,57 @@ def _get_wrapper(proto_type: str) -> Type: }[proto_type] +_Value = Union[str, bytes] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] + + class ServiceStub(ABC): """ Base class for async gRPC service stubs. """ - def __init__(self, channel: grpclib.client.Channel) -> None: + def __init__( + self, + channel: grpclib.client.Channel, + *, + timeout: Optional[float] = None, + deadline: Optional[grpclib.metadata.Deadline] = None, + metadata: Optional[_MetadataLike] = None, + ) -> None: self.channel = channel + self.timeout = timeout + self.deadline = deadline + self.metadata = metadata + + def __resolve_request_kwargs( + self, + timeout: Optional[float], + deadline: Optional[grpclib.metadata.Deadline], + metadata: Optional[_MetadataLike], + ): + return { + "timeout": self.timeout if timeout is None else timeout, + "deadline": self.deadline if deadline is None else deadline, + "metadata": self.metadata if metadata is None else metadata, + } async def _unary_unary( - self, route: str, request: "IProtoMessage", response_type: Type[T] + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional[grpclib.metadata.Deadline] = None, + metadata: Optional[_MetadataLike] = None, ) -> T: """Make a unary request and return the response.""" async with self.channel.request( - route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type + route, + grpclib.const.Cardinality.UNARY_UNARY, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: await stream.send_message(request, end=True) response = await stream.recv_message() @@ -1021,11 +1060,22 @@ class ServiceStub(ABC): return response async def _unary_stream( - self, route: str, request: "IProtoMessage", response_type: Type[T] + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional[grpclib.metadata.Deadline] = None, + metadata: Optional[_MetadataLike] = None, ) -> AsyncGenerator[T, None]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( - route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type + route, + grpclib.const.Cardinality.UNARY_STREAM, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: await stream.send_message(request, end=True) async for message in stream: diff --git a/betterproto/tests/test_service_stub.py b/betterproto/tests/test_service_stub.py index 84de7b6..a5ba200 100644 --- a/betterproto/tests/test_service_stub.py +++ b/betterproto/tests/test_service_stub.py @@ -7,17 +7,24 @@ from .service import DoThingResponse, DoThingRequest, ExampleServiceStub class ExampleService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook - async def DoThing(self, stream: 'grpclib.server.Stream[DoThingRequest, DoThingResponse]'): + async def DoThing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): request = await stream.recv_message() + print("self.test_hook", self.test_hook) + if self.test_hook is not None: + self.test_hook(stream) for iteration in range(request.iterations): pass await stream.send_message(DoThingResponse(request.iterations)) - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { - '/service.ExampleService/DoThing': grpclib.const.Handler( + "/service.ExampleService/DoThing": grpclib.const.Handler( self.DoThing, grpclib.const.Cardinality.UNARY_UNARY, DoThingRequest, @@ -26,10 +33,91 @@ class ExampleService: } +async def _test_stub(stub, iterations=42, **kwargs): + response = await stub.do_thing(iterations=iterations) + assert response.successful_iterations == iterations + + +def _get_server_side_test(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + @pytest.mark.asyncio async def test_simple_service_call(): - ITERATIONS = 42 async with ChannelFor([ExampleService()]) as channel: - stub = ExampleServiceStub(channel) - response = await stub.do_thing(iterations=ITERATIONS) + await _test_stub(ExampleServiceStub(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] + ) as channel: + await _test_stub( + ExampleServiceStub(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] + ) as channel: + await _test_stub( + ExampleServiceStub(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + ITERATIONS = 99 + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] + ) as channel: + stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.successful_iterations == ITERATIONS + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ExampleService( + test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata) + ) + ] + ) as channel: + stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) assert response.successful_iterations == ITERATIONS