Client and Service Stubs take 1 request parameter, not one for each field (#311)

This commit is contained in:
efokschaner
2022-01-17 10:58:57 -08:00
committed by GitHub
parent 6dd7baa26c
commit d260f071e0
13 changed files with 140 additions and 193 deletions

View File

@@ -1,6 +1,6 @@
from abc import ABC
from collections.abc import AsyncIterable
from typing import Callable, Any, Dict
from typing import Any, Callable, Dict
import grpclib
import grpclib.server
@@ -15,10 +15,10 @@ class ServiceBase(ABC):
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
request: Any,
) -> None:
response_iter = handler(**request_kwargs)
response_iter = handler(request)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
# yielding at least once

View File

@@ -31,13 +31,15 @@ reference to `A` to `B`'s `fields` attribute.
import builtins
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.importing import get_type_reference, parse_source_type_name
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
@@ -46,21 +48,15 @@ from betterproto.compile.naming import (
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field,
FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel,
FieldDescriptorProtoType,
FileDescriptorProto,
MethodDescriptorProto,
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
@@ -69,7 +65,6 @@ from ..compile.naming import (
pythonize_method_name,
)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
PLACEHOLDER = object()
@@ -675,12 +670,8 @@ class ServiceMethodCompiler(ProtoContentBase):
self.parent.methods.append(self)
# Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
# Check for Async imports
if self.client_streaming:
@@ -694,37 +685,6 @@ class ServiceMethodCompiler(ProtoContentBase):
super().__post_init__() # check for unset fields
@property
def mutable_default_args(self) -> Dict[str, str]:
"""Handle mutable default arguments.
Returns a list of tuples containing the name and default value
for arguments to this message who's default value is mutable.
The defaults are swapped out for None and replaced back inside
the method's body.
Reference:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Returns
-------
Dict[str, str]
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
if (
not self.client_streaming
and f.default_value_string != "None"
and f.mutable
):
mutable_default_args[f.py_name] = f.default_value_string
self.output_file.typing_imports.add("Optional")
return mutable_default_args
@property
def py_name(self) -> str:
"""Pythonized method name."""
@@ -782,6 +742,17 @@ class ServiceMethodCompiler(ProtoContentBase):
source_type=self.proto_obj.input_type,
).strip('"')
@property
def py_input_message_param(self) -> str:
"""Param name corresponding to py_input_message_type.
Returns
-------
str
Param name corresponding to py_input_message_type.
"""
return pythonize_field_name(self.py_input_message_type)
@property
def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the

View File

@@ -79,51 +79,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }},
):
@@ -131,7 +101,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }},
):
yield response
@@ -141,14 +111,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }}
)
{% endif %}{# client streaming #}
@@ -167,19 +137,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%},
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
@@ -194,25 +155,17 @@ class {{ service.py_name }}Base(ServiceBase):
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %}
request_kwargs = {"request_iterator": stream.__aiter__()}
request = stream.__aiter__()
{% endif %}
{% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs)
response = await self.{{ method.py_name }}(request)
await stream.send_message(response)
{% else %}
await self._call_rpc_handler_server_stream(
self.{{ method.py_name }},
stream,
request_kwargs,
request,
)
{% endif %}