diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 7476b38..2184a9c 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,12 +1,11 @@ #!/usr/bin/env python import itertools -import json import os.path -import re import sys import textwrap -from typing import Any, List, Tuple +from collections import defaultdict +from typing import Dict, List, Optional, Type try: import black @@ -24,33 +23,23 @@ from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, - FileDescriptorProto, - ServiceDescriptorProto, ) from betterproto.casing import safe_snake_case -import google.protobuf.wrappers_pb2 +import google.protobuf.wrappers_pb2 as google_wrappers - -WRAPPER_TYPES = { - google.protobuf.wrappers_pb2.DoubleValue: "float", - google.protobuf.wrappers_pb2.FloatValue: "float", - google.protobuf.wrappers_pb2.Int64Value: "int", - google.protobuf.wrappers_pb2.UInt64Value: "int", - google.protobuf.wrappers_pb2.Int32Value: "int", - google.protobuf.wrappers_pb2.UInt32Value: "int", - google.protobuf.wrappers_pb2.BoolValue: "bool", - google.protobuf.wrappers_pb2.StringValue: "str", - google.protobuf.wrappers_pb2.BytesValue: "bytes", -} - - -def get_wrapper_type(type_name: str) -> (Any, str): - for wrapper, wrapped_type in WRAPPER_TYPES.items(): - if wrapper.DESCRIPTOR.full_name == type_name: - return wrapper, wrapped_type - return None, None +WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, { + 'google.protobuf.DoubleValue': google_wrappers.DoubleValue, + 'google.protobuf.FloatValue': google_wrappers.FloatValue, + 'google.protobuf.Int64Value': google_wrappers.Int64Value, + 'google.protobuf.UInt64Value': google_wrappers.UInt64Value, + 'google.protobuf.Int32Value': google_wrappers.Int32Value, + 'google.protobuf.UInt32Value': google_wrappers.UInt32Value, + 'google.protobuf.BoolValue': google_wrappers.BoolValue, + 'google.protobuf.StringValue': google_wrappers.StringValue, + 'google.protobuf.BytesValue': google_wrappers.BytesValue, +}) def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str: @@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True type_name = type_name.lstrip(".") # Check if type is wrapper. - wrapper, wrapped_type = get_wrapper_type(type_name) + wrapper_class = WRAPPER_TYPES[type_name] if unwrap: - if wrapper: - return f"Optional[{wrapped_type}]" + if wrapper_class: + wrapped_type = type(wrapper_class().value) + return f"Optional[{wrapped_type.__name__}]" if type_name == "google.protobuf.Duration": return "timedelta" if type_name == "google.protobuf.Timestamp": return "datetime" - elif wrapper: - imports.add(f"from {wrapper.__module__} import {wrapper.__name__}") - return f"{wrapper.__name__}" + elif wrapper_class: + imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}") + return f"{wrapper_class.__name__}" if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".")