Move parsing of protobuf data types and services into separate methods
This commit is contained in:
@@ -163,186 +163,10 @@ def generate_code(request, response):
|
|||||||
for proto_file in output_package_content["files"]:
|
for proto_file in output_package_content["files"]:
|
||||||
item: DescriptorProto
|
item: DescriptorProto
|
||||||
for item, path in traverse(proto_file):
|
for item, path in traverse(proto_file):
|
||||||
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
read_protobuf_type(input_package_name, item, path, proto_file, template_data)
|
||||||
|
|
||||||
if isinstance(item, DescriptorProto):
|
|
||||||
# print(item, file=sys.stderr)
|
|
||||||
if item.options.map_entry:
|
|
||||||
# Skip generated map entry messages since we just use dicts
|
|
||||||
continue
|
|
||||||
|
|
||||||
data.update(
|
|
||||||
{
|
|
||||||
"type": "Message",
|
|
||||||
"comment": get_comment(proto_file, path),
|
|
||||||
"properties": [],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, f in enumerate(item.field):
|
|
||||||
t = py_type(input_package_name, template_data["imports"], f)
|
|
||||||
zero = get_py_zero(f.type)
|
|
||||||
|
|
||||||
repeated = False
|
|
||||||
packed = False
|
|
||||||
|
|
||||||
field_type = f.Type.Name(f.type).lower()[5:]
|
|
||||||
|
|
||||||
field_wraps = ""
|
|
||||||
match_wrapper = re.match(
|
|
||||||
r"\.google\.protobuf\.(.+)Value", f.type_name
|
|
||||||
)
|
|
||||||
if match_wrapper:
|
|
||||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
|
||||||
if hasattr(betterproto, wrapped_type):
|
|
||||||
field_wraps = f"betterproto.{wrapped_type}"
|
|
||||||
|
|
||||||
map_types = None
|
|
||||||
if f.type == 11:
|
|
||||||
# This might be a map...
|
|
||||||
message_type = f.type_name.split(".").pop().lower()
|
|
||||||
# message_type = py_type(package)
|
|
||||||
map_entry = f"{f.name.replace('_', '').lower()}entry"
|
|
||||||
|
|
||||||
if message_type == map_entry:
|
|
||||||
for nested in item.nested_type:
|
|
||||||
if (
|
|
||||||
nested.name.replace("_", "").lower()
|
|
||||||
== map_entry
|
|
||||||
):
|
|
||||||
if nested.options.map_entry:
|
|
||||||
# print("Found a map!", file=sys.stderr)
|
|
||||||
k = py_type(
|
|
||||||
input_package_name,
|
|
||||||
template_data["imports"],
|
|
||||||
nested.field[0],
|
|
||||||
)
|
|
||||||
v = py_type(
|
|
||||||
input_package_name,
|
|
||||||
template_data["imports"],
|
|
||||||
nested.field[1],
|
|
||||||
)
|
|
||||||
t = f"Dict[{k}, {v}]"
|
|
||||||
field_type = "map"
|
|
||||||
map_types = (
|
|
||||||
f.Type.Name(nested.field[0].type),
|
|
||||||
f.Type.Name(nested.field[1].type),
|
|
||||||
)
|
|
||||||
template_data["typing_imports"].add("Dict")
|
|
||||||
|
|
||||||
if f.label == 3 and field_type != "map":
|
|
||||||
# Repeated field
|
|
||||||
repeated = True
|
|
||||||
t = f"List[{t}]"
|
|
||||||
zero = "[]"
|
|
||||||
template_data["typing_imports"].add("List")
|
|
||||||
|
|
||||||
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
|
||||||
packed = True
|
|
||||||
|
|
||||||
one_of = ""
|
|
||||||
if f.HasField("oneof_index"):
|
|
||||||
one_of = item.oneof_decl[f.oneof_index].name
|
|
||||||
|
|
||||||
if "Optional[" in t:
|
|
||||||
template_data["typing_imports"].add("Optional")
|
|
||||||
|
|
||||||
if "timedelta" in t:
|
|
||||||
template_data["datetime_imports"].add("timedelta")
|
|
||||||
elif "datetime" in t:
|
|
||||||
template_data["datetime_imports"].add("datetime")
|
|
||||||
|
|
||||||
data["properties"].append(
|
|
||||||
{
|
|
||||||
"name": f.name,
|
|
||||||
"py_name": pythonize_field_name(f.name),
|
|
||||||
"number": f.number,
|
|
||||||
"comment": get_comment(proto_file, path + [2, i]),
|
|
||||||
"proto_type": int(f.type),
|
|
||||||
"field_type": field_type,
|
|
||||||
"field_wraps": field_wraps,
|
|
||||||
"map_types": map_types,
|
|
||||||
"type": t,
|
|
||||||
"zero": zero,
|
|
||||||
"repeated": repeated,
|
|
||||||
"packed": packed,
|
|
||||||
"one_of": one_of,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
# print(f, file=sys.stderr)
|
|
||||||
|
|
||||||
template_data["messages"].append(data)
|
|
||||||
elif isinstance(item, EnumDescriptorProto):
|
|
||||||
# print(item.name, path, file=sys.stderr)
|
|
||||||
data.update(
|
|
||||||
{
|
|
||||||
"type": "Enum",
|
|
||||||
"comment": get_comment(proto_file, path),
|
|
||||||
"entries": [
|
|
||||||
{
|
|
||||||
"name": v.name,
|
|
||||||
"value": v.number,
|
|
||||||
"comment": get_comment(proto_file, path + [2, i]),
|
|
||||||
}
|
|
||||||
for i, v in enumerate(item.value)
|
|
||||||
],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
template_data["enums"].append(data)
|
|
||||||
|
|
||||||
for i, service in enumerate(proto_file.service):
|
for i, service in enumerate(proto_file.service):
|
||||||
# print(service, file=sys.stderr)
|
read_protobuf_service(i, input_package_name, proto_file, service, template_data)
|
||||||
|
|
||||||
data = {
|
|
||||||
"name": service.name,
|
|
||||||
"py_name": pythonize_class_name(service.name),
|
|
||||||
"comment": get_comment(proto_file, [6, i]),
|
|
||||||
"methods": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
for j, method in enumerate(service.method):
|
|
||||||
input_message = None
|
|
||||||
input_type = get_type_reference(
|
|
||||||
input_package_name, template_data["imports"], method.input_type
|
|
||||||
).strip('"')
|
|
||||||
for msg in template_data["messages"]:
|
|
||||||
if msg["name"] == input_type:
|
|
||||||
input_message = msg
|
|
||||||
for field in msg["properties"]:
|
|
||||||
if field["zero"] == "None":
|
|
||||||
template_data["typing_imports"].add("Optional")
|
|
||||||
break
|
|
||||||
|
|
||||||
data["methods"].append(
|
|
||||||
{
|
|
||||||
"name": method.name,
|
|
||||||
"py_name": pythonize_method_name(method.name),
|
|
||||||
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
|
||||||
"route": f"/{input_package_name}.{service.name}/{method.name}",
|
|
||||||
"input": get_type_reference(
|
|
||||||
input_package_name, template_data["imports"], method.input_type
|
|
||||||
).strip('"'),
|
|
||||||
"input_message": input_message,
|
|
||||||
"output": get_type_reference(
|
|
||||||
input_package_name,
|
|
||||||
template_data["imports"],
|
|
||||||
method.output_type,
|
|
||||||
unwrap=False,
|
|
||||||
),
|
|
||||||
"client_streaming": method.client_streaming,
|
|
||||||
"server_streaming": method.server_streaming,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if method.client_streaming:
|
|
||||||
template_data["typing_imports"].add("AsyncIterable")
|
|
||||||
template_data["typing_imports"].add("Iterable")
|
|
||||||
template_data["typing_imports"].add("Union")
|
|
||||||
if method.server_streaming:
|
|
||||||
template_data["typing_imports"].add("AsyncIterator")
|
|
||||||
|
|
||||||
template_data["services"].append(data)
|
|
||||||
|
|
||||||
template_data["imports"] = sorted(template_data["imports"])
|
template_data["imports"] = sorted(template_data["imports"])
|
||||||
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
|
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
|
||||||
@@ -379,6 +203,186 @@ def generate_code(request, response):
|
|||||||
print(f"Writing {output_package_name}", file=sys.stderr)
|
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def read_protobuf_service(i, input_package_name, proto_file, service, template_data):
|
||||||
|
# print(service, file=sys.stderr)
|
||||||
|
data = {
|
||||||
|
"name": service.name,
|
||||||
|
"py_name": pythonize_class_name(service.name),
|
||||||
|
"comment": get_comment(proto_file, [6, i]),
|
||||||
|
"methods": [],
|
||||||
|
}
|
||||||
|
for j, method in enumerate(service.method):
|
||||||
|
input_message = None
|
||||||
|
input_type = get_type_reference(
|
||||||
|
input_package_name, template_data["imports"], method.input_type
|
||||||
|
).strip('"')
|
||||||
|
for msg in template_data["messages"]:
|
||||||
|
if msg["name"] == input_type:
|
||||||
|
input_message = msg
|
||||||
|
for field in msg["properties"]:
|
||||||
|
if field["zero"] == "None":
|
||||||
|
template_data["typing_imports"].add("Optional")
|
||||||
|
break
|
||||||
|
|
||||||
|
data["methods"].append(
|
||||||
|
{
|
||||||
|
"name": method.name,
|
||||||
|
"py_name": pythonize_method_name(method.name),
|
||||||
|
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
||||||
|
"route": f"/{input_package_name}.{service.name}/{method.name}",
|
||||||
|
"input": get_type_reference(
|
||||||
|
input_package_name, template_data["imports"], method.input_type
|
||||||
|
).strip('"'),
|
||||||
|
"input_message": input_message,
|
||||||
|
"output": get_type_reference(
|
||||||
|
input_package_name,
|
||||||
|
template_data["imports"],
|
||||||
|
method.output_type,
|
||||||
|
unwrap=False,
|
||||||
|
),
|
||||||
|
"client_streaming": method.client_streaming,
|
||||||
|
"server_streaming": method.server_streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if method.client_streaming:
|
||||||
|
template_data["typing_imports"].add("AsyncIterable")
|
||||||
|
template_data["typing_imports"].add("Iterable")
|
||||||
|
template_data["typing_imports"].add("Union")
|
||||||
|
if method.server_streaming:
|
||||||
|
template_data["typing_imports"].add("AsyncIterator")
|
||||||
|
template_data["services"].append(data)
|
||||||
|
|
||||||
|
|
||||||
|
def read_protobuf_type(input_package_name, item, path, proto_file, template_data):
|
||||||
|
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
||||||
|
if isinstance(item, DescriptorProto):
|
||||||
|
# print(item, file=sys.stderr)
|
||||||
|
if item.options.map_entry:
|
||||||
|
# Skip generated map entry messages since we just use dicts
|
||||||
|
return
|
||||||
|
|
||||||
|
data.update(
|
||||||
|
{
|
||||||
|
"type": "Message",
|
||||||
|
"comment": get_comment(proto_file, path),
|
||||||
|
"properties": [],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, f in enumerate(item.field):
|
||||||
|
t = py_type(input_package_name, template_data["imports"], f)
|
||||||
|
zero = get_py_zero(f.type)
|
||||||
|
|
||||||
|
repeated = False
|
||||||
|
packed = False
|
||||||
|
|
||||||
|
field_type = f.Type.Name(f.type).lower()[5:]
|
||||||
|
|
||||||
|
field_wraps = ""
|
||||||
|
match_wrapper = re.match(
|
||||||
|
r"\.google\.protobuf\.(.+)Value", f.type_name
|
||||||
|
)
|
||||||
|
if match_wrapper:
|
||||||
|
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||||
|
if hasattr(betterproto, wrapped_type):
|
||||||
|
field_wraps = f"betterproto.{wrapped_type}"
|
||||||
|
|
||||||
|
map_types = None
|
||||||
|
if f.type == 11:
|
||||||
|
# This might be a map...
|
||||||
|
message_type = f.type_name.split(".").pop().lower()
|
||||||
|
# message_type = py_type(package)
|
||||||
|
map_entry = f"{f.name.replace('_', '').lower()}entry"
|
||||||
|
|
||||||
|
if message_type == map_entry:
|
||||||
|
for nested in item.nested_type:
|
||||||
|
if (
|
||||||
|
nested.name.replace("_", "").lower()
|
||||||
|
== map_entry
|
||||||
|
):
|
||||||
|
if nested.options.map_entry:
|
||||||
|
# print("Found a map!", file=sys.stderr)
|
||||||
|
k = py_type(
|
||||||
|
input_package_name,
|
||||||
|
template_data["imports"],
|
||||||
|
nested.field[0],
|
||||||
|
)
|
||||||
|
v = py_type(
|
||||||
|
input_package_name,
|
||||||
|
template_data["imports"],
|
||||||
|
nested.field[1],
|
||||||
|
)
|
||||||
|
t = f"Dict[{k}, {v}]"
|
||||||
|
field_type = "map"
|
||||||
|
map_types = (
|
||||||
|
f.Type.Name(nested.field[0].type),
|
||||||
|
f.Type.Name(nested.field[1].type),
|
||||||
|
)
|
||||||
|
template_data["typing_imports"].add("Dict")
|
||||||
|
|
||||||
|
if f.label == 3 and field_type != "map":
|
||||||
|
# Repeated field
|
||||||
|
repeated = True
|
||||||
|
t = f"List[{t}]"
|
||||||
|
zero = "[]"
|
||||||
|
template_data["typing_imports"].add("List")
|
||||||
|
|
||||||
|
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
||||||
|
packed = True
|
||||||
|
|
||||||
|
one_of = ""
|
||||||
|
if f.HasField("oneof_index"):
|
||||||
|
one_of = item.oneof_decl[f.oneof_index].name
|
||||||
|
|
||||||
|
if "Optional[" in t:
|
||||||
|
template_data["typing_imports"].add("Optional")
|
||||||
|
|
||||||
|
if "timedelta" in t:
|
||||||
|
template_data["datetime_imports"].add("timedelta")
|
||||||
|
elif "datetime" in t:
|
||||||
|
template_data["datetime_imports"].add("datetime")
|
||||||
|
|
||||||
|
data["properties"].append(
|
||||||
|
{
|
||||||
|
"name": f.name,
|
||||||
|
"py_name": pythonize_field_name(f.name),
|
||||||
|
"number": f.number,
|
||||||
|
"comment": get_comment(proto_file, path + [2, i]),
|
||||||
|
"proto_type": int(f.type),
|
||||||
|
"field_type": field_type,
|
||||||
|
"field_wraps": field_wraps,
|
||||||
|
"map_types": map_types,
|
||||||
|
"type": t,
|
||||||
|
"zero": zero,
|
||||||
|
"repeated": repeated,
|
||||||
|
"packed": packed,
|
||||||
|
"one_of": one_of,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# print(f, file=sys.stderr)
|
||||||
|
|
||||||
|
template_data["messages"].append(data)
|
||||||
|
elif isinstance(item, EnumDescriptorProto):
|
||||||
|
# print(item.name, path, file=sys.stderr)
|
||||||
|
data.update(
|
||||||
|
{
|
||||||
|
"type": "Enum",
|
||||||
|
"comment": get_comment(proto_file, path),
|
||||||
|
"entries": [
|
||||||
|
{
|
||||||
|
"name": v.name,
|
||||||
|
"value": v.number,
|
||||||
|
"comment": get_comment(proto_file, path + [2, i]),
|
||||||
|
}
|
||||||
|
for i, v in enumerate(item.value)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
template_data["enums"].append(data)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""The plugin's main entry point."""
|
"""The plugin's main entry point."""
|
||||||
# Read request message from stdin
|
# Read request message from stdin
|
||||||
|
|||||||
Reference in New Issue
Block a user