Read desired wrapper type directly from wrapper definition
This commit is contained in:
parent
c50d9e2fdc
commit
8f0caf1db2
@ -1,12 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Any, List, Tuple
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
try:
|
||||
import black
|
||||
@ -24,33 +23,23 @@ from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
import google.protobuf.wrappers_pb2
|
||||
import google.protobuf.wrappers_pb2 as google_wrappers
|
||||
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
google.protobuf.wrappers_pb2.DoubleValue: "float",
|
||||
google.protobuf.wrappers_pb2.FloatValue: "float",
|
||||
google.protobuf.wrappers_pb2.Int64Value: "int",
|
||||
google.protobuf.wrappers_pb2.UInt64Value: "int",
|
||||
google.protobuf.wrappers_pb2.Int32Value: "int",
|
||||
google.protobuf.wrappers_pb2.UInt32Value: "int",
|
||||
google.protobuf.wrappers_pb2.BoolValue: "bool",
|
||||
google.protobuf.wrappers_pb2.StringValue: "str",
|
||||
google.protobuf.wrappers_pb2.BytesValue: "bytes",
|
||||
}
|
||||
|
||||
|
||||
def get_wrapper_type(type_name: str) -> (Any, str):
|
||||
for wrapper, wrapped_type in WRAPPER_TYPES.items():
|
||||
if wrapper.DESCRIPTOR.full_name == type_name:
|
||||
return wrapper, wrapped_type
|
||||
return None, None
|
||||
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
|
||||
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
|
||||
'google.protobuf.FloatValue': google_wrappers.FloatValue,
|
||||
'google.protobuf.Int64Value': google_wrappers.Int64Value,
|
||||
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
|
||||
'google.protobuf.Int32Value': google_wrappers.Int32Value,
|
||||
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
|
||||
'google.protobuf.BoolValue': google_wrappers.BoolValue,
|
||||
'google.protobuf.StringValue': google_wrappers.StringValue,
|
||||
'google.protobuf.BytesValue': google_wrappers.BytesValue,
|
||||
})
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
|
||||
@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
# Check if type is wrapper.
|
||||
wrapper, wrapped_type = get_wrapper_type(type_name)
|
||||
wrapper_class = WRAPPER_TYPES[type_name]
|
||||
|
||||
if unwrap:
|
||||
if wrapper:
|
||||
return f"Optional[{wrapped_type}]"
|
||||
if wrapper_class:
|
||||
wrapped_type = type(wrapper_class().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
elif wrapper:
|
||||
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
|
||||
return f"{wrapper.__name__}"
|
||||
elif wrapper_class:
|
||||
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
|
||||
return f"{wrapper_class.__name__}"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
|
Loading…
x
Reference in New Issue
Block a user