Clarify variable names

This commit is contained in:
boukeversteegh 2020-07-05 12:24:21 +02:00
parent 98d00f0d21
commit f2e87192b0

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import collections
import itertools import itertools
import os.path import os.path
import pathlib import pathlib
@ -8,6 +8,8 @@ import sys
import textwrap import textwrap
from typing import List, Union from typing import List, Union
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
import betterproto import betterproto
from betterproto.compile.importing import get_type_reference from betterproto.compile.importing import get_type_reference
from betterproto.compile.naming import ( from betterproto.compile.naming import (
@ -129,7 +131,8 @@ def generate_code(request, response):
) )
template = env.get_template("template.py.j2") template = env.get_template("template.py.j2")
output_map = {} # Gather output packages
output_package_files = collections.defaultdict()
for proto_file in request.proto_file: for proto_file in request.proto_file:
if ( if (
proto_file.package == "google.protobuf" proto_file.package == "google.protobuf"
@ -137,21 +140,18 @@ def generate_code(request, response):
): ):
continue continue
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) output_package = proto_file.package
output_package_files.setdefault(
output_package, {"input_package": proto_file.package, "files": []}
)
output_package_files[output_package]["files"].append(proto_file)
if output_file not in output_map: output_paths = set()
output_map[output_file] = {"package": proto_file.package, "files": []} for output_package_name, output_package_content in output_package_files.items():
output_map[output_file]["files"].append(proto_file) input_package_name = output_package_content["input_package"]
template_data = {
# TODO: Figure out how to handle gRPC request/response messages and add "input_package": input_package_name,
# processing below for Service. "files": [f.name for f in output_package_content["files"]],
for filename, options in output_map.items():
package = options["package"]
# print(package, filename, file=sys.stderr)
output = {
"package": package,
"files": [f.name for f in options["files"]],
"imports": set(), "imports": set(),
"datetime_imports": set(), "datetime_imports": set(),
"typing_imports": set(), "typing_imports": set(),
@ -160,7 +160,7 @@ def generate_code(request, response):
"services": [], "services": [],
} }
for proto_file in options["files"]: for proto_file in output_package_content["files"]:
item: DescriptorProto item: DescriptorProto
for item, path in traverse(proto_file): for item, path in traverse(proto_file):
data = {"name": item.name, "py_name": pythonize_class_name(item.name)} data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
@ -180,7 +180,7 @@ def generate_code(request, response):
) )
for i, f in enumerate(item.field): for i, f in enumerate(item.field):
t = py_type(package, output["imports"], f) t = py_type(input_package_name, template_data["imports"], f)
zero = get_py_zero(f.type) zero = get_py_zero(f.type)
repeated = False repeated = False
@ -213,13 +213,13 @@ def generate_code(request, response):
if nested.options.map_entry: if nested.options.map_entry:
# print("Found a map!", file=sys.stderr) # print("Found a map!", file=sys.stderr)
k = py_type( k = py_type(
package, input_package_name,
output["imports"], template_data["imports"],
nested.field[0], nested.field[0],
) )
v = py_type( v = py_type(
package, input_package_name,
output["imports"], template_data["imports"],
nested.field[1], nested.field[1],
) )
t = f"Dict[{k}, {v}]" t = f"Dict[{k}, {v}]"
@ -228,14 +228,14 @@ def generate_code(request, response):
f.Type.Name(nested.field[0].type), f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type), f.Type.Name(nested.field[1].type),
) )
output["typing_imports"].add("Dict") template_data["typing_imports"].add("Dict")
if f.label == 3 and field_type != "map": if f.label == 3 and field_type != "map":
# Repeated field # Repeated field
repeated = True repeated = True
t = f"List[{t}]" t = f"List[{t}]"
zero = "[]" zero = "[]"
output["typing_imports"].add("List") template_data["typing_imports"].add("List")
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True packed = True
@ -245,12 +245,12 @@ def generate_code(request, response):
one_of = item.oneof_decl[f.oneof_index].name one_of = item.oneof_decl[f.oneof_index].name
if "Optional[" in t: if "Optional[" in t:
output["typing_imports"].add("Optional") template_data["typing_imports"].add("Optional")
if "timedelta" in t: if "timedelta" in t:
output["datetime_imports"].add("timedelta") template_data["datetime_imports"].add("timedelta")
elif "datetime" in t: elif "datetime" in t:
output["datetime_imports"].add("datetime") template_data["datetime_imports"].add("datetime")
data["properties"].append( data["properties"].append(
{ {
@ -271,7 +271,7 @@ def generate_code(request, response):
) )
# print(f, file=sys.stderr) # print(f, file=sys.stderr)
output["messages"].append(data) template_data["messages"].append(data)
elif isinstance(item, EnumDescriptorProto): elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr) # print(item.name, path, file=sys.stderr)
data.update( data.update(
@ -289,7 +289,7 @@ def generate_code(request, response):
} }
) )
output["enums"].append(data) template_data["enums"].append(data)
for i, service in enumerate(proto_file.service): for i, service in enumerate(proto_file.service):
# print(service, file=sys.stderr) # print(service, file=sys.stderr)
@ -304,14 +304,14 @@ def generate_code(request, response):
for j, method in enumerate(service.method): for j, method in enumerate(service.method):
input_message = None input_message = None
input_type = get_type_reference( input_type = get_type_reference(
package, output["imports"], method.input_type input_package_name, template_data["imports"], method.input_type
).strip('"') ).strip('"')
for msg in output["messages"]: for msg in template_data["messages"]:
if msg["name"] == input_type: if msg["name"] == input_type:
input_message = msg input_message = msg
for field in msg["properties"]: for field in msg["properties"]:
if field["zero"] == "None": if field["zero"] == "None":
output["typing_imports"].add("Optional") template_data["typing_imports"].add("Optional")
break break
data["methods"].append( data["methods"].append(
@ -319,14 +319,14 @@ def generate_code(request, response):
"name": method.name, "name": method.name,
"py_name": pythonize_method_name(method.name), "py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, i, 2, j], indent=8), "comment": get_comment(proto_file, [6, i, 2, j], indent=8),
"route": f"/{package}.{service.name}/{method.name}", "route": f"/{input_package_name}.{service.name}/{method.name}",
"input": get_type_reference( "input": get_type_reference(
package, output["imports"], method.input_type input_package_name, template_data["imports"], method.input_type
).strip('"'), ).strip('"'),
"input_message": input_message, "input_message": input_message,
"output": get_type_reference( "output": get_type_reference(
package, input_package_name,
output["imports"], template_data["imports"],
method.output_type, method.output_type,
unwrap=False, unwrap=False,
), ),
@ -336,30 +336,32 @@ def generate_code(request, response):
) )
if method.client_streaming: if method.client_streaming:
output["typing_imports"].add("AsyncIterable") template_data["typing_imports"].add("AsyncIterable")
output["typing_imports"].add("Iterable") template_data["typing_imports"].add("Iterable")
output["typing_imports"].add("Union") template_data["typing_imports"].add("Union")
if method.server_streaming: if method.server_streaming:
output["typing_imports"].add("AsyncIterator") template_data["typing_imports"].add("AsyncIterator")
output["services"].append(data) template_data["services"].append(data)
output["imports"] = sorted(output["imports"]) template_data["imports"] = sorted(template_data["imports"])
output["datetime_imports"] = sorted(output["datetime_imports"]) template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
output["typing_imports"] = sorted(output["typing_imports"]) template_data["typing_imports"] = sorted(template_data["typing_imports"])
# Fill response # Fill response
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)
f = response.file.add() f = response.file.add()
f.name = filename f.name = str(output_path)
# Render and then format the output file. # Render and then format the output file.
f.content = black.format_str( f.content = black.format_str(
template.render(description=output), template.render(description=template_data),
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
) )
# Make each output directory a package with __init__ file # Make each output directory a package with __init__ file
output_paths = set(pathlib.Path(path) for path in output_map.keys())
init_files = ( init_files = (
set( set(
directory.joinpath("__init__.py") directory.joinpath("__init__.py")
@ -373,8 +375,8 @@ def generate_code(request, response):
init = response.file.add() init = response.file.add()
init.name = str(init_file) init.name = str(init_file)
for filename in sorted(output_paths.union(init_files)): for output_package_name in sorted(output_paths.union(init_files)):
print(f"Writing {filename}", file=sys.stderr) print(f"Writing {output_package_name}", file=sys.stderr)
def main(): def main():