diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 597bf1a..d70786b 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -30,38 +30,55 @@ from google.protobuf.descriptor_pb2 import ( from betterproto.casing import safe_snake_case +import google.protobuf.wrappers_pb2 + WRAPPER_TYPES = { - "google.protobuf.DoubleValue": "float", - "google.protobuf.FloatValue": "float", - "google.protobuf.Int64Value": "int", - "google.protobuf.UInt64Value": "int", - "google.protobuf.Int32Value": "int", - "google.protobuf.UInt32Value": "int", - "google.protobuf.BoolValue": "bool", - "google.protobuf.StringValue": "str", - "google.protobuf.BytesValue": "bytes", + google.protobuf.wrappers_pb2.DoubleValue: "float", + google.protobuf.wrappers_pb2.FloatValue: "float", + google.protobuf.wrappers_pb2.Int64Value: "int", + google.protobuf.wrappers_pb2.UInt64Value: "int", + google.protobuf.wrappers_pb2.Int32Value: "int", + google.protobuf.wrappers_pb2.UInt32Value: "int", + google.protobuf.wrappers_pb2.BoolValue: "bool", + google.protobuf.wrappers_pb2.StringValue: "str", + google.protobuf.wrappers_pb2.BytesValue: "bytes", } -def get_ref_type(package: str, imports: set, type_name: str) -> str: +def get_wrapper_type(type_name: str) -> (Any, str): + for wrapper, wrapped_type in WRAPPER_TYPES.items(): + if wrapper.DESCRIPTOR.full_name == type_name: + return wrapper, wrapped_type + return None, None + + +def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: """ Return a Python type name for a proto type reference. Adds the import if - necessary. + necessary. Unwraps well known type if required. """ # If the package name is a blank string, then this should still work # because by convention packages are lowercase and message/enum types are # pascal-cased. May require refactoring in the future. type_name = type_name.lstrip(".") - if type_name in WRAPPER_TYPES: - return f"Optional[{WRAPPER_TYPES[type_name]}]" + # Check if type is wrapper. + wrapper, wrapped_type = get_wrapper_type(type_name) - if type_name == "google.protobuf.Duration": - return "timedelta" + if unwrap: + if wrapper: + return f"Optional[{wrapped_type}]" - if type_name == "google.protobuf.Timestamp": - return "datetime" + if type_name == "google.protobuf.Duration": + return "timedelta" + + if type_name == "google.protobuf.Timestamp": + return "datetime" + else: + if wrapper: + imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") + return f"{wrapper.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".") @@ -379,7 +396,7 @@ def generate_code(request, response): ).strip('"'), "input_message": input_message, "output": get_ref_type( - package, output["imports"], method.output_type + package, output["imports"], method.output_type, unwrap=False ).strip('"'), "client_streaming": method.client_streaming, "server_streaming": method.server_streaming, diff --git a/betterproto/tests/googletypes_service.proto b/betterproto/tests/googletypes_service.proto new file mode 100644 index 0000000..4bdca68 --- /dev/null +++ b/betterproto/tests/googletypes_service.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +import "google/protobuf/wrappers.proto"; + +service Test { + rpc GetInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value); + rpc GetInt64 (Input) returns (google.protobuf.Int64Value); + rpc GetOutput (Input) returns (Output); +} + +message Input { + +} + +message Output { + google.protobuf.Int64Value int64 = 1; +} \ No newline at end of file