Fixes #23 again, a broken test made it seem the issue was fixed before.
This commit is contained in:
parent
dedead048f
commit
3f519d4fb1
@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
> `2.0.0` will be released once the interface is stable.
|
> `2.0.0` will be released once the interface is stable.
|
||||||
|
|
||||||
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
|
- Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83)
|
||||||
- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
|
- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75)
|
||||||
- Fix No arguments are generated for stub methods when using import with proto definition
|
|
||||||
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
|
- Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25)
|
||||||
|
|
||||||
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)
|
- Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57)
|
||||||
|
@ -11,12 +11,13 @@ from typing import List, Union
|
|||||||
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
|
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto.compile.importing import get_type_reference
|
from betterproto.compile.importing import get_type_reference, parse_source_type_name
|
||||||
from betterproto.compile.naming import (
|
from betterproto.compile.naming import (
|
||||||
pythonize_class_name,
|
pythonize_class_name,
|
||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
pythonize_method_name,
|
pythonize_method_name,
|
||||||
)
|
)
|
||||||
|
from betterproto.lib.google.protobuf import ServiceDescriptorProto
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# betterproto[compiler] specific dependencies
|
# betterproto[compiler] specific dependencies
|
||||||
@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]:
|
|||||||
return zero
|
return zero
|
||||||
|
|
||||||
|
|
||||||
# Todo: Keep information about nested hierarchy
|
|
||||||
def traverse(proto_file):
|
def traverse(proto_file):
|
||||||
|
# Todo: Keep information about nested hierarchy
|
||||||
def _traverse(path, items, prefix=""):
|
def _traverse(path, items, prefix=""):
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
# Adjust the name since we flatten the heirarchy.
|
# Adjust the name since we flatten the hierarchy.
|
||||||
|
# Todo: don't change the name, but include full name in returned tuple
|
||||||
item.name = next_prefix = prefix + item.name
|
item.name = next_prefix = prefix + item.name
|
||||||
yield item, path + [i]
|
yield item, path + [i]
|
||||||
|
|
||||||
@ -162,17 +164,21 @@ def generate_code(request, response):
|
|||||||
output_package_content["template_data"] = template_data
|
output_package_content["template_data"] = template_data
|
||||||
|
|
||||||
# Read Messages and Enums
|
# Read Messages and Enums
|
||||||
|
output_types = []
|
||||||
for output_package_name, output_package_content in output_package_files.items():
|
for output_package_name, output_package_content in output_package_files.items():
|
||||||
for proto_file in output_package_content["files"]:
|
for proto_file in output_package_content["files"]:
|
||||||
for item, path in traverse(proto_file):
|
for item, path in traverse(proto_file):
|
||||||
read_protobuf_object(item, path, proto_file, output_package_content)
|
type_data = read_protobuf_type(
|
||||||
|
item, path, proto_file, output_package_content
|
||||||
|
)
|
||||||
|
output_types.append(type_data)
|
||||||
|
|
||||||
# Read Services
|
# Read Services
|
||||||
for output_package_name, output_package_content in output_package_files.items():
|
for output_package_name, output_package_content in output_package_files.items():
|
||||||
for proto_file in output_package_content["files"]:
|
for proto_file in output_package_content["files"]:
|
||||||
for index, service in enumerate(proto_file.service):
|
for index, service in enumerate(proto_file.service):
|
||||||
read_protobuf_service(
|
read_protobuf_service(
|
||||||
service, index, proto_file, output_package_content
|
service, index, proto_file, output_package_content, output_types
|
||||||
)
|
)
|
||||||
|
|
||||||
# Render files
|
# Render files
|
||||||
@ -214,63 +220,31 @@ def generate_code(request, response):
|
|||||||
print(f"Writing {output_package_name}", file=sys.stderr)
|
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
def read_protobuf_service(service: DescriptorProto, index, proto_file, content):
|
def lookup_method_input_type(method, types):
|
||||||
|
package, name = parse_source_type_name(method.input_type)
|
||||||
|
|
||||||
|
for known_type in types:
|
||||||
|
if known_type["type"] != "Message":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Nested types are currently flattened without dots.
|
||||||
|
# Todo: keep a fully quantified name in types, that is comparable with method.input_type
|
||||||
|
if (
|
||||||
|
package == known_type["package"]
|
||||||
|
and name.replace(".", "") == known_type["name"]
|
||||||
|
):
|
||||||
|
return known_type
|
||||||
|
|
||||||
|
|
||||||
|
def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
|
||||||
input_package_name = content["input_package"]
|
input_package_name = content["input_package"]
|
||||||
template_data = content["template_data"]
|
template_data = content["template_data"]
|
||||||
# print(service, file=sys.stderr)
|
|
||||||
data = {
|
data = {
|
||||||
"name": service.name,
|
"name": item.name,
|
||||||
"py_name": pythonize_class_name(service.name),
|
"py_name": pythonize_class_name(item.name),
|
||||||
"comment": get_comment(proto_file, [6, index]),
|
"descriptor": item,
|
||||||
"methods": [],
|
"package": input_package_name,
|
||||||
}
|
}
|
||||||
for j, method in enumerate(service.method):
|
|
||||||
input_message = None
|
|
||||||
input_type = get_type_reference(
|
|
||||||
input_package_name, template_data["imports"], method.input_type
|
|
||||||
).strip('"')
|
|
||||||
for msg in template_data["messages"]:
|
|
||||||
if msg["name"] == input_type:
|
|
||||||
input_message = msg
|
|
||||||
for field in msg["properties"]:
|
|
||||||
if field["zero"] == "None":
|
|
||||||
template_data["typing_imports"].add("Optional")
|
|
||||||
break
|
|
||||||
|
|
||||||
data["methods"].append(
|
|
||||||
{
|
|
||||||
"name": method.name,
|
|
||||||
"py_name": pythonize_method_name(method.name),
|
|
||||||
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
|
|
||||||
"route": f"/{input_package_name}.{service.name}/{method.name}",
|
|
||||||
"input": get_type_reference(
|
|
||||||
input_package_name, template_data["imports"], method.input_type
|
|
||||||
).strip('"'),
|
|
||||||
"input_message": input_message,
|
|
||||||
"output": get_type_reference(
|
|
||||||
input_package_name,
|
|
||||||
template_data["imports"],
|
|
||||||
method.output_type,
|
|
||||||
unwrap=False,
|
|
||||||
),
|
|
||||||
"client_streaming": method.client_streaming,
|
|
||||||
"server_streaming": method.server_streaming,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
if method.client_streaming:
|
|
||||||
template_data["typing_imports"].add("AsyncIterable")
|
|
||||||
template_data["typing_imports"].add("Iterable")
|
|
||||||
template_data["typing_imports"].add("Union")
|
|
||||||
if method.server_streaming:
|
|
||||||
template_data["typing_imports"].add("AsyncIterator")
|
|
||||||
template_data["services"].append(data)
|
|
||||||
|
|
||||||
|
|
||||||
def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content):
|
|
||||||
input_package_name = content["input_package"]
|
|
||||||
template_data = content["template_data"]
|
|
||||||
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
|
||||||
if isinstance(item, DescriptorProto):
|
if isinstance(item, DescriptorProto):
|
||||||
# print(item, file=sys.stderr)
|
# print(item, file=sys.stderr)
|
||||||
if item.options.map_entry:
|
if item.options.map_entry:
|
||||||
@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
|
|||||||
# print(f, file=sys.stderr)
|
# print(f, file=sys.stderr)
|
||||||
|
|
||||||
template_data["messages"].append(data)
|
template_data["messages"].append(data)
|
||||||
|
return data
|
||||||
elif isinstance(item, EnumDescriptorProto):
|
elif isinstance(item, EnumDescriptorProto):
|
||||||
# print(item.name, path, file=sys.stderr)
|
# print(item.name, path, file=sys.stderr)
|
||||||
data.update(
|
data.update(
|
||||||
@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con
|
|||||||
)
|
)
|
||||||
|
|
||||||
template_data["enums"].append(data)
|
template_data["enums"].append(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def read_protobuf_service(
|
||||||
|
service: ServiceDescriptorProto, index, proto_file, content, output_types
|
||||||
|
):
|
||||||
|
input_package_name = content["input_package"]
|
||||||
|
template_data = content["template_data"]
|
||||||
|
# print(service, file=sys.stderr)
|
||||||
|
data = {
|
||||||
|
"name": service.name,
|
||||||
|
"py_name": pythonize_class_name(service.name),
|
||||||
|
"comment": get_comment(proto_file, [6, index]),
|
||||||
|
"methods": [],
|
||||||
|
}
|
||||||
|
for j, method in enumerate(service.method):
|
||||||
|
method_input_message = lookup_method_input_type(method, output_types)
|
||||||
|
|
||||||
|
if method_input_message:
|
||||||
|
for field in method_input_message["properties"]:
|
||||||
|
if field["zero"] == "None":
|
||||||
|
template_data["typing_imports"].add("Optional")
|
||||||
|
|
||||||
|
data["methods"].append(
|
||||||
|
{
|
||||||
|
"name": method.name,
|
||||||
|
"py_name": pythonize_method_name(method.name),
|
||||||
|
"comment": get_comment(proto_file, [6, index, 2, j], indent=8),
|
||||||
|
"route": f"/{input_package_name}.{service.name}/{method.name}",
|
||||||
|
"input": get_type_reference(
|
||||||
|
input_package_name, template_data["imports"], method.input_type
|
||||||
|
).strip('"'),
|
||||||
|
"input_message": method_input_message,
|
||||||
|
"output": get_type_reference(
|
||||||
|
input_package_name,
|
||||||
|
template_data["imports"],
|
||||||
|
method.output_type,
|
||||||
|
unwrap=False,
|
||||||
|
),
|
||||||
|
"client_streaming": method.client_streaming,
|
||||||
|
"server_streaming": method.server_streaming,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if method.client_streaming:
|
||||||
|
template_data["typing_imports"].add("AsyncIterable")
|
||||||
|
template_data["typing_imports"].add("Iterable")
|
||||||
|
template_data["typing_imports"].add("Union")
|
||||||
|
if method.server_streaming:
|
||||||
|
template_data["typing_imports"].add("AsyncIterator")
|
||||||
|
template_data["services"].append(data)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
@ -0,0 +1,7 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package child;
|
||||||
|
|
||||||
|
message ChildRequestMessage {
|
||||||
|
int32 child_argument = 1;
|
||||||
|
}
|
@ -1,11 +1,14 @@
|
|||||||
syntax = "proto3";
|
syntax = "proto3";
|
||||||
|
|
||||||
import "request_message.proto";
|
import "request_message.proto";
|
||||||
|
import "child_package_request_message.proto";
|
||||||
|
|
||||||
// Tests generated service correctly imports the RequestMessage
|
// Tests generated service correctly imports the RequestMessage
|
||||||
|
|
||||||
service Test {
|
service Test {
|
||||||
rpc DoThing (RequestMessage) returns (RequestResponse);
|
rpc DoThing (RequestMessage) returns (RequestResponse);
|
||||||
|
rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse);
|
||||||
|
rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -13,3 +16,8 @@ message RequestResponse {
|
|||||||
int32 value = 1;
|
int32 value = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message Nested {
|
||||||
|
message RequestMessage {
|
||||||
|
int32 nestedArgument = 1;
|
||||||
|
}
|
||||||
|
}
|
@ -1,16 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from betterproto.tests.mocks import MockChannel
|
|
||||||
from betterproto.tests.output_betterproto.import_service_input_message import (
|
|
||||||
RequestResponse,
|
|
||||||
TestStub,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service")
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_service_correctly_imports_reference_message():
|
|
||||||
mock_response = RequestResponse(value=10)
|
|
||||||
service = TestStub(MockChannel([mock_response]))
|
|
||||||
response = await service.do_thing()
|
|
||||||
assert mock_response == response
|
|
@ -0,0 +1,31 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.tests.mocks import MockChannel
|
||||||
|
from betterproto.tests.output_betterproto.import_service_input_message import (
|
||||||
|
RequestResponse,
|
||||||
|
TestStub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_correctly_imports_reference_message():
|
||||||
|
mock_response = RequestResponse(value=10)
|
||||||
|
service = TestStub(MockChannel([mock_response]))
|
||||||
|
response = await service.do_thing(argument=1)
|
||||||
|
assert mock_response == response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_correctly_imports_reference_message_from_child_package():
|
||||||
|
mock_response = RequestResponse(value=10)
|
||||||
|
service = TestStub(MockChannel([mock_response]))
|
||||||
|
response = await service.do_thing2(child_argument=1)
|
||||||
|
assert mock_response == response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_correctly_imports_nested_reference():
|
||||||
|
mock_response = RequestResponse(value=10)
|
||||||
|
service = TestStub(MockChannel([mock_response]))
|
||||||
|
response = await service.do_thing3(nested_argument=1)
|
||||||
|
assert mock_response == response
|
Loading…
x
Reference in New Issue
Block a user