Adding basic support (untested) for client streaming
This commit is contained in:
		
				
					committed by
					
						 Nat Noordanus
						Nat Noordanus
					
				
			
			
				
	
			
			
			
						parent
						
							a46979c8a6
						
					
				
				
					commit
					a757da1b29
				
			| @@ -14,10 +14,12 @@ from typing import ( | ||||
|     Collection, | ||||
|     Dict, | ||||
|     Generator, | ||||
|     Iterator, | ||||
|     List, | ||||
|     Mapping, | ||||
|     Optional, | ||||
|     Set, | ||||
|     SupportsBytes, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     TypeVar, | ||||
| @@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: | ||||
|  | ||||
| # Bound type variable to allow methods to return `self` of subclasses | ||||
| T = TypeVar("T", bound="Message") | ||||
| ST = TypeVar("ST", bound="IProtoMessage") | ||||
|  | ||||
|  | ||||
| class ProtoClassMetadata: | ||||
| @@ -1104,3 +1107,38 @@ class ServiceStub(ABC): | ||||
|             await stream.send_message(request, end=True) | ||||
|             async for message in stream: | ||||
|                 yield message | ||||
|  | ||||
|     async def _stream_unary( | ||||
|         self, | ||||
|         route: str, | ||||
|         request_iterator: Iterator["IProtoMessage"], | ||||
|         request_type: Type[ST], | ||||
|         response_type: Type[T], | ||||
|     ) -> T: | ||||
|         """Make a stream request and return the response.""" | ||||
|         async with self.channel.request( | ||||
|             route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type | ||||
|         ) as stream: | ||||
|             for message in request_iterator: | ||||
|                 await stream.send_message(message) | ||||
|             await stream.send_request(end=True) | ||||
|             response = await stream.recv_message() | ||||
|             assert response is not None | ||||
|             return response | ||||
|  | ||||
|     async def _stream_stream( | ||||
|         self, | ||||
|         route: str, | ||||
|         request_iterator: Iterator["IProtoMessage"], | ||||
|         request_type: Type[ST], | ||||
|         response_type: Type[T], | ||||
|     ) -> AsyncGenerator[T, None]: | ||||
|         """Make a stream request and return the stream response iterator.""" | ||||
|         async with self.channel.request( | ||||
|             route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type | ||||
|         ) as stream: | ||||
|             for message in request_iterator: | ||||
|                 await stream.send_message(message) | ||||
|             await stream.send_request(end=True) | ||||
|             async for message in stream: | ||||
|                 yield message | ||||
|   | ||||
| @@ -311,8 +311,6 @@ def generate_code(request, response): | ||||
|                 } | ||||
|  | ||||
|                 for j, method in enumerate(service.method): | ||||
|                     if method.client_streaming: | ||||
|                         raise NotImplementedError("Client streaming not yet supported") | ||||
|  | ||||
|                     input_message = None | ||||
|                     input_type = get_ref_type( | ||||
| @@ -350,6 +348,9 @@ def generate_code(request, response): | ||||
|                     if method.server_streaming: | ||||
|                         output["typing_imports"].add("AsyncGenerator") | ||||
|  | ||||
|                     if method.client_streaming: | ||||
|                         output["typing_imports"].add("Iterator") | ||||
|  | ||||
|                 output["services"].append(data) | ||||
|  | ||||
|         output["imports"] = sorted(output["imports"]) | ||||
|   | ||||
| @@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||
|  | ||||
|     {% endif %} | ||||
|     {% for method in service.methods %} | ||||
|     async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: | ||||
|     async def {{ method.py_name }}(self | ||||
|         {%- if not method.client_streaming -%} | ||||
|             {%- if method.input_message and method.input_message.properties -%}, *, | ||||
|                 {%- for field in method.input_message.properties -%} | ||||
|                     {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} | ||||
|                                             Optional[{{ field.type }}] | ||||
|                                          {%- else -%} | ||||
|                                             {{ field.type }} | ||||
|                                          {%- endif -%} = {{ field.zero }} | ||||
|                     {%- if not loop.last %}, {% endif -%} | ||||
|                 {%- endfor -%} | ||||
|             {%- endif -%} | ||||
|         {%- else -%} | ||||
|             {# Client streaming: need a request iterator instead #} | ||||
|             , request_iterator: Iterator["{{ method.input }}"] | ||||
|         {%- endif -%} | ||||
|             ) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: | ||||
|         {% if method.comment %} | ||||
| {{ method.comment }} | ||||
|  | ||||
|         {% endif %} | ||||
|         {% if not method.client_streaming %} | ||||
|         request = {{ method.input }}() | ||||
|         {% for field in method.input_message.properties %} | ||||
|             {% if field.field_type == 'message' %} | ||||
| @@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||
|         request.{{ field.py_name }} = {{ field.py_name }} | ||||
|             {% endif %} | ||||
|         {% endfor %} | ||||
|         {% endif %} | ||||
|  | ||||
|         {% if method.server_streaming %} | ||||
|         {% if method.client_streaming %} | ||||
|         async for response in self._stream_stream( | ||||
|             "{{ method.route }}", | ||||
|             request_iterator, | ||||
|             {{ method.input }}, | ||||
|             {{ method.output }}, | ||||
|         ): | ||||
|             yield response | ||||
|         {% else %}{# i.e. not client streaming #} | ||||
|         async for response in self._unary_stream( | ||||
|             "{{ method.route }}", | ||||
|             request, | ||||
|             {{ method.output }}, | ||||
|         ): | ||||
|             yield response | ||||
|         {% else %} | ||||
|  | ||||
|         {% endif %}{# if client streaming #} | ||||
|         {% else %}{# i.e. not server streaming #} | ||||
|         {% if method.client_streaming %} | ||||
|         return await self._stream_unary( | ||||
|             "{{ method.route }}", | ||||
|             request_iterator, | ||||
|             {{ method.input }}, | ||||
|             {{ method.output }} | ||||
|         ) | ||||
|         {% else %}{# i.e. not client streaming #} | ||||
|         return await self._unary_unary( | ||||
|             "{{ method.route }}", | ||||
|             request, | ||||
|             {{ method.output }}, | ||||
|             {{ method.output }} | ||||
|         ) | ||||
|         {% endif %}{# client streaming #} | ||||
|         {% endif %} | ||||
|  | ||||
|     {% endfor %} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user