From a757da1b293969f148ba512bb7b5392dfb817cb6 Mon Sep 17 00:00:00 2001 From: Hans Lellelid Date: Mon, 11 May 2020 15:30:29 -0400 Subject: [PATCH] Adding basic support (untested) for client streaming --- betterproto/__init__.py | 38 ++++++++++++++++++++++++ betterproto/plugin.py | 5 ++-- betterproto/templates/template.py.j2 | 44 ++++++++++++++++++++++++++-- 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 5d901be..a2e7a18 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -14,10 +14,12 @@ from typing import ( Collection, Dict, Generator, + Iterator, List, Mapping, Optional, Set, + SupportsBytes, Tuple, Type, TypeVar, @@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") +ST = TypeVar("ST", bound="IProtoMessage") class ProtoClassMetadata: @@ -1104,3 +1107,38 @@ class ServiceStub(ABC): await stream.send_message(request, end=True) async for message in stream: yield message + + async def _stream_unary( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> T: + """Make a stream request and return the response.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> AsyncGenerator[T, None]: + """Make a stream request and return the stream response iterator.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + async for message in stream: + yield message diff --git a/betterproto/plugin.py b/betterproto/plugin.py index e300318..b877ce6 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -311,8 +311,6 @@ def generate_code(request, response): } for j, method in enumerate(service.method): - if method.client_streaming: - raise NotImplementedError("Client streaming not yet supported") input_message = None input_type = get_ref_type( @@ -350,6 +348,9 @@ def generate_code(request, response): if method.server_streaming: output["typing_imports"].add("AsyncGenerator") + if method.client_streaming: + output["typing_imports"].add("Iterator") + output["services"].append(data) output["imports"] = sorted(output["imports"]) diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index 3a19422..c4c3029 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% for method in service.methods %} - async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + {%- if method.input_message and method.input_message.properties -%}, *, + {%- for field in method.input_message.properties -%} + {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} + Optional[{{ field.type }}] + {%- else -%} + {{ field.type }} + {%- endif -%} = {{ field.zero }} + {%- if not loop.last %}, {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: Iterator["{{ method.input }}"] + {%- endif -%} + ) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} {% endif %} + {% if not method.client_streaming %} request = {{ method.input }}() {% for field in method.input_message.properties %} {% if field.field_type == 'message' %} @@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): request.{{ field.py_name }} = {{ field.py_name }} {% endif %} {% endfor %} + {% endif %} {% if method.server_streaming %} + {% if method.client_streaming %} + async for response in self._stream_stream( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }}, + ): + yield response + {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, {{ method.output }}, ): yield response - {% else %} + + {% endif %}{# if client streaming #} + {% else %}{# i.e. not server streaming #} + {% if method.client_streaming %} + return await self._stream_unary( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }} + ) + {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output }}, + {{ method.output }} ) + {% endif %}{# client streaming #} {% endif %} {% endfor %}