Read desired wrapper type directly from wrapper definition
This commit is contained in:
parent
c50d9e2fdc
commit
8f0caf1db2
@ -1,12 +1,11 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import Any, List, Tuple
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import black
|
import black
|
||||||
@ -24,33 +23,23 @@ from google.protobuf.descriptor_pb2 import (
|
|||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
FieldDescriptorProto,
|
FieldDescriptorProto,
|
||||||
FileDescriptorProto,
|
|
||||||
ServiceDescriptorProto,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from betterproto.casing import safe_snake_case
|
from betterproto.casing import safe_snake_case
|
||||||
|
|
||||||
import google.protobuf.wrappers_pb2
|
import google.protobuf.wrappers_pb2 as google_wrappers
|
||||||
|
|
||||||
|
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(lambda: None, {
|
||||||
WRAPPER_TYPES = {
|
'google.protobuf.DoubleValue': google_wrappers.DoubleValue,
|
||||||
google.protobuf.wrappers_pb2.DoubleValue: "float",
|
'google.protobuf.FloatValue': google_wrappers.FloatValue,
|
||||||
google.protobuf.wrappers_pb2.FloatValue: "float",
|
'google.protobuf.Int64Value': google_wrappers.Int64Value,
|
||||||
google.protobuf.wrappers_pb2.Int64Value: "int",
|
'google.protobuf.UInt64Value': google_wrappers.UInt64Value,
|
||||||
google.protobuf.wrappers_pb2.UInt64Value: "int",
|
'google.protobuf.Int32Value': google_wrappers.Int32Value,
|
||||||
google.protobuf.wrappers_pb2.Int32Value: "int",
|
'google.protobuf.UInt32Value': google_wrappers.UInt32Value,
|
||||||
google.protobuf.wrappers_pb2.UInt32Value: "int",
|
'google.protobuf.BoolValue': google_wrappers.BoolValue,
|
||||||
google.protobuf.wrappers_pb2.BoolValue: "bool",
|
'google.protobuf.StringValue': google_wrappers.StringValue,
|
||||||
google.protobuf.wrappers_pb2.StringValue: "str",
|
'google.protobuf.BytesValue': google_wrappers.BytesValue,
|
||||||
google.protobuf.wrappers_pb2.BytesValue: "bytes",
|
})
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_wrapper_type(type_name: str) -> (Any, str):
|
|
||||||
for wrapper, wrapped_type in WRAPPER_TYPES.items():
|
|
||||||
if wrapper.DESCRIPTOR.full_name == type_name:
|
|
||||||
return wrapper, wrapped_type
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
|
def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True) -> str:
|
||||||
@ -64,20 +53,21 @@ def get_ref_type(package: str, imports: set, type_name: str, unwrap: bool = True
|
|||||||
type_name = type_name.lstrip(".")
|
type_name = type_name.lstrip(".")
|
||||||
|
|
||||||
# Check if type is wrapper.
|
# Check if type is wrapper.
|
||||||
wrapper, wrapped_type = get_wrapper_type(type_name)
|
wrapper_class = WRAPPER_TYPES[type_name]
|
||||||
|
|
||||||
if unwrap:
|
if unwrap:
|
||||||
if wrapper:
|
if wrapper_class:
|
||||||
return f"Optional[{wrapped_type}]"
|
wrapped_type = type(wrapper_class().value)
|
||||||
|
return f"Optional[{wrapped_type.__name__}]"
|
||||||
|
|
||||||
if type_name == "google.protobuf.Duration":
|
if type_name == "google.protobuf.Duration":
|
||||||
return "timedelta"
|
return "timedelta"
|
||||||
|
|
||||||
if type_name == "google.protobuf.Timestamp":
|
if type_name == "google.protobuf.Timestamp":
|
||||||
return "datetime"
|
return "datetime"
|
||||||
elif wrapper:
|
elif wrapper_class:
|
||||||
imports.add(f"from {wrapper.__module__} import {wrapper.__name__}")
|
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
|
||||||
return f"{wrapper.__name__}"
|
return f"{wrapper_class.__name__}"
|
||||||
|
|
||||||
if type_name.startswith(package):
|
if type_name.startswith(package):
|
||||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user