Clarify variable names
This commit is contained in:
parent
98d00f0d21
commit
f2e87192b0
@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import collections
|
||||||
import itertools
|
import itertools
|
||||||
import os.path
|
import os.path
|
||||||
import pathlib
|
import pathlib
|
||||||
@ -8,6 +8,8 @@ import sys
|
|||||||
import textwrap
|
import textwrap
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
|
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto.compile.importing import get_type_reference
|
from betterproto.compile.importing import get_type_reference
|
||||||
from betterproto.compile.naming import (
|
from betterproto.compile.naming import (
|
||||||
@ -129,7 +131,8 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
template = env.get_template("template.py.j2")
|
template = env.get_template("template.py.j2")
|
||||||
|
|
||||||
output_map = {}
|
# Gather output packages
|
||||||
|
output_package_files = collections.defaultdict()
|
||||||
for proto_file in request.proto_file:
|
for proto_file in request.proto_file:
|
||||||
if (
|
if (
|
||||||
proto_file.package == "google.protobuf"
|
proto_file.package == "google.protobuf"
|
||||||
@ -137,21 +140,18 @@ def generate_code(request, response):
|
|||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
|
output_package = proto_file.package
|
||||||
|
output_package_files.setdefault(
|
||||||
|
output_package, {"input_package": proto_file.package, "files": []}
|
||||||
|
)
|
||||||
|
output_package_files[output_package]["files"].append(proto_file)
|
||||||
|
|
||||||
if output_file not in output_map:
|
output_paths = set()
|
||||||
output_map[output_file] = {"package": proto_file.package, "files": []}
|
for output_package_name, output_package_content in output_package_files.items():
|
||||||
output_map[output_file]["files"].append(proto_file)
|
input_package_name = output_package_content["input_package"]
|
||||||
|
template_data = {
|
||||||
# TODO: Figure out how to handle gRPC request/response messages and add
|
"input_package": input_package_name,
|
||||||
# processing below for Service.
|
"files": [f.name for f in output_package_content["files"]],
|
||||||
|
|
||||||
for filename, options in output_map.items():
|
|
||||||
package = options["package"]
|
|
||||||
# print(package, filename, file=sys.stderr)
|
|
||||||
output = {
|
|
||||||
"package": package,
|
|
||||||
"files": [f.name for f in options["files"]],
|
|
||||||
"imports": set(),
|
"imports": set(),
|
||||||
"datetime_imports": set(),
|
"datetime_imports": set(),
|
||||||
"typing_imports": set(),
|
"typing_imports": set(),
|
||||||
@ -160,7 +160,7 @@ def generate_code(request, response):
|
|||||||
"services": [],
|
"services": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for proto_file in options["files"]:
|
for proto_file in output_package_content["files"]:
|
||||||
item: DescriptorProto
|
item: DescriptorProto
|
||||||
for item, path in traverse(proto_file):
|
for item, path in traverse(proto_file):
|
||||||
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
||||||
@ -180,7 +180,7 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for i, f in enumerate(item.field):
|
for i, f in enumerate(item.field):
|
||||||
t = py_type(package, output["imports"], f)
|
t = py_type(input_package_name, template_data["imports"], f)
|
||||||
zero = get_py_zero(f.type)
|
zero = get_py_zero(f.type)
|
||||||
|
|
||||||
repeated = False
|
repeated = False
|
||||||
@ -213,13 +213,13 @@ def generate_code(request, response):
|
|||||||
if nested.options.map_entry:
|
if nested.options.map_entry:
|
||||||
# print("Found a map!", file=sys.stderr)
|
# print("Found a map!", file=sys.stderr)
|
||||||
k = py_type(
|
k = py_type(
|
||||||
package,
|
input_package_name,
|
||||||
output["imports"],
|
template_data["imports"],
|
||||||
nested.field[0],
|
nested.field[0],
|
||||||
)
|
)
|
||||||
v = py_type(
|
v = py_type(
|
||||||
package,
|
input_package_name,
|
||||||
output["imports"],
|
template_data["imports"],
|
||||||
nested.field[1],
|
nested.field[1],
|
||||||
)
|
)
|
||||||
t = f"Dict[{k}, {v}]"
|
t = f"Dict[{k}, {v}]"
|
||||||
@ -228,14 +228,14 @@ def generate_code(request, response):
|
|||||||
f.Type.Name(nested.field[0].type),
|
f.Type.Name(nested.field[0].type),
|
||||||
f.Type.Name(nested.field[1].type),
|
f.Type.Name(nested.field[1].type),
|
||||||
)
|
)
|
||||||
output["typing_imports"].add("Dict")
|
template_data["typing_imports"].add("Dict")
|
||||||
|
|
||||||
if f.label == 3 and field_type != "map":
|
if f.label == 3 and field_type != "map":
|
||||||
# Repeated field
|
# Repeated field
|
||||||
repeated = True
|
repeated = True
|
||||||
t = f"List[{t}]"
|
t = f"List[{t}]"
|
||||||
zero = "[]"
|
zero = "[]"
|
||||||
output["typing_imports"].add("List")
|
template_data["typing_imports"].add("List")
|
||||||
|
|
||||||
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
||||||
packed = True
|
packed = True
|
||||||
@ -245,12 +245,12 @@ def generate_code(request, response):
|
|||||||
one_of = item.oneof_decl[f.oneof_index].name
|
one_of = item.oneof_decl[f.oneof_index].name
|
||||||
|
|
||||||
if "Optional[" in t:
|
if "Optional[" in t:
|
||||||
output["typing_imports"].add("Optional")
|
template_data["typing_imports"].add("Optional")
|
||||||
|
|
||||||
if "timedelta" in t:
|
if "timedelta" in t:
|
||||||
output["datetime_imports"].add("timedelta")
|
template_data["datetime_imports"].add("timedelta")
|
||||||
elif "datetime" in t:
|
elif "datetime" in t:
|
||||||
output["datetime_imports"].add("datetime")
|
template_data["datetime_imports"].add("datetime")
|
||||||
|
|
||||||
data["properties"].append(
|
data["properties"].append(
|
||||||
{
|
{
|
||||||
@ -271,7 +271,7 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
# print(f, file=sys.stderr)
|
# print(f, file=sys.stderr)
|
||||||
|
|
||||||
output["messages"].append(data)
|
template_data["messages"].append(data)
|
||||||
elif isinstance(item, EnumDescriptorProto):
|
elif isinstance(item, EnumDescriptorProto):
|
||||||
# print(item.name, path, file=sys.stderr)
|
# print(item.name, path, file=sys.stderr)
|
||||||
data.update(
|
data.update(
|
||||||
@ -289,7 +289,7 @@ def generate_code(request, response):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
output["enums"].append(data)
|
template_data["enums"].append(data)
|
||||||
|
|
||||||
for i, service in enumerate(proto_file.service):
|
for i, service in enumerate(proto_file.service):
|
||||||
# print(service, file=sys.stderr)
|
# print(service, file=sys.stderr)
|
||||||
@ -304,14 +304,14 @@ def generate_code(request, response):
|
|||||||
for j, method in enumerate(service.method):
|
for j, method in enumerate(service.method):
|
||||||
input_message = None
|
input_message = None
|
||||||
input_type = get_type_reference(
|
input_type = get_type_reference(
|
||||||
package, output["imports"], method.input_type
|
input_package_name, template_data["imports"], method.input_type
|
||||||
).strip('"')
|
).strip('"')
|
||||||
for msg in output["messages"]:
|
for msg in template_data["messages"]:
|
||||||
if msg["name"] == input_type:
|
if msg["name"] == input_type:
|
||||||
input_message = msg
|
input_message = msg
|
||||||
for field in msg["properties"]:
|
for field in msg["properties"]:
|
||||||
if field["zero"] == "None":
|
if field["zero"] == "None":
|
||||||
output["typing_imports"].add("Optional")
|
template_data["typing_imports"].add("Optional")
|
||||||
break
|
break
|
||||||
|
|
||||||
data["methods"].append(
|
data["methods"].append(
|
||||||
@ -319,14 +319,14 @@ def generate_code(request, response):
|
|||||||
"name": method.name,
|
"name": method.name,
|
||||||
"py_name": pythonize_method_name(method.name),
|
"py_name": pythonize_method_name(method.name),
|
||||||
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
||||||
"route": f"/{package}.{service.name}/{method.name}",
|
"route": f"/{input_package_name}.{service.name}/{method.name}",
|
||||||
"input": get_type_reference(
|
"input": get_type_reference(
|
||||||
package, output["imports"], method.input_type
|
input_package_name, template_data["imports"], method.input_type
|
||||||
).strip('"'),
|
).strip('"'),
|
||||||
"input_message": input_message,
|
"input_message": input_message,
|
||||||
"output": get_type_reference(
|
"output": get_type_reference(
|
||||||
package,
|
input_package_name,
|
||||||
output["imports"],
|
template_data["imports"],
|
||||||
method.output_type,
|
method.output_type,
|
||||||
unwrap=False,
|
unwrap=False,
|
||||||
),
|
),
|
||||||
@ -336,30 +336,32 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if method.client_streaming:
|
if method.client_streaming:
|
||||||
output["typing_imports"].add("AsyncIterable")
|
template_data["typing_imports"].add("AsyncIterable")
|
||||||
output["typing_imports"].add("Iterable")
|
template_data["typing_imports"].add("Iterable")
|
||||||
output["typing_imports"].add("Union")
|
template_data["typing_imports"].add("Union")
|
||||||
if method.server_streaming:
|
if method.server_streaming:
|
||||||
output["typing_imports"].add("AsyncIterator")
|
template_data["typing_imports"].add("AsyncIterator")
|
||||||
|
|
||||||
output["services"].append(data)
|
template_data["services"].append(data)
|
||||||
|
|
||||||
output["imports"] = sorted(output["imports"])
|
template_data["imports"] = sorted(template_data["imports"])
|
||||||
output["datetime_imports"] = sorted(output["datetime_imports"])
|
template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
|
||||||
output["typing_imports"] = sorted(output["typing_imports"])
|
template_data["typing_imports"] = sorted(template_data["typing_imports"])
|
||||||
|
|
||||||
# Fill response
|
# Fill response
|
||||||
|
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
|
||||||
|
output_paths.add(output_path)
|
||||||
|
|
||||||
f = response.file.add()
|
f = response.file.add()
|
||||||
f.name = filename
|
f.name = str(output_path)
|
||||||
|
|
||||||
# Render and then format the output file.
|
# Render and then format the output file.
|
||||||
f.content = black.format_str(
|
f.content = black.format_str(
|
||||||
template.render(description=output),
|
template.render(description=template_data),
|
||||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Make each output directory a package with __init__ file
|
# Make each output directory a package with __init__ file
|
||||||
output_paths = set(pathlib.Path(path) for path in output_map.keys())
|
|
||||||
init_files = (
|
init_files = (
|
||||||
set(
|
set(
|
||||||
directory.joinpath("__init__.py")
|
directory.joinpath("__init__.py")
|
||||||
@ -373,8 +375,8 @@ def generate_code(request, response):
|
|||||||
init = response.file.add()
|
init = response.file.add()
|
||||||
init.name = str(init_file)
|
init.name = str(init_file)
|
||||||
|
|
||||||
for filename in sorted(output_paths.union(init_files)):
|
for output_package_name in sorted(output_paths.union(init_files)):
|
||||||
print(f"Writing {filename}", file=sys.stderr)
|
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user