Fixes #23 again, a broken test made it seem the issue was fixed before.
This commit is contained in:
		@@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 | 
			
		||||
> `2.0.0` will be released once the interface is stable.
 | 
			
		||||
 | 
			
		||||
- Add support for gRPC  and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
 | 
			
		||||
- Switch from  to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
 | 
			
		||||
- Fix No arguments are generated for stub methods when using import with proto definition 
 | 
			
		||||
- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
 | 
			
		||||
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
 | 
			
		||||
 | 
			
		||||
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)
 | 
			
		||||
 
 | 
			
		||||
@@ -11,12 +11,13 @@ from typing import List, Union
 | 
			
		||||
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
 | 
			
		||||
 | 
			
		||||
import betterproto
 | 
			
		||||
from betterproto.compile.importing import get_type_reference
 | 
			
		||||
from betterproto.compile.importing import get_type_reference, parse_source_type_name
 | 
			
		||||
from betterproto.compile.naming import (
 | 
			
		||||
    pythonize_class_name,
 | 
			
		||||
    pythonize_field_name,
 | 
			
		||||
    pythonize_method_name,
 | 
			
		||||
)
 | 
			
		||||
from betterproto.lib.google.protobuf import ServiceDescriptorProto
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    # betterproto[compiler] specific dependencies
 | 
			
		||||
@@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]:
 | 
			
		||||
    return zero
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Todo: Keep information about nested hierarchy
 | 
			
		||||
def traverse(proto_file):
 | 
			
		||||
    # Todo: Keep information about nested hierarchy
 | 
			
		||||
    def _traverse(path, items, prefix=""):
 | 
			
		||||
        for i, item in enumerate(items):
 | 
			
		||||
            # Adjust the name since we flatten the heirarchy.
 | 
			
		||||
            # Adjust the name since we flatten the hierarchy.
 | 
			
		||||
            # Todo: don't change the name, but include full name in returned tuple
 | 
			
		||||
            item.name = next_prefix = prefix + item.name
 | 
			
		||||
            yield item, path + [i]
 | 
			
		||||
 | 
			
		||||
@@ -162,17 +164,21 @@ def generate_code(request, response):
 | 
			
		||||
        output_package_content["template_data"] = template_data
 | 
			
		||||
 | 
			
		||||
    # Read Messages and Enums
 | 
			
		||||
    output_types = []
 | 
			
		||||
    for output_package_name, output_package_content in output_package_files.items():
 | 
			
		||||
        for proto_file in output_package_content["files"]:
 | 
			
		||||
            for item, path in traverse(proto_file):
 | 
			
		||||
                read_protobuf_object(item, path, proto_file, output_package_content)
 | 
			
		||||
                type_data = read_protobuf_type(
 | 
			
		||||
                    item, path, proto_file, output_package_content
 | 
			
		||||
                )
 | 
			
		||||
                output_types.append(type_data)
 | 
			
		||||
 | 
			
		||||
    # Read Services
 | 
			
		||||
    for output_package_name, output_package_content in output_package_files.items():
 | 
			
		||||
        for proto_file in output_package_content["files"]:
 | 
			
		||||
            for index, service in enumerate(proto_file.service):
 | 
			
		||||
                read_protobuf_service(
 | 
			
		||||
                    service, index, proto_file, output_package_content
 | 
			
		||||
                    service, index, proto_file, output_package_content, output_types
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    # Render files
 | 
			
		||||
@@ -214,63 +220,31 @@ def generate_code(request, response):
 | 
			
		||||
        print(f"Writing {output_package_name}", file=sys.stderr)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_protobuf_service(service: DescriptorProto, index, proto_file, content):
 | 
			
		||||
def lookup_method_input_type(method, types):
 | 
			
		||||
    package, name = parse_source_type_name(method.input_type)
 | 
			
		||||
 | 
			
		||||
    for known_type in types:
 | 
			
		||||
        if known_type["type"] != "Message":
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        # Nested types are currently flattened without dots.
 | 
			
		||||
        # Todo: keep a fully quantified name in types, that is comparable with method.input_type
 | 
			
		||||
        if (
 | 
			
		||||
            package == known_type["package"]
 | 
			
		||||
            and name.replace(".", "") == known_type["name"]
 | 
			
		||||
        ):
 | 
			
		||||
            return known_type
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
 | 
			
		||||
    input_package_name = content["input_package"]
 | 
			
		||||
    template_data = content["template_data"]
 | 
			
		||||
    # print(service, file=sys.stderr)
 | 
			
		||||
    data = {
 | 
			
		||||
        "name": service.name,
 | 
			
		||||
        "py_name": pythonize_class_name(service.name),
 | 
			
		||||
        "comment": get_comment(proto_file, [6, index]),
 | 
			
		||||
        "methods": [],
 | 
			
		||||
        "name": item.name,
 | 
			
		||||
        "py_name": pythonize_class_name(item.name),
 | 
			
		||||
        "descriptor": item,
 | 
			
		||||
        "package": input_package_name,
 | 
			
		||||
    }
 | 
			
		||||
    for j, method in enumerate(service.method):
 | 
			
		||||
        input_message = None
 | 
			
		||||
        input_type = get_type_reference(
 | 
			
		||||
            input_package_name, template_data["imports"], method.input_type
 | 
			
		||||
        ).strip('"')
 | 
			
		||||
        for msg in template_data["messages"]:
 | 
			
		||||
            if msg["name"] == input_type:
 | 
			
		||||
                input_message = msg
 | 
			
		||||
                for field in msg["properties"]:
 | 
			
		||||
                    if field["zero"] == "None":
 | 
			
		||||
                        template_data["typing_imports"].add("Optional")
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
        data["methods"].append(
 | 
			
		||||
            {
 | 
			
		||||
                "name": method.name,
 | 
			
		||||
                "py_name": pythonize_method_name(method.name),
 | 
			
		||||
                "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
 | 
			
		||||
                "route": f"/{input_package_name}.{service.name}/{method.name}",
 | 
			
		||||
                "input": get_type_reference(
 | 
			
		||||
                    input_package_name, template_data["imports"], method.input_type
 | 
			
		||||
                ).strip('"'),
 | 
			
		||||
                "input_message": input_message,
 | 
			
		||||
                "output": get_type_reference(
 | 
			
		||||
                    input_package_name,
 | 
			
		||||
                    template_data["imports"],
 | 
			
		||||
                    method.output_type,
 | 
			
		||||
                    unwrap=False,
 | 
			
		||||
                ),
 | 
			
		||||
                "client_streaming": method.client_streaming,
 | 
			
		||||
                "server_streaming": method.server_streaming,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if method.client_streaming:
 | 
			
		||||
            template_data["typing_imports"].add("AsyncIterable")
 | 
			
		||||
            template_data["typing_imports"].add("Iterable")
 | 
			
		||||
            template_data["typing_imports"].add("Union")
 | 
			
		||||
        if method.server_streaming:
 | 
			
		||||
            template_data["typing_imports"].add("AsyncIterator")
 | 
			
		||||
    template_data["services"].append(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content):
 | 
			
		||||
    input_package_name = content["input_package"]
 | 
			
		||||
    template_data = content["template_data"]
 | 
			
		||||
    data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
 | 
			
		||||
    if isinstance(item, DescriptorProto):
 | 
			
		||||
        # print(item, file=sys.stderr)
 | 
			
		||||
        if item.options.map_entry:
 | 
			
		||||
@@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
 | 
			
		||||
            # print(f, file=sys.stderr)
 | 
			
		||||
 | 
			
		||||
        template_data["messages"].append(data)
 | 
			
		||||
        return data
 | 
			
		||||
    elif isinstance(item, EnumDescriptorProto):
 | 
			
		||||
        # print(item.name, path, file=sys.stderr)
 | 
			
		||||
        data.update(
 | 
			
		||||
@@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        template_data["enums"].append(data)
 | 
			
		||||
        return data
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_protobuf_service(
 | 
			
		||||
    service: ServiceDescriptorProto, index, proto_file, content, output_types
 | 
			
		||||
):
 | 
			
		||||
    input_package_name = content["input_package"]
 | 
			
		||||
    template_data = content["template_data"]
 | 
			
		||||
    # print(service, file=sys.stderr)
 | 
			
		||||
    data = {
 | 
			
		||||
        "name": service.name,
 | 
			
		||||
        "py_name": pythonize_class_name(service.name),
 | 
			
		||||
        "comment": get_comment(proto_file, [6, index]),
 | 
			
		||||
        "methods": [],
 | 
			
		||||
    }
 | 
			
		||||
    for j, method in enumerate(service.method):
 | 
			
		||||
        method_input_message = lookup_method_input_type(method, output_types)
 | 
			
		||||
 | 
			
		||||
        if method_input_message:
 | 
			
		||||
            for field in method_input_message["properties"]:
 | 
			
		||||
                if field["zero"] == "None":
 | 
			
		||||
                    template_data["typing_imports"].add("Optional")
 | 
			
		||||
 | 
			
		||||
        data["methods"].append(
 | 
			
		||||
            {
 | 
			
		||||
                "name": method.name,
 | 
			
		||||
                "py_name": pythonize_method_name(method.name),
 | 
			
		||||
                "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
 | 
			
		||||
                "route": f"/{input_package_name}.{service.name}/{method.name}",
 | 
			
		||||
                "input": get_type_reference(
 | 
			
		||||
                    input_package_name, template_data["imports"], method.input_type
 | 
			
		||||
                ).strip('"'),
 | 
			
		||||
                "input_message": method_input_message,
 | 
			
		||||
                "output": get_type_reference(
 | 
			
		||||
                    input_package_name,
 | 
			
		||||
                    template_data["imports"],
 | 
			
		||||
                    method.output_type,
 | 
			
		||||
                    unwrap=False,
 | 
			
		||||
                ),
 | 
			
		||||
                "client_streaming": method.client_streaming,
 | 
			
		||||
                "server_streaming": method.server_streaming,
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if method.client_streaming:
 | 
			
		||||
            template_data["typing_imports"].add("AsyncIterable")
 | 
			
		||||
            template_data["typing_imports"].add("Iterable")
 | 
			
		||||
            template_data["typing_imports"].add("Union")
 | 
			
		||||
        if method.server_streaming:
 | 
			
		||||
            template_data["typing_imports"].add("AsyncIterator")
 | 
			
		||||
    template_data["services"].append(data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,7 @@
 | 
			
		||||
syntax = "proto3";
 | 
			
		||||
 | 
			
		||||
package child;
 | 
			
		||||
 | 
			
		||||
message ChildRequestMessage {
 | 
			
		||||
    int32 child_argument = 1;
 | 
			
		||||
}
 | 
			
		||||
@@ -1,11 +1,14 @@
 | 
			
		||||
syntax = "proto3";
 | 
			
		||||
 | 
			
		||||
import "request_message.proto";
 | 
			
		||||
import "child_package_request_message.proto";
 | 
			
		||||
 | 
			
		||||
// Tests generated service correctly imports the RequestMessage
 | 
			
		||||
 | 
			
		||||
service Test {
 | 
			
		||||
    rpc DoThing (RequestMessage) returns (RequestResponse);
 | 
			
		||||
    rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse);
 | 
			
		||||
    rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -13,3 +16,8 @@ message RequestResponse {
 | 
			
		||||
    int32 value = 1;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
message Nested {
 | 
			
		||||
    message RequestMessage {
 | 
			
		||||
        int32 nestedArgument = 1;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -1,16 +0,0 @@
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from betterproto.tests.mocks import MockChannel
 | 
			
		||||
from betterproto.tests.output_betterproto.import_service_input_message import (
 | 
			
		||||
    RequestResponse,
 | 
			
		||||
    TestStub,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
async def test_service_correctly_imports_reference_message():
 | 
			
		||||
    mock_response = RequestResponse(value=10)
 | 
			
		||||
    service = TestStub(MockChannel([mock_response]))
 | 
			
		||||
    response = await service.do_thing()
 | 
			
		||||
    assert mock_response == response
 | 
			
		||||
@@ -0,0 +1,31 @@
 | 
			
		||||
import pytest
 | 
			
		||||
 | 
			
		||||
from betterproto.tests.mocks import MockChannel
 | 
			
		||||
from betterproto.tests.output_betterproto.import_service_input_message import (
 | 
			
		||||
    RequestResponse,
 | 
			
		||||
    TestStub,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
async def test_service_correctly_imports_reference_message():
 | 
			
		||||
    mock_response = RequestResponse(value=10)
 | 
			
		||||
    service = TestStub(MockChannel([mock_response]))
 | 
			
		||||
    response = await service.do_thing(argument=1)
 | 
			
		||||
    assert mock_response == response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
async def test_service_correctly_imports_reference_message_from_child_package():
 | 
			
		||||
    mock_response = RequestResponse(value=10)
 | 
			
		||||
    service = TestStub(MockChannel([mock_response]))
 | 
			
		||||
    response = await service.do_thing2(child_argument=1)
 | 
			
		||||
    assert mock_response == response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.asyncio
 | 
			
		||||
async def test_service_correctly_imports_nested_reference():
 | 
			
		||||
    mock_response = RequestResponse(value=10)
 | 
			
		||||
    service = TestStub(MockChannel([mock_response]))
 | 
			
		||||
    response = await service.do_thing3(nested_argument=1)
 | 
			
		||||
    assert mock_response == response
 | 
			
		||||
		Reference in New Issue
	
	Block a user