Support using Google's wrapper types as RPC output values
This commit is contained in:
parent
ce9f492f50
commit
499489f1d3
@ -30,38 +30,55 @@ from google.protobuf.descriptor_pb2 import (
|
|||||||
|
|
||||||
from betterproto.casing import safe_snake_case
|
from betterproto.casing import safe_snake_case
|
||||||
|
|
||||||
|
import google.protobuf.wrappers_pb2
|
||||||
|
|
||||||
|
|
||||||
WRAPPER_TYPES = {
|
WRAPPER_TYPES = {
|
||||||
"google.protobuf.DoubleValue": "float",
|
google.protobuf.wrappers_pb2.DoubleValue: "float",
|
||||||
"google.protobuf.FloatValue": "float",
|
google.protobuf.wrappers_pb2.FloatValue: "float",
|
||||||
"google.protobuf.Int64Value": "int",
|
google.protobuf.wrappers_pb2.Int64Value: "int",
|
||||||
"google.protobuf.UInt64Value": "int",
|
google.protobuf.wrappers_pb2.UInt64Value: "int",
|
||||||
"google.protobuf.Int32Value": "int",
|
google.protobuf.wrappers_pb2.Int32Value: "int",
|
||||||
"google.protobuf.UInt32Value": "int",
|
google.protobuf.wrappers_pb2.UInt32Value: "int",
|
||||||
"google.protobuf.BoolValue": "bool",
|
google.protobuf.wrappers_pb2.BoolValue: "bool",
|
||||||
"google.protobuf.StringValue": "str",
|
google.protobuf.wrappers_pb2.StringValue: "str",
|
||||||
"google.protobuf.BytesValue": "bytes",
|
google.protobuf.wrappers_pb2.BytesValue: "bytes",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
def get_wrapper_type(type_name: str) -> (Any, str):
|
||||||
|
for wrapper, wrapped_type in WRAPPER_TYPES.items():
|
||||||
|
if wrapper.DESCRIPTOR.full_name == type_name:
|
||||||
|
return wrapper, wrapped_type
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
|
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
Return a Python type name for a proto type reference. Adds the import if
|
Return a Python type name for a proto type reference. Adds the import if
|
||||||
necessary.
|
necessary. Unwraps well known type if required.
|
||||||
"""
|
"""
|
||||||
# If the package name is a blank string, then this should still work
|
# If the package name is a blank string, then this should still work
|
||||||
# because by convention packages are lowercase and message/enum types are
|
# because by convention packages are lowercase and message/enum types are
|
||||||
# pascal-cased. May require refactoring in the future.
|
# pascal-cased. May require refactoring in the future.
|
||||||
type_name = type_name.lstrip(".")
|
type_name = type_name.lstrip(".")
|
||||||
|
|
||||||
if type_name in WRAPPER_TYPES:
|
# Check if type is wrapper.
|
||||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
wrapper, wrapped_type = get_wrapper_type(type_name)
|
||||||
|
|
||||||
if type_name == "google.protobuf.Duration":
|
if unwrap:
|
||||||
return "timedelta"
|
if wrapper:
|
||||||
|
return f"Optional[{wrapped_type}]"
|
||||||
|
|
||||||
if type_name == "google.protobuf.Timestamp":
|
if type_name == "google.protobuf.Duration":
|
||||||
return "datetime"
|
return "timedelta"
|
||||||
|
|
||||||
|
if type_name == "google.protobuf.Timestamp":
|
||||||
|
return "datetime"
|
||||||
|
else:
|
||||||
|
if wrapper:
|
||||||
|
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
|
||||||
|
return f"{wrapper.__name__}"
|
||||||
|
|
||||||
if type_name.startswith(package):
|
if type_name.startswith(package):
|
||||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||||
@ -379,7 +396,7 @@ def generate_code(request, response):
|
|||||||
).strip('"'),
|
).strip('"'),
|
||||||
"input_message": input_message,
|
"input_message": input_message,
|
||||||
"output": get_ref_type(
|
"output": get_ref_type(
|
||||||
package, output["imports"], method.output_type
|
package, output["imports"], method.output_type, unwrap=False
|
||||||
).strip('"'),
|
).strip('"'),
|
||||||
"client_streaming": method.client_streaming,
|
"client_streaming": method.client_streaming,
|
||||||
"server_streaming": method.server_streaming,
|
"server_streaming": method.server_streaming,
|
||||||
|
18
betterproto/tests/googletypes_service.proto
Normal file
18
betterproto/tests/googletypes_service.proto
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
import "google/protobuf/wrappers.proto";
|
||||||
|
|
||||||
|
service Test {
|
||||||
|
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
|
||||||
|
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
|
||||||
|
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
|
||||||
|
rpc GetOutput (Input) returns (Output);
|
||||||
|
}
|
||||||
|
|
||||||
|
message Input {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
message Output {
|
||||||
|
google.protobuf.Int64Value int64 = 1;
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user