Client and Service Stubs take 1 request parameter, not one for each field (#311)

This commit is contained in:
efokschaner
2022-01-17 10:58:57 -08:00
committed by GitHub
parent 6dd7baa26c
commit d260f071e0
13 changed files with 140 additions and 193 deletions

View File

@@ -1,23 +1,24 @@
import asyncio
import sys
import grpclib
import grpclib.metadata
import grpclib.server
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from grpclib.testing import ChannelFor
from tests.output_betterproto.service.service import (
DoThingRequest,
DoThingResponse,
GetThingRequest,
TestStub as ThingServiceClient,
)
import grpclib
import grpclib.metadata
import grpclib.server
from grpclib.testing import ChannelFor
import pytest
from betterproto.grpc.util.async_channel import AsyncChannel
from tests.output_betterproto.service.service import TestStub as ThingServiceClient
from .thing_service import ThingService
async def _test_client(client, name="clean room", **kwargs):
response = await client.do_thing(name=name)
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(DoThingRequest(name=name))
assert response.names == [name]
@@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
)
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_thing(name="something")
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
async with ChannelFor([service]) as channel:
with pytest.raises(grpclib.exceptions.GRPCError) as e:
await ThingServiceClient(channel).do_many_things(
request_iterator=[DoThingRequest(name="something")]
do_thing_request_iterator=[DoThingRequest(name="something")]
)
await _test_client(ThingServiceClient(channel))
assert e.value.status == grpclib.Status.UNAUTHENTICATED
@@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
async with ChannelFor([ThingService()]) as channel:
client = ThingServiceClient(channel)
expected_versions = [5, 4, 3, 2, 1]
async for response in client.get_thing_versions(name=thing_name):
async for response in client.get_thing_versions(
GetThingRequest(name=thing_name)
):
assert response.name == thing_name
assert response.version == expected_versions.pop()

View File

@@ -1,49 +1,48 @@
from typing import AsyncIterator, AsyncIterable
from typing import AsyncIterable, AsyncIterator
import pytest
from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import (
TestBase,
TestStub,
ExampleRequest,
ExampleResponse,
TestBase,
TestStub,
)
class ExampleService(TestBase):
async def example_unary_unary(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> "ExampleResponse":
return ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_unary_stream(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
yield response
yield response
yield response
async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse":
async for example_request in request_iterator:
async for example_request in example_request_iterator:
return ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator:
async for example_request in example_request_iterator:
yield ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
@@ -52,44 +51,32 @@ class ExampleService(TestBase):
@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
test_string = "test string"
test_int = 42
example_request = ExampleRequest("test string", 42)
async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel)
# unary unary
response = await stub.example_unary_unary(
example_string="test string",
example_integer=42,
)
assert response.example_string == test_string
assert response.example_integer == test_int
response = await stub.example_unary_unary(example_request)
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# unary stream
async for response in stub.example_unary_stream(
example_string="test string",
example_integer=42,
):
assert response.example_string == test_string
assert response.example_integer == test_int
async for response in stub.example_unary_stream(example_request):
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator():
yield request
yield request
yield request
yield example_request
yield example_request
yield example_request
response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream stream
async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer

View File

@@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub
from tests.output_betterproto.googletypes_response import Input, TestStub
test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5),
@@ -22,14 +21,15 @@ test_cases = [
@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)
method_param = Input()
await service_method(service)
await service_method(service, method_param)
assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)
@@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
"""
grpclib does not unwrap wrapper values returned by services
@@ -47,8 +47,9 @@ async def test_service_unwraps_response(
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))
method_param = Input()
response_value = await service_method(service)
response_value = await service_method(service, method_param)
assert response_value == value
assert type(response_value) == type(value)

View File

@@ -1,7 +1,7 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import (
Input,
Output,
TestStub,
)
@@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
)
service = TestStub(MockChannel(responses=[output]))
response = await service.get_output()
response = await service.get_output(Input())
assert response.double_value == 10.0
assert response.float_value == 12.0

View File

@@ -1,17 +1,21 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.import_service_input_message import (
NestedRequestMessage,
RequestMessage,
RequestResponse,
TestStub,
)
from tests.output_betterproto.import_service_input_message.child import (
ChildRequestMessage,
)
@pytest.mark.asyncio
async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing(argument=1)
response = await service.do_thing(RequestMessage(1))
assert mock_response == response
@@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
async def test_service_correctly_imports_reference_message_from_child_package():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing2(child_argument=1)
response = await service.do_thing2(ChildRequestMessage(1))
assert mock_response == response
@@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
async def test_service_correctly_imports_nested_reference():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing3(nested_argument=1)
response = await service.do_thing3(NestedRequestMessage(1))
assert mock_response == response

View File

@@ -1,8 +1,9 @@
import betterproto
from dataclasses import dataclass
from typing import Optional, List, Dict
from datetime import datetime
from inspect import signature
from inspect import Parameter, signature
from typing import Dict, List, Optional
import betterproto
def test_has_field():
@@ -349,10 +350,8 @@ def test_recursive_message():
def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Test as RecursiveMessage,
Intermediate,
)
from tests.output_betterproto.recursivemessage import Intermediate
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
@@ -479,8 +478,10 @@ def test_iso_datetime_list():
assert all([isinstance(item, datetime) for item in msg.timestamps])
def test_enum_service_argument__expected_default_value():
from tests.output_betterproto.service.service import ThingType, TestStub
def test_service_argument__expected_parameter():
from tests.output_betterproto.service.service import TestStub
sig = signature(TestStub.do_thing)
assert sig.parameters["type"].default == ThingType.UNKNOWN
do_thing_request_parameter = sig.parameters["do_thing_request"]
assert do_thing_request_parameter.default is Parameter.empty
assert do_thing_request_parameter.annotation == "DoThingRequest"