Adding basic support (untested) for client streaming

This commit is contained in:
Hans Lellelid 2020-05-11 15:30:29 -04:00 committed by Nat Noordanus
parent a46979c8a6
commit a757da1b29
3 changed files with 82 additions and 5 deletions

View File

@ -14,10 +14,12 @@ from typing import (
Collection, Collection,
Dict, Dict,
Generator, Generator,
Iterator,
List, List,
Mapping, Mapping,
Optional, Optional,
Set, Set,
SupportsBytes,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
# Bound type variable to allow methods to return `self` of subclasses # Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message") T = TypeVar("T", bound="Message")
ST = TypeVar("ST", bound="IProtoMessage")
class ProtoClassMetadata: class ProtoClassMetadata:
@ -1104,3 +1107,38 @@ class ServiceStub(ABC):
await stream.send_message(request, end=True) await stream.send_message(request, end=True)
async for message in stream: async for message in stream:
yield message yield message
async def _stream_unary(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> T:
"""Make a stream request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _stream_stream(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> AsyncGenerator[T, None]:
"""Make a stream request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
async for message in stream:
yield message

View File

@ -311,8 +311,6 @@ def generate_code(request, response):
} }
for j, method in enumerate(service.method): for j, method in enumerate(service.method):
if method.client_streaming:
raise NotImplementedError("Client streaming not yet supported")
input_message = None input_message = None
input_type = get_ref_type( input_type = get_ref_type(
@ -350,6 +348,9 @@ def generate_code(request, response):
if method.server_streaming: if method.server_streaming:
output["typing_imports"].add("AsyncGenerator") output["typing_imports"].add("AsyncGenerator")
if method.client_streaming:
output["typing_imports"].add("Iterator")
output["services"].append(data) output["services"].append(data)
output["imports"] = sorted(output["imports"]) output["imports"] = sorted(output["imports"])

View File

@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endif %} {% endif %}
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.input_message and method.input_message.properties -%}, *,
{%- for field in method.input_message.properties -%}
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
Optional[{{ field.type }}]
{%- else -%}
{{ field.type }}
{%- endif -%} = {{ field.zero }}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Iterator["{{ method.input }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}
{% endif %} {% endif %}
{% if not method.client_streaming %}
request = {{ method.input }}() request = {{ method.input }}()
{% for field in method.input_message.properties %} {% for field in method.input_message.properties %}
{% if field.field_type == 'message' %} {% if field.field_type == 'message' %}
@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
request.{{ field.py_name }} = {{ field.py_name }} request.{{ field.py_name }} = {{ field.py_name }}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% endif %}
{% if method.server_streaming %} {% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }},
):
yield response
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream( async for response in self._unary_stream(
"{{ method.route }}", "{{ method.route }}",
request, request,
{{ method.output }}, {{ method.output }},
): ):
yield response yield response
{% else %}
{% endif %}{# if client streaming #}
{% else %}{# i.e. not server streaming #}
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary( return await self._unary_unary(
"{{ method.route }}", "{{ method.route }}",
request, request,
{{ method.output }}, {{ method.output }}
) )
{% endif %}{# client streaming #}
{% endif %} {% endif %}
{% endfor %} {% endfor %}