Merge pull request #40 from boukeversteegh/pr/wrapper-as-output
Support using Google's wrapper types as RPC output values
This commit is contained in:
commit
1a87ea43a1
@ -1,12 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Any, List, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
try:
|
||||
import black
|
||||
@ -24,44 +23,51 @@ from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
import google.protobuf.wrappers_pb2 as google_wrappers
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
"google.protobuf.DoubleValue": "float",
|
||||
"google.protobuf.FloatValue": "float",
|
||||
"google.protobuf.Int64Value": "int",
|
||||
"google.protobuf.UInt64Value": "int",
|
||||
"google.protobuf.Int32Value": "int",
|
||||
"google.protobuf.UInt32Value": "int",
|
||||
"google.protobuf.BoolValue": "bool",
|
||||
"google.protobuf.StringValue": "str",
|
||||
"google.protobuf.BytesValue": "bytes",
|
||||
}
|
||||
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
|
||||
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
|
||||
'google.protobuf.FloatValue': google_wrappers.FloatValue,
|
||||
'google.protobuf.Int64Value': google_wrappers.Int64Value,
|
||||
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
|
||||
'google.protobuf.Int32Value': google_wrappers.Int32Value,
|
||||
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
|
||||
'google.protobuf.BoolValue': google_wrappers.BoolValue,
|
||||
'google.protobuf.StringValue': google_wrappers.StringValue,
|
||||
'google.protobuf.BytesValue': google_wrappers.BytesValue,
|
||||
})
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary.
|
||||
necessary. Unwraps well known type if required.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
if type_name in WRAPPER_TYPES:
|
||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
||||
# Check if type is wrapper.
|
||||
wrapper_class = WRAPPER_TYPES[type_name]
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
if unwrap:
|
||||
if wrapper_class:
|
||||
wrapped_type = type(wrapper_class().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
elif wrapper_class:
|
||||
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
|
||||
return f"{wrapper_class.__name__}"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
@ -379,7 +385,7 @@ def generate_code(request, response):
|
||||
).strip('"'),
|
||||
"input_message": input_message,
|
||||
"output": get_ref_type(
|
||||
package, output["imports"], method.output_type
|
||||
package, output["imports"], method.output_type, unwrap=False
|
||||
).strip('"'),
|
||||
"client_streaming": method.client_streaming,
|
||||
"server_streaming": method.server_streaming,
|
||||
|
@ -2,17 +2,20 @@ syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that wrapped values can be used directly as return values
|
||||
|
||||
service Test {
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
|
||||
rpc GetFloat (Input) returns (google.protobuf.FloatValue);
|
||||
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||
rpc GetOutput (Input) returns (Output);
|
||||
rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
|
||||
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
|
||||
rpc GetBool (Input) returns (google.protobuf.BoolValue);
|
||||
rpc GetString (Input) returns (google.protobuf.StringValue);
|
||||
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
||||
|
||||
message Output {
|
||||
google.protobuf.Int64Value int64 = 1;
|
||||
}
|
@ -1,20 +1,53 @@
|
||||
from typing import Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import google.protobuf.wrappers_pb2 as wrappers
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
|
||||
TestStub
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
class TestStubChild(TestStub):
|
||||
async def _unary_unary(self, route, request, response_type, **kwargs):
|
||||
self.response_type = response_type
|
||||
test_cases = [
|
||||
(TestStub.get_double, wrappers.DoubleValue, 2.5),
|
||||
(TestStub.get_float, wrappers.FloatValue, 2.5),
|
||||
(TestStub.get_int64, wrappers.Int64Value, -64),
|
||||
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
|
||||
(TestStub.get_int32, wrappers.Int32Value, -32),
|
||||
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
|
||||
(TestStub.get_bool, wrappers.BoolValue, True),
|
||||
(TestStub.get_string, wrappers.StringValue, "string"),
|
||||
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
pytest.skip("todo")
|
||||
stub = TestStubChild(None)
|
||||
await stub.get_int64()
|
||||
assert stub.response_type != Optional[int]
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[wrapped_value])
|
||||
service = TestStub(channel)
|
||||
|
||||
await service_method(service)
|
||||
|
||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_service_unwraps_response(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||
|
||||
response_value = await service_method(service)
|
||||
|
||||
assert type(response_value) == value
|
||||
assert type(response_value) == type(value)
|
||||
|
@ -0,0 +1,24 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that wrapped values are supported as part of output message
|
||||
service Test {
|
||||
rpc getOutput (Input) returns (Output);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
||||
|
||||
message Output {
|
||||
google.protobuf.DoubleValue double_value = 1;
|
||||
google.protobuf.FloatValue float_value = 2;
|
||||
google.protobuf.Int64Value int64_value = 3;
|
||||
google.protobuf.UInt64Value uint64_value = 4;
|
||||
google.protobuf.Int32Value int32_value = 5;
|
||||
google.protobuf.UInt32Value uint32_value = 6;
|
||||
google.protobuf.BoolValue bool_value = 7;
|
||||
google.protobuf.StringValue string_value = 8;
|
||||
google.protobuf.BytesValue bytes_value = 9;
|
||||
}
|
@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
|
||||
Output,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
||||
"""
|
||||
We do not not need to implement value unwrapping for embedded well-known types,
|
||||
as this is already handled by grpclib. This test merely shows that this is the case.
|
||||
"""
|
||||
output = Output(
|
||||
double_value=10.0,
|
||||
float_value=12.0,
|
||||
int64_value=-13,
|
||||
uint64_value=14,
|
||||
int32_value=-15,
|
||||
uint32_value=16,
|
||||
bool_value=True,
|
||||
string_value="string",
|
||||
bytes_value=bytes(0xFF)[0:4],
|
||||
)
|
||||
|
||||
service = TestStub(MockChannel(responses=[output]))
|
||||
response = await service.get_output()
|
||||
|
||||
assert response.double_value == 10.0
|
||||
assert response.float_value == 12.0
|
||||
assert response.int64_value == -13
|
||||
assert response.uint64_value == 14
|
||||
assert response.int32_value == -15
|
||||
assert response.uint32_value == 16
|
||||
assert response.bool_value
|
||||
assert response.string_value == "string"
|
||||
assert response.bytes_value == bytes(0xFF)[0:4]
|
39
betterproto/tests/mocks.py
Normal file
39
betterproto/tests/mocks.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import List
|
||||
|
||||
from grpclib.client import Channel
|
||||
|
||||
|
||||
class MockChannel(Channel):
|
||||
# noinspection PyMissingConstructor
|
||||
def __init__(self, responses: List) -> None:
|
||||
self.responses = responses
|
||||
self.requests = []
|
||||
|
||||
def request(self, route, cardinality, request, response_type, **kwargs):
|
||||
self.requests.append(
|
||||
{
|
||||
"route": route,
|
||||
"cardinality": cardinality,
|
||||
"request": request,
|
||||
"response_type": response_type,
|
||||
}
|
||||
)
|
||||
return MockStream(self.responses)
|
||||
|
||||
|
||||
class MockStream:
|
||||
def __init__(self, responses: List) -> None:
|
||||
super().__init__()
|
||||
self.responses = responses
|
||||
|
||||
async def recv_message(self):
|
||||
return self.responses.pop(0)
|
||||
|
||||
async def send_message(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return True
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
@ -16,7 +16,11 @@ from google.protobuf.descriptor_pool import DescriptorPool
|
||||
from google.protobuf.json_format import Parse
|
||||
|
||||
|
||||
excluded_test_cases = {"googletypes_response", "service"}
|
||||
excluded_test_cases = {
|
||||
"googletypes_response",
|
||||
"googletypes_response_embedded",
|
||||
"service",
|
||||
}
|
||||
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
|
||||
|
||||
plugin_output_package = "betterproto.tests.output_betterproto"
|
||||
|
Loading…
x
Reference in New Issue
Block a user