From 5f7e4d58efa8aeb7af28c9ed89e4870548a1ce62 Mon Sep 17 00:00:00 2001 From: Gabriel Pajot <50614942+alk-gpajot@users.noreply.github.com> Date: Fri, 18 Mar 2022 23:36:27 +0100 Subject: [PATCH] Fix documentation for nested enums (#351) --- src/betterproto/plugin/parser.py | 40 +++++++++----------- tests/inputs/nestedtwice/nestedtwice.proto | 12 ++++++ tests/inputs/nestedtwice/test_nestedtwice.py | 25 ++++++++++++ 3 files changed, 54 insertions(+), 23 deletions(-) create mode 100644 tests/inputs/nestedtwice/test_nestedtwice.py diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index e05e568..5d23e4f 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -1,9 +1,7 @@ -import itertools import pathlib import sys from typing import ( - TYPE_CHECKING, - Iterator, + Generator, List, Set, Tuple, @@ -13,7 +11,6 @@ from typing import ( from betterproto.lib.google.protobuf import ( DescriptorProto, EnumDescriptorProto, - FieldDescriptorProto, FileDescriptorProto, ServiceDescriptorProto, ) @@ -40,35 +37,32 @@ from .models import ( ) -if TYPE_CHECKING: - from google.protobuf.descriptor import Descriptor - - def traverse( - proto_file: FieldDescriptorProto, -) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]": + proto_file: FileDescriptorProto, +) -> Generator[ + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None +]: # Todo: Keep information about nested hierarchy def _traverse( - path: List[int], items: List["EnumDescriptorProto"], prefix="" - ) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]: + path: List[int], + items: Union[List[EnumDescriptorProto], List[DescriptorProto]], + prefix: str = "", + ) -> Generator[ + Tuple[Union[EnumDescriptorProto, DescriptorProto], List[int]], None, None + ]: for i, item in enumerate(items): # Adjust the name since we flatten the hierarchy. # Todo: don't change the name, but include full name in returned tuple item.name = next_prefix = f"{prefix}_{item.name}" - yield item, path + [i] + yield item, [*path, i] if isinstance(item, DescriptorProto): - for enum in item.enum_type: - enum.name = f"{next_prefix}_{enum.name}" - yield enum, path + [i, 4] + # Get nested types. + yield from _traverse([*path, i, 4], item.enum_type, next_prefix) + yield from _traverse([*path, i, 3], item.nested_type, next_prefix) - if item.nested_type: - for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): - yield n, p - - return itertools.chain( - _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) - ) + yield from _traverse([5], proto_file.enum_type) + yield from _traverse([4], proto_file.message_type) def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: diff --git a/tests/inputs/nestedtwice/nestedtwice.proto b/tests/inputs/nestedtwice/nestedtwice.proto index 9a54a86..84d142a 100644 --- a/tests/inputs/nestedtwice/nestedtwice.proto +++ b/tests/inputs/nestedtwice/nestedtwice.proto @@ -2,27 +2,39 @@ syntax = "proto3"; package nestedtwice; +/* Test doc. */ message Test { + /* Top doc. */ message Top { + /* Middle doc. */ message Middle { + /* TopMiddleBottom doc.*/ message TopMiddleBottom { + // TopMiddleBottom.a doc. string a = 1; } + /* EnumBottom doc. */ enum EnumBottom{ + /* EnumBottom.A doc. */ A = 0; B = 1; } + /* Bottom doc. */ message Bottom { + /* Bottom.foo doc. */ string foo = 1; } reserved 1; + /* Middle.bottom doc. */ repeated Bottom bottom = 2; repeated EnumBottom enumBottom=3; repeated TopMiddleBottom topMiddleBottom=4; bool bar = 5; } + /* Top.name doc. */ string name = 1; Middle middle = 2; } + /* Test.top doc. */ Top top = 1; } diff --git a/tests/inputs/nestedtwice/test_nestedtwice.py b/tests/inputs/nestedtwice/test_nestedtwice.py new file mode 100644 index 0000000..606467c --- /dev/null +++ b/tests/inputs/nestedtwice/test_nestedtwice.py @@ -0,0 +1,25 @@ +import pytest + +from tests.output_betterproto.nestedtwice import ( + Test, + TestTop, + TestTopMiddle, + TestTopMiddleBottom, + TestTopMiddleEnumBottom, + TestTopMiddleTopMiddleBottom, +) + + +@pytest.mark.parametrize( + ("cls", "expected_comment"), + [ + (Test, "Test doc."), + (TestTopMiddleEnumBottom, "EnumBottom doc."), + (TestTop, "Top doc."), + (TestTopMiddle, "Middle doc."), + (TestTopMiddleTopMiddleBottom, "TopMiddleBottom doc."), + (TestTopMiddleBottom, "Bottom doc."), + ], +) +def test_comment(cls, expected_comment): + assert cls.__doc__ == expected_comment