Client and Service Stubs take 1 request parameter, not one for each field (#311)

This commit is contained in:
efokschaner
2022-01-17 10:58:57 -08:00
committed by GitHub
parent 6dd7baa26c
commit d260f071e0
13 changed files with 140 additions and 193 deletions

View File

@@ -1,49 +1,48 @@
from typing import AsyncIterator, AsyncIterable
from typing import AsyncIterable, AsyncIterator
import pytest
from grpclib.testing import ChannelFor
from tests.output_betterproto.example_service.example_service import (
TestBase,
TestStub,
ExampleRequest,
ExampleResponse,
TestBase,
TestStub,
)
class ExampleService(TestBase):
async def example_unary_unary(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> "ExampleResponse":
return ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_unary_stream(
self, example_string: str, example_integer: int
self, example_request: ExampleRequest
) -> AsyncIterator["ExampleResponse"]:
response = ExampleResponse(
example_string=example_string,
example_integer=example_integer,
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
yield response
yield response
yield response
async def example_stream_unary(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> "ExampleResponse":
async for example_request in request_iterator:
async for example_request in example_request_iterator:
return ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
)
async def example_stream_stream(
self, request_iterator: AsyncIterator["ExampleRequest"]
self, example_request_iterator: AsyncIterator["ExampleRequest"]
) -> AsyncIterator["ExampleResponse"]:
async for example_request in request_iterator:
async for example_request in example_request_iterator:
yield ExampleResponse(
example_string=example_request.example_string,
example_integer=example_request.example_integer,
@@ -52,44 +51,32 @@ class ExampleService(TestBase):
@pytest.mark.asyncio
async def test_calls_with_different_cardinalities():
test_string = "test string"
test_int = 42
example_request = ExampleRequest("test string", 42)
async with ChannelFor([ExampleService()]) as channel:
stub = TestStub(channel)
# unary unary
response = await stub.example_unary_unary(
example_string="test string",
example_integer=42,
)
assert response.example_string == test_string
assert response.example_integer == test_int
response = await stub.example_unary_unary(example_request)
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# unary stream
async for response in stub.example_unary_stream(
example_string="test string",
example_integer=42,
):
assert response.example_string == test_string
assert response.example_integer == test_int
async for response in stub.example_unary_stream(example_request):
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream unary
request = ExampleRequest(
example_string=test_string,
example_integer=42,
)
async def request_iterator():
yield request
yield request
yield request
yield example_request
yield example_request
yield example_request
response = await stub.example_stream_unary(request_iterator())
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer
# stream stream
async for response in stub.example_stream_stream(request_iterator()):
assert response.example_string == test_string
assert response.example_integer == test_int
assert response.example_string == example_request.example_string
assert response.example_integer == example_request.example_integer

View File

@@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
import betterproto.lib.google.protobuf as protobuf
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response import TestStub
from tests.output_betterproto.googletypes_response import Input, TestStub
test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5),
@@ -22,14 +21,15 @@ test_cases = [
@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)
method_param = Input()
await service_method(service)
await service_method(service, method_param)
assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)
@@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
"""
grpclib does not unwrap wrapper values returned by services
@@ -47,8 +47,9 @@ async def test_service_unwraps_response(
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))
method_param = Input()
response_value = await service_method(service)
response_value = await service_method(service, method_param)
assert response_value == value
assert type(response_value) == type(value)

View File

@@ -1,7 +1,7 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_response_embedded import (
Input,
Output,
TestStub,
)
@@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
)
service = TestStub(MockChannel(responses=[output]))
response = await service.get_output()
response = await service.get_output(Input())
assert response.double_value == 10.0
assert response.float_value == 12.0

View File

@@ -1,17 +1,21 @@
import pytest
from tests.mocks import MockChannel
from tests.output_betterproto.import_service_input_message import (
NestedRequestMessage,
RequestMessage,
RequestResponse,
TestStub,
)
from tests.output_betterproto.import_service_input_message.child import (
ChildRequestMessage,
)
@pytest.mark.asyncio
async def test_service_correctly_imports_reference_message():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing(argument=1)
response = await service.do_thing(RequestMessage(1))
assert mock_response == response
@@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
async def test_service_correctly_imports_reference_message_from_child_package():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing2(child_argument=1)
response = await service.do_thing2(ChildRequestMessage(1))
assert mock_response == response
@@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
async def test_service_correctly_imports_nested_reference():
mock_response = RequestResponse(value=10)
service = TestStub(MockChannel([mock_response]))
response = await service.do_thing3(nested_argument=1)
response = await service.do_thing3(NestedRequestMessage(1))
assert mock_response == response