diff --git a/CHANGELOG.md b/CHANGELOG.md index 383d3f7..c5c65b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 > `2.0.0` will be released once the interface is stable. - Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83) -- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) -- Fix No arguments are generated for stub methods when using import with proto definition +- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) - Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25) - Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 529de36..0d88d47 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,5 +1,5 @@ #!/usr/bin/env python - +import collections import itertools import os.path import pathlib @@ -8,13 +8,16 @@ import sys import textwrap from typing import List, Union +from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest + import betterproto -from betterproto.compile.importing import get_type_reference +from betterproto.compile.importing import get_type_reference, parse_source_type_name from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, pythonize_method_name, ) +from betterproto.lib.google.protobuf import ServiceDescriptorProto try: # betterproto[compiler] specific dependencies @@ -75,9 +78,11 @@ def get_py_zero(type_num: int) -> Union[str, float]: def traverse(proto_file): + # Todo: Keep information about nested hierarchy def _traverse(path, items, prefix=""): for i, item in enumerate(items): # Adjust the name since we 
flatten the hierarchy. + # Todo: don't change the name, but include full name in returned tuple item.name = next_prefix = prefix + item.name yield item, path + [i] @@ -129,7 +134,8 @@ def generate_code(request, response): ) template = env.get_template("template.py.j2") - output_map = {} + # Gather output packages + output_package_files = collections.defaultdict() for proto_file in request.proto_file: if ( proto_file.package == "google.protobuf" @@ -137,21 +143,17 @@ def generate_code(request, response): ): continue - output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) + output_package = proto_file.package + output_package_files.setdefault( + output_package, {"input_package": proto_file.package, "files": []} + ) + output_package_files[output_package]["files"].append(proto_file) - if output_file not in output_map: - output_map[output_file] = {"package": proto_file.package, "files": []} - output_map[output_file]["files"].append(proto_file) - - # TODO: Figure out how to handle gRPC request/response messages and add - # processing below for Service. 
- - for filename, options in output_map.items(): - package = options["package"] - # print(package, filename, file=sys.stderr) - output = { - "package": package, - "files": [f.name for f in options["files"]], + # Initialize Template data for each package + for output_package_name, output_package_content in output_package_files.items(): + template_data = { + "input_package": output_package_content["input_package"], + "files": [f.name for f in output_package_content["files"]], "imports": set(), "datetime_imports": set(), "typing_imports": set(), @@ -159,207 +161,48 @@ def generate_code(request, response): "enums": [], "services": [], } + output_package_content["template_data"] = template_data - for proto_file in options["files"]: - item: DescriptorProto + # Read Messages and Enums + output_types = [] + for output_package_name, output_package_content in output_package_files.items(): + for proto_file in output_package_content["files"]: for item, path in traverse(proto_file): - data = {"name": item.name, "py_name": pythonize_class_name(item.name)} + type_data = read_protobuf_type( + item, path, proto_file, output_package_content + ) + output_types.append(type_data) - if isinstance(item, DescriptorProto): - # print(item, file=sys.stderr) - if item.options.map_entry: - # Skip generated map entry messages since we just use dicts - continue + # Read Services + for output_package_name, output_package_content in output_package_files.items(): + for proto_file in output_package_content["files"]: + for index, service in enumerate(proto_file.service): + read_protobuf_service( + service, index, proto_file, output_package_content, output_types + ) - data.update( - { - "type": "Message", - "comment": get_comment(proto_file, path), - "properties": [], - } - ) - - for i, f in enumerate(item.field): - t = py_type(package, output["imports"], f) - zero = get_py_zero(f.type) - - repeated = False - packed = False - - field_type = f.Type.Name(f.type).lower()[5:] - - field_wraps = "" - 
match_wrapper = re.match( - r"\.google\.protobuf\.(.+)Value", f.type_name - ) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto, wrapped_type): - field_wraps = f"betterproto.{wrapped_type}" - - map_types = None - if f.type == 11: - # This might be a map... - message_type = f.type_name.split(".").pop().lower() - # message_type = py_type(package) - map_entry = f"{f.name.replace('_', '').lower()}entry" - - if message_type == map_entry: - for nested in item.nested_type: - if ( - nested.name.replace("_", "").lower() - == map_entry - ): - if nested.options.map_entry: - # print("Found a map!", file=sys.stderr) - k = py_type( - package, - output["imports"], - nested.field[0], - ) - v = py_type( - package, - output["imports"], - nested.field[1], - ) - t = f"Dict[{k}, {v}]" - field_type = "map" - map_types = ( - f.Type.Name(nested.field[0].type), - f.Type.Name(nested.field[1].type), - ) - output["typing_imports"].add("Dict") - - if f.label == 3 and field_type != "map": - # Repeated field - repeated = True - t = f"List[{t}]" - zero = "[]" - output["typing_imports"].add("List") - - if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: - packed = True - - one_of = "" - if f.HasField("oneof_index"): - one_of = item.oneof_decl[f.oneof_index].name - - if "Optional[" in t: - output["typing_imports"].add("Optional") - - if "timedelta" in t: - output["datetime_imports"].add("timedelta") - elif "datetime" in t: - output["datetime_imports"].add("datetime") - - data["properties"].append( - { - "name": f.name, - "py_name": pythonize_field_name(f.name), - "number": f.number, - "comment": get_comment(proto_file, path + [2, i]), - "proto_type": int(f.type), - "field_type": field_type, - "field_wraps": field_wraps, - "map_types": map_types, - "type": t, - "zero": zero, - "repeated": repeated, - "packed": packed, - "one_of": one_of, - } - ) - # print(f, file=sys.stderr) - - output["messages"].append(data) - elif isinstance(item, 
EnumDescriptorProto): - # print(item.name, path, file=sys.stderr) - data.update( - { - "type": "Enum", - "comment": get_comment(proto_file, path), - "entries": [ - { - "name": v.name, - "value": v.number, - "comment": get_comment(proto_file, path + [2, i]), - } - for i, v in enumerate(item.value) - ], - } - ) - - output["enums"].append(data) - - for i, service in enumerate(proto_file.service): - # print(service, file=sys.stderr) - - data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, i]), - "methods": [], - } - - for j, method in enumerate(service.method): - input_message = None - input_type = get_type_reference( - package, output["imports"], method.input_type - ).strip('"') - for msg in output["messages"]: - if msg["name"] == input_type: - input_message = msg - for field in msg["properties"]: - if field["zero"] == "None": - output["typing_imports"].add("Optional") - break - - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, i, 2, j], indent=8), - "route": f"/{package}.{service.name}/{method.name}", - "input": get_type_reference( - package, output["imports"], method.input_type - ).strip('"'), - "input_message": input_message, - "output": get_type_reference( - package, - output["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - } - ) - - if method.client_streaming: - output["typing_imports"].add("AsyncIterable") - output["typing_imports"].add("Iterable") - output["typing_imports"].add("Union") - if method.server_streaming: - output["typing_imports"].add("AsyncIterator") - - output["services"].append(data) - - output["imports"] = sorted(output["imports"]) - output["datetime_imports"] = sorted(output["datetime_imports"]) - output["typing_imports"] = sorted(output["typing_imports"]) + # Render files + 
output_paths = set() + for output_package_name, output_package_content in output_package_files.items(): + template_data = output_package_content["template_data"] + template_data["imports"] = sorted(template_data["imports"]) + template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) + template_data["typing_imports"] = sorted(template_data["typing_imports"]) # Fill response + output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") + output_paths.add(output_path) + f = response.file.add() - f.name = filename + f.name = str(output_path) # Render and then format the output file. f.content = black.format_str( - template.render(description=output), + template.render(description=template_data), mode=black.FileMode(target_versions={black.TargetVersion.PY37}), ) # Make each output directory a package with __init__ file - output_paths = set(pathlib.Path(path) for path in output_map.keys()) init_files = ( set( directory.joinpath("__init__.py") @@ -373,8 +216,207 @@ def generate_code(request, response): init = response.file.add() init.name = str(init_file) - for filename in sorted(output_paths.union(init_files)): - print(f"Writing {filename}", file=sys.stderr) + for output_package_name in sorted(output_paths.union(init_files)): + print(f"Writing {output_package_name}", file=sys.stderr) + + +def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): + input_package_name = content["input_package"] + template_data = content["template_data"] + data = { + "name": item.name, + "py_name": pythonize_class_name(item.name), + "descriptor": item, + "package": input_package_name, + } + if isinstance(item, DescriptorProto): + # print(item, file=sys.stderr) + if item.options.map_entry: + # Skip generated map entry messages since we just use dicts + return + + data.update( + { + "type": "Message", + "comment": get_comment(proto_file, path), + "properties": [], + } + ) + + for i, f in enumerate(item.field): + t = 
py_type(input_package_name, template_data["imports"], f) + zero = get_py_zero(f.type) + + repeated = False + packed = False + + field_type = f.Type.Name(f.type).lower()[5:] + + field_wraps = "" + match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name) + if match_wrapper: + wrapped_type = "TYPE_" + match_wrapper.group(1).upper() + if hasattr(betterproto, wrapped_type): + field_wraps = f"betterproto.{wrapped_type}" + + map_types = None + if f.type == 11: + # This might be a map... + message_type = f.type_name.split(".").pop().lower() + # message_type = py_type(package) + map_entry = f"{f.name.replace('_', '').lower()}entry" + + if message_type == map_entry: + for nested in item.nested_type: + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + # print("Found a map!", file=sys.stderr) + k = py_type( + input_package_name, + template_data["imports"], + nested.field[0], + ) + v = py_type( + input_package_name, + template_data["imports"], + nested.field[1], + ) + t = f"Dict[{k}, {v}]" + field_type = "map" + map_types = ( + f.Type.Name(nested.field[0].type), + f.Type.Name(nested.field[1].type), + ) + template_data["typing_imports"].add("Dict") + + if f.label == 3 and field_type != "map": + # Repeated field + repeated = True + t = f"List[{t}]" + zero = "[]" + template_data["typing_imports"].add("List") + + if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: + packed = True + + one_of = "" + if f.HasField("oneof_index"): + one_of = item.oneof_decl[f.oneof_index].name + + if "Optional[" in t: + template_data["typing_imports"].add("Optional") + + if "timedelta" in t: + template_data["datetime_imports"].add("timedelta") + elif "datetime" in t: + template_data["datetime_imports"].add("datetime") + + data["properties"].append( + { + "name": f.name, + "py_name": pythonize_field_name(f.name), + "number": f.number, + "comment": get_comment(proto_file, path + [2, i]), + "proto_type": int(f.type), + "field_type": field_type, + 
"field_wraps": field_wraps, + "map_types": map_types, + "type": t, + "zero": zero, + "repeated": repeated, + "packed": packed, + "one_of": one_of, + } + ) + # print(f, file=sys.stderr) + + template_data["messages"].append(data) + return data + elif isinstance(item, EnumDescriptorProto): + # print(item.name, path, file=sys.stderr) + data.update( + { + "type": "Enum", + "comment": get_comment(proto_file, path), + "entries": [ + { + "name": v.name, + "value": v.number, + "comment": get_comment(proto_file, path + [2, i]), + } + for i, v in enumerate(item.value) + ], + } + ) + + template_data["enums"].append(data) + return data + + +def lookup_method_input_type(method, types): + package, name = parse_source_type_name(method.input_type) + + for known_type in types: + if known_type["type"] != "Message": + continue + + # Nested types are currently flattened without dots. + # Todo: keep a fully qualified name in types, that is comparable with method.input_type + if ( + package == known_type["package"] + and name.replace(".", "") == known_type["name"] + ): + return known_type + + +def read_protobuf_service( + service: ServiceDescriptorProto, index, proto_file, content, output_types +): + input_package_name = content["input_package"] + template_data = content["template_data"] + # print(service, file=sys.stderr) + data = { + "name": service.name, + "py_name": pythonize_class_name(service.name), + "comment": get_comment(proto_file, [6, index]), + "methods": [], + } + for j, method in enumerate(service.method): + method_input_message = lookup_method_input_type(method, output_types) + + if method_input_message: + for field in method_input_message["properties"]: + if field["zero"] == "None": + template_data["typing_imports"].add("Optional") + + data["methods"].append( + { + "name": method.name, + "py_name": pythonize_method_name(method.name), + "comment": get_comment(proto_file, [6, index, 2, j], indent=8), + "route": f"/{input_package_name}.{service.name}/{method.name}", + 
"input": get_type_reference( + input_package_name, template_data["imports"], method.input_type + ).strip('"'), + "input_message": method_input_message, + "output": get_type_reference( + input_package_name, + template_data["imports"], + method.output_type, + unwrap=False, + ), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + ) + + if method.client_streaming: + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") + if method.server_streaming: + template_data["typing_imports"].add("AsyncIterator") + template_data["services"].append(data) def main(): @@ -386,6 +428,10 @@ def main(): request = plugin.CodeGeneratorRequest() request.ParseFromString(data) + dump_file = os.getenv("BETTERPROTO_DUMP") + if dump_file: + dump_request(dump_file, request) + # Create response response = plugin.CodeGeneratorResponse() @@ -399,5 +445,16 @@ def main(): sys.stdout.buffer.write(output) +def dump_request(dump_file: str, request: CodeGeneratorRequest): + """ + For developers: Supports running plugin.py standalone so it's possible to debug it. + Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. + Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. 
+ """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) + + if __name__ == "__main__": main() diff --git a/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto new file mode 100644 index 0000000..6380db2 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package child; + +message ChildRequestMessage { + int32 child_argument = 1; +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto index a5073db..7ca9c46 100644 --- a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto +++ b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto @@ -1,11 +1,14 @@ syntax = "proto3"; import "request_message.proto"; +import "child_package_request_message.proto"; // Tests generated service correctly imports the RequestMessage service Test { rpc DoThing (RequestMessage) returns (RequestResponse); + rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse); + rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse); } @@ -13,3 +16,8 @@ message RequestResponse { int32 value = 1; } +message Nested { + message RequestMessage { + int32 nestedArgument = 1; + } +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service.py b/betterproto/tests/inputs/import_service_input_message/test_import_service.py deleted file mode 100644 index 891b77a..0000000 --- a/betterproto/tests/inputs/import_service_input_message/test_import_service.py +++ /dev/null @@ -1,16 
+0,0 @@ -import pytest - -from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.import_service_input_message import ( - RequestResponse, - TestStub, -) - - -@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service") -@pytest.mark.asyncio -async def test_service_correctly_imports_reference_message(): - mock_response = RequestResponse(value=10) - service = TestStub(MockChannel([mock_response])) - response = await service.do_thing() - assert mock_response == response diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py new file mode 100644 index 0000000..e53fc48 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py @@ -0,0 +1,31 @@ +import pytest + +from betterproto.tests.mocks import MockChannel +from betterproto.tests.output_betterproto.import_service_input_message import ( + RequestResponse, + TestStub, +) + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing(argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message_from_child_package(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing2(child_argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_nested_reference(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing3(nested_argument=1) + assert mock_response == response