From 499489f1d342029f11f78fdd70ce34208e086299 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 10 May 2020 16:34:20 +0200 Subject: [PATCH 1/5] Support using Google's wrapper types as RPC output values --- betterproto/plugin.py | 53 ++++++++++++++------- betterproto/tests/googletypes_service.proto | 18 +++++++ 2 files changed, 53 insertions(+), 18 deletions(-) create mode 100644 betterproto/tests/googletypes_service.proto diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 597bf1a..d70786b 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -30,38 +30,55 @@ from google.protobuf.descriptor_pb2 import ( from betterproto.casing import safe_snake_case +import google.protobuf.wrappers_pb2 + WRAPPER_TYPES = { - "google.protobuf.DoubleValue": "float", - "google.protobuf.FloatValue": "float", - "google.protobuf.Int64Value": "int", - "google.protobuf.UInt64Value": "int", - "google.protobuf.Int32Value": "int", - "google.protobuf.UInt32Value": "int", - "google.protobuf.BoolValue": "bool", - "google.protobuf.StringValue": "str", - "google.protobuf.BytesValue": "bytes", + google.protobuf.wrappers_pb2.DoubleValue: "float", + google.protobuf.wrappers_pb2.FloatValue: "float", + google.protobuf.wrappers_pb2.Int64Value: "int", + google.protobuf.wrappers_pb2.UInt64Value: "int", + google.protobuf.wrappers_pb2.Int32Value: "int", + google.protobuf.wrappers_pb2.UInt32Value: "int", + google.protobuf.wrappers_pb2.BoolValue: "bool", + google.protobuf.wrappers_pb2.StringValue: "str", + google.protobuf.wrappers_pb2.BytesValue: "bytes", } -def get_ref_type(package: str, imports: set, type_name: str) -> str: +def get_wrapper_type(type_name: str) -> (Any, str): + for wrapper, wrapped_type in WRAPPER_TYPES.items(): + if wrapper.DESCRIPTOR.full_name == type_name: + return wrapper, wrapped_type + return None, None + + +def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: """ Return a Python type name for a proto type reference. Adds the import if - necessary. + necessary. Unwraps well known type if required. """ # If the package name is a blank string, then this should still work # because by convention packages are lowercase and message/enum types are # pascal-cased. May require refactoring in the future. type_name = type_name.lstrip(".") - if type_name in WRAPPER_TYPES: - return f"Optional[{WRAPPER_TYPES[type_name]}]" + # Check if type is wrapper. + wrapper, wrapped_type = get_wrapper_type(type_name) - if type_name == "google.protobuf.Duration": - return "timedelta" + if unwrap: + if wrapper: + return f"Optional[{wrapped_type}]" - if type_name == "google.protobuf.Timestamp": - return "datetime" + if type_name == "google.protobuf.Duration": + return "timedelta" + + if type_name == "google.protobuf.Timestamp": + return "datetime" + else: + if wrapper: + imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") + return f"{wrapper.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".") @@ -379,7 +396,7 @@ def generate_code(request, response): ).strip('"'), "input_message": input_message, "output": get_ref_type( - package, output["imports"], method.output_type + package, output["imports"], method.output_type, unwrap=False ).strip('"'), "client_streaming": method.client_streaming, "server_streaming": method.server_streaming, diff --git a/betterproto/tests/googletypes_service.proto b/betterproto/tests/googletypes_service.proto new file mode 100644 index 0000000..4bdca68 --- /dev/null +++ b/betterproto/tests/googletypes_service.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +service Test { + rpc GetInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetInt64 (Input) returns (google.protobuf.Int64Value); + rpc GetOutput (Input) returns (Output); +} + +message Input { + +} + +message Output { + google.protobuf.Int64Value int64 = 1; +} \ No newline at end of file From 7e9ba0866c8d636296f463345ec42609a3bae4ea Mon Sep 17 00:00:00 2001 From: Bouke Versteegh Date: Thu, 21 May 2020 22:55:26 +0200 Subject: [PATCH 2/5] cleanup --- betterproto/plugin.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index d70786b..83b87d6 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -75,10 +75,9 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True if type_name == "google.protobuf.Timestamp": return "datetime" - else: - if wrapper: - imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") - return f"{wrapper.__name__}" + elif wrapper: + imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") + return f"{wrapper.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".") From 35548cb43e7d71ecc2c2ccc4a076f1bf7ea7c032 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 24 May 2020 12:33:36 +0200 Subject: [PATCH 3/5] Test all supported wrapper types. Add xfail test for unwrapping the value --- betterproto/tests/googletypes_service.proto | 18 ------ .../googletypes_response.proto | 17 +++--- .../test_googletypes_response.py | 55 +++++++++++++++---- betterproto/tests/mocks.py | 39 +++++++++++++ 4 files changed, 93 insertions(+), 36 deletions(-) delete mode 100644 betterproto/tests/googletypes_service.proto create mode 100644 betterproto/tests/mocks.py diff --git a/betterproto/tests/googletypes_service.proto b/betterproto/tests/googletypes_service.proto deleted file mode 100644 index 4bdca68..0000000 --- a/betterproto/tests/googletypes_service.proto +++ /dev/null @@ -1,18 +0,0 @@ -syntax = "proto3"; - -import "google/protobuf/wrappers.proto"; - -service Test { - rpc GetInt32 (Input) returns (google.protobuf.Int32Value); - rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); - rpc GetInt64 (Input) returns (google.protobuf.Int64Value); - rpc GetOutput (Input) returns (Output); -} - -message Input { - -} - -message Output { - google.protobuf.Int64Value int64 = 1; -} \ No newline at end of file diff --git a/betterproto/tests/inputs/googletypes_response/googletypes_response.proto b/betterproto/tests/inputs/googletypes_response/googletypes_response.proto index 4bdca68..ee8dbbd 100644 --- a/betterproto/tests/inputs/googletypes_response/googletypes_response.proto +++ b/betterproto/tests/inputs/googletypes_response/googletypes_response.proto @@ -2,17 +2,20 @@ syntax = "proto3"; import "google/protobuf/wrappers.proto"; +// Tests that wrapped return values can be used + service Test { - rpc GetInt32 (Input) returns (google.protobuf.Int32Value); - rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetDouble (Input) returns (google.protobuf.DoubleValue); + rpc GetFloat (Input) returns (google.protobuf.FloatValue); rpc GetInt64 (Input) returns (google.protobuf.Int64Value); - rpc GetOutput (Input) returns (Output); + rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value); + rpc GetInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value); + rpc GetBool (Input) returns (google.protobuf.BoolValue); + rpc GetString (Input) returns (google.protobuf.StringValue); + rpc GetBytes (Input) returns (google.protobuf.BytesValue); } message Input { } - -message Output { - google.protobuf.Int64Value int64 = 1; -} \ No newline at end of file diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index fba2070..76c012b 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -1,20 +1,53 @@ -from typing import Optional +from typing import Any, Callable, Optional +import google.protobuf.wrappers_pb2 as wrappers import pytest +from betterproto.tests.mocks import MockChannel from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( - TestStub + TestStub, ) - -class TestStubChild(TestStub): - async def _unary_unary(self, route, request, response_type, **kwargs): - self.response_type = response_type +test_cases = [ + (TestStub.get_double, wrappers.DoubleValue, 2.5), + (TestStub.get_float, wrappers.FloatValue, 2.5), + (TestStub.get_int64, wrappers.Int64Value, -64), + (TestStub.get_u_int64, wrappers.UInt64Value, 64), + (TestStub.get_int32, wrappers.Int32Value, -32), + (TestStub.get_u_int32, wrappers.UInt32Value, 32), + (TestStub.get_bool, wrappers.BoolValue, True), + (TestStub.get_string, wrappers.StringValue, "string"), + (TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]), +] @pytest.mark.asyncio -async def test(): - pytest.skip("todo") - stub = TestStubChild(None) - await stub.get_int64() - assert stub.response_type != Optional[int] +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_channel_receives_wrapped_type( + service_method: Callable[[TestStub], Any], wrapper_class: Callable, value +): + wrapped_value = wrapper_class() + wrapped_value.value = value + channel = MockChannel(responses=[wrapped_value]) + service = TestStub(channel) + + await service_method(service) + + assert channel.requests[0]["response_type"] != Optional[type(value)] + assert channel.requests[0]["response_type"] == type(wrapped_value) + + +@pytest.mark.asyncio +@pytest.mark.xfail +@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) +async def test_service_unwraps_response( + service_method: Callable[[TestStub], Any], wrapper_class: Callable, value +): + wrapped_value = wrapper_class() + wrapped_value.value = value + service = TestStub(MockChannel(responses=[wrapped_value])) + + response_value = await service_method(service) + + assert type(response_value) == value + assert type(response_value) == type(value) diff --git a/betterproto/tests/mocks.py b/betterproto/tests/mocks.py new file mode 100644 index 0000000..287a58f --- /dev/null +++ b/betterproto/tests/mocks.py @@ -0,0 +1,39 @@ +from typing import List + +from grpclib.client import Channel + + +class MockChannel(Channel): + # noinspection PyMissingConstructor + def __init__(self, responses: List) -> None: + self.responses = responses + self.requests = [] + + def request(self, route, cardinality, request, response_type, **kwargs): + self.requests.append( + { + "route": route, + "cardinality": cardinality, + "request": request, + "response_type": response_type, + } + ) + return MockStream(self.responses) + + +class MockStream: + def __init__(self, responses: List) -> None: + super().__init__() + self.responses = responses + + async def recv_message(self): + return next(self.responses) + + async def send_message(self, *args, **kwargs): + pass + + async def __aexit__(self, exc_type, exc_val, exc_tb): + return True + + async def __aenter__(self): + return True From c50d9e2fdcb6c7aa3592c40fc9971ce08a67dc6a Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 24 May 2020 14:48:39 +0200 Subject: [PATCH 4/5] Add test for generating embedded wellknown types in outputs. --- .../googletypes_response.proto | 2 +- .../googletypes_response_embedded.proto | 24 ++++++++++++ .../test_googletypes_response_embedded.py | 39 +++++++++++++++++++ betterproto/tests/mocks.py | 4 +- betterproto/tests/test_inputs.py | 6 ++- 5 files changed, 71 insertions(+), 4 deletions(-) create mode 100644 betterproto/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto create mode 100644 betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py diff --git a/betterproto/tests/inputs/googletypes_response/googletypes_response.proto b/betterproto/tests/inputs/googletypes_response/googletypes_response.proto index ee8dbbd..9b0be5c 100644 --- a/betterproto/tests/inputs/googletypes_response/googletypes_response.proto +++ b/betterproto/tests/inputs/googletypes_response/googletypes_response.proto @@ -2,7 +2,7 @@ syntax = "proto3"; import "google/protobuf/wrappers.proto"; -// Tests that wrapped return values can be used +// Tests that wrapped values can be used directly as return values service Test { rpc GetDouble (Input) returns (google.protobuf.DoubleValue); diff --git a/betterproto/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto b/betterproto/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto new file mode 100644 index 0000000..89ae4cc --- /dev/null +++ b/betterproto/tests/inputs/googletypes_response_embedded/googletypes_response_embedded.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +// Tests that wrapped values are supported as part of output message +service Test { + rpc getOutput (Input) returns (Output); +} + +message Input { + +} + +message Output { + google.protobuf.DoubleValue double_value = 1; + google.protobuf.FloatValue float_value = 2; + google.protobuf.Int64Value int64_value = 3; + google.protobuf.UInt64Value uint64_value = 4; + google.protobuf.Int32Value int32_value = 5; + google.protobuf.UInt32Value uint32_value = 6; + google.protobuf.BoolValue bool_value = 7; + google.protobuf.StringValue string_value = 8; + google.protobuf.BytesValue bytes_value = 9; +} diff --git a/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py b/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py new file mode 100644 index 0000000..00b980a --- /dev/null +++ b/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py @@ -0,0 +1,39 @@ +import pytest + +from betterproto.tests.mocks import MockChannel +from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import ( + Output, + TestStub, +) + + +@pytest.mark.asyncio +async def test_service_passes_through_unwrapped_values_embedded_in_response(): + """ + We do not not need to implement value unwrapping for embedded well-known types, + as this is already handled by grpclib. This test merely shows that this is the case. + """ + output = Output( + double_value=10.0, + float_value=12.0, + int64_value=-13, + uint64_value=14, + int32_value=-15, + uint32_value=16, + bool_value=True, + string_value="string", + bytes_value=bytes(0xFF)[0:4], + ) + + service = TestStub(MockChannel(responses=[output])) + response = await service.get_output() + + assert response.double_value == 10.0 + assert response.float_value == 12.0 + assert response.int64_value == -13 + assert response.uint64_value == 14 + assert response.int32_value == -15 + assert response.uint32_value == 16 + assert response.bool_value + assert response.string_value == "string" + assert response.bytes_value == bytes(0xFF)[0:4] diff --git a/betterproto/tests/mocks.py b/betterproto/tests/mocks.py index 287a58f..326b892 100644 --- a/betterproto/tests/mocks.py +++ b/betterproto/tests/mocks.py @@ -27,7 +27,7 @@ class MockStream: self.responses = responses async def recv_message(self): - return next(self.responses) + return self.responses.pop(0) async def send_message(self, *args, **kwargs): pass @@ -36,4 +36,4 @@ class MockStream: return True async def __aenter__(self): - return True + return self diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index c8fb7d3..b1041e5 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -15,7 +15,11 @@ from google.protobuf.descriptor_pool import DescriptorPool from google.protobuf.json_format import Parse -excluded_test_cases = {"googletypes_response", "service"} +excluded_test_cases = { + "googletypes_response", + "googletypes_response_embedded", + "service", +} test_case_names = {*get_directories(inputs_path)} - excluded_test_cases plugin_output_package = "betterproto.tests.output_betterproto" From 8f0caf1db2c455f02a47258cd9a9ec32d3cc9dca Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 24 May 2020 14:50:56 +0200 Subject: [PATCH 5/5] Read desired wrapper type directly from wrapper definition --- betterproto/plugin.py | 52 +++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 7476b38..2184a9c 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,12 +1,11 @@ #!/usr/bin/env python import itertools -import json import os.path -import re import sys import textwrap -from typing import Any, List, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Type try: import black @@ -24,33 +23,23 @@ from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, - FileDescriptorProto, - ServiceDescriptorProto, ) from betterproto.casing import safe_snake_case -import google.protobuf.wrappers_pb2 +import google.protobuf.wrappers_pb2 as google_wrappers - -WRAPPER_TYPES = { - google.protobuf.wrappers_pb2.DoubleValue: "float", - google.protobuf.wrappers_pb2.FloatValue: "float", - google.protobuf.wrappers_pb2.Int64Value: "int", - google.protobuf.wrappers_pb2.UInt64Value: "int", - google.protobuf.wrappers_pb2.Int32Value: "int", - google.protobuf.wrappers_pb2.UInt32Value: "int", - google.protobuf.wrappers_pb2.BoolValue: "bool", - google.protobuf.wrappers_pb2.StringValue: "str", - google.protobuf.wrappers_pb2.BytesValue: "bytes", -} - - -def get_wrapper_type(type_name: str) -> (Any, str): - for wrapper, wrapped_type in WRAPPER_TYPES.items(): - if wrapper.DESCRIPTOR.full_name == type_name: - return wrapper, wrapped_type - return None, None +WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, { + 'google.protobuf.DoubleValue': google_wrappers.DoubleValue, + 'google.protobuf.FloatValue': google_wrappers.FloatValue, + 'google.protobuf.Int64Value': google_wrappers.Int64Value, + 'google.protobuf.UInt64Value': google_wrappers.UInt64Value, + 'google.protobuf.Int32Value': google_wrappers.Int32Value, + 'google.protobuf.UInt32Value': google_wrappers.UInt32Value, + 'google.protobuf.BoolValue': google_wrappers.BoolValue, + 'google.protobuf.StringValue': google_wrappers.StringValue, + 'google.protobuf.BytesValue': google_wrappers.BytesValue, +}) def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: @@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True type_name = type_name.lstrip(".") # Check if type is wrapper. - wrapper, wrapped_type = get_wrapper_type(type_name) + wrapper_class = WRAPPER_TYPES[type_name] if unwrap: - if wrapper: - return f"Optional[{wrapped_type}]" + if wrapper_class: + wrapped_type = type(wrapper_class().value) + return f"Optional[{wrapped_type.__name__}]" if type_name == "google.protobuf.Duration": return "timedelta" if type_name == "google.protobuf.Timestamp": return "datetime" - elif wrapper: - imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") - return f"{wrapper.__name__}" + elif wrapper_class: + imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}") + return f"{wrapper_class.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".")