diff --git a/betterproto/plugin.py b/betterproto/plugin.py index c96a1ed..f1a6d30 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -114,11 +114,11 @@ def get_comment(proto_file, path: List[int]) -> str: sci.leading_comments.strip().replace("\n", ""), width=75 ) - if path[-2] == 2: + if path[-2] == 2 and path[-4] != 6: # This is a field return " # " + " # ".join(lines) else: - # This is a class + # This is a message, enum, service, or method if len(lines) == 1 and len(lines[0]) < 70: lines[0] = lines[0].strip('"') return f' """{lines[0]}"""' @@ -278,13 +278,16 @@ def generate_code(request, response): output["enums"].append(data) - for service in proto_file.service: + for i, service in enumerate(proto_file.service): # print(service, file=sys.stderr) - # TODO: comments - data = {"name": service.name, "methods": []} + data = { + "name": service.name, + "comment": get_comment(proto_file, [6, i]), + "methods": [], + } - for method in service.method: + for j, method in enumerate(service.method): if method.client_streaming: raise NotImplementedError("Client streaming not yet supported") @@ -304,6 +307,7 @@ def generate_code(request, response): { "name": method.name, "py_name": snake_case(method.name), + "comment": get_comment(proto_file, [6, i, 2, j]), "route": f"/{package}.{service.name}/{method.name}", "input": get_ref_type( package, output["imports"], method.input_type diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 2e3441a..a8af85e 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -54,8 +54,16 @@ class {{ message.name }}(betterproto.Message): {% endfor %} {% for service in description.services %} class {{ service.name }}Stub(betterproto.ServiceStub): + {% if service.comment %} +{{ service.comment }} + + {% endif %} {% for method in service.methods %} async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + {% if method.comment %} +{{ method.comment }} + + {% endif %} request = {{ method.input }}() {% for field in method.input_message.properties %} {% if field.field_type == 'message' %}