Support Duration/Timestamp Google well-known types
This commit is contained in:
@@ -30,6 +30,19 @@ from google.protobuf.descriptor_pb2 import (
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
"google.protobuf.DoubleValue": "float",
|
||||
"google.protobuf.FloatValue": "float",
|
||||
"google.protobuf.Int64Value": "int",
|
||||
"google.protobuf.UInt64Value": "int",
|
||||
"google.protobuf.Int32Value": "int",
|
||||
"google.protobuf.UInt32Value": "int",
|
||||
"google.protobuf.BoolValue": "bool",
|
||||
"google.protobuf.StringValue": "str",
|
||||
"google.protobuf.BytesValue": "bytes",
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
@@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
if type_name in WRAPPER_TYPES:
|
||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
@@ -152,6 +175,9 @@ def generate_code(request, response):
|
||||
output_map = {}
|
||||
for proto_file in request.proto_file:
|
||||
out = proto_file.package
|
||||
if out == "google.protobuf":
|
||||
continue
|
||||
|
||||
if not out:
|
||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
||||
|
||||
@@ -169,6 +195,7 @@ def generate_code(request, response):
|
||||
"package": package,
|
||||
"files": [f.name for f in options["files"]],
|
||||
"imports": set(),
|
||||
"datetime_imports": set(),
|
||||
"typing_imports": set(),
|
||||
"messages": [],
|
||||
"enums": [],
|
||||
@@ -258,6 +285,14 @@ def generate_code(request, response):
|
||||
if f.HasField("oneof_index"):
|
||||
one_of = item.oneof_decl[f.oneof_index].name
|
||||
|
||||
if "Optional[" in t:
|
||||
output["typing_imports"].add("Optional")
|
||||
|
||||
if "timedelta" in t:
|
||||
output["datetime_imports"].add("timedelta")
|
||||
elif "datetime" in t:
|
||||
output["datetime_imports"].add("datetime")
|
||||
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
@@ -346,6 +381,7 @@ def generate_code(request, response):
|
||||
output["services"].append(data)
|
||||
|
||||
output["imports"] = sorted(output["imports"])
|
||||
output["datetime_imports"] = sorted(output["datetime_imports"])
|
||||
output["typing_imports"] = sorted(output["typing_imports"])
|
||||
|
||||
# Fill response
|
||||
|
||||
Reference in New Issue
Block a user