Client and Service Stubs take 1 request parameter, not one for each field (#311)

This commit is contained in:
efokschaner 2022-01-17 10:58:57 -08:00 committed by GitHub
parent 6dd7baa26c
commit d260f071e0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 140 additions and 193 deletions

1
.gitignore vendored
View File

@ -17,3 +17,4 @@ output
.venv .venv
.asv .asv
venv venv
.devcontainer

View File

@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`. - Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
## [Unreleased]
- **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters.
Update your client calls and server handlers as follows:
Clients before:
```py
response = await service.echo(value="hello", extra_times=1)
```
Clients after:
```py
response = await service.echo(EchoRequest(value="hello", extra_times=1))
```
Servers before:
```py
async def echo(self, value: str, extra_times: int) -> EchoResponse:
```
Servers after:
```py
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
# Use echo_request.value
# Use echo_request.extra_times
```
## [2.0.0b4] - 2022-01-03 ## [2.0.0b4] - 2022-01-03
- **Breaking**: the minimum Python version has been bumped to `3.6.2` - **Breaking**: the minimum Python version has been bumped to `3.6.2`

View File

@ -177,10 +177,10 @@ from grpclib.client import Channel
async def main(): async def main():
channel = Channel(host="127.0.0.1", port=50051) channel = Channel(host="127.0.0.1", port=50051)
service = echo.EchoStub(channel) service = echo.EchoStub(channel)
response = await service.echo(value="hello", extra_times=1) response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
print(response) print(response)
async for response in service.echo_stream(value="hello", extra_times=1): async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
print(response) print(response)
# don't forget to close the channel when done! # don't forget to close the channel when done!
@ -206,18 +206,18 @@ service methods:
```python ```python
import asyncio import asyncio
from echo import EchoBase, EchoResponse, EchoStreamResponse from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
from grpclib.server import Server from grpclib.server import Server
from typing import AsyncIterator from typing import AsyncIterator
class EchoService(EchoBase): class EchoService(EchoBase):
async def echo(self, value: str, extra_times: int) -> "EchoResponse": async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
return EchoResponse([value for _ in range(extra_times)]) return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]: async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
for _ in range(extra_times): for _ in range(echo_request.extra_times):
yield EchoStreamResponse(value) yield EchoStreamResponse(echo_request.value)
async def main(): async def main():

View File

@ -111,7 +111,7 @@ omit = ["betterproto/tests/*"]
legacy_tox_ini = """ legacy_tox_ini = """
[tox] [tox]
isolated_build = true isolated_build = true
envlist = py36, py37, py38 envlist = py36, py37, py38, py310
[testenv] [testenv]
whitelist_externals = poetry whitelist_externals = poetry

View File

@ -1,6 +1,6 @@
from abc import ABC from abc import ABC
from collections.abc import AsyncIterable from collections.abc import AsyncIterable
from typing import Callable, Any, Dict from typing import Any, Callable, Dict
import grpclib import grpclib
import grpclib.server import grpclib.server
@ -15,10 +15,10 @@ class ServiceBase(ABC):
self, self,
handler: Callable, handler: Callable,
stream: grpclib.server.Stream, stream: grpclib.server.Stream,
request_kwargs: Dict[str, Any], request: Any,
) -> None: ) -> None:
response_iter = handler(**request_kwargs) response_iter = handler(request)
# check if response is actually an AsyncIterator # check if response is actually an AsyncIterator
# this might be false if the method just returns without # this might be false if the method just returns without
# yielding at least once # yielding at least once

View File

@ -31,13 +31,15 @@ reference to `A` to `B`'s `fields` attribute.
import builtins import builtins
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
import betterproto import betterproto
from betterproto import which_one_of from betterproto import which_one_of
from betterproto.casing import sanitize_name from betterproto.casing import sanitize_name
from betterproto.compile.importing import ( from betterproto.compile.importing import get_type_reference, parse_source_type_name
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import ( from betterproto.compile.naming import (
pythonize_class_name, pythonize_class_name,
pythonize_field_name, pythonize_field_name,
@ -46,21 +48,15 @@ from betterproto.compile.naming import (
from betterproto.lib.google.protobuf import ( from betterproto.lib.google.protobuf import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field, Field,
FieldDescriptorProto, FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel, FieldDescriptorProtoLabel,
FieldDescriptorProtoType,
FileDescriptorProto,
MethodDescriptorProto,
) )
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
from ..casing import sanitize_name from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name from ..compile.importing import get_type_reference, parse_source_type_name
from ..compile.naming import ( from ..compile.naming import (
@ -69,7 +65,6 @@ from ..compile.naming import (
pythonize_method_name, pythonize_method_name,
) )
# Create a unique placeholder to deal with # Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses # https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
PLACEHOLDER = object() PLACEHOLDER = object()
@ -675,12 +670,8 @@ class ServiceMethodCompiler(ProtoContentBase):
self.parent.methods.append(self) self.parent.methods.append(self)
# Check for imports # Check for imports
if self.py_input_message:
for f in self.py_input_message.fields:
f.add_imports_to(self.output_file)
if "Optional" in self.py_output_message_type: if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional") self.output_file.typing_imports.add("Optional")
self.mutable_default_args # ensure this is called before rendering
# Check for Async imports # Check for Async imports
if self.client_streaming: if self.client_streaming:
@ -694,37 +685,6 @@ class ServiceMethodCompiler(ProtoContentBase):
super().__post_init__() # check for unset fields super().__post_init__() # check for unset fields
@property
def mutable_default_args(self) -> Dict[str, str]:
"""Handle mutable default arguments.
Returns a list of tuples containing the name and default value
for arguments to this message who's default value is mutable.
The defaults are swapped out for None and replaced back inside
the method's body.
Reference:
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
Returns
-------
Dict[str, str]
Name and actual default value (as a string)
for each argument with mutable default values.
"""
mutable_default_args = {}
if self.py_input_message:
for f in self.py_input_message.fields:
if (
not self.client_streaming
and f.default_value_string != "None"
and f.mutable
):
mutable_default_args[f.py_name] = f.default_value_string
self.output_file.typing_imports.add("Optional")
return mutable_default_args
@property @property
def py_name(self) -> str: def py_name(self) -> str:
"""Pythonized method name.""" """Pythonized method name."""
@ -782,6 +742,17 @@ class ServiceMethodCompiler(ProtoContentBase):
source_type=self.proto_obj.input_type, source_type=self.proto_obj.input_type,
).strip('"') ).strip('"')
@property
def py_input_message_param(self) -> str:
"""Param name corresponding to py_input_message_type.
Returns
-------
str
Param name corresponding to py_input_message_type.
"""
return pythonize_field_name(self.py_input_message_type)
@property @property
def py_output_message_type(self) -> str: def py_output_message_type(self) -> str:
"""String representation of the Python type corresponding to the """String representation of the Python type corresponding to the

View File

@ -79,51 +79,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self async def {{ method.py_name }}(self
{%- if not method.client_streaming -%} {%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, *, {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%} {%- else -%}
{# Client streaming: need a request iterator instead #} {# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%} {%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}
{% endif %} {% endif %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
{% else %}
request.{{ field.py_name }} = {{ field.py_name }}
{% endif %}
{% endfor %}
{% endif %}
{% if method.server_streaming %} {% if method.server_streaming %}
{% if method.client_streaming %} {% if method.client_streaming %}
async for response in self._stream_stream( async for response in self._stream_stream(
"{{ method.route }}", "{{ method.route }}",
request_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
): ):
@ -131,7 +101,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
async for response in self._unary_stream( async for response in self._unary_stream(
"{{ method.route }}", "{{ method.route }}",
request, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
): ):
yield response yield response
@ -141,14 +111,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if method.client_streaming %} {% if method.client_streaming %}
return await self._stream_unary( return await self._stream_unary(
"{{ method.route }}", "{{ method.route }}",
request_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }}
) )
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
return await self._unary_unary( return await self._unary_unary(
"{{ method.route }}", "{{ method.route }}",
request, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }}
) )
{% endif %}{# client streaming #} {% endif %}{# client streaming #}
@ -167,19 +137,10 @@ class {{ service.py_name }}Base(ServiceBase):
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self async def {{ method.py_name }}(self
{%- if not method.client_streaming -%} {%- if not method.client_streaming -%}
{%- if method.py_input_message and method.py_input_message.fields -%}, {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.annotation }}
{%- endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%} {%- else -%}
{# Client streaming: need a request iterator instead #} {# Client streaming: need a request iterator instead #}
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"] , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
{%- endif -%} {%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %} {% if method.comment %}
@ -194,25 +155,17 @@ class {{ service.py_name }}Base(ServiceBase):
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
{% if not method.client_streaming %} {% if not method.client_streaming %}
request = await stream.recv_message() request = await stream.recv_message()
request_kwargs = {
{% for field in method.py_input_message.fields %}
"{{ field.py_name }}": request.{{ field.py_name }},
{% endfor %}
}
{% else %} {% else %}
request_kwargs = {"request_iterator": stream.__aiter__()} request = stream.__aiter__()
{% endif %} {% endif %}
{% if not method.server_streaming %} {% if not method.server_streaming %}
response = await self.{{ method.py_name }}(**request_kwargs) response = await self.{{ method.py_name }}(request)
await stream.send_message(response) await stream.send_message(response)
{% else %} {% else %}
await self._call_rpc_handler_server_stream( await self._call_rpc_handler_server_stream(
self.{{ method.py_name }}, self.{{ method.py_name }},
stream, stream,
request_kwargs, request,
) )
{% endif %} {% endif %}

View File

@ -1,23 +1,24 @@
import asyncio import asyncio
import sys import sys
import grpclib
import grpclib.metadata
import grpclib.server
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from grpclib.testing import ChannelFor
from tests.output_betterproto.service.service import ( from tests.output_betterproto.service.service import (
DoThingRequest, DoThingRequest,
DoThingResponse, DoThingResponse,
GetThingRequest, GetThingRequest,
TestStub as ThingServiceClient,
) )
import grpclib from tests.output_betterproto.service.service import TestStub as ThingServiceClient
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from .thing_service import ThingService from .thing_service import ThingService
async def _test_client(client, name="clean room", **kwargs): async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(name=name) response = await client.do_thing(DoThingRequest(name=name))
assert response.names == [name] assert response.names == [name]
@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
) )
async with ChannelFor([service]) as channel: async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e: with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something") await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED assert e.value.status == grpclib.Status.UNAUTHENTICATED
@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
async with ChannelFor([service]) as channel: async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e: with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things( await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")] do_thing_request_iterator=[DoThingRequest(name="something")]
) )
await _test_client(ThingServiceClient(channel)) await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED assert e.value.status == grpclib.Status.UNAUTHENTICATED
@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
async with ChannelFor([ThingService()]) as channel: async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel) client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1] expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(name=thing_name): async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name assert response.name == thing_name
assert response.version == expected_versions.pop() assert response.version == expected_versions.pop()

View File

@ -1,49 +1,48 @@
from typing import AsyncIterator, AsyncIterable from typing import AsyncIterable, AsyncIterator
import pytest import pytest
from grpclib.testing import ChannelFor from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import ( from tests.output_betterproto.example_service.example_service import (
TestBase,
TestStub,
ExampleRequest, ExampleRequest,
ExampleResponse, ExampleResponse,
TestBase,
TestStub,
) )
class ExampleService(TestBase): class ExampleService(TestBase):
async def example_unary_unary( async def example_unary_unary(
self, example_string: str, example_integer: int self, example_request: ExampleRequest
) -> "ExampleResponse": ) -> "ExampleResponse":
return ExampleResponse( return ExampleResponse(
example_string=example_string, example_string=example_request.example_string,
example_integer=example_integer, example_integer=example_request.example_integer,
) )
async def example_unary_stream( async def example_unary_stream(
self, example_string: str, example_integer: int self, example_request: ExampleRequest
) -> AsyncIterator["ExampleResponse"]: ) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse( response = ExampleResponse(
example_string=example_string, example_string=example_request.example_string,
example_integer=example_integer, example_integer=example_request.example_integer,
) )
yield response yield response
yield response yield response
yield response yield response
async def example_stream_unary( async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"] self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse": ) -> "ExampleResponse":
async for example_request in request_iterator: async for example_request in example_request_iterator:
return ExampleResponse( return ExampleResponse(
example_string=example_request.example_string, example_string=example_request.example_string,
example_integer=example_request.example_integer, example_integer=example_request.example_integer,
) )
async def example_stream_stream( async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"] self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]: ) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator: async for example_request in example_request_iterator:
yield ExampleResponse( yield ExampleResponse(
example_string=example_request.example_string, example_string=example_request.example_string,
example_integer=example_request.example_integer, example_integer=example_request.example_integer,
@ -52,44 +51,32 @@ class ExampleService(TestBase):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calls_with_different_cardinalities(): async def test_calls_with_different_cardinalities():
test_string = "test string" example_request = ExampleRequest("test string", 42)
test_int = 42
async with ChannelFor([ExampleService()]) as channel: async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel) stub = TestStub(channel)
# unary unary # unary unary
response = await stub.example_unary_unary( response = await stub.example_unary_unary(example_request)
example_string="test string", assert response.example_string == example_request.example_string
example_integer=42, assert response.example_integer == example_request.example_integer
)
assert response.example_string == test_string
assert response.example_integer == test_int
# unary stream # unary stream
async for response in stub.example_unary_stream( async for response in stub.example_unary_stream(example_request):
example_string="test string", assert response.example_string == example_request.example_string
example_integer=42, assert response.example_integer == example_request.example_integer
):
assert response.example_string == test_string
assert response.example_integer == test_int
# stream unary # stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator(): async def request_iterator():
yield request yield example_request
yield request yield example_request
yield request yield example_request
response = await stub.example_stream_unary(request_iterator()) response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string assert response.example_string == example_request.example_string
assert response.example_integer == test_int assert response.example_integer == example_request.example_integer
# stream stream # stream stream
async for response in stub.example_stream_stream(request_iterator()): async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string assert response.example_string == example_request.example_string
assert response.example_integer == test_int assert response.example_integer == example_request.example_integer

View File

@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf import betterproto.lib.google.protobuf as protobuf
import pytest import pytest
from tests.mocks import MockChannel from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub from tests.output_betterproto.googletypes_response import Input, TestStub
test_cases = [ test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5), (TestStub.get_double, protobuf.DoubleValue, 2.5),
@ -22,14 +21,15 @@ test_cases = [
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type( async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
): ):
wrapped_value = wrapper_class() wrapped_value = wrapper_class()
wrapped_value.value = value wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value]) channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel) service = TestStub(channel)
method_param = Input()
await service_method(service) await service_method(service, method_param)
assert channel.requests[0]["response_type"] != Optional[type(value)] assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value) assert channel.requests[0]["response_type"] == type(wrapped_value)
@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
@pytest.mark.xfail @pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response( async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
): ):
""" """
grpclib does not unwrap wrapper values returned by services grpclib does not unwrap wrapper values returned by services
@ -47,8 +47,9 @@ async def test_service_unwraps_response(
wrapped_value = wrapper_class() wrapped_value = wrapper_class()
wrapped_value.value = value wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value])) service = TestStub(MockChannel(responses=[wrapped_value]))
method_param = Input()
response_value = await service_method(service) response_value = await service_method(service, method_param)
assert response_value == value assert response_value == value
assert type(response_value) == type(value) assert type(response_value) == type(value)

View File

@ -1,7 +1,7 @@
import pytest import pytest
from tests.mocks import MockChannel from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import ( from tests.output_betterproto.googletypes_response_embedded import (
Input,
Output, Output,
TestStub, TestStub,
) )
@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
) )
service = TestStub(MockChannel(responses=[output])) service = TestStub(MockChannel(responses=[output]))
response = await service.get_output() response = await service.get_output(Input())
assert response.double_value == 10.0 assert response.double_value == 10.0
assert response.float_value == 12.0 assert response.float_value == 12.0

View File

@ -1,17 +1,21 @@
import pytest import pytest
from tests.mocks import MockChannel from tests.mocks import MockChannel
from tests.output_betterproto.import_service_input_message import ( from tests.output_betterproto.import_service_input_message import (
NestedRequestMessage,
RequestMessage,
RequestResponse, RequestResponse,
TestStub, TestStub,
) )
from tests.output_betterproto.import_service_input_message.child import (
ChildRequestMessage,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_correctly_imports_reference_message(): async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10) mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response])) service = TestStub(MockChannel([mock_response]))
response = await service.do_thing(argument=1) response = await service.do_thing(RequestMessage(1))
assert mock_response == response assert mock_response == response
@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
async def test_service_correctly_imports_reference_message_from_child_package(): async def test_service_correctly_imports_reference_message_from_child_package():
mock_response = RequestResponse(value=10) mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response])) service = TestStub(MockChannel([mock_response]))
response = await service.do_thing2(child_argument=1) response = await service.do_thing2(ChildRequestMessage(1))
assert mock_response == response assert mock_response == response
@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
async def test_service_correctly_imports_nested_reference(): async def test_service_correctly_imports_nested_reference():
mock_response = RequestResponse(value=10) mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response])) service = TestStub(MockChannel([mock_response]))
response = await service.do_thing3(nested_argument=1) response = await service.do_thing3(NestedRequestMessage(1))
assert mock_response == response assert mock_response == response

View File

@ -1,8 +1,9 @@
import betterproto
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, List, Dict
from datetime import datetime from datetime import datetime
from inspect import signature from inspect import Parameter, signature
from typing import Dict, List, Optional
import betterproto
def test_has_field(): def test_has_field():
@ -349,10 +350,8 @@ def test_recursive_message():
def test_recursive_message_defaults(): def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import ( from tests.output_betterproto.recursivemessage import Intermediate
Test as RecursiveMessage, from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
Intermediate,
)
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
@ -479,8 +478,10 @@ def test_iso_datetime_list():
assert all([isinstance(item, datetime) for item in msg.timestamps]) assert all([isinstance(item, datetime) for item in msg.timestamps])
def test_enum_service_argument__expected_default_value(): def test_service_argument__expected_parameter():
from tests.output_betterproto.service.service import ThingType, TestStub from tests.output_betterproto.service.service import TestStub
sig = signature(TestStub.do_thing) sig = signature(TestStub.do_thing)
assert sig.parameters["type"].default == ThingType.UNKNOWN do_thing_request_parameter = sig.parameters["do_thing_request"]
assert do_thing_request_parameter.default is Parameter.empty
assert do_thing_request_parameter.annotation == "DoThingRequest"