diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index 54e5797..b19e806 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -127,6 +127,7 @@ class ServiceStub(ABC): response_type, **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: + await stream.send_request() await self._send_messages(stream, request_iterator) response = await stream.recv_message() assert response is not None diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py index c44d4c0..10254f0 100644 --- a/tests/grpc/test_grpclib_client.py +++ b/tests/grpc/test_grpclib_client.py @@ -272,3 +272,27 @@ async def test_async_gen_for_stream_stream_request(): assert response_index == len( expected_things ), "Didn't receive all expected responses" + + +@pytest.mark.asyncio +async def test_stream_unary_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [DoThingRequest(name) for name in things] + response = await client.do_many_things(requests) + assert len(response.names) == 0 + + +@pytest.mark.asyncio +async def test_stream_stream_with_empty_iterable(): + things = [] # empty + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + requests = [GetThingRequest(name) for name in things] + responses = [ + response async for response in client.get_different_things(requests) + ] + assert len(responses) == 0 diff --git a/tests/grpc/thing_service.py b/tests/grpc/thing_service.py index 1d7c27a..7723a29 100644 --- a/tests/grpc/thing_service.py +++ b/tests/grpc/thing_service.py @@ -27,7 +27,7 @@ class ThingService: async def do_many_things( self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" ): - thing_names = [request.name for request in stream] + thing_names = [request.name async for request in stream] if self.test_hook is not None: self.test_hook(stream) await stream.send_message(DoThingResponse(thing_names))