diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 27788af..8791c24 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,5 +1,5 @@ #!/usr/bin/env python - +import collections import itertools import os.path import pathlib @@ -8,6 +8,8 @@ import sys import textwrap from typing import List, Union +from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest + import betterproto from betterproto.compile.importing import get_type_reference from betterproto.compile.naming import ( @@ -129,7 +131,8 @@ def generate_code(request, response): ) template = env.get_template("template.py.j2") - output_map = {} + # Gather output packages + output_package_files = collections.defaultdict() for proto_file in request.proto_file: if ( proto_file.package == "google.protobuf" @@ -137,21 +140,18 @@ def generate_code(request, response): ): continue - output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) + output_package = proto_file.package + output_package_files.setdefault( + output_package, {"input_package": proto_file.package, "files": []} + ) + output_package_files[output_package]["files"].append(proto_file) - if output_file not in output_map: - output_map[output_file] = {"package": proto_file.package, "files": []} - output_map[output_file]["files"].append(proto_file) - - # TODO: Figure out how to handle gRPC request/response messages and add - # processing below for Service. - - for filename, options in output_map.items(): - package = options["package"] - # print(package, filename, file=sys.stderr) - output = { - "package": package, - "files": [f.name for f in options["files"]], + output_paths = set() + for output_package_name, output_package_content in output_package_files.items(): + input_package_name = output_package_content["input_package"] + template_data = { + "input_package": input_package_name, + "files": [f.name for f in output_package_content["files"]], "imports": set(), "datetime_imports": set(), "typing_imports": set(), @@ -160,7 +160,7 @@ def generate_code(request, response): "services": [], } - for proto_file in options["files"]: + for proto_file in output_package_content["files"]: item: DescriptorProto for item, path in traverse(proto_file): data = {"name": item.name, "py_name": pythonize_class_name(item.name)} @@ -180,7 +180,7 @@ def generate_code(request, response): ) for i, f in enumerate(item.field): - t = py_type(package, output["imports"], f) + t = py_type(input_package_name, template_data["imports"], f) zero = get_py_zero(f.type) repeated = False @@ -213,13 +213,13 @@ def generate_code(request, response): if nested.options.map_entry: # print("Found a map!", file=sys.stderr) k = py_type( - package, - output["imports"], + input_package_name, + template_data["imports"], nested.field[0], ) v = py_type( - package, - output["imports"], + input_package_name, + template_data["imports"], nested.field[1], ) t = f"Dict[{k}, {v}]" @@ -228,14 +228,14 @@ def generate_code(request, response): f.Type.Name(nested.field[0].type), f.Type.Name(nested.field[1].type), ) - output["typing_imports"].add("Dict") + template_data["typing_imports"].add("Dict") if f.label == 3 and field_type != "map": # Repeated field repeated = True t = f"List[{t}]" zero = "[]" - output["typing_imports"].add("List") + template_data["typing_imports"].add("List") if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: packed = True @@ -245,12 +245,12 @@ def generate_code(request, response): one_of = item.oneof_decl[f.oneof_index].name if "Optional[" in t: - output["typing_imports"].add("Optional") + template_data["typing_imports"].add("Optional") if "timedelta" in t: - output["datetime_imports"].add("timedelta") + template_data["datetime_imports"].add("timedelta") elif "datetime" in t: - output["datetime_imports"].add("datetime") + template_data["datetime_imports"].add("datetime") data["properties"].append( { @@ -271,7 +271,7 @@ def generate_code(request, response): ) # print(f, file=sys.stderr) - output["messages"].append(data) + template_data["messages"].append(data) elif isinstance(item, EnumDescriptorProto): # print(item.name, path, file=sys.stderr) data.update( @@ -289,7 +289,7 @@ def generate_code(request, response): } ) - output["enums"].append(data) + template_data["enums"].append(data) for i, service in enumerate(proto_file.service): # print(service, file=sys.stderr) @@ -304,14 +304,14 @@ def generate_code(request, response): for j, method in enumerate(service.method): input_message = None input_type = get_type_reference( - package, output["imports"], method.input_type + input_package_name, template_data["imports"], method.input_type ).strip('"') - for msg in output["messages"]: + for msg in template_data["messages"]: if msg["name"] == input_type: input_message = msg for field in msg["properties"]: if field["zero"] == "None": - output["typing_imports"].add("Optional") + template_data["typing_imports"].add("Optional") break data["methods"].append( @@ -319,14 +319,14 @@ def generate_code(request, response): "name": method.name, "py_name": pythonize_method_name(method.name), "comment": get_comment(proto_file, [6, i, 2, j], indent=8), - "route": f"/{package}.{service.name}/{method.name}", + "route": f"/{input_package_name}.{service.name}/{method.name}", "input": get_type_reference( - package, output["imports"], method.input_type + input_package_name, template_data["imports"], method.input_type ).strip('"'), "input_message": input_message, "output": get_type_reference( - package, - output["imports"], + input_package_name, + template_data["imports"], method.output_type, unwrap=False, ), @@ -336,30 +336,32 @@ def generate_code(request, response): ) if method.client_streaming: - output["typing_imports"].add("AsyncIterable") - output["typing_imports"].add("Iterable") - output["typing_imports"].add("Union") + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") if method.server_streaming: - output["typing_imports"].add("AsyncIterator") + template_data["typing_imports"].add("AsyncIterator") - output["services"].append(data) + template_data["services"].append(data) - output["imports"] = sorted(output["imports"]) - output["datetime_imports"] = sorted(output["datetime_imports"]) - output["typing_imports"] = sorted(output["typing_imports"]) + template_data["imports"] = sorted(template_data["imports"]) + template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) + template_data["typing_imports"] = sorted(template_data["typing_imports"]) # Fill response + output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") + output_paths.add(output_path) + f = response.file.add() - f.name = filename + f.name = str(output_path) # Render and then format the output file. f.content = black.format_str( - template.render(description=output), + template.render(description=template_data), mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), ) # Make each output directory a package with __init__ file - output_paths = set(pathlib.Path(path) for path in output_map.keys()) init_files = ( set( directory.joinpath("__init__.py") @@ -373,8 +375,8 @@ def generate_code(request, response): init = response.file.add() init.name = str(init_file) - for filename in sorted(output_paths.union(init_files)): - print(f"Writing {filename}", file=sys.stderr) + for output_package_name in sorted(output_paths.union(init_files)): + print(f"Writing {output_package_name}", file=sys.stderr) def main():