Generate grpclib service stubs (#170)

This commit is contained in:
Tim Schmidt 2020-12-04 22:22:11 +01:00 committed by GitHub
parent 73cea12e1f
commit 1d54ef8f99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 232 additions and 1 deletions

View File

@ -0,0 +1,30 @@
from abc import ABC
from collections import AsyncIterable
from typing import Callable, Any, Dict
import grpclib
import grpclib.server
class ServiceBase(ABC):
"""
Base class for async gRPC servers.
"""
async def _call_rpc_handler_server_stream(
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
) -> None:
response_iter = handler(**request_kwargs)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
# yielding at least once
# in that case, we just interpret it as an empty iterator
if isinstance(response_iter, AsyncIterable):
async for response_message in response_iter:
await stream.send_message(response_message)
else:
response_iter.close()

View File

@ -553,6 +553,7 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Add service to output file # Add service to output file
self.output_file.services.append(self) self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields super().__post_init__() # check for unset fields
@property @property

View File

@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %} {% endif %}
import betterproto import betterproto
from betterproto.grpc.grpclib_server import ServiceBase
{% if output_file.services %} {% if output_file.services %}
import grpclib import grpclib
{% endif %} {% endif %}
@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
Optional[{{ field.annotation }}] Optional[{{ field.annotation }}]
{%- else -%} {%- else -%}
{{ field.annotation }} {{ field.annotation }}
{%- endif -%} = {%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%} {%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }} {{ field.default_value_string }}
{%- else -%} {%- else -%}
@ -154,6 +155,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %} {% endfor %}
{% endfor %} {% endfor %}
{% for service in output_file.services %}
class {{ service.py_name }}Base(ServiceBase):
{% if service.comment %}
{{ service.comment }}
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%},
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
{% endfor %}
{% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %}
request_kwargs = {"request_iterator": stream.__aiter__()}
{% endif %}
{% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs)
await stream.send_message(response)
{% else %}
await self._call_rpc_handler_server_stream(
self.{{ method.py_name }},
stream,
request_kwargs,
)
{% endif %}
{% endfor %}
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
self.__rpc_{{ method.py_name }},
{% if not method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.UNARY_UNARY,
{% elif not method.client_streaming and method.server_streaming %}
grpclib.const.Cardinality.UNARY_STREAM,
{% elif method.client_streaming and not method.server_streaming %}
grpclib.const.Cardinality.STREAM_UNARY,
{% else %}
grpclib.const.Cardinality.STREAM_STREAM,
{% endif %}
{{ method.py_input_message_type }},
{{ method.py_output_message_type }},
),
{% endfor %}
}
{% endfor %}
{% for i in output_file.imports|sort %} {% for i in output_file.imports|sort %}
{{ i }} {{ i }}
{% endfor %} {% endfor %}

View File

@ -17,4 +17,5 @@ services = {
"import_service_input_message", "import_service_input_message",
"googletypes_service_returns_empty", "googletypes_service_returns_empty",
"googletypes_service_returns_googletype", "googletypes_service_returns_googletype",
"example_service",
} }

View File

@ -0,0 +1,20 @@
syntax = "proto3";
package example_service;
service Test {
rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
}
message ExampleRequest {
string example_string = 1;
int64 example_integer = 2;
}
message ExampleResponse {
string example_string = 1;
int64 example_integer = 2;
}

View File

@ -0,0 +1,95 @@
from typing import AsyncIterator, AsyncIterable
import pytest
from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import (
TestBase,
TestStub,
ExampleRequest,
ExampleResponse,
)
class ExampleService(TestBase):
async def example_unary_unary(
self, example_string: str, example_integer: int
) -> "ExampleResponse":
return ExampleResponse(
example_string=example_string,
example_integer=example_integer,
)
async def example_unary_stream(
self, example_string: str, example_integer: int
) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse(
example_string=example_string,
example_integer=example_integer,
)
yield response
yield response
yield response
async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse":
async for example_request in request_iterator:
return ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator:
yield ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
test_string = "test string"
test_int = 42
async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel)
# unary unary
response = await stub.example_unary_unary(
example_string="test string",
example_integer=42,
)
assert response.example_string == test_string
assert response.example_integer == test_int
# unary stream
async for response in stub.example_unary_stream(
example_string="test string",
example_integer=42,
):
assert response.example_string == test_string
assert response.example_integer == test_int
# stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator():
yield request
yield request
yield request
response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string
assert response.example_integer == test_int
# stream stream
async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string
assert response.example_integer == test_int