Client and Service Stubs take 1 request parameter, not one for each field (#311)

This commit is contained in:
efokschaner 2022-01-17 10:58:57 -08:00 committed by GitHub
parent 6dd7baa26c
commit d260f071e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 140 additions and 193 deletions

1
.gitignore vendored
View File

@ -17,3 +17,4 @@ output
.venv
.asv
venv
.devcontainer

View File

@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [Unreleased]
- **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters.
Update your client calls and server handlers as follows:
Clients before:
```py
response = await service.echo(value="hello", extra_times=1)
```
Clients after:
```py
response = await service.echo(EchoRequest(value="hello", extra_times=1))
```
Servers before:
```py
async def echo(self, value: str, extra_times: int) -> EchoResponse:
```
Servers after:
```py
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
# Use echo_request.value
# Use echo_request.extra_times
```
## [2.0.0b4] - 2022-01-03
- **Breaking**: the minimum Python version has been bumped to `3.6.2`

View File

@ -177,10 +177,10 @@ from grpclib.client import Channel
async def main():
channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1)
response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
print(response)
async for response in service.echo_stream(value="hello", extra_times=1):
async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
print(response)
# don't forget to close the channel when done!
@ -206,18 +206,18 @@ service methods:
```python
import asyncio
from echo import EchoBase, EchoResponse, EchoStreamResponse
from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
from grpclib.server import Server
from typing import AsyncIterator
class EchoService(EchoBase):
async def echo(self, value: str, extra_times: int) -> "EchoResponse":
return EchoResponse([value for _ in range(extra_times)])
async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
for _ in range(extra_times):
yield EchoStreamResponse(value)
async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
for _ in range(echo_request.extra_times):
yield EchoStreamResponse(echo_request.value)
async def main():

View File

@ -111,7 +111,7 @@ omit = ["betterproto/tests/*"]
legacy_tox_ini = """
[tox]
isolated_build = true
envlist = py36, py37, py38
envlist = py36, py37, py38, py310
[testenv]
whitelist_externals = poetry

View File

@ -1,6 +1,6 @@
from abc import ABC
from collections.abc import AsyncIterable
from typing import Callable, Any, Dict
from typing import Any, Callable, Dict
import grpclib
import grpclib.server
@ -15,10 +15,10 @@ class ServiceBase(ABC):
self,
handler: Callable,
stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any],
request: Any,
) -> None:
response_iter = handler(**request_kwargs)
response_iter = handler(request)
# check if response is actually an AsyncIterator
# this might be false if the method just returns without
# yielding at least once

View File

@ -31,13 +31,15 @@ reference to `A` to `B`'s `fields` attribute.
import builtins
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.importing import get_type_reference, parse_source_type_name
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
@ -46,21 +48,15 @@ from betterproto.compile.naming import (
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field,
FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel,
FieldDescriptorProtoType,
FileDescriptorProto,
MethodDescriptorProto,
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import (
@ -69,7 +65,6 @@ from ..compile.naming import (
pythonize_method_name,
)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
PLACEHOLDER = object()
@ -675,12 +670,8 @@ class ServiceMethodCompiler(ProtoContentBase):
self.parent.methods.append(self)
# Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
# Check for Async imports
if self.client_streaming:
@ -694,37 +685,6 @@ class ServiceMethodCompiler(ProtoContentBase):
super().__post_init__() # check for unset fields
@property
def mutable_default_args(self) -> Dict[str, str]:
"""Handle mutable default arguments.
Returns a list of tuples containing the name and default value
for arguments to this message who's default value is mutable.
The defaults are swapped out for None and replaced back inside
the method's body.
Reference:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Returns
-------
Dict[str, str]
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
if (
not self.client_streaming
and f.default_value_string != "None"
and f.mutable
):
mutable_default_args[f.py_name] = f.default_value_string
self.output_file.typing_imports.add("Optional")
return mutable_default_args
@property
def py_name(self) -> str:
"""Pythonized method name."""
@ -782,6 +742,17 @@ class ServiceMethodCompiler(ProtoContentBase):
source_type=self.proto_obj.input_type,
).strip('"')
@property
def py_input_message_param(self) -> str:
"""Param name corresponding to py_input_message_type.
Returns
-------
str
Param name corresponding to py_input_message_type.
"""
return pythonize_field_name(self.py_input_message_type)
@property
def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the

View File

@ -79,51 +79,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %}
{% if method.client_streaming %}
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }},
):
@ -131,7 +101,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }},
):
yield response
@ -141,14 +111,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if method.client_streaming %}
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }}
)
{% endif %}{# client streaming #}
@ -167,19 +137,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%},
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
@ -194,25 +155,17 @@ class {{ service.py_name }}Base(ServiceBase):
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %}
request_kwargs = {"request_iterator": stream.__aiter__()}
request = stream.__aiter__()
{% endif %}
{% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs)
response = await self.{{ method.py_name }}(request)
await stream.send_message(response)
{% else %}
await self._call_rpc_handler_server_stream(
self.{{ method.py_name }},
stream,
request_kwargs,
request,
)
{% endif %}

View File

@ -1,23 +1,24 @@
import asyncio
import sys
import grpclib
import grpclib.metadata
import grpclib.server
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from grpclib.testing import ChannelFor
from tests.output_betterproto.service.service import (
DoThingRequest,
DoThingResponse,
GetThingRequest,
TestStub as ThingServiceClient,
)
import grpclib
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from tests.output_betterproto.service.service import TestStub as ThingServiceClient
from .thing_service import ThingService
async def _test_client(client, name="clean room", **kwargs):
response = await client.do_thing(name=name)
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(DoThingRequest(name=name))
assert response.names == [name]
@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something")
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")]
do_thing_request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(name=thing_name):
async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name
assert response.version == expected_versions.pop()

View File

@ -1,49 +1,48 @@
from typing import AsyncIterator, AsyncIterable
from typing import AsyncIterable, AsyncIterator
import pytest
from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import (
TestBase,
TestStub,
ExampleRequest,
ExampleResponse,
TestBase,
TestStub,
)
class ExampleService(TestBase):
async def example_unary_unary(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> "ExampleResponse":
return ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_unary_stream(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
yield response
yield response
yield response
async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse":
async for example_request in request_iterator:
async for example_request in example_request_iterator:
return ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator:
async for example_request in example_request_iterator:
yield ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
@ -52,44 +51,32 @@ class ExampleService(TestBase):
@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
test_string = "test string"
test_int = 42
example_request = ExampleRequest("test string", 42)
async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel)
# unary unary
response = await stub.example_unary_unary(
example_string="test string",
example_integer=42,
)
assert response.example_string == test_string
assert response.example_integer == test_int
response = await stub.example_unary_unary(example_request)
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# unary stream
async for response in stub.example_unary_stream(
example_string="test string",
example_integer=42,
):
assert response.example_string == test_string
assert response.example_integer == test_int
async for response in stub.example_unary_stream(example_request):
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator():
yield request
yield request
yield request
yield example_request
yield example_request
yield example_request
response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream stream
async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer

View File

@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub
from tests.output_betterproto.googletypes_response import Input, TestStub
test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5),
@ -22,14 +21,15 @@ test_cases = [
@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)
method_param = Input()
await service_method(service)
await service_method(service, method_param)
assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)
@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
"""
grpclib does not unwrap wrapper values returned by services
@ -47,8 +47,9 @@ async def test_service_unwraps_response(
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))
method_param = Input()
response_value = await service_method(service)
response_value = await service_method(service, method_param)
assert response_value == value
assert type(response_value) == type(value)

View File

@ -1,7 +1,7 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import (
Input,
Output,
TestStub,
)
@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
)
service = TestStub(MockChannel(responses=[output]))
response = await service.get_output()
response = await service.get_output(Input())
assert response.double_value == 10.0
assert response.float_value == 12.0

View File

@ -1,17 +1,21 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.import_service_input_message import (
NestedRequestMessage,
RequestMessage,
RequestResponse,
TestStub,
)
from tests.output_betterproto.import_service_input_message.child import (
ChildRequestMessage,
)
@pytest.mark.asyncio
async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing(argument=1)
response = await service.do_thing(RequestMessage(1))
assert mock_response == response
@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
async def test_service_correctly_imports_reference_message_from_child_package():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing2(child_argument=1)
response = await service.do_thing2(ChildRequestMessage(1))
assert mock_response == response
@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
async def test_service_correctly_imports_nested_reference():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing3(nested_argument=1)
response = await service.do_thing3(NestedRequestMessage(1))
assert mock_response == response

View File

@ -1,8 +1,9 @@
import betterproto
from dataclasses import dataclass
from typing import Optional, List, Dict
from datetime import datetime
from inspect import signature
from inspect import Parameter, signature
from typing import Dict, List, Optional
import betterproto
def test_has_field():
@ -349,10 +350,8 @@ def test_recursive_message():
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Test as RecursiveMessage,
Intermediate,
)
from tests.output_betterproto.recursivemessage import Intermediate
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
@ -479,8 +478,10 @@ def test_iso_datetime_list():
assert all([isinstance(item, datetime) for item in msg.timestamps])
def test_enum_service_argument__expected_default_value():
from tests.output_betterproto.service.service import ThingType, TestStub
def test_service_argument__expected_parameter():
from tests.output_betterproto.service.service import TestStub
sig = signature(TestStub.do_thing)
assert sig.parameters["type"].default == ThingType.UNKNOWN
do_thing_request_parameter = sig.parameters["do_thing_request"]
assert do_thing_request_parameter.default is Parameter.empty
assert do_thing_request_parameter.annotation == "DoThingRequest"