Test all supported wrapper types. Add xfail test for unwrapping the value

This commit is contained in:
boukeversteegh 2020-05-24 12:33:36 +02:00
parent b711d1e11f
commit 35548cb43e
4 changed files with 93 additions and 36 deletions

View File

@ -1,18 +0,0 @@
syntax = "proto3";
import "google/protobuf/wrappers.proto";
service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output);
}
message Input {
}
message Output {
google.protobuf.Int64Value int64 = 1;
}

View File

@ -2,17 +2,20 @@ syntax = "proto3";
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";
// Tests that wrapped return values can be used
service Test { service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value); rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); rpc GetFloat (Input) returns (google.protobuf.FloatValue);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value); rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output); rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
rpc GetBool (Input) returns (google.protobuf.BoolValue);
rpc GetString (Input) returns (google.protobuf.StringValue);
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
} }
message Input { message Input {
} }
message Output {
google.protobuf.Int64Value int64 = 1;
}

View File

@ -1,20 +1,53 @@
from typing import Optional from typing import Any, Callable, Optional
import google.protobuf.wrappers_pb2 as wrappers
import pytest import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub TestStub,
) )
test_cases = [
class TestStubChild(TestStub): (TestStub.get_double, wrappers.DoubleValue, 2.5),
async def _unary_unary(self, route, request, response_type, **kwargs): (TestStub.get_float, wrappers.FloatValue, 2.5),
self.response_type = response_type (TestStub.get_int64, wrappers.Int64Value, -64),
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
(TestStub.get_int32, wrappers.Int32Value, -32),
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
(TestStub.get_bool, wrappers.BoolValue, True),
(TestStub.get_string, wrappers.StringValue, "string"),
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test(): @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
pytest.skip("todo") async def test_channel_receives_wrapped_type(
stub = TestStubChild(None) service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
await stub.get_int64() ):
assert stub.response_type != Optional[int] wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)
await service_method(service)
assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)
@pytest.mark.asyncio
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))
response_value = await service_method(service)
assert type(response_value) == value
assert type(response_value) == type(value)

View File

@ -0,0 +1,39 @@
from typing import List
from grpclib.client import Channel
class MockChannel(Channel):
# noinspection PyMissingConstructor
def __init__(self, responses: List) -> None:
self.responses = responses
self.requests = []
def request(self, route, cardinality, request, response_type, **kwargs):
self.requests.append(
{
"route": route,
"cardinality": cardinality,
"request": request,
"response_type": response_type,
}
)
return MockStream(self.responses)
class MockStream:
def __init__(self, responses: List) -> None:
super().__init__()
self.responses = responses
async def recv_message(self):
return next(self.responses)
async def send_message(self, *args, **kwargs):
pass
async def __aexit__(self, exc_type, exc_val, exc_tb):
return True
async def __aenter__(self):
return True