diff --git a/src/betterproto/grpc/grpclib_server.py b/src/betterproto/grpc/grpclib_server.py new file mode 100644 index 0000000..59bc7d4 --- /dev/null +++ b/src/betterproto/grpc/grpclib_server.py @@ -0,0 +1,30 @@ +from abc import ABC +from collections import AsyncIterable +from typing import Callable, Any, Dict + +import grpclib +import grpclib.server + + +class ServiceBase(ABC): + """ + Base class for async gRPC servers. + """ + + async def _call_rpc_handler_server_stream( + self, + handler: Callable, + stream: grpclib.server.Stream, + request_kwargs: Dict[str, Any], + ) -> None: + + response_iter = handler(**request_kwargs) + # check if response is actually an AsyncIterator + # this might be false if the method just returns without + # yielding at least once + # in that case, we just interpret it as an empty iterator + if isinstance(response_iter, AsyncIterable): + async for response_message in response_iter: + await stream.send_message(response_message) + else: + response_iter.close() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 8d487a4..09217b9 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -553,6 +553,7 @@ class ServiceCompiler(ProtoContentBase): def __post_init__(self) -> None: # Add service to output file self.output_file.services.append(self) + self.output_file.typing_imports.add("Dict") super().__post_init__() # check for unset fields @property diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 753d340..de53963 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no {% endif %} import betterproto +from betterproto.grpc.grpclib_server import ServiceBase {% if output_file.services %} import grpclib {% endif %} @@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): Optional[{{ field.annotation }}] {%- else -%} {{ field.annotation }} - {%- endif -%} = + {%- endif -%} = {%- if field.py_name not in method.mutable_default_args -%} {{ field.default_value_string }} {%- else -%} @@ -154,6 +155,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% endfor %} +{% for service in output_file.services %} +class {{ service.py_name }}Base(ServiceBase): + {% if service.comment %} +{{ service.comment }} + + {% endif %} + + {% for method in service.methods %} + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + {%- if method.py_input_message and method.py_input_message.fields -%}, + {%- for field in method.py_input_message.fields -%} + {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} + Optional[{{ field.annotation }}] + {%- else -%} + {{ field.annotation }} + {%- endif -%} + {%- if not loop.last %}, {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] + {%- endif -%} + ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} + raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED) + + {% endfor %} + + {% for method in service.methods %} + async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: + {% if not method.client_streaming %} + request = await stream.recv_message() + + request_kwargs = { + {% for field in method.py_input_message.fields %} + "{{ field.py_name }}": request.{{ field.py_name }}, + {% endfor %} + } + + {% else %} + request_kwargs = {"request_iterator": stream.__aiter__()} + {% endif %} + + {% if not method.server_streaming %} + response = await self.{{ method.py_name }}(**request_kwargs) + await stream.send_message(response) + {% else %} + await self._call_rpc_handler_server_stream( + self.{{ method.py_name }}, + stream, + request_kwargs, + ) + {% endif %} + + {% endfor %} + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + {% for method in service.methods %} + "{{ method.route }}": grpclib.const.Handler( + self.__rpc_{{ method.py_name }}, + {% if not method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.UNARY_UNARY, + {% elif not method.client_streaming and method.server_streaming %} + grpclib.const.Cardinality.UNARY_STREAM, + {% elif method.client_streaming and not method.server_streaming %} + grpclib.const.Cardinality.STREAM_UNARY, + {% else %} + grpclib.const.Cardinality.STREAM_STREAM, + {% endif %} + {{ method.py_input_message_type }}, + {{ method.py_output_message_type }}, + ), + {% endfor %} + } + +{% endfor %} + {% for i in output_file.imports|sort %} {{ i }} {% endfor %} diff --git a/tests/inputs/config.py b/tests/inputs/config.py index 7d14667..9b7b288 100644 --- a/tests/inputs/config.py +++ b/tests/inputs/config.py @@ -17,4 +17,5 @@ services = { "import_service_input_message", "googletypes_service_returns_empty", "googletypes_service_returns_googletype", + "example_service", } diff --git a/tests/inputs/example_service/example_service.proto b/tests/inputs/example_service/example_service.proto new file mode 100644 index 0000000..96455cc --- /dev/null +++ b/tests/inputs/example_service/example_service.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package example_service; + +service Test { + rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse); + rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse); + rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse); + rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse); +} + +message ExampleRequest { + string example_string = 1; + int64 example_integer = 2; +} + +message ExampleResponse { + string example_string = 1; + int64 example_integer = 2; +} diff --git a/tests/inputs/example_service/test_example_service.py b/tests/inputs/example_service/test_example_service.py new file mode 100644 index 0000000..12d646b --- /dev/null +++ b/tests/inputs/example_service/test_example_service.py @@ -0,0 +1,95 @@ +from typing import AsyncIterator, AsyncIterable + +import pytest +from grpclib.testing import ChannelFor + +from tests.output_betterproto.example_service.example_service import ( + TestBase, + TestStub, + ExampleRequest, + ExampleResponse, +) + + +class ExampleService(TestBase): + async def example_unary_unary( + self, example_string: str, example_integer: int + ) -> "ExampleResponse": + return ExampleResponse( + example_string=example_string, + example_integer=example_integer, + ) + + async def example_unary_stream( + self, example_string: str, example_integer: int + ) -> AsyncIterator["ExampleResponse"]: + response = ExampleResponse( + example_string=example_string, + example_integer=example_integer, + ) + yield response + yield response + yield response + + async def example_stream_unary( + self, request_iterator: AsyncIterator["ExampleRequest"] + ) -> "ExampleResponse": + async for example_request in request_iterator: + return ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + async def example_stream_stream( + self, request_iterator: AsyncIterator["ExampleRequest"] + ) -> AsyncIterator["ExampleResponse"]: + async for example_request in request_iterator: + yield ExampleResponse( + example_string=example_request.example_string, + example_integer=example_request.example_integer, + ) + + +@pytest.mark.asyncio +async def test_calls_with_different_cardinalities(): + test_string = "test string" + test_int = 42 + + async with ChannelFor([ExampleService()]) as channel: + stub = TestStub(channel) + + # unary unary + response = await stub.example_unary_unary( + example_string="test string", + example_integer=42, + ) + assert response.example_string == test_string + assert response.example_integer == test_int + + # unary stream + async for response in stub.example_unary_stream( + example_string="test string", + example_integer=42, + ): + assert response.example_string == test_string + assert response.example_integer == test_int + + # stream unary + request = ExampleRequest( + example_string=test_string, + example_integer=42, + ) + + async def request_iterator(): + yield request + yield request + yield request + + response = await stub.example_stream_unary(request_iterator()) + assert response.example_string == test_string + assert response.example_integer == test_int + + # stream stream + async for response in stub.example_stream_stream(request_iterator()): + assert response.example_string == test_string + assert response.example_integer == test_int