Merge pull request #40 from boukeversteegh/pr/wrapper-as-output
Support using Google's wrapper types as RPC output values
This commit is contained in:
		| @@ -1,12 +1,11 @@ | |||||||
| #!/usr/bin/env python | #!/usr/bin/env python | ||||||
|  |  | ||||||
| import itertools | import itertools | ||||||
| import json |  | ||||||
| import os.path | import os.path | ||||||
| import re |  | ||||||
| import sys | import sys | ||||||
| import textwrap | import textwrap | ||||||
| from typing import Any, List, Tuple | from collections import defaultdict | ||||||
|  | from typing import Dict, List, Optional, Type | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     import black |     import black | ||||||
| @@ -24,44 +23,51 @@ from google.protobuf.descriptor_pb2 import ( | |||||||
|     DescriptorProto, |     DescriptorProto, | ||||||
|     EnumDescriptorProto, |     EnumDescriptorProto, | ||||||
|     FieldDescriptorProto, |     FieldDescriptorProto, | ||||||
|     FileDescriptorProto, |  | ||||||
|     ServiceDescriptorProto, |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| from betterproto.casing import safe_snake_case | from betterproto.casing import safe_snake_case | ||||||
|  |  | ||||||
|  | import google.protobuf.wrappers_pb2 as google_wrappers | ||||||
|  |  | ||||||
| WRAPPER_TYPES = { | WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, { | ||||||
|     "google.protobuf.DoubleValue": "float", |     'google.protobuf.DoubleValue': google_wrappers.DoubleValue, | ||||||
|     "google.protobuf.FloatValue": "float", |     'google.protobuf.FloatValue': google_wrappers.FloatValue, | ||||||
|     "google.protobuf.Int64Value": "int", |     'google.protobuf.Int64Value': google_wrappers.Int64Value, | ||||||
|     "google.protobuf.UInt64Value": "int", |     'google.protobuf.UInt64Value': google_wrappers.UInt64Value, | ||||||
|     "google.protobuf.Int32Value": "int", |     'google.protobuf.Int32Value': google_wrappers.Int32Value, | ||||||
|     "google.protobuf.UInt32Value": "int", |     'google.protobuf.UInt32Value': google_wrappers.UInt32Value, | ||||||
|     "google.protobuf.BoolValue": "bool", |     'google.protobuf.BoolValue': google_wrappers.BoolValue, | ||||||
|     "google.protobuf.StringValue": "str", |     'google.protobuf.StringValue': google_wrappers.StringValue, | ||||||
|     "google.protobuf.BytesValue": "bytes", |     'google.protobuf.BytesValue': google_wrappers.BytesValue, | ||||||
| } | }) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_ref_type(package: str, imports: set, type_name: str) -> str: | def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: | ||||||
|     """ |     """ | ||||||
|     Return a Python type name for a proto type reference. Adds the import if |     Return a Python type name for a proto type reference. Adds the import if | ||||||
|     necessary. |     necessary. Unwraps well known type if required. | ||||||
|     """ |     """ | ||||||
|     # If the package name is a blank string, then this should still work |     # If the package name is a blank string, then this should still work | ||||||
|     # because by convention packages are lowercase and message/enum types are |     # because by convention packages are lowercase and message/enum types are | ||||||
|     # pascal-cased. May require refactoring in the future. |     # pascal-cased. May require refactoring in the future. | ||||||
|     type_name = type_name.lstrip(".") |     type_name = type_name.lstrip(".") | ||||||
|  |  | ||||||
|     if type_name in WRAPPER_TYPES: |     # Check if type is wrapper. | ||||||
|         return f"Optional[{WRAPPER_TYPES[type_name]}]" |     wrapper_class = WRAPPER_TYPES[type_name] | ||||||
|  |  | ||||||
|  |     if unwrap: | ||||||
|  |         if wrapper_class: | ||||||
|  |             wrapped_type = type(wrapper_class().value) | ||||||
|  |             return f"Optional[{wrapped_type.__name__}]" | ||||||
|  |  | ||||||
|         if type_name == "google.protobuf.Duration": |         if type_name == "google.protobuf.Duration": | ||||||
|             return "timedelta" |             return "timedelta" | ||||||
|  |  | ||||||
|         if type_name == "google.protobuf.Timestamp": |         if type_name == "google.protobuf.Timestamp": | ||||||
|             return "datetime" |             return "datetime" | ||||||
|  |     elif wrapper_class: | ||||||
|  |         imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}") | ||||||
|  |         return f"{wrapper_class.__name__}" | ||||||
|  |  | ||||||
|     if type_name.startswith(package): |     if type_name.startswith(package): | ||||||
|         parts = type_name.lstrip(package).lstrip(".").split(".") |         parts = type_name.lstrip(package).lstrip(".").split(".") | ||||||
| @@ -379,7 +385,7 @@ def generate_code(request, response): | |||||||
|                             ).strip('"'), |                             ).strip('"'), | ||||||
|                             "input_message": input_message, |                             "input_message": input_message, | ||||||
|                             "output": get_ref_type( |                             "output": get_ref_type( | ||||||
|                                 package, output["imports"], method.output_type |                                 package, output["imports"], method.output_type, unwrap=False | ||||||
|                             ).strip('"'), |                             ).strip('"'), | ||||||
|                             "client_streaming": method.client_streaming, |                             "client_streaming": method.client_streaming, | ||||||
|                             "server_streaming": method.server_streaming, |                             "server_streaming": method.server_streaming, | ||||||
|   | |||||||
| @@ -2,17 +2,20 @@ syntax = "proto3"; | |||||||
|  |  | ||||||
| import "google/protobuf/wrappers.proto"; | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
|  | // Tests that wrapped values can be used directly as return values | ||||||
|  |  | ||||||
| service Test { | service Test { | ||||||
|     rpc GetInt32 (Input) returns (google.protobuf.Int32Value); |     rpc GetDouble (Input) returns (google.protobuf.DoubleValue); | ||||||
|     rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); |     rpc GetFloat (Input) returns (google.protobuf.FloatValue); | ||||||
|     rpc GetInt64 (Input) returns (google.protobuf.Int64Value); |     rpc GetInt64 (Input) returns (google.protobuf.Int64Value); | ||||||
|     rpc GetOutput (Input) returns (Output); |     rpc GetUInt64 (Input) returns (google.protobuf.UInt64Value); | ||||||
|  |     rpc GetInt32 (Input) returns (google.protobuf.Int32Value); | ||||||
|  |     rpc GetUInt32 (Input) returns (google.protobuf.UInt32Value); | ||||||
|  |     rpc GetBool (Input) returns (google.protobuf.BoolValue); | ||||||
|  |     rpc GetString (Input) returns (google.protobuf.StringValue); | ||||||
|  |     rpc GetBytes (Input) returns (google.protobuf.BytesValue); | ||||||
| } | } | ||||||
|  |  | ||||||
| message Input { | message Input { | ||||||
|  |  | ||||||
| } | } | ||||||
|  |  | ||||||
| message Output { |  | ||||||
|     google.protobuf.Int64Value int64 = 1; |  | ||||||
| } |  | ||||||
| @@ -1,20 +1,53 @@ | |||||||
| from typing import Optional | from typing import Any, Callable, Optional | ||||||
|  |  | ||||||
|  | import google.protobuf.wrappers_pb2 as wrappers | ||||||
| import pytest | import pytest | ||||||
|  |  | ||||||
|  | from betterproto.tests.mocks import MockChannel | ||||||
| from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( | from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( | ||||||
|     TestStub |     TestStub, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | test_cases = [ | ||||||
| class TestStubChild(TestStub): |     (TestStub.get_double, wrappers.DoubleValue, 2.5), | ||||||
|     async def _unary_unary(self, route, request, response_type, **kwargs): |     (TestStub.get_float, wrappers.FloatValue, 2.5), | ||||||
|         self.response_type = response_type |     (TestStub.get_int64, wrappers.Int64Value, -64), | ||||||
|  |     (TestStub.get_u_int64, wrappers.UInt64Value, 64), | ||||||
|  |     (TestStub.get_int32, wrappers.Int32Value, -32), | ||||||
|  |     (TestStub.get_u_int32, wrappers.UInt32Value, 32), | ||||||
|  |     (TestStub.get_bool, wrappers.BoolValue, True), | ||||||
|  |     (TestStub.get_string, wrappers.StringValue, "string"), | ||||||
|  |     (TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]), | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.asyncio | @pytest.mark.asyncio | ||||||
| async def test(): | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
|     pytest.skip("todo") | async def test_channel_receives_wrapped_type( | ||||||
|     stub = TestStubChild(None) |     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value | ||||||
|     await stub.get_int64() | ): | ||||||
|     assert stub.response_type != Optional[int] |     wrapped_value = wrapper_class() | ||||||
|  |     wrapped_value.value = value | ||||||
|  |     channel = MockChannel(responses=[wrapped_value]) | ||||||
|  |     service = TestStub(channel) | ||||||
|  |  | ||||||
|  |     await service_method(service) | ||||||
|  |  | ||||||
|  |     assert channel.requests[0]["response_type"] != Optional[type(value)] | ||||||
|  |     assert channel.requests[0]["response_type"] == type(wrapped_value) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.xfail | ||||||
|  | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
|  | async def test_service_unwraps_response( | ||||||
|  |     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value | ||||||
|  | ): | ||||||
|  |     wrapped_value = wrapper_class() | ||||||
|  |     wrapped_value.value = value | ||||||
|  |     service = TestStub(MockChannel(responses=[wrapped_value])) | ||||||
|  |  | ||||||
|  |     response_value = await service_method(service) | ||||||
|  |  | ||||||
|  |     assert type(response_value) == value | ||||||
|  |     assert type(response_value) == type(value) | ||||||
|   | |||||||
| @@ -0,0 +1,24 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
|  | // Tests that wrapped values are supported as part of output message | ||||||
|  | service Test { | ||||||
|  |     rpc getOutput (Input) returns (Output); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Input { | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Output { | ||||||
|  |     google.protobuf.DoubleValue double_value = 1; | ||||||
|  |     google.protobuf.FloatValue float_value = 2; | ||||||
|  |     google.protobuf.Int64Value int64_value = 3; | ||||||
|  |     google.protobuf.UInt64Value uint64_value = 4; | ||||||
|  |     google.protobuf.Int32Value int32_value = 5; | ||||||
|  |     google.protobuf.UInt32Value uint32_value = 6; | ||||||
|  |     google.protobuf.BoolValue bool_value = 7; | ||||||
|  |     google.protobuf.StringValue string_value = 8; | ||||||
|  |     google.protobuf.BytesValue bytes_value = 9; | ||||||
|  | } | ||||||
| @@ -0,0 +1,39 @@ | |||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | from betterproto.tests.mocks import MockChannel | ||||||
|  | from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import ( | ||||||
|  |     Output, | ||||||
|  |     TestStub, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | async def test_service_passes_through_unwrapped_values_embedded_in_response(): | ||||||
|  |     """ | ||||||
|  |     We do not not need to implement value unwrapping for embedded well-known types, | ||||||
|  |     as this is already handled by grpclib. This test merely shows that this is the case. | ||||||
|  |     """ | ||||||
|  |     output = Output( | ||||||
|  |         double_value=10.0, | ||||||
|  |         float_value=12.0, | ||||||
|  |         int64_value=-13, | ||||||
|  |         uint64_value=14, | ||||||
|  |         int32_value=-15, | ||||||
|  |         uint32_value=16, | ||||||
|  |         bool_value=True, | ||||||
|  |         string_value="string", | ||||||
|  |         bytes_value=bytes(0xFF)[0:4], | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     service = TestStub(MockChannel(responses=[output])) | ||||||
|  |     response = await service.get_output() | ||||||
|  |  | ||||||
|  |     assert response.double_value == 10.0 | ||||||
|  |     assert response.float_value == 12.0 | ||||||
|  |     assert response.int64_value == -13 | ||||||
|  |     assert response.uint64_value == 14 | ||||||
|  |     assert response.int32_value == -15 | ||||||
|  |     assert response.uint32_value == 16 | ||||||
|  |     assert response.bool_value | ||||||
|  |     assert response.string_value == "string" | ||||||
|  |     assert response.bytes_value == bytes(0xFF)[0:4] | ||||||
							
								
								
									
										39
									
								
								betterproto/tests/mocks.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										39
									
								
								betterproto/tests/mocks.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,39 @@ | |||||||
|  | from typing import List | ||||||
|  |  | ||||||
|  | from grpclib.client import Channel | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MockChannel(Channel): | ||||||
|  |     # noinspection PyMissingConstructor | ||||||
|  |     def __init__(self, responses: List) -> None: | ||||||
|  |         self.responses = responses | ||||||
|  |         self.requests = [] | ||||||
|  |  | ||||||
|  |     def request(self, route, cardinality, request, response_type, **kwargs): | ||||||
|  |         self.requests.append( | ||||||
|  |             { | ||||||
|  |                 "route": route, | ||||||
|  |                 "cardinality": cardinality, | ||||||
|  |                 "request": request, | ||||||
|  |                 "response_type": response_type, | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |         return MockStream(self.responses) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MockStream: | ||||||
|  |     def __init__(self, responses: List) -> None: | ||||||
|  |         super().__init__() | ||||||
|  |         self.responses = responses | ||||||
|  |  | ||||||
|  |     async def recv_message(self): | ||||||
|  |         return self.responses.pop(0) | ||||||
|  |  | ||||||
|  |     async def send_message(self, *args, **kwargs): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|  |     async def __aexit__(self, exc_type, exc_val, exc_tb): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     async def __aenter__(self): | ||||||
|  |         return self | ||||||
| @@ -16,7 +16,11 @@ from google.protobuf.descriptor_pool import DescriptorPool | |||||||
| from google.protobuf.json_format import Parse | from google.protobuf.json_format import Parse | ||||||
|  |  | ||||||
|  |  | ||||||
| excluded_test_cases = {"googletypes_response", "service"} | excluded_test_cases = { | ||||||
|  |     "googletypes_response", | ||||||
|  |     "googletypes_response_embedded", | ||||||
|  |     "service", | ||||||
|  | } | ||||||
| test_case_names = {*get_directories(inputs_path)} - excluded_test_cases | test_case_names = {*get_directories(inputs_path)} - excluded_test_cases | ||||||
|  |  | ||||||
| plugin_output_package = "betterproto.tests.output_betterproto" | plugin_output_package = "betterproto.tests.output_betterproto" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user