From 3f519d4fb1216a127ebbc6b228e21b46cc239cfd Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 17:14:53 +0200 Subject: [PATCH] Fixes #23 again, a broken test made it seem the issue was fixed before. --- CHANGELOG.md | 3 +- betterproto/plugin.py | 142 +++++++++++------- .../child_package_request_message.proto | 7 + .../import_service_input_message.proto | 8 + .../test_import_service.py | 16 -- .../test_import_service_input_message.py | 31 ++++ 6 files changed, 131 insertions(+), 76 deletions(-) create mode 100644 betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto delete mode 100644 betterproto/tests/inputs/import_service_input_message/test_import_service.py create mode 100644 betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 383d3f7..c5c65b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 > `2.0.0` will be released once the interface is stable. - Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83) -- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) -- Fix No arguments are generated for stub methods when using import with proto definition +- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) - Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25) - Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index a795efa..4ab1b93 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -11,12 +11,13 @@ from typing import List, Union from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest import betterproto -from betterproto.compile.importing import get_type_reference +from betterproto.compile.importing import get_type_reference, parse_source_type_name from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, pythonize_method_name, ) +from betterproto.lib.google.protobuf import ServiceDescriptorProto try: # betterproto[compiler] specific dependencies @@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]: return zero -# Todo: Keep information about nested hierarchy def traverse(proto_file): + # Todo: Keep information about nested hierarchy def _traverse(path, items, prefix=""): for i, item in enumerate(items): - # Adjust the name since we flatten the heirarchy. + # Adjust the name since we flatten the hierarchy. + # Todo: don't change the name, but include full name in returned tuple item.name = next_prefix = prefix + item.name yield item, path + [i] @@ -162,17 +164,21 @@ def generate_code(request, response): output_package_content["template_data"] = template_data # Read Messages and Enums + output_types = [] for output_package_name, output_package_content in output_package_files.items(): for proto_file in output_package_content["files"]: for item, path in traverse(proto_file): - read_protobuf_object(item, path, proto_file, output_package_content) + type_data = read_protobuf_type( + item, path, proto_file, output_package_content + ) + output_types.append(type_data) # Read Services for output_package_name, output_package_content in output_package_files.items(): for proto_file in output_package_content["files"]: for index, service in enumerate(proto_file.service): read_protobuf_service( - service, index, proto_file, output_package_content + service, index, proto_file, output_package_content, output_types ) # Render files @@ -214,63 +220,31 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) -def read_protobuf_service(service: DescriptorProto, index, proto_file, content): +def lookup_method_input_type(method, types): + package, name = parse_source_type_name(method.input_type) + + for known_type in types: + if known_type["type"] != "Message": + continue + + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is comparable with method.input_type + if ( + package == known_type["package"] + and name.replace(".", "") == known_type["name"] + ): + return known_type + + +def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): input_package_name = content["input_package"] template_data = content["template_data"] - # print(service, file=sys.stderr) data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, index]), - "methods": [], + "name": item.name, + "py_name": pythonize_class_name(item.name), + "descriptor": item, + "package": input_package_name, } - for j, method in enumerate(service.method): - input_message = None - input_type = get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"') - for msg in template_data["messages"]: - if msg["name"] == input_type: - input_message = msg - for field in msg["properties"]: - if field["zero"] == "None": - template_data["typing_imports"].add("Optional") - break - - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, index, 2, j], indent=8), - "route": f"/{input_package_name}.{service.name}/{method.name}", - "input": get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"'), - "input_message": input_message, - "output": get_type_reference( - input_package_name, - template_data["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - } - ) - - if method.client_streaming: - template_data["typing_imports"].add("AsyncIterable") - template_data["typing_imports"].add("Iterable") - template_data["typing_imports"].add("Union") - if method.server_streaming: - template_data["typing_imports"].add("AsyncIterator") - template_data["services"].append(data) - - -def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content): - input_package_name = content["input_package"] - template_data = content["template_data"] - data = {"name": item.name, "py_name": pythonize_class_name(item.name)} if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) if item.options.map_entry: @@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con # print(f, file=sys.stderr) template_data["messages"].append(data) + return data elif isinstance(item, EnumDescriptorProto): # print(item.name, path, file=sys.stderr) data.update( @@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con ) template_data["enums"].append(data) + return data + + +def read_protobuf_service( + service: ServiceDescriptorProto, index, proto_file, content, output_types +): + input_package_name = content["input_package"] + template_data = content["template_data"] + # print(service, file=sys.stderr) + data = { + "name": service.name, + "py_name": pythonize_class_name(service.name), + "comment": get_comment(proto_file, [6, index]), + "methods": [], + } + for j, method in enumerate(service.method): + method_input_message = lookup_method_input_type(method, output_types) + + if method_input_message: + for field in method_input_message["properties"]: + if field["zero"] == "None": + template_data["typing_imports"].add("Optional") + + data["methods"].append( + { + "name": method.name, + "py_name": pythonize_method_name(method.name), + "comment": get_comment(proto_file, [6, index, 2, j], indent=8), + "route": f"/{input_package_name}.{service.name}/{method.name}", + "input": get_type_reference( + input_package_name, template_data["imports"], method.input_type + ).strip('"'), + "input_message": method_input_message, + "output": get_type_reference( + input_package_name, + template_data["imports"], + method.output_type, + unwrap=False, + ), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + ) + + if method.client_streaming: + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") + if method.server_streaming: + template_data["typing_imports"].add("AsyncIterator") + template_data["services"].append(data) def main(): diff --git a/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto new file mode 100644 index 0000000..6380db2 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package child; + +message ChildRequestMessage { + int32 child_argument = 1; +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto index a5073db..7ca9c46 100644 --- a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto +++ b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto @@ -1,11 +1,14 @@ syntax = "proto3"; import "request_message.proto"; +import "child_package_request_message.proto"; // Tests generated service correctly imports the RequestMessage service Test { rpc DoThing (RequestMessage) returns (RequestResponse); + rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse); + rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse); } @@ -13,3 +16,8 @@ message RequestResponse { int32 value = 1; } +message Nested { + message RequestMessage { + int32 nestedArgument = 1; + } +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service.py b/betterproto/tests/inputs/import_service_input_message/test_import_service.py deleted file mode 100644 index 891b77a..0000000 --- a/betterproto/tests/inputs/import_service_input_message/test_import_service.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.import_service_input_message import ( - RequestResponse, - TestStub, -) - - -@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service") -@pytest.mark.asyncio -async def test_service_correctly_imports_reference_message(): - mock_response = RequestResponse(value=10) - service = TestStub(MockChannel([mock_response])) - response = await service.do_thing() - assert mock_response == response diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py new file mode 100644 index 0000000..e53fc48 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py @@ -0,0 +1,31 @@ +import pytest + +from betterproto.tests.mocks import MockChannel +from betterproto.tests.output_betterproto.import_service_input_message import ( + RequestResponse, + TestStub, +) + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing(argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message_from_child_package(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing2(child_argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_nested_reference(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing3(nested_argument=1) + assert mock_response == response