Test all supported wrapper types. Add xfail test for unwrapping the value
This commit is contained in:
parent
b711d1e11f
commit
35548cb43e
@ -1,18 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
service Test {
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||
rpc GetOutput (Input) returns (Output);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
||||
|
||||
message Output {
|
||||
google.protobuf.Int64Value int64 = 1;
|
||||
}
|
@ -2,17 +2,20 @@ syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that wrapped return values can be used
|
||||
|
||||
service Test {
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
|
||||
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
|
||||
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||
rpc GetOutput (Input) returns (Output);
|
||||
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
|
||||
rpc GetBool (Input) returns (google.protobuf.BoolValue);
|
||||
rpc GetString (Input) returns (google.protobuf.StringValue);
|
||||
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
||||
|
||||
message Output {
|
||||
google.protobuf.Int64Value int64 = 1;
|
||||
}
|
@ -1,20 +1,53 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import google.protobuf.wrappers_pb2 as wrappers
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
|
||||
TestStub
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
class TestStubChild(TestStub):
|
||||
async def _unary_unary(self, route, request, response_type, **kwargs):
|
||||
self.response_type = response_type
|
||||
test_cases = [
|
||||
(TestStub.get_double, wrappers.DoubleValue, 2.5),
|
||||
(TestStub.get_float, wrappers.FloatValue, 2.5),
|
||||
(TestStub.get_int64, wrappers.Int64Value, -64),
|
||||
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
|
||||
(TestStub.get_int32, wrappers.Int32Value, -32),
|
||||
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
|
||||
(TestStub.get_bool, wrappers.BoolValue, True),
|
||||
(TestStub.get_string, wrappers.StringValue, "string"),
|
||||
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
pytest.skip("todo")
|
||||
stub = TestStubChild(None)
|
||||
await stub.get_int64()
|
||||
assert stub.response_type != Optional[int]
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[wrapped_value])
|
||||
service = TestStub(channel)
|
||||
|
||||
await service_method(service)
|
||||
|
||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_service_unwraps_response(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||
|
||||
response_value = await service_method(service)
|
||||
|
||||
assert type(response_value) == value
|
||||
assert type(response_value) == type(value)
|
||||
|
39
betterproto/tests/mocks.py
Normal file
39
betterproto/tests/mocks.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
from grpclib.client import Channel
|
||||
|
||||
|
||||
class MockChannel(Channel):
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, responses: List) -> None:
|
||||
self.responses = responses
|
||||
self.requests = []
|
||||
|
||||
def request(self, route, cardinality, request, response_type, **kwargs):
|
||||
self.requests.append(
|
||||
{
|
||||
"route": route,
|
||||
"cardinality": cardinality,
|
||||
"request": request,
|
||||
"response_type": response_type,
|
||||
}
|
||||
)
|
||||
return MockStream(self.responses)
|
||||
|
||||
|
||||
class MockStream:
|
||||
def __init__(self, responses: List) -> None:
|
||||
super().__init__()
|
||||
self.responses = responses
|
||||
|
||||
async def recv_message(self):
|
||||
return next(self.responses)
|
||||
|
||||
async def send_message(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return True
|
||||
|
||||
async def __aenter__(self):
|
||||
return True
|
Loading…
x
Reference in New Issue
Block a user