Adding basic support (untested) for client streaming
This commit is contained in:
parent
a46979c8a6
commit
a757da1b29
@ -14,10 +14,12 @@ from typing import (
|
|||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
SupportsBytes,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
T = TypeVar("T", bound="Message")
|
T = TypeVar("T", bound="Message")
|
||||||
|
ST = TypeVar("ST", bound="IProtoMessage")
|
||||||
|
|
||||||
|
|
||||||
class ProtoClassMetadata:
|
class ProtoClassMetadata:
|
||||||
@ -1104,3 +1107,38 @@ class ServiceStub(ABC):
|
|||||||
await stream.send_message(request, end=True)
|
await stream.send_message(request, end=True)
|
||||||
async for message in stream:
|
async for message in stream:
|
||||||
yield message
|
yield message
|
||||||
|
|
||||||
|
async def _stream_unary(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: Iterator["IProtoMessage"],
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
) -> T:
|
||||||
|
"""Make a stream request and return the response."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
|
||||||
|
) as stream:
|
||||||
|
for message in request_iterator:
|
||||||
|
await stream.send_message(message)
|
||||||
|
await stream.send_request(end=True)
|
||||||
|
response = await stream.recv_message()
|
||||||
|
assert response is not None
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _stream_stream(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: Iterator["IProtoMessage"],
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
) -> AsyncGenerator[T, None]:
|
||||||
|
"""Make a stream request and return the stream response iterator."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
|
||||||
|
) as stream:
|
||||||
|
for message in request_iterator:
|
||||||
|
await stream.send_message(message)
|
||||||
|
await stream.send_request(end=True)
|
||||||
|
async for message in stream:
|
||||||
|
yield message
|
||||||
|
@ -311,8 +311,6 @@ def generate_code(request, response):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for j, method in enumerate(service.method):
|
for j, method in enumerate(service.method):
|
||||||
if method.client_streaming:
|
|
||||||
raise NotImplementedError("Client streaming not yet supported")
|
|
||||||
|
|
||||||
input_message = None
|
input_message = None
|
||||||
input_type = get_ref_type(
|
input_type = get_ref_type(
|
||||||
@ -350,6 +348,9 @@ def generate_code(request, response):
|
|||||||
if method.server_streaming:
|
if method.server_streaming:
|
||||||
output["typing_imports"].add("AsyncGenerator")
|
output["typing_imports"].add("AsyncGenerator")
|
||||||
|
|
||||||
|
if method.client_streaming:
|
||||||
|
output["typing_imports"].add("Iterator")
|
||||||
|
|
||||||
output["services"].append(data)
|
output["services"].append(data)
|
||||||
|
|
||||||
output["imports"] = sorted(output["imports"])
|
output["imports"] = sorted(output["imports"])
|
||||||
|
@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
async def {{ method.py_name }}(self
|
||||||
|
{%- if not method.client_streaming -%}
|
||||||
|
{%- if method.input_message and method.input_message.properties -%}, *,
|
||||||
|
{%- for field in method.input_message.properties -%}
|
||||||
|
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
|
||||||
|
Optional[{{ field.type }}]
|
||||||
|
{%- else -%}
|
||||||
|
{{ field.type }}
|
||||||
|
{%- endif -%} = {{ field.zero }}
|
||||||
|
{%- if not loop.last %}, {% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{# Client streaming: need a request iterator instead #}
|
||||||
|
, request_iterator: Iterator["{{ method.input }}"]
|
||||||
|
{%- endif -%}
|
||||||
|
) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if not method.client_streaming %}
|
||||||
request = {{ method.input }}()
|
request = {{ method.input }}()
|
||||||
{% for field in method.input_message.properties %}
|
{% for field in method.input_message.properties %}
|
||||||
{% if field.field_type == 'message' %}
|
{% if field.field_type == 'message' %}
|
||||||
@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
request.{{ field.py_name }} = {{ field.py_name }}
|
request.{{ field.py_name }} = {{ field.py_name }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if method.server_streaming %}
|
{% if method.server_streaming %}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
async for response in self._stream_stream(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }},
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
async for response in self._unary_stream(
|
async for response in self._unary_stream(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }},
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
{% else %}
|
|
||||||
|
{% endif %}{# if client streaming #}
|
||||||
|
{% else %}{# i.e. not server streaming #}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
return await self._stream_unary(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }}
|
||||||
|
)
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
return await self._unary_unary(
|
return await self._unary_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }}
|
||||||
)
|
)
|
||||||
|
{% endif %}{# client streaming #}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user