Generate grpclib service stubs (#170)
This commit is contained in:
30
src/betterproto/grpc/grpclib_server.py
Normal file
30
src/betterproto/grpc/grpclib_server.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from abc import ABC
|
||||
from collections import AsyncIterable
|
||||
from typing import Callable, Any, Dict
|
||||
|
||||
import grpclib
|
||||
import grpclib.server
|
||||
|
||||
|
||||
class ServiceBase(ABC):
|
||||
"""
|
||||
Base class for async gRPC servers.
|
||||
"""
|
||||
|
||||
async def _call_rpc_handler_server_stream(
|
||||
self,
|
||||
handler: Callable,
|
||||
stream: grpclib.server.Stream,
|
||||
request_kwargs: Dict[str, Any],
|
||||
) -> None:
|
||||
|
||||
response_iter = handler(**request_kwargs)
|
||||
# check if response is actually an AsyncIterator
|
||||
# this might be false if the method just returns without
|
||||
# yielding at least once
|
||||
# in that case, we just interpret it as an empty iterator
|
||||
if isinstance(response_iter, AsyncIterable):
|
||||
async for response_message in response_iter:
|
||||
await stream.send_message(response_message)
|
||||
else:
|
||||
response_iter.close()
|
||||
@@ -553,6 +553,7 @@ class ServiceCompiler(ProtoContentBase):
|
||||
def __post_init__(self) -> None:
|
||||
# Add service to output file
|
||||
self.output_file.services.append(self)
|
||||
self.output_file.typing_imports.add("Dict")
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
|
||||
@@ -15,6 +15,7 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
{% if output_file.services %}
|
||||
import grpclib
|
||||
{% endif %}
|
||||
@@ -82,7 +83,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
Optional[{{ field.annotation }}]
|
||||
{%- else -%}
|
||||
{{ field.annotation }}
|
||||
{%- endif -%} =
|
||||
{%- endif -%} =
|
||||
{%- if field.py_name not in method.mutable_default_args -%}
|
||||
{{ field.default_value_string }}
|
||||
{%- else -%}
|
||||
@@ -154,6 +155,89 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
|
||||
{% for service in output_file.services %}
|
||||
class {{ service.py_name }}Base(ServiceBase):
|
||||
{% if service.comment %}
|
||||
{{ service.comment }}
|
||||
|
||||
{% endif %}
|
||||
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self
|
||||
{%- if not method.client_streaming -%}
|
||||
{%- if method.py_input_message and method.py_input_message.fields -%},
|
||||
{%- for field in method.py_input_message.fields -%}
|
||||
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
||||
Optional[{{ field.annotation }}]
|
||||
{%- else -%}
|
||||
{{ field.annotation }}
|
||||
{%- endif -%}
|
||||
{%- if not loop.last %}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
{% endif %}
|
||||
raise grpclib.GRPCError(grpclib.const.Status.UNIMPLEMENTED)
|
||||
|
||||
{% endfor %}
|
||||
|
||||
{% for method in service.methods %}
|
||||
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
||||
{% if not method.client_streaming %}
|
||||
request = await stream.recv_message()
|
||||
|
||||
request_kwargs = {
|
||||
{% for field in method.py_input_message.fields %}
|
||||
"{{ field.py_name }}": request.{{ field.py_name }},
|
||||
{% endfor %}
|
||||
}
|
||||
|
||||
{% else %}
|
||||
request_kwargs = {"request_iterator": stream.__aiter__()}
|
||||
{% endif %}
|
||||
|
||||
{% if not method.server_streaming %}
|
||||
response = await self.{{ method.py_name }}(**request_kwargs)
|
||||
await stream.send_message(response)
|
||||
{% else %}
|
||||
await self._call_rpc_handler_server_stream(
|
||||
self.{{ method.py_name }},
|
||||
stream,
|
||||
request_kwargs,
|
||||
)
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||
return {
|
||||
{% for method in service.methods %}
|
||||
"{{ method.route }}": grpclib.const.Handler(
|
||||
self.__rpc_{{ method.py_name }},
|
||||
{% if not method.client_streaming and not method.server_streaming %}
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
{% elif not method.client_streaming and method.server_streaming %}
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
{% elif method.client_streaming and not method.server_streaming %}
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
{% else %}
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
{% endif %}
|
||||
{{ method.py_input_message_type }},
|
||||
{{ method.py_output_message_type }},
|
||||
),
|
||||
{% endfor %}
|
||||
}
|
||||
|
||||
{% endfor %}
|
||||
|
||||
{% for i in output_file.imports|sort %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
|
||||
Reference in New Issue
Block a user