From 4b6f55dce58d82f8db14ada7e08063c14eab9b94 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sun, 7 Jun 2020 17:51:26 +0200 Subject: [PATCH] Finish implementation and testing of client Including stream_unary and stream_stream call methods. Also - improve organisation of relevant tests - fix some generated type annotations - Add AsyncChannel utility cos it's useful --- betterproto/__init__.py | 2 +- betterproto/grpc/grpclib_client.py | 68 ++++-- betterproto/grpc/util/__init__.py | 0 betterproto/grpc/util/async_channel.py | 204 ++++++++++++++++++ betterproto/plugin.py | 9 +- betterproto/templates/template.py.j2 | 16 +- betterproto/tests/grpc/__init__.py | 0 betterproto/tests/grpc/test_grpclib_client.py | 150 +++++++++++++ betterproto/tests/grpc/thing_service.py | 83 +++++++ .../test_googletypes_response.py | 2 +- .../tests/inputs/service/test_service.py | 132 ------------ betterproto/tests/test_service_client.py | 176 --------------- 12 files changed, 503 insertions(+), 339 deletions(-) create mode 100644 betterproto/grpc/util/__init__.py create mode 100644 betterproto/grpc/util/async_channel.py create mode 100644 betterproto/tests/grpc/__init__.py create mode 100644 betterproto/tests/grpc/test_grpclib_client.py create mode 100644 betterproto/tests/grpc/thing_service.py delete mode 100644 betterproto/tests/inputs/service/test_service.py delete mode 100644 betterproto/tests/test_service_client.py diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 8288aaf..6a53d65 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -593,7 +593,7 @@ class Message(ABC): serialize_empty = False if isinstance(value, Message) and value._serialized_on_wire: # Empty messages can still be sent on the wire if they were - # set (or received empty). + # set (or recieved empty). serialize_empty = True if value == self._get_field_default(field_name) and not ( diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py index 757982e..7218574 100644 --- a/betterproto/grpc/grpclib_client.py +++ b/betterproto/grpc/grpclib_client.py @@ -1,7 +1,8 @@ from abc import ABC +import asyncio import grpclib.const from typing import ( - AsyncGenerator, + Any, AsyncIterator, Collection, Iterator, @@ -16,17 +17,18 @@ from .._types import ST, T if TYPE_CHECKING: from grpclib._protocols import IProtoMessage - from grpclib.client import Channel + from grpclib.client import Channel, Stream from grpclib.metadata import Deadline _Value = Union[str, bytes] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] +_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]] class ServiceStub(ABC): """ - Base class for async gRPC service stubs. + Base class for async gRPC clients. """ def __init__( @@ -86,7 +88,7 @@ class ServiceStub(ABC): timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[_MetadataLike] = None, - ) -> AsyncGenerator[T, None]: + ) -> AsyncIterator[T]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( route, @@ -102,17 +104,23 @@ class ServiceStub(ABC): async def _stream_unary( self, route: str, - request_iterator: Iterator["IProtoMessage"], + request_iterator: _MessageSource, request_type: Type[ST], response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, ) -> T: """Make a stream request and return the response.""" async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + route, + grpclib.const.Cardinality.STREAM_UNARY, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) + await self._send_messages(stream, request_iterator) response = await stream.recv_message() assert response is not None return response @@ -120,16 +128,42 @@ class ServiceStub(ABC): async def _stream_stream( self, route: str, - request_iterator: Iterator["IProtoMessage"], + request_iterator: _MessageSource, request_type: Type[ST], response_type: Type[T], - ) -> AsyncGenerator[T, None]: - """Make a stream request and return the stream response iterator.""" + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncIterator[T]: + """ + Make a stream request and return an AsyncIterator to iterate over response + messages. + """ async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + route, + grpclib.const.Cardinality.STREAM_STREAM, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: - for message in request_iterator: + await stream.send_request() + sending_task = asyncio.ensure_future( + self._send_messages(stream, request_iterator) + ) + try: + async for response in stream: + yield response + except: + sending_task.cancel() + raise + + @staticmethod + async def _send_messages(stream, messages: _MessageSource): + if hasattr(messages, "__aiter__"): + async for message in messages: await stream.send_message(message) - await stream.send_request(end=True) - async for message in stream: - yield message + else: + for message in messages: + await stream.send_message(message) + await stream.end() diff --git a/betterproto/grpc/util/__init__.py b/betterproto/grpc/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py new file mode 100644 index 0000000..7e83c94 --- /dev/null +++ b/betterproto/grpc/util/async_channel.py @@ -0,0 +1,204 @@ +import asyncio +from typing import ( + AsyncIterable, + AsyncIterator, + Iterable, + Optional, + TypeVar, + Union, +) + +T = TypeVar("T") + + +class ChannelClosed(Exception): + """ + An exception raised on an attempt to send through a closed channel + """ + + pass + + +class ChannelDone(Exception): + """ + An exception raised on an attempt to send recieve from a channel that is both closed + and empty. + """ + + pass + + +class AsyncChannel(AsyncIterable[T]): + """ + A buffered async channel for sending items between coroutines with FIFO semantics. + + This makes decoupled bidirection steaming gRPC requests easy if used like: + + .. code-block:: python + client = GeneratedStub(grpclib_chan) + # The channel can be initialised with items to send immediately + request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)]) + async for response in client.rpc_call(request_chan): + # The response iterator will remain active until the connection is closed + ... + # More items can be sent at any time + await request_chan.send(ReqestObject(...)) + ... + # The channel must be closed to complete the gRPC connection + request_chan.close() + + Items can be sent through the channel by either: + - providing an iterable to the constructor + - providing an iterable to the send_from method + - passing them to the send method one at a time + + Items can be recieved from the channel by either: + - iterating over the channel with a for loop to get all items + - calling the recieve method to get one item at a time + + If the channel is empty then recievers will wait until either an item appears or the + channel is closed. + + Once the channel is closed then subsequent attempt to send through the channel will + fail with a ChannelClosed exception. + + When th channel is closed and empty then it is done, and further attempts to recieve + from it will fail with a ChannelDone exception + + If multiple coroutines recieve from the channel concurrently, each item sent will be + recieved by only one of the recievers. + + :param source: + An optional iterable will items that should be sent through the channel + immediately. + :param buffer_limit: + Limit the number of items that can be buffered in the channel, A value less than + 1 implies no limit. If the channel is full then attempts to send more items will + result in the sender waiting until an item is recieved from the channel. + :param close: + If set to True then the channel will automatically close after exhausting source + or immediately if no source is provided. + """ + + def __init__( + self, + source: Union[Iterable[T], AsyncIterable[T]] = tuple(), + *, + buffer_limit: int = 0, + close: bool = False, + ): + self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) + self._closed = False + self._sending_task = ( + asyncio.ensure_future(self.send_from(source, close)) if source else None + ) + self._waiting_recievers: int = 0 + # Track whether flush has been invoked so it can only happen once + self._flushed = False + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + if self.done: + raise StopAsyncIteration + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + raise StopAsyncIteration + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def closed(self) -> bool: + """ + Returns True if this channel is closed and no-longer accepting new items + """ + return self._closed + + def done(self) -> bool: + """ + Check if this channel is done. + + :return: True if this channel is closed and and has been drained of items in + which case any further attempts to recieve an item from this channel will raise + a ChannelDone exception. + """ + # After close the channel is not yet done until there is at least one waiting + # reciever per enqueued item. + return self._closed and self._queue.qsize() <= self._waiting_recievers + + async def send_from( + self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False + ): + """ + Iterates the given [Async]Iterable and sends all the resulting items. + If close is set to True then subsequent send calls will be rejected with a + ChannelClosed exception. + :param source: an iterable of items to send + :param close: + if True then the channel will be closed after the source has been exhausted + + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + if isinstance(source, AsyncIterable): + async for item in source: + await self._queue.put(item) + else: + for item in source: + await self._queue.put(item) + if close: + # Complete the closing process + await self.close() + + async def send(self, item: T): + """ + Send a single item over this channel. + :param item: The item to send + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + await self._queue.put(item) + + async def recieve(self) -> Optional[T]: + """ + Returns the next item from this channel when it becomes available, + or None if the channel is closed before another item is sent. + :return: An item from the channel + """ + if self.done: + raise ChannelDone("Cannot recieve from a closed channel") + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + return None + return result + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def close(self): + """ + Close this channel to new items + """ + if self._sending_task is not None: + self._sending_task.cancel() + self._closed = True + asyncio.ensure_future(self._flush_queue()) + + async def _flush_queue(self): + """ + To be called after the channel is closed. Pushes a number of self.__flush + objects to the queue to ensure no waiting consumers get deadlocked. + """ + if not self._flushed: + self._flushed = True + deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize()) + for _ in range(deadlocked_recievers): + await self._queue.put(self.__flush) + + # A special signal object for flushing the queue when the channel is closed + __flush = object() diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 44515d5..85fd905 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -344,11 +344,12 @@ def generate_code(request, response): } ) - if method.server_streaming: - output["typing_imports"].add("AsyncGenerator") - if method.client_streaming: - output["typing_imports"].add("Iterator") + output["typing_imports"].add("AsyncIterable") + output["typing_imports"].add("Iterable") + output["typing_imports"].add("Union") + if method.server_streaming: + output["typing_imports"].add("AsyncIterator") output["services"].append(data) diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index c4c3029..3894619 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -77,9 +77,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , request_iterator: Iterator["{{ method.input }}"] + , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]] {%- endif -%} - ) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} @@ -97,7 +97,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% if method.server_streaming %} - {% if method.client_streaming %} + {% if method.client_streaming %} async for response in self._stream_stream( "{{ method.route }}", request_iterator, @@ -105,7 +105,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ method.output }}, ): yield response - {% else %}{# i.e. not client streaming #} + {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, @@ -113,22 +113,22 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): ): yield response - {% endif %}{# if client streaming #} + {% endif %}{# if client streaming #} {% else %}{# i.e. not server streaming #} - {% if method.client_streaming %} + {% if method.client_streaming %} return await self._stream_unary( "{{ method.route }}", request_iterator, {{ method.input }}, {{ method.output }} ) - {% else %}{# i.e. not client streaming #} + {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, {{ method.output }} ) - {% endif %}{# client streaming #} + {% endif %}{# client streaming #} {% endif %} {% endfor %} diff --git a/betterproto/tests/grpc/__init__.py b/betterproto/tests/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py new file mode 100644 index 0000000..dc57fe4 --- /dev/null +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -0,0 +1,150 @@ +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from grpclib.testing import ChannelFor +import pytest +from betterproto.grpc.util.async_channel import AsyncChannel +from .thing_service import ThingService + + +async def _test_client(client, name="clean room", **kwargs): + response = await client.do_thing(name=name) + assert response.names == [name] + + +def _assert_request_meta_recieved(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_client(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata), + ) + ] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + +@pytest.mark.asyncio +async def test_async_gen_for_unary_stream_request(): + thing_name = "my milkshakes" + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + expected_versions = [5, 4, 3, 2, 1] + async for response in client.get_thing_versions(name=thing_name): + assert response.name == thing_name + assert response.version == expected_versions.pop() + + +@pytest.mark.asyncio +async def test_async_gen_for_stream_stream_request(): + some_things = ["cake", "cricket", "coral reef"] + more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"] + expected_things = (*some_things, *more_things) + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + # Use an AsyncChannel to decouple sending and recieving, it'll send some_things + # immediately and we'll use it to send more_things later, after recieving some + # results + request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + response_index = 0 + async for response in client.get_different_things(request_chan): + assert response.name == expected_things[response_index] + assert response.version == response_index + 1 + response_index += 1 + if more_things: + # Send some more requests as we recieve reponses to be sure coordination of + # send/recieve events doesn't matter + another_response = await request_chan.send( + GetThingRequest(more_things.pop(0)) + ) + if another_response is not None: + assert another_response.name == expected_things[response_index] + assert another_response.version == response_index + response_index += 1 + else: + # No more things to send make sure channel is closed + await request_chan.close() diff --git a/betterproto/tests/grpc/thing_service.py b/betterproto/tests/grpc/thing_service.py new file mode 100644 index 0000000..bc9fff8 --- /dev/null +++ b/betterproto/tests/grpc/thing_service.py @@ -0,0 +1,83 @@ +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from typing import Any, Dict + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def do_thing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def do_many_things( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def get_thing_versions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request.name, version=version_num) + ) + + async def get_different_things( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Respond to each input item immediately + response_num = 0 + async for request in stream: + response_num += 1 + await stream.send_message( + GetThingResponse(name=request.name, version=response_num) + ) + + def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.do_thing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.do_many_things, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.get_thing_versions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.get_different_things, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index 02fa193..bd5f602 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -23,7 +23,7 @@ test_cases = [ @pytest.mark.asyncio @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) -async def test_channel_receives_wrapped_type( +async def test_channel_recieves_wrapped_type( service_method: Callable[[TestStub], Any], wrapper_class: Callable, value ): wrapped_value = wrapper_class() diff --git a/betterproto/tests/inputs/service/test_service.py b/betterproto/tests/inputs/service/test_service.py deleted file mode 100644 index 2a6ca59..0000000 --- a/betterproto/tests/inputs/service/test_service.py +++ /dev/null @@ -1,132 +0,0 @@ -import betterproto -import grpclib -from grpclib.testing import ChannelFor -import pytest -from typing import Dict - -from betterproto.tests.output_betterproto.service.service import ( - DoThingResponse, - DoThingRequest, - TestStub as ExampleServiceStub, -) - - -class ExampleService: - def __init__(self, test_hook=None): - # This lets us pass assertions to the servicer ;) - self.test_hook = test_hook - - async def DoThing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - request = await stream.recv_message() - print("self.test_hook", self.test_hook) - if self.test_hook is not None: - self.test_hook(stream) - for iteration in range(request.iterations): - pass - await stream.send_message(DoThingResponse(request.iterations)) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/service.Test/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - DoThingRequest, - DoThingResponse, - ) - } - - -async def _test_stub(stub, iterations=42, **kwargs): - response = await stub.do_thing(iterations=iterations) - assert response.successful_iterations == iterations - - -def _get_server_side_test(deadline, metadata): - def server_side_test(stream): - assert stream.deadline._timestamp == pytest.approx( - deadline._timestamp, 1 - ), "The provided deadline should be recieved serverside" - assert ( - stream.metadata["authorization"] == metadata["authorization"] - ), "The provided authorization metadata should be recieved serverside" - - return server_side_test - - -@pytest.mark.asyncio -async def test_simple_service_call(): - async with ChannelFor([ExampleService()]) as channel: - await _test_stub(ExampleServiceStub(channel)) - - -@pytest.mark.asyncio -async def test_service_call_with_upfront_request_params(): - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - ) - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, timeout=timeout, metadata=metadata) - ) - - -@pytest.mark.asyncio -async def test_service_call_lower_level_with_overrides(): - ITERATIONS = 99 - - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) - kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - deadline=kwarg_deadline, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - kwarg_timeout = 9000 - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) - kwarg_metadata = {"authorization": "09876"} - async with ChannelFor( - [ - ExampleService( - test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata) - ) - ] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - timeout=kwarg_timeout, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS diff --git a/betterproto/tests/test_service_client.py b/betterproto/tests/test_service_client.py deleted file mode 100644 index 586095d..0000000 --- a/betterproto/tests/test_service_client.py +++ /dev/null @@ -1,176 +0,0 @@ -import betterproto -import grpclib -from grpclib.testing import ChannelFor -import pytest -from typing import Dict -from betterproto.tests.output_betterproto.service.service import ( - DoThingResponse, - DoThingRequest, - GetThingRequest, - GetThingResponse, - TestStub as ThingServiceClient, -) - - -class ThingService: - def __init__(self, test_hook=None): - # This lets us pass assertions to the servicer ;) - self.test_hook = test_hook - - async def DoThing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - request = await stream.recv_message() - if self.test_hook is not None: - self.test_hook(stream) - await stream.send_message(DoThingResponse([request.name])) - - async def DoManyThings( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - thing_names = [request.name for request in stream] - if self.test_hook is not None: - self.test_hook(stream) - await stream.send_message(DoThingResponse(thing_names)) - - async def GetThingVersions( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): - request = await stream.recv_message() - if self.test_hook is not None: - self.test_hook(stream) - for version_num in range(1, 6): - await stream.send_message( - GetThingResponse(name=request, version=version_num) - ) - - async def GetDifferentThings( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): - if self.test_hook is not None: - self.test_hook(stream) - # Response to each input item immediately - for request in stream: - await stream.send_message(GetThingResponse(name=request.name, version=1)) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/service.Test/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - DoThingRequest, - DoThingResponse, - ), - "/service.Test/DoManyThings": grpclib.const.Handler( - self.DoManyThings, - grpclib.const.Cardinality.STREAM_UNARY, - DoThingRequest, - DoThingResponse, - ), - "/service.Test/GetThingVersions": grpclib.const.Handler( - self.GetThingVersions, - grpclib.const.Cardinality.UNARY_STREAM, - GetThingRequest, - GetThingResponse, - ), - "/service.Test/GetDifferentThings": grpclib.const.Handler( - self.GetDifferentThings, - grpclib.const.Cardinality.STREAM_STREAM, - GetThingRequest, - GetThingResponse, - ), - } - - -async def _test_stub(stub, name="clean room", **kwargs): - response = await stub.do_thing(name=name) - assert response.names == [name] - - -def _assert_request_meta_recieved(deadline, metadata): - def server_side_test(stream): - assert stream.deadline._timestamp == pytest.approx( - deadline._timestamp, 1 - ), "The provided deadline should be recieved serverside" - assert ( - stream.metadata["authorization"] == metadata["authorization"] - ), "The provided authorization metadata should be recieved serverside" - - return server_side_test - - -@pytest.mark.asyncio -async def test_simple_service_call(): - async with ChannelFor([ThingService()]) as channel: - await _test_stub(ThingServiceClient(channel)) - - -@pytest.mark.asyncio -async def test_service_call_with_upfront_request_params(): - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - await _test_stub( - ThingServiceClient(channel, deadline=deadline, metadata=metadata) - ) - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - await _test_stub( - ThingServiceClient(channel, timeout=timeout, metadata=metadata) - ) - - -@pytest.mark.asyncio -async def test_service_call_lower_level_with_overrides(): - THING_TO_DO = "get milk" - - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) - kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(THING_TO_DO), - DoThingResponse, - deadline=kwarg_deadline, - metadata=kwarg_metadata, - ) - assert response.names == [THING_TO_DO] - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - kwarg_timeout = 9000 - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) - kwarg_metadata = {"authorization": "09876"} - async with ChannelFor( - [ - ThingService( - test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata) - ) - ] - ) as channel: - stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(THING_TO_DO), - DoThingResponse, - timeout=kwarg_timeout, - metadata=kwarg_metadata, - ) - assert response.names == [THING_TO_DO]