Merge pull request #40 from boukeversteegh/pr/wrapper-as-output

Support using Google's wrapper types as RPC output values
This commit is contained in:
nat 2020-05-24 19:06:30 +02:00 committed by GitHub
commit 1a87ea43a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 192 additions and 44 deletions

View File

@ -1,12 +1,11 @@
#!/usr/bin/env python #!/usr/bin/env python
import itertools import itertools
import json
import os.path import os.path
import re
import sys import sys
import textwrap import textwrap
from typing import Any, List, Tuple from collections import defaultdict
from typing import Dict, List, Optional, Type
try: try:
import black import black
@ -24,44 +23,51 @@ from google.protobuf.descriptor_pb2 import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
FieldDescriptorProto, FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
) )
from betterproto.casing import safe_snake_case from betterproto.casing import safe_snake_case
import google.protobuf.wrappers_pb2 as google_wrappers
WRAPPER_TYPES = { WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
"google.protobuf.DoubleValue": "float", 'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
"google.protobuf.FloatValue": "float", 'google.protobuf.FloatValue': google_wrappers.FloatValue,
"google.protobuf.Int64Value": "int", 'google.protobuf.Int64Value': google_wrappers.Int64Value,
"google.protobuf.UInt64Value": "int", 'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
"google.protobuf.Int32Value": "int", 'google.protobuf.Int32Value': google_wrappers.Int32Value,
"google.protobuf.UInt32Value": "int", 'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
"google.protobuf.BoolValue": "bool", 'google.protobuf.BoolValue': google_wrappers.BoolValue,
"google.protobuf.StringValue": "str", 'google.protobuf.StringValue': google_wrappers.StringValue,
"google.protobuf.BytesValue": "bytes", 'google.protobuf.BytesValue': google_wrappers.BytesValue,
} })
def get_ref_type(package: str, imports: set, type_name: str) -> str: def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
necessary. necessary. Unwraps well known type if required.
""" """
# If the package name is a blank string, then this should still work # If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are # because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future. # pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".") type_name = type_name.lstrip(".")
if type_name in WRAPPER_TYPES: # Check if type is wrapper.
return f"Optional[{WRAPPER_TYPES[type_name]}]" wrapper_class = WRAPPER_TYPES[type_name]
if type_name == "google.protobuf.Duration": if unwrap:
return "timedelta" if wrapper_class:
wrapped_type = type(wrapper_class().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Timestamp": if type_name == "google.protobuf.Duration":
return "datetime" return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
elif wrapper_class:
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
return f"{wrapper_class.__name__}"
if type_name.startswith(package): if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".") parts = type_name.lstrip(package).lstrip(".").split(".")
@ -379,7 +385,7 @@ def generate_code(request, response):
).strip('"'), ).strip('"'),
"input_message": input_message, "input_message": input_message,
"output": get_ref_type( "output": get_ref_type(
package, output["imports"], method.output_type package, output["imports"], method.output_type, unwrap=False
).strip('"'), ).strip('"'),
"client_streaming": method.client_streaming, "client_streaming": method.client_streaming,
"server_streaming": method.server_streaming, "server_streaming": method.server_streaming,

View File

@ -2,17 +2,20 @@ syntax = "proto3";
import "google/protobuf/wrappers.proto"; import "google/protobuf/wrappers.proto";
// Tests that wrapped values can be used directly as return values
service Test { service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value); rpc GetDouble (Input) returns (google.protobuf.DoubleValue);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); rpc GetFloat (Input) returns (google.protobuf.FloatValue);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value); rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output); rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value);
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value);
rpc GetBool (Input) returns (google.protobuf.BoolValue);
rpc GetString (Input) returns (google.protobuf.StringValue);
rpc GetBytes (Input) returns (google.protobuf.BytesValue);
} }
message Input { message Input {
} }
message Output {
google.protobuf.Int64Value int64 = 1;
}

View File

@ -1,20 +1,53 @@
from typing import Optional from typing import Any, Callable, Optional
import google.protobuf.wrappers_pb2 as wrappers
import pytest import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub TestStub,
) )
test_cases = [
class TestStubChild(TestStub): (TestStub.get_double, wrappers.DoubleValue, 2.5),
async def _unary_unary(self, route, request, response_type, **kwargs): (TestStub.get_float, wrappers.FloatValue, 2.5),
self.response_type = response_type (TestStub.get_int64, wrappers.Int64Value, -64),
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
(TestStub.get_int32, wrappers.Int32Value, -32),
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
(TestStub.get_bool, wrappers.BoolValue, True),
(TestStub.get_string, wrappers.StringValue, "string"),
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test(): @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
pytest.skip("todo") async def test_channel_receives_wrapped_type(
stub = TestStubChild(None) service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
await stub.get_int64() ):
assert stub.response_type != Optional[int] wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[wrapped_value])
service = TestStub(channel)
await service_method(service)
assert channel.requests[0]["response_type"] != Optional[type(value)]
assert channel.requests[0]["response_type"] == type(wrapped_value)
@pytest.mark.asyncio
@pytest.mark.xfail
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_service_unwraps_response(
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
service = TestStub(MockChannel(responses=[wrapped_value]))
response_value = await service_method(service)
assert type(response_value) == value
assert type(response_value) == type(value)

View File

@ -0,0 +1,24 @@
syntax = "proto3";
import "google/protobuf/wrappers.proto";
// Tests that wrapped values are supported as part of output message
service Test {
rpc getOutput (Input) returns (Output);
}
message Input {
}
message Output {
google.protobuf.DoubleValue double_value = 1;
google.protobuf.FloatValue float_value = 2;
google.protobuf.Int64Value int64_value = 3;
google.protobuf.UInt64Value uint64_value = 4;
google.protobuf.Int32Value int32_value = 5;
google.protobuf.UInt32Value uint32_value = 6;
google.protobuf.BoolValue bool_value = 7;
google.protobuf.StringValue string_value = 8;
google.protobuf.BytesValue bytes_value = 9;
}

View File

@ -0,0 +1,39 @@
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
Output,
TestStub,
)
@pytest.mark.asyncio
async def test_service_passes_through_unwrapped_values_embedded_in_response():
"""
We do not not need to implement value unwrapping for embedded well-known types,
as this is already handled by grpclib. This test merely shows that this is the case.
"""
output = Output(
double_value=10.0,
float_value=12.0,
int64_value=-13,
uint64_value=14,
int32_value=-15,
uint32_value=16,
bool_value=True,
string_value="string",
bytes_value=bytes(0xFF)[0:4],
)
service = TestStub(MockChannel(responses=[output]))
response = await service.get_output()
assert response.double_value == 10.0
assert response.float_value == 12.0
assert response.int64_value == -13
assert response.uint64_value == 14
assert response.int32_value == -15
assert response.uint32_value == 16
assert response.bool_value
assert response.string_value == "string"
assert response.bytes_value == bytes(0xFF)[0:4]

View File

@ -0,0 +1,39 @@
from typing import List
from grpclib.client import Channel
class MockChannel(Channel):
# noinspection PyMissingConstructor
def __init__(self, responses: List) -> None:
self.responses = responses
self.requests = []
def request(self, route, cardinality, request, response_type, **kwargs):
self.requests.append(
{
"route": route,
"cardinality": cardinality,
"request": request,
"response_type": response_type,
}
)
return MockStream(self.responses)
class MockStream:
def __init__(self, responses: List) -> None:
super().__init__()
self.responses = responses
async def recv_message(self):
return self.responses.pop(0)
async def send_message(self, *args, **kwargs):
pass
async def __aexit__(self, exc_type, exc_val, exc_tb):
return True
async def __aenter__(self):
return self

View File

@ -16,7 +16,11 @@ from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import Parse from google.protobuf.json_format import Parse
excluded_test_cases = {"googletypes_response", "service"} excluded_test_cases = {
"googletypes_response",
"googletypes_response_embedded",
"service",
}
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
plugin_output_package = "betterproto.tests.output_betterproto" plugin_output_package = "betterproto.tests.output_betterproto"