Read desired wrapper type directly from wrapper definition

This commit is contained in:
boukeversteegh 2020-05-24 14:50:56 +02:00
parent c50d9e2fdc
commit 8f0caf1db2

View File

@ -1,12 +1,11 @@
#!/usr/bin/env python
import itertools
import json
import os.path
import re
import sys
import textwrap
from typing import Any, List, Tuple
from collections import defaultdict
from typing import Dict, List, Optional, Type
try:
import black
@ -24,33 +23,23 @@ from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
from betterproto.casing import safe_snake_case
import google.protobuf.wrappers_pb2
import google.protobuf.wrappers_pb2 as google_wrappers
WRAPPER_TYPES = {
google.protobuf.wrappers_pb2.DoubleValue: "float",
google.protobuf.wrappers_pb2.FloatValue: "float",
google.protobuf.wrappers_pb2.Int64Value: "int",
google.protobuf.wrappers_pb2.UInt64Value: "int",
google.protobuf.wrappers_pb2.Int32Value: "int",
google.protobuf.wrappers_pb2.UInt32Value: "int",
google.protobuf.wrappers_pb2.BoolValue: "bool",
google.protobuf.wrappers_pb2.StringValue: "str",
google.protobuf.wrappers_pb2.BytesValue: "bytes",
}
def get_wrapper_type(type_name: str) -> (Any, str):
for wrapper, wrapped_type in WRAPPER_TYPES.items():
if wrapper.DESCRIPTOR.full_name == type_name:
return wrapper, wrapped_type
return None, None
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
'google.protobuf.FloatValue': google_wrappers.FloatValue,
'google.protobuf.Int64Value': google_wrappers.Int64Value,
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
'google.protobuf.Int32Value': google_wrappers.Int32Value,
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
'google.protobuf.BoolValue': google_wrappers.BoolValue,
'google.protobuf.StringValue': google_wrappers.StringValue,
'google.protobuf.BytesValue': google_wrappers.BytesValue,
})
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True
type_name = type_name.lstrip(".")
# Check if type is wrapper.
wrapper, wrapped_type = get_wrapper_type(type_name)
wrapper_class = WRAPPER_TYPES[type_name]
if unwrap:
if wrapper:
return f"Optional[{wrapped_type}]"
if wrapper_class:
wrapped_type = type(wrapper_class().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
elif wrapper:
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
return f"{wrapper.__name__}"
elif wrapper_class:
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
return f"{wrapper_class.__name__}"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")