Clarify variable names
This commit is contained in:
		| @@ -1,5 +1,5 @@ | |||||||
| #!/usr/bin/env python | #!/usr/bin/env python | ||||||
|  | import collections | ||||||
| import itertools | import itertools | ||||||
| import os.path | import os.path | ||||||
| import pathlib | import pathlib | ||||||
| @@ -8,6 +8,8 @@ import sys | |||||||
| import textwrap | import textwrap | ||||||
| from typing import List, Union | from typing import List, Union | ||||||
|  |  | ||||||
|  | from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| from betterproto.compile.importing import get_type_reference | from betterproto.compile.importing import get_type_reference | ||||||
| from betterproto.compile.naming import ( | from betterproto.compile.naming import ( | ||||||
| @@ -129,7 +131,8 @@ def generate_code(request, response): | |||||||
|     ) |     ) | ||||||
|     template = env.get_template("template.py.j2") |     template = env.get_template("template.py.j2") | ||||||
|  |  | ||||||
|     output_map = {} |     # Gather output packages | ||||||
|  |     output_package_files = collections.defaultdict() | ||||||
|     for proto_file in request.proto_file: |     for proto_file in request.proto_file: | ||||||
|         if ( |         if ( | ||||||
|             proto_file.package == "google.protobuf" |             proto_file.package == "google.protobuf" | ||||||
| @@ -137,21 +140,18 @@ def generate_code(request, response): | |||||||
|         ): |         ): | ||||||
|             continue |             continue | ||||||
|  |  | ||||||
|         output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) |         output_package = proto_file.package | ||||||
|  |         output_package_files.setdefault( | ||||||
|  |             output_package, {"input_package": proto_file.package, "files": []} | ||||||
|  |         ) | ||||||
|  |         output_package_files[output_package]["files"].append(proto_file) | ||||||
|  |  | ||||||
|         if output_file not in output_map: |     output_paths = set() | ||||||
|             output_map[output_file] = {"package": proto_file.package, "files": []} |     for output_package_name, output_package_content in output_package_files.items(): | ||||||
|         output_map[output_file]["files"].append(proto_file) |         input_package_name = output_package_content["input_package"] | ||||||
|  |         template_data = { | ||||||
|     # TODO: Figure out how to handle gRPC request/response messages and add |             "input_package": input_package_name, | ||||||
|     # processing below for Service. |             "files": [f.name for f in output_package_content["files"]], | ||||||
|  |  | ||||||
|     for filename, options in output_map.items(): |  | ||||||
|         package = options["package"] |  | ||||||
|         # print(package, filename, file=sys.stderr) |  | ||||||
|         output = { |  | ||||||
|             "package": package, |  | ||||||
|             "files": [f.name for f in options["files"]], |  | ||||||
|             "imports": set(), |             "imports": set(), | ||||||
|             "datetime_imports": set(), |             "datetime_imports": set(), | ||||||
|             "typing_imports": set(), |             "typing_imports": set(), | ||||||
| @@ -160,7 +160,7 @@ def generate_code(request, response): | |||||||
|             "services": [], |             "services": [], | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         for proto_file in options["files"]: |         for proto_file in output_package_content["files"]: | ||||||
|             item: DescriptorProto |             item: DescriptorProto | ||||||
|             for item, path in traverse(proto_file): |             for item, path in traverse(proto_file): | ||||||
|                 data = {"name": item.name, "py_name": pythonize_class_name(item.name)} |                 data = {"name": item.name, "py_name": pythonize_class_name(item.name)} | ||||||
| @@ -180,7 +180,7 @@ def generate_code(request, response): | |||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                     for i, f in enumerate(item.field): |                     for i, f in enumerate(item.field): | ||||||
|                         t = py_type(package, output["imports"], f) |                         t = py_type(input_package_name, template_data["imports"], f) | ||||||
|                         zero = get_py_zero(f.type) |                         zero = get_py_zero(f.type) | ||||||
|  |  | ||||||
|                         repeated = False |                         repeated = False | ||||||
| @@ -213,13 +213,13 @@ def generate_code(request, response): | |||||||
|                                         if nested.options.map_entry: |                                         if nested.options.map_entry: | ||||||
|                                             # print("Found a map!", file=sys.stderr) |                                             # print("Found a map!", file=sys.stderr) | ||||||
|                                             k = py_type( |                                             k = py_type( | ||||||
|                                                 package, |                                                 input_package_name, | ||||||
|                                                 output["imports"], |                                                 template_data["imports"], | ||||||
|                                                 nested.field[0], |                                                 nested.field[0], | ||||||
|                                             ) |                                             ) | ||||||
|                                             v = py_type( |                                             v = py_type( | ||||||
|                                                 package, |                                                 input_package_name, | ||||||
|                                                 output["imports"], |                                                 template_data["imports"], | ||||||
|                                                 nested.field[1], |                                                 nested.field[1], | ||||||
|                                             ) |                                             ) | ||||||
|                                             t = f"Dict[{k}, {v}]" |                                             t = f"Dict[{k}, {v}]" | ||||||
| @@ -228,14 +228,14 @@ def generate_code(request, response): | |||||||
|                                                 f.Type.Name(nested.field[0].type), |                                                 f.Type.Name(nested.field[0].type), | ||||||
|                                                 f.Type.Name(nested.field[1].type), |                                                 f.Type.Name(nested.field[1].type), | ||||||
|                                             ) |                                             ) | ||||||
|                                             output["typing_imports"].add("Dict") |                                             template_data["typing_imports"].add("Dict") | ||||||
|  |  | ||||||
|                         if f.label == 3 and field_type != "map": |                         if f.label == 3 and field_type != "map": | ||||||
|                             # Repeated field |                             # Repeated field | ||||||
|                             repeated = True |                             repeated = True | ||||||
|                             t = f"List[{t}]" |                             t = f"List[{t}]" | ||||||
|                             zero = "[]" |                             zero = "[]" | ||||||
|                             output["typing_imports"].add("List") |                             template_data["typing_imports"].add("List") | ||||||
|  |  | ||||||
|                             if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: |                             if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: | ||||||
|                                 packed = True |                                 packed = True | ||||||
| @@ -245,12 +245,12 @@ def generate_code(request, response): | |||||||
|                             one_of = item.oneof_decl[f.oneof_index].name |                             one_of = item.oneof_decl[f.oneof_index].name | ||||||
|  |  | ||||||
|                         if "Optional[" in t: |                         if "Optional[" in t: | ||||||
|                             output["typing_imports"].add("Optional") |                             template_data["typing_imports"].add("Optional") | ||||||
|  |  | ||||||
|                         if "timedelta" in t: |                         if "timedelta" in t: | ||||||
|                             output["datetime_imports"].add("timedelta") |                             template_data["datetime_imports"].add("timedelta") | ||||||
|                         elif "datetime" in t: |                         elif "datetime" in t: | ||||||
|                             output["datetime_imports"].add("datetime") |                             template_data["datetime_imports"].add("datetime") | ||||||
|  |  | ||||||
|                         data["properties"].append( |                         data["properties"].append( | ||||||
|                             { |                             { | ||||||
| @@ -271,7 +271,7 @@ def generate_code(request, response): | |||||||
|                         ) |                         ) | ||||||
|                         # print(f, file=sys.stderr) |                         # print(f, file=sys.stderr) | ||||||
|  |  | ||||||
|                     output["messages"].append(data) |                     template_data["messages"].append(data) | ||||||
|                 elif isinstance(item, EnumDescriptorProto): |                 elif isinstance(item, EnumDescriptorProto): | ||||||
|                     # print(item.name, path, file=sys.stderr) |                     # print(item.name, path, file=sys.stderr) | ||||||
|                     data.update( |                     data.update( | ||||||
| @@ -289,7 +289,7 @@ def generate_code(request, response): | |||||||
|                         } |                         } | ||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                     output["enums"].append(data) |                     template_data["enums"].append(data) | ||||||
|  |  | ||||||
|             for i, service in enumerate(proto_file.service): |             for i, service in enumerate(proto_file.service): | ||||||
|                 # print(service, file=sys.stderr) |                 # print(service, file=sys.stderr) | ||||||
| @@ -304,14 +304,14 @@ def generate_code(request, response): | |||||||
|                 for j, method in enumerate(service.method): |                 for j, method in enumerate(service.method): | ||||||
|                     input_message = None |                     input_message = None | ||||||
|                     input_type = get_type_reference( |                     input_type = get_type_reference( | ||||||
|                         package, output["imports"], method.input_type |                         input_package_name, template_data["imports"], method.input_type | ||||||
|                     ).strip('"') |                     ).strip('"') | ||||||
|                     for msg in output["messages"]: |                     for msg in template_data["messages"]: | ||||||
|                         if msg["name"] == input_type: |                         if msg["name"] == input_type: | ||||||
|                             input_message = msg |                             input_message = msg | ||||||
|                             for field in msg["properties"]: |                             for field in msg["properties"]: | ||||||
|                                 if field["zero"] == "None": |                                 if field["zero"] == "None": | ||||||
|                                     output["typing_imports"].add("Optional") |                                     template_data["typing_imports"].add("Optional") | ||||||
|                             break |                             break | ||||||
|  |  | ||||||
|                     data["methods"].append( |                     data["methods"].append( | ||||||
| @@ -319,14 +319,14 @@ def generate_code(request, response): | |||||||
|                             "name": method.name, |                             "name": method.name, | ||||||
|                             "py_name": pythonize_method_name(method.name), |                             "py_name": pythonize_method_name(method.name), | ||||||
|                             "comment": get_comment(proto_file, [6, i, 2, j], indent=8), |                             "comment": get_comment(proto_file, [6, i, 2, j], indent=8), | ||||||
|                             "route": f"/{package}.{service.name}/{method.name}", |                             "route": f"/{input_package_name}.{service.name}/{method.name}", | ||||||
|                             "input": get_type_reference( |                             "input": get_type_reference( | ||||||
|                                 package, output["imports"], method.input_type |                                 input_package_name, template_data["imports"], method.input_type | ||||||
|                             ).strip('"'), |                             ).strip('"'), | ||||||
|                             "input_message": input_message, |                             "input_message": input_message, | ||||||
|                             "output": get_type_reference( |                             "output": get_type_reference( | ||||||
|                                 package, |                                 input_package_name, | ||||||
|                                 output["imports"], |                                 template_data["imports"], | ||||||
|                                 method.output_type, |                                 method.output_type, | ||||||
|                                 unwrap=False, |                                 unwrap=False, | ||||||
|                             ), |                             ), | ||||||
| @@ -336,30 +336,32 @@ def generate_code(request, response): | |||||||
|                     ) |                     ) | ||||||
|  |  | ||||||
|                     if method.client_streaming: |                     if method.client_streaming: | ||||||
|                         output["typing_imports"].add("AsyncIterable") |                         template_data["typing_imports"].add("AsyncIterable") | ||||||
|                         output["typing_imports"].add("Iterable") |                         template_data["typing_imports"].add("Iterable") | ||||||
|                         output["typing_imports"].add("Union") |                         template_data["typing_imports"].add("Union") | ||||||
|                     if method.server_streaming: |                     if method.server_streaming: | ||||||
|                         output["typing_imports"].add("AsyncIterator") |                         template_data["typing_imports"].add("AsyncIterator") | ||||||
|  |  | ||||||
|                 output["services"].append(data) |                 template_data["services"].append(data) | ||||||
|  |  | ||||||
|         output["imports"] = sorted(output["imports"]) |         template_data["imports"] = sorted(template_data["imports"]) | ||||||
|         output["datetime_imports"] = sorted(output["datetime_imports"]) |         template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) | ||||||
|         output["typing_imports"] = sorted(output["typing_imports"]) |         template_data["typing_imports"] = sorted(template_data["typing_imports"]) | ||||||
|  |  | ||||||
|         # Fill response |         # Fill response | ||||||
|  |         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") | ||||||
|  |         output_paths.add(output_path) | ||||||
|  |  | ||||||
|         f = response.file.add() |         f = response.file.add() | ||||||
|         f.name = filename |         f.name = str(output_path) | ||||||
|  |  | ||||||
|         # Render and then format the output file. |         # Render and then format the output file. | ||||||
|         f.content = black.format_str( |         f.content = black.format_str( | ||||||
|             template.render(description=output), |             template.render(description=template_data), | ||||||
|             mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), |             mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|     # Make each output directory a package with __init__ file |     # Make each output directory a package with __init__ file | ||||||
|     output_paths = set(pathlib.Path(path) for path in output_map.keys()) |  | ||||||
|     init_files = ( |     init_files = ( | ||||||
|         set( |         set( | ||||||
|             directory.joinpath("__init__.py") |             directory.joinpath("__init__.py") | ||||||
| @@ -373,8 +375,8 @@ def generate_code(request, response): | |||||||
|         init = response.file.add() |         init = response.file.add() | ||||||
|         init.name = str(init_file) |         init.name = str(init_file) | ||||||
|  |  | ||||||
|     for filename in sorted(output_paths.union(init_files)): |     for output_package_name in sorted(output_paths.union(init_files)): | ||||||
|         print(f"Writing {filename}", file=sys.stderr) |         print(f"Writing {output_package_name}", file=sys.stderr) | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user