Client and Service Stubs take 1 request parameter, not one for each field (#311)
This commit is contained in:
@@ -1,23 +1,24 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import grpclib
|
||||
import grpclib.metadata
|
||||
import grpclib.server
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from grpclib.testing import ChannelFor
|
||||
from tests.output_betterproto.service.service import (
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
GetThingRequest,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
import grpclib.metadata
|
||||
import grpclib.server
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from tests.output_betterproto.service.service import TestStub as ThingServiceClient
|
||||
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client, name="clean room", **kwargs):
|
||||
response = await client.do_thing(name=name)
|
||||
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
||||
response = await client.do_thing(DoThingRequest(name=name))
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
@@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
|
||||
)
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_thing(name="something")
|
||||
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
|
||||
|
||||
@@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_many_things(
|
||||
request_iterator=[DoThingRequest(name="something")]
|
||||
do_thing_request_iterator=[DoThingRequest(name="something")]
|
||||
)
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
@@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(name=thing_name):
|
||||
async for response in client.get_thing_versions(
|
||||
GetThingRequest(name=thing_name)
|
||||
):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
@@ -1,49 +1,48 @@
|
||||
from typing import AsyncIterator, AsyncIterable
|
||||
from typing import AsyncIterable, AsyncIterator
|
||||
|
||||
import pytest
|
||||
from grpclib.testing import ChannelFor
|
||||
|
||||
from tests.output_betterproto.example_service.example_service import (
|
||||
TestBase,
|
||||
TestStub,
|
||||
ExampleRequest,
|
||||
ExampleResponse,
|
||||
TestBase,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
class ExampleService(TestBase):
|
||||
async def example_unary_unary(
|
||||
self, example_string: str, example_integer: int
|
||||
self, example_request: ExampleRequest
|
||||
) -> "ExampleResponse":
|
||||
return ExampleResponse(
|
||||
example_string=example_string,
|
||||
example_integer=example_integer,
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
|
||||
async def example_unary_stream(
|
||||
self, example_string: str, example_integer: int
|
||||
self, example_request: ExampleRequest
|
||||
) -> AsyncIterator["ExampleResponse"]:
|
||||
response = ExampleResponse(
|
||||
example_string=example_string,
|
||||
example_integer=example_integer,
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
yield response
|
||||
yield response
|
||||
yield response
|
||||
|
||||
async def example_stream_unary(
|
||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||
) -> "ExampleResponse":
|
||||
async for example_request in request_iterator:
|
||||
async for example_request in example_request_iterator:
|
||||
return ExampleResponse(
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
|
||||
async def example_stream_stream(
|
||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||
) -> AsyncIterator["ExampleResponse"]:
|
||||
async for example_request in request_iterator:
|
||||
async for example_request in example_request_iterator:
|
||||
yield ExampleResponse(
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
@@ -52,44 +51,32 @@ class ExampleService(TestBase):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_with_different_cardinalities():
|
||||
test_string = "test string"
|
||||
test_int = 42
|
||||
example_request = ExampleRequest("test string", 42)
|
||||
|
||||
async with ChannelFor([ExampleService()]) as channel:
|
||||
stub = TestStub(channel)
|
||||
|
||||
# unary unary
|
||||
response = await stub.example_unary_unary(
|
||||
example_string="test string",
|
||||
example_integer=42,
|
||||
)
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
response = await stub.example_unary_unary(example_request)
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# unary stream
|
||||
async for response in stub.example_unary_stream(
|
||||
example_string="test string",
|
||||
example_integer=42,
|
||||
):
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
async for response in stub.example_unary_stream(example_request):
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# stream unary
|
||||
request = ExampleRequest(
|
||||
example_string=test_string,
|
||||
example_integer=42,
|
||||
)
|
||||
|
||||
async def request_iterator():
|
||||
yield request
|
||||
yield request
|
||||
yield request
|
||||
yield example_request
|
||||
yield example_request
|
||||
yield example_request
|
||||
|
||||
response = await stub.example_stream_unary(request_iterator())
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# stream stream
|
||||
async for response in stub.example_stream_stream(request_iterator()):
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
@@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import betterproto.lib.google.protobuf as protobuf
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.googletypes_response import TestStub
|
||||
from tests.output_betterproto.googletypes_response import Input, TestStub
|
||||
|
||||
test_cases = [
|
||||
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||
@@ -22,14 +21,15 @@ test_cases = [
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[wrapped_value])
|
||||
service = TestStub(channel)
|
||||
method_param = Input()
|
||||
|
||||
await service_method(service)
|
||||
await service_method(service, method_param)
|
||||
|
||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||
@@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_service_unwraps_response(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||
):
|
||||
"""
|
||||
grpclib does not unwrap wrapper values returned by services
|
||||
@@ -47,8 +47,9 @@ async def test_service_unwraps_response(
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||
method_param = Input()
|
||||
|
||||
response_value = await service_method(service)
|
||||
response_value = await service_method(service, method_param)
|
||||
|
||||
assert response_value == value
|
||||
assert type(response_value) == type(value)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.googletypes_response_embedded import (
|
||||
Input,
|
||||
Output,
|
||||
TestStub,
|
||||
)
|
||||
@@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
||||
)
|
||||
|
||||
service = TestStub(MockChannel(responses=[output]))
|
||||
response = await service.get_output()
|
||||
response = await service.get_output(Input())
|
||||
|
||||
assert response.double_value == 10.0
|
||||
assert response.float_value == 12.0
|
||||
|
@@ -1,17 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.import_service_input_message import (
|
||||
NestedRequestMessage,
|
||||
RequestMessage,
|
||||
RequestResponse,
|
||||
TestStub,
|
||||
)
|
||||
from tests.output_betterproto.import_service_input_message.child import (
|
||||
ChildRequestMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_correctly_imports_reference_message():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing(argument=1)
|
||||
response = await service.do_thing(RequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
||||
|
||||
@@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
|
||||
async def test_service_correctly_imports_reference_message_from_child_package():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing2(child_argument=1)
|
||||
response = await service.do_thing2(ChildRequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
||||
|
||||
@@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
|
||||
async def test_service_correctly_imports_nested_reference():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing3(nested_argument=1)
|
||||
response = await service.do_thing3(NestedRequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import betterproto
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
from inspect import signature
|
||||
from inspect import Parameter, signature
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import betterproto
|
||||
|
||||
|
||||
def test_has_field():
|
||||
@@ -349,10 +350,8 @@ def test_recursive_message():
|
||||
|
||||
|
||||
def test_recursive_message_defaults():
|
||||
from tests.output_betterproto.recursivemessage import (
|
||||
Test as RecursiveMessage,
|
||||
Intermediate,
|
||||
)
|
||||
from tests.output_betterproto.recursivemessage import Intermediate
|
||||
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||
|
||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
|
||||
@@ -479,8 +478,10 @@ def test_iso_datetime_list():
|
||||
assert all([isinstance(item, datetime) for item in msg.timestamps])
|
||||
|
||||
|
||||
def test_enum_service_argument__expected_default_value():
|
||||
from tests.output_betterproto.service.service import ThingType, TestStub
|
||||
def test_service_argument__expected_parameter():
|
||||
from tests.output_betterproto.service.service import TestStub
|
||||
|
||||
sig = signature(TestStub.do_thing)
|
||||
assert sig.parameters["type"].default == ThingType.UNKNOWN
|
||||
do_thing_request_parameter = sig.parameters["do_thing_request"]
|
||||
assert do_thing_request_parameter.default is Parameter.empty
|
||||
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
||||
|
Reference in New Issue
Block a user