diff --git a/src/betterproto/grpc/grpclib_client.py b/src/betterproto/grpc/grpclib_client.py index 1068cce..99ec8d7 100644 --- a/src/betterproto/grpc/grpclib_client.py +++ b/src/betterproto/grpc/grpclib_client.py @@ -76,8 +76,8 @@ class ServiceStub(ABC): ) as stream: await stream.send_message(request, end=True) response = await stream.recv_message() - assert response is not None - return response + assert response is not None + return response async def _unary_stream( self, @@ -122,8 +122,8 @@ class ServiceStub(ABC): ) as stream: await self._send_messages(stream, request_iterator) response = await stream.recv_message() - assert response is not None - return response + assert response is not None + return response async def _stream_stream( self, diff --git a/tests/grpc/test_grpclib_client.py b/tests/grpc/test_grpclib_client.py index ed35c77..8eb5297 100644 --- a/tests/grpc/test_grpclib_client.py +++ b/tests/grpc/test_grpclib_client.py @@ -9,6 +9,7 @@ from tests.output_betterproto.service.service import ( ) import grpclib import grpclib.metadata +import grpclib.server from grpclib.testing import ChannelFor import pytest from betterproto.grpc.util.async_channel import AsyncChannel @@ -32,12 +33,59 @@ def _assert_request_meta_received(deadline, metadata): return server_side_test +@pytest.fixture +def handler_trailer_only_unauthenticated(): + async def handler(stream: grpclib.server.Stream): + await stream.recv_message() + await stream.send_initial_metadata() + await stream.send_trailing_metadata(status=grpclib.Status.UNAUTHENTICATED) + + return handler + + @pytest.mark.asyncio async def test_simple_service_call(): async with ChannelFor([ThingService()]) as channel: await _test_client(ThingServiceClient(channel)) +@pytest.mark.asyncio +async def test_trailer_only_error_unary_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_thing", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_thing(name="something") + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + +@pytest.mark.asyncio +async def test_trailer_only_error_stream_unary( + mocker, handler_trailer_only_unauthenticated +): + service = ThingService() + mocker.patch.object( + service, + "do_many_things", + side_effect=handler_trailer_only_unauthenticated, + autospec=True, + ) + async with ChannelFor([service]) as channel: + with pytest.raises(grpclib.exceptions.GRPCError) as e: + await ThingServiceClient(channel).do_many_things( + request_iterator=[DoThingRequest(name="something")] + ) + await _test_client(ThingServiceClient(channel)) + assert e.value.status == grpclib.Status.UNAUTHENTICATED + + @pytest.mark.asyncio @pytest.mark.skipif( sys.version_info < (3, 8), reason="async mock spy does works for python3.8+"