Support using Google's wrapper types as RPC output values

This commit is contained in:
boukeversteegh 2020-05-10 16:34:20 +02:00
parent ce9f492f50
commit 499489f1d3
2 changed files with 53 additions and 18 deletions

View File

@ -30,38 +30,55 @@ from google.protobuf.descriptor_pb2 import (
from betterproto.casing import safe_snake_case
import google.protobuf.wrappers_pb2
WRAPPER_TYPES = {
"google.protobuf.DoubleValue": "float",
"google.protobuf.FloatValue": "float",
"google.protobuf.Int64Value": "int",
"google.protobuf.UInt64Value": "int",
"google.protobuf.Int32Value": "int",
"google.protobuf.UInt32Value": "int",
"google.protobuf.BoolValue": "bool",
"google.protobuf.StringValue": "str",
"google.protobuf.BytesValue": "bytes",
google.protobuf.wrappers_pb2.DoubleValue: "float",
google.protobuf.wrappers_pb2.FloatValue: "float",
google.protobuf.wrappers_pb2.Int64Value: "int",
google.protobuf.wrappers_pb2.UInt64Value: "int",
google.protobuf.wrappers_pb2.Int32Value: "int",
google.protobuf.wrappers_pb2.UInt32Value: "int",
google.protobuf.wrappers_pb2.BoolValue: "bool",
google.protobuf.wrappers_pb2.StringValue: "str",
google.protobuf.wrappers_pb2.BytesValue: "bytes",
}
def get_ref_type(package: str, imports: set, type_name: str) -> str:
def get_wrapper_type(type_name: str) -> (Any, str):
for wrapper, wrapped_type in WRAPPER_TYPES.items():
if wrapper.DESCRIPTOR.full_name == type_name:
return wrapper, wrapped_type
return None, None
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary.
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
if type_name in WRAPPER_TYPES:
return f"Optional[{WRAPPER_TYPES[type_name]}]"
# Check if type is wrapper.
wrapper, wrapped_type = get_wrapper_type(type_name)
if unwrap:
if wrapper:
return f"Optional[{wrapped_type}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
else:
if wrapper:
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
return f"{wrapper.__name__}"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
@ -379,7 +396,7 @@ def generate_code(request, response):
).strip('"'),
"input_message": input_message,
"output": get_ref_type(
package, output["imports"], method.output_type
package, output["imports"], method.output_type, unwrap=False
).strip('"'),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,

View File

@ -0,0 +1,18 @@
syntax = "proto3";
import "google/protobuf/wrappers.proto";
service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output);
}
message Input {
}
message Output {
google.protobuf.Int64Value int64 = 1;
}