Generate grpclib service stubs (#170)
This commit is contained in:
parent
73cea12e1f
commit
1d54ef8f99
30
src/betterproto/grpc/grpclib_server.py
Normal file
30
src/betterproto/grpc/grpclib_server.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from abc import ABC
|
||||||
|
from collections import AsyncIterable
|
||||||
|
from typing import Callable, Any, Dict
|
||||||
|
|
||||||
|
import grpclib
|
||||||
|
import grpclib.server
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceBase(ABC):
|
||||||
|
"""
|
||||||
|
Base class for async gRPC servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _call_rpc_handler_server_stream(
|
||||||
|
self,
|
||||||
|
handler: Callable,
|
||||||
|
stream: grpclib.server.Stream,
|
||||||
|
request_kwargs: Dict[str, Any],
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
response_iter = handler(**request_kwargs)
|
||||||
|
# check if response is actually an AsyncIterator
|
||||||
|
# this might be false if the method just returns without
|
||||||
|
# yielding at least once
|
||||||
|
# in that case, we just interpret it as an empty iterator
|
||||||
|
if isinstance(response_iter, AsyncIterable):
|
||||||
|
async for response_message in response_iter:
|
||||||
|
await stream.send_message(response_message)
|
||||||
|
else:
|
||||||
|
response_iter.close()
|
@ -553,6 +553,7 @@ class ServiceCompiler(ProtoContentBase):
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Add service to output file
|
# Add service to output file
|
||||||
self.output_file.services.append(self)
|
self.output_file.services.append(self)
|
||||||
|
self.output_file.typing_imports.add("Dict")
|
||||||
super().__post_init__() # check for unset fields
|
super().__post_init__() # check for unset fields
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
|
from betterproto.grpc.grpclib_server import ServiceBase
|
||||||
{% if output_file.services %}
|
{% if output_file.services %}
|
||||||
import grpclib
|
import grpclib
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
Optional[{{ field.annotation }}]
|
Optional[{{ field.annotation }}]
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{{ field.annotation }}
|
{{ field.annotation }}
|
||||||
{%- endif -%} =
|
{%- endif -%} =
|
||||||
{%- if field.py_name not in method.mutable_default_args -%}
|
{%- if field.py_name not in method.mutable_default_args -%}
|
||||||
{{ field.default_value_string }}
|
{{ field.default_value_string }}
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
@ -154,6 +155,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% for service in output_file.services %}
|
||||||
|
class {{ service.py_name }}Base(ServiceBase):
|
||||||
|
{% if service.comment %}
|
||||||
|
{{ service.comment }}
|
||||||
|
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% for method in service.methods %}
|
||||||
|
async def {{ method.py_name }}(self
|
||||||
|
{%- if not method.client_streaming -%}
|
||||||
|
{%- if method.py_input_message and method.py_input_message.fields -%},
|
||||||
|
{%- for field in method.py_input_message.fields -%}
|
||||||
|
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
||||||
|
Optional[{{ field.annotation }}]
|
||||||
|
{%- else -%}
|
||||||
|
{{ field.annotation }}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- if not loop.last %}, {% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{# Client streaming: need a request iterator instead #}
|
||||||
|
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||||
|
{%- endif -%}
|
||||||
|
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
|
{% if method.comment %}
|
||||||
|
{{ method.comment }}
|
||||||
|
|
||||||
|
{% endif %}
|
||||||
|
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
|
||||||
|
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{% for method in service.methods %}
|
||||||
|
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
||||||
|
{% if not method.client_streaming %}
|
||||||
|
request = await stream.recv_message()
|
||||||
|
|
||||||
|
request_kwargs = {
|
||||||
|
{% for field in method.py_input_message.fields %}
|
||||||
|
"{{ field.py_name }}": request.{{ field.py_name }},
|
||||||
|
{% endfor %}
|
||||||
|
}
|
||||||
|
|
||||||
|
{% else %}
|
||||||
|
request_kwargs = {"request_iterator": stream.__aiter__()}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if not method.server_streaming %}
|
||||||
|
response = await self.{{ method.py_name }}(**request_kwargs)
|
||||||
|
await stream.send_message(response)
|
||||||
|
{% else %}
|
||||||
|
await self._call_rpc_handler_server_stream(
|
||||||
|
self.{{ method.py_name }},
|
||||||
|
stream,
|
||||||
|
request_kwargs,
|
||||||
|
)
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||||
|
return {
|
||||||
|
{% for method in service.methods %}
|
||||||
|
"{{ method.route }}": grpclib.const.Handler(
|
||||||
|
self.__rpc_{{ method.py_name }},
|
||||||
|
{% if not method.client_streaming and not method.server_streaming %}
|
||||||
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
|
{% elif not method.client_streaming and method.server_streaming %}
|
||||||
|
grpclib.const.Cardinality.UNARY_STREAM,
|
||||||
|
{% elif method.client_streaming and not method.server_streaming %}
|
||||||
|
grpclib.const.Cardinality.STREAM_UNARY,
|
||||||
|
{% else %}
|
||||||
|
grpclib.const.Cardinality.STREAM_STREAM,
|
||||||
|
{% endif %}
|
||||||
|
{{ method.py_input_message_type }},
|
||||||
|
{{ method.py_output_message_type }},
|
||||||
|
),
|
||||||
|
{% endfor %}
|
||||||
|
}
|
||||||
|
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
{% for i in output_file.imports|sort %}
|
{% for i in output_file.imports|sort %}
|
||||||
{{ i }}
|
{{ i }}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
@ -17,4 +17,5 @@ services = {
|
|||||||
"import_service_input_message",
|
"import_service_input_message",
|
||||||
"googletypes_service_returns_empty",
|
"googletypes_service_returns_empty",
|
||||||
"googletypes_service_returns_googletype",
|
"googletypes_service_returns_googletype",
|
||||||
|
"example_service",
|
||||||
}
|
}
|
||||||
|
20
tests/inputs/example_service/example_service.proto
Normal file
20
tests/inputs/example_service/example_service.proto
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package example_service;
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc ExampleUnaryUnary(ExampleRequest) returns (ExampleResponse);
|
||||||
|
rpc ExampleUnaryStream(ExampleRequest) returns (stream ExampleResponse);
|
||||||
|
rpc ExampleStreamUnary(stream ExampleRequest) returns (ExampleResponse);
|
||||||
|
rpc ExampleStreamStream(stream ExampleRequest) returns (stream ExampleResponse);
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExampleRequest {
|
||||||
|
string example_string = 1;
|
||||||
|
int64 example_integer = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ExampleResponse {
|
||||||
|
string example_string = 1;
|
||||||
|
int64 example_integer = 2;
|
||||||
|
}
|
95
tests/inputs/example_service/test_example_service.py
Normal file
95
tests/inputs/example_service/test_example_service.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
from typing import AsyncIterator, AsyncIterable
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from grpclib.testing import ChannelFor
|
||||||
|
|
||||||
|
from tests.output_betterproto.example_service.example_service import (
|
||||||
|
TestBase,
|
||||||
|
TestStub,
|
||||||
|
ExampleRequest,
|
||||||
|
ExampleResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleService(TestBase):
|
||||||
|
async def example_unary_unary(
|
||||||
|
self, example_string: str, example_integer: int
|
||||||
|
) -> "ExampleResponse":
|
||||||
|
return ExampleResponse(
|
||||||
|
example_string=example_string,
|
||||||
|
example_integer=example_integer,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def example_unary_stream(
|
||||||
|
self, example_string: str, example_integer: int
|
||||||
|
) -> AsyncIterator["ExampleResponse"]:
|
||||||
|
response = ExampleResponse(
|
||||||
|
example_string=example_string,
|
||||||
|
example_integer=example_integer,
|
||||||
|
)
|
||||||
|
yield response
|
||||||
|
yield response
|
||||||
|
yield response
|
||||||
|
|
||||||
|
async def example_stream_unary(
|
||||||
|
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||||
|
) -> "ExampleResponse":
|
||||||
|
async for example_request in request_iterator:
|
||||||
|
return ExampleResponse(
|
||||||
|
example_string=example_request.example_string,
|
||||||
|
example_integer=example_request.example_integer,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def example_stream_stream(
|
||||||
|
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||||
|
) -> AsyncIterator["ExampleResponse"]:
|
||||||
|
async for example_request in request_iterator:
|
||||||
|
yield ExampleResponse(
|
||||||
|
example_string=example_request.example_string,
|
||||||
|
example_integer=example_request.example_integer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_calls_with_different_cardinalities():
|
||||||
|
test_string = "test string"
|
||||||
|
test_int = 42
|
||||||
|
|
||||||
|
async with ChannelFor([ExampleService()]) as channel:
|
||||||
|
stub = TestStub(channel)
|
||||||
|
|
||||||
|
# unary unary
|
||||||
|
response = await stub.example_unary_unary(
|
||||||
|
example_string="test string",
|
||||||
|
example_integer=42,
|
||||||
|
)
|
||||||
|
assert response.example_string == test_string
|
||||||
|
assert response.example_integer == test_int
|
||||||
|
|
||||||
|
# unary stream
|
||||||
|
async for response in stub.example_unary_stream(
|
||||||
|
example_string="test string",
|
||||||
|
example_integer=42,
|
||||||
|
):
|
||||||
|
assert response.example_string == test_string
|
||||||
|
assert response.example_integer == test_int
|
||||||
|
|
||||||
|
# stream unary
|
||||||
|
request = ExampleRequest(
|
||||||
|
example_string=test_string,
|
||||||
|
example_integer=42,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def request_iterator():
|
||||||
|
yield request
|
||||||
|
yield request
|
||||||
|
yield request
|
||||||
|
|
||||||
|
response = await stub.example_stream_unary(request_iterator())
|
||||||
|
assert response.example_string == test_string
|
||||||
|
assert response.example_integer == test_int
|
||||||
|
|
||||||
|
# stream stream
|
||||||
|
async for response in stub.example_stream_stream(request_iterator()):
|
||||||
|
assert response.example_string == test_string
|
||||||
|
assert response.example_integer == test_int
|
Loading…
x
Reference in New Issue
Block a user