From f7c2fd1194e8e6419fe36928c3898a7411d3810d Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 7 Jun 2020 16:57:57 +0200 Subject: [PATCH] Support nested messages, fix casing. Support test-cases in packages. --- Pipfile | 1 - betterproto/__init__.py | 4 +- betterproto/casing.py | 46 ++++++- betterproto/compile/importing.py | 114 ++++++++++-------- betterproto/compile/naming.py | 13 ++ betterproto/plugin.py | 49 +++----- betterproto/tests/inputs/casing/casing.proto | 1 + .../tests/inputs/casing/test_casing.py | 1 + betterproto/tests/inputs/config.py | 17 ++- .../import_circular_dependency.proto | 6 +- .../child.proto | 11 ++ .../import_root_package_from_child.proto | 11 -- betterproto/tests/inputs/nested/nested.proto | 2 +- .../tests/inputs/nestedtwice/nestedtwice.json | 10 +- .../inputs/nestedtwice/nestedtwice.proto | 20 +-- betterproto/tests/inputs/ref/ref.proto | 2 - betterproto/tests/test_casing.py | 89 ++++++++++++++ betterproto/tests/test_get_ref_type.py | 84 ++++++++----- betterproto/tests/test_inputs.py | 15 ++- 19 files changed, 333 insertions(+), 163 deletions(-) create mode 100644 betterproto/compile/naming.py create mode 100644 betterproto/tests/inputs/import_root_package_from_child/child.proto delete mode 100644 betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto create mode 100644 betterproto/tests/test_casing.py diff --git a/Pipfile b/Pipfile index 0e9397c..455081e 100644 --- a/Pipfile +++ b/Pipfile @@ -15,7 +15,6 @@ rope = "*" protobuf = "*" jinja2 = "*" grpclib = "*" -stringcase = "*" black = "*" backports-datetime-fromisoformat = "*" dataclasses = "*" diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 5d901be..9ed73f6 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -30,7 +30,7 @@ from typing import ( import grpclib.const import stringcase -from .casing import safe_snake_case +from .casing import safe_snake_case, snake_case if TYPE_CHECKING: from grpclib._protocols import IProtoMessage @@ -132,7 +132,7 @@ class Casing(enum.Enum): """Casing constants for serialization.""" CAMEL = stringcase.camelcase - SNAKE = stringcase.snakecase + SNAKE = snake_case class _PLACEHOLDER: diff --git a/betterproto/casing.py b/betterproto/casing.py index 67ca9a2..919f02e 100644 --- a/betterproto/casing.py +++ b/betterproto/casing.py @@ -1,9 +1,21 @@ -import stringcase +import re + +# Word delimiters and symbols that will not be preserved when re-casing. +# language=PythonRegExp +SYMBOLS = "[^a-zA-Z0-9]*" + +# Optionally capitalized word. +# language=PythonRegExp +WORD = "[A-Z]*[a-z]*[0-9]*" + +# Uppercase word, not followed by lowercase letters. +# language=PythonRegExp +WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*" def safe_snake_case(value: str) -> str: """Snake case a value taking into account Python keywords.""" - value = stringcase.snakecase(value) + value = snake_case(value) if value in [ "and", "as", @@ -39,3 +51,33 @@ def safe_snake_case(value: str) -> str: # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles value += "_" return value + + +def snake_case(value: str): + """ + Join words with an underscore into lowercase and remove symbols. + """ + snake = re.sub( + f"{SYMBOLS}({WORD_UPPER}|{WORD})", lambda groups: "_" + groups[1].lower(), value + ) + return snake.strip("_") + + +def pascal_case(value: str): + """ + Capitalize each word and remove symbols. + """ + return re.sub( + f"{SYMBOLS}({WORD_UPPER}|{WORD})", lambda groups: groups[1].capitalize(), value + ) + + +def camel_case(value: str): + """ + Capitalize all words except first and remove symbols. + """ + return capitalize_first(pascal_case(value)) + + +def capitalize_first(value: str): + return value[0:1].lower() + value[1:] diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 8b9ee3f..3988718 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -1,59 +1,72 @@ -from functools import reduce +import re from typing import Dict, List, Type -import stringcase - from betterproto import safe_snake_case +from betterproto.compile.naming import pythonize_class_name from betterproto.lib.google import protobuf as google_protobuf WRAPPER_TYPES: Dict[str, Type] = { - "google.protobuf.DoubleValue": google_protobuf.DoubleValue, - "google.protobuf.FloatValue": google_protobuf.FloatValue, - "google.protobuf.Int32Value": google_protobuf.Int32Value, - "google.protobuf.Int64Value": google_protobuf.Int64Value, - "google.protobuf.UInt32Value": google_protobuf.UInt32Value, - "google.protobuf.UInt64Value": google_protobuf.UInt64Value, - "google.protobuf.BoolValue": google_protobuf.BoolValue, - "google.protobuf.StringValue": google_protobuf.StringValue, - "google.protobuf.BytesValue": google_protobuf.BytesValue, + ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, + ".google.protobuf.FloatValue": google_protobuf.FloatValue, + ".google.protobuf.Int32Value": google_protobuf.Int32Value, + ".google.protobuf.Int64Value": google_protobuf.Int64Value, + ".google.protobuf.UInt32Value": google_protobuf.UInt32Value, + ".google.protobuf.UInt64Value": google_protobuf.UInt64Value, + ".google.protobuf.BoolValue": google_protobuf.BoolValue, + ".google.protobuf.StringValue": google_protobuf.StringValue, + ".google.protobuf.BytesValue": google_protobuf.BytesValue, } +def parse_source_type_name(field_type_name): + """ + Split full source type name into package and type name. + E.g. 'root.package.Message' -> ('root.package', 'Message') + 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') + """ + package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name) + if package_match: + package = package_match.group(1) + name = package_match.group(2) + else: + package = "" + name = field_type_name.lstrip(".") + return package, name + + def get_ref_type( - package: str, imports: set, type_name: str, unwrap: bool = True + package: str, imports: set, source_type: str, unwrap: bool = True, ) -> str: """ Return a Python type name for a proto type reference. Adds the import if necessary. Unwraps well known type if required. """ - # If the package name is a blank string, then this should still work - # because by convention packages are lowercase and message/enum types are - # pascal-cased. May require refactoring in the future. - type_name = type_name.lstrip(".") - - is_wrapper = type_name in WRAPPER_TYPES + is_wrapper = source_type in WRAPPER_TYPES if unwrap: if is_wrapper: - wrapped_type = type(WRAPPER_TYPES[type_name]().value) + wrapped_type = type(WRAPPER_TYPES[source_type]().value) return f"Optional[{wrapped_type.__name__}]" - if type_name == "google.protobuf.Duration": + if source_type == ".google.protobuf.Duration": return "timedelta" - if type_name == "google.protobuf.Timestamp": + if source_type == ".google.protobuf.Timestamp": return "datetime" - # Use precompiled classes for google.protobuf.* objects - if type_name.startswith("google.protobuf.") and type_name.count(".") == 2: - type_name = type_name.rsplit(".", maxsplit=1)[1] - import_package = "betterproto.lib.google.protobuf" - import_alias = safe_snake_case(import_package) - imports.add(f"import {import_package} as {import_alias}") - return f"{import_alias}.{type_name}" + source_package, source_type = parse_source_type_name(source_type) + + # Use precompiled classes for google.protobuf.* objects + if source_package == "google.protobuf": + string_import = f"betterproto.lib.{source_package}" + py_type = source_type + string_alias = safe_snake_case(string_import) + imports.add(f"import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + + py_package: List[str] = source_package.split(".") if source_package else [] + py_type: str = pythonize_class_name(source_type) - importing_package: List[str] = type_name.split(".") - importing_type: str = stringcase.pascalcase(importing_package.pop()) current_package: List[str] = package.split(".") if package else [] # importing sibling @@ -67,9 +80,8 @@ def get_ref_type( package = foo.bar name = foo.bar.Baz """ - if importing_package == current_package: - imports.add(f"from . import {importing_type}") - return importing_type + if py_package == current_package: + return f'"{py_type}"' # importing child & descendent: """ @@ -79,18 +91,18 @@ def get_ref_type( package = name = foo.bar.Baz """ - if importing_package[0 : len(current_package)] == current_package: - importing_descendent = importing_package[len(current_package) :] + if py_package[0 : len(current_package)] == current_package: + importing_descendent = py_package[len(current_package) :] string_from = ".".join(importing_descendent[0:-1]) string_import = importing_descendent[-1] if string_from: string_alias = "_".join(importing_descendent) imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{importing_type}" + return f"{string_alias}.{py_type}" else: imports.add(f"from . import {string_import}") - return f"{string_import}.{importing_type}" + return f"{string_import}.{py_type}" # importing parent & ancestor """ @@ -103,10 +115,10 @@ def get_ref_type( package = foo.bar.baz name = Bar """ - if current_package[0 : len(importing_package)] == importing_package: - distance_up = len(current_package) - len(importing_package) - imports.add(f"from .{'.' * distance_up} import {importing_type}") - return importing_type + if current_package[0 : len(py_package)] == py_package: + distance_up = len(current_package) - len(py_package) + imports.add(f"from .{'.' * distance_up} import {py_type}") + return py_type # importing unrelated or cousin """ @@ -116,20 +128,16 @@ def get_ref_type( package = foo.bar.baz name = foo.example.Bar """ - shared_ancestory = [ - pair[0] - for pair in zip(current_package, importing_package) - if pair[0] == pair[1] + pair[0] for pair in zip(current_package, py_package) if pair[0] == pair[1] ] distance_up = len(current_package) - len(shared_ancestory) - string_from = f".{'.' * distance_up}" + ".".join( - importing_package[len(shared_ancestory) : -1] + py_package[len(shared_ancestory) : -1] ) - string_import = importing_package[-1] - string_alias = f"{'_' * distance_up}" + safe_snake_case( - ".".join(importing_package[len(shared_ancestory) :]) + string_import = py_package[-1] + alias = f"{'_' * distance_up}" + safe_snake_case( + ".".join(py_package[len(shared_ancestory) :]) ) - imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{importing_type}" + imports.add(f"from {string_from} import {string_import} as {alias}") + return f"{alias}.{py_type}" diff --git a/betterproto/compile/naming.py b/betterproto/compile/naming.py new file mode 100644 index 0000000..3d56852 --- /dev/null +++ b/betterproto/compile/naming.py @@ -0,0 +1,13 @@ +from betterproto import casing + + +def pythonize_class_name(name): + return casing.pascal_case(name) + + +def pythonize_field_name(name: str): + return casing.safe_snake_case(name) + + +def pythonize_method_name(name: str): + return casing.safe_snake_case(name) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index e300318..e30571f 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -10,6 +10,11 @@ from typing import List from betterproto.casing import safe_snake_case from betterproto.compile.importing import get_ref_type import betterproto +from betterproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) try: # betterproto[compiler] specific dependencies @@ -35,27 +40,22 @@ except ImportError as err: raise SystemExit(1) -def py_type( - package: str, - imports: set, - message: DescriptorProto, - descriptor: FieldDescriptorProto, -) -> str: - if descriptor.type in [1, 2, 6, 7, 15, 16]: +def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: + if field.type in [1, 2, 6, 7, 15, 16]: return "float" - elif descriptor.type in [3, 4, 5, 13, 17, 18]: + elif field.type in [3, 4, 5, 13, 17, 18]: return "int" - elif descriptor.type == 8: + elif field.type == 8: return "bool" - elif descriptor.type == 9: + elif field.type == 9: return "str" - elif descriptor.type in [11, 14]: + elif field.type in [11, 14]: # Type referencing another defined Message or a named enum - return get_ref_type(package, imports, descriptor.type_name) - elif descriptor.type == 12: + return get_ref_type(package, imports, field.type_name) + elif field.type == 12: return "bytes" else: - raise NotImplementedError(f"Unknown type {descriptor.type}") + raise NotImplementedError(f"Unknown type {field.type}") def get_py_zero(type_num: int) -> str: @@ -160,17 +160,10 @@ def generate_code(request, response): "services": [], } - type_mapping = {} - for proto_file in options["files"]: - # print(proto_file.message_type, file=sys.stderr) - # print(proto_file.service, file=sys.stderr) - # print(proto_file.source_code_info, file=sys.stderr) - + item: DescriptorProto for item, path in traverse(proto_file): - # print(item, file=sys.stderr) - # print(path, file=sys.stderr) - data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)} + data = {"name": item.name, "py_name": pythonize_class_name(item.name)} if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) @@ -187,7 +180,7 @@ def generate_code(request, response): ) for i, f in enumerate(item.field): - t = py_type(package, output["imports"], item, f) + t = py_type(package, output["imports"], f) zero = get_py_zero(f.type) repeated = False @@ -222,13 +215,11 @@ def generate_code(request, response): k = py_type( package, output["imports"], - item, nested.field[0], ) v = py_type( package, output["imports"], - item, nested.field[1], ) t = f"Dict[{k}, {v}]" @@ -264,7 +255,7 @@ def generate_code(request, response): data["properties"].append( { "name": f.name, - "py_name": safe_snake_case(f.name), + "py_name": pythonize_field_name(f.name), "number": f.number, "comment": get_comment(proto_file, path + [2, i]), "proto_type": int(f.type), @@ -305,7 +296,7 @@ def generate_code(request, response): data = { "name": service.name, - "py_name": stringcase.pascalcase(service.name), + "py_name": pythonize_class_name(service.name), "comment": get_comment(proto_file, [6, i]), "methods": [], } @@ -329,7 +320,7 @@ def generate_code(request, response): data["methods"].append( { "name": method.name, - "py_name": stringcase.snakecase(method.name), + "py_name": pythonize_method_name(method.name), "comment": get_comment(proto_file, [6, i, 2, j], indent=8), "route": f"/{package}.{service.name}/{method.name}", "input": get_ref_type( diff --git a/betterproto/tests/inputs/casing/casing.proto b/betterproto/tests/inputs/casing/casing.proto index ad0c427..ca458b5 100644 --- a/betterproto/tests/inputs/casing/casing.proto +++ b/betterproto/tests/inputs/casing/casing.proto @@ -10,6 +10,7 @@ message Test { int32 camelCase = 1; my_enum snake_case = 2; snake_case_message snake_case_message = 3; + int32 UPPERCASE = 4; } message snake_case_message { diff --git a/betterproto/tests/inputs/casing/test_casing.py b/betterproto/tests/inputs/casing/test_casing.py index 17f01bd..1c0dc80 100644 --- a/betterproto/tests/inputs/casing/test_casing.py +++ b/betterproto/tests/inputs/casing/test_casing.py @@ -8,6 +8,7 @@ def test_message_attributes(): message, "snake_case_message" ), "snake_case field name is same in python" assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" + assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python" def test_message_casing(): diff --git a/betterproto/tests/inputs/config.py b/betterproto/tests/inputs/config.py index 2525d8f..48e4b09 100644 --- a/betterproto/tests/inputs/config.py +++ b/betterproto/tests/inputs/config.py @@ -1,12 +1,7 @@ # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Remove from list when fixed. tests = { - "import_root_sibling", # 61 - "import_child_package_from_package", # 58 - "import_root_package_from_child", # 60 - "import_parent_package_from_child", # 59 - "import_circular_dependency", # failing because of other bugs now - "import_packages_same_name", # 25 + "import_circular_dependency", "oneof_enum", # 63 "casing_message_field_uppercase", # 11 "namespace_keywords", # 70 @@ -15,6 +10,16 @@ tests = { "googletypes_value", # 9 } + +# Defines where the main package for this test resides. +# Needed to test relative package imports. +packages = { + "import_root_package_from_child": ".child", + "import_parent_package_from_child": ".parent.child", + "repeatedmessage": ".repeatedmessage", + "service": ".service", +} + services = { "googletypes_response", "googletypes_response_embedded", diff --git a/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto index 589d14f..7d02aad 100644 --- a/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto +++ b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto @@ -3,9 +3,9 @@ syntax = "proto3"; import "root.proto"; import "other.proto"; -// This test-case verifies that future implementations will support circular dependencies in the generated python files. +// This test-case verifies support for circular dependencies in the generated python files. // -// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file. +// This is important because we generate 1 python file/module per package, rather than 1 file per proto file. // // Scenario: // @@ -24,5 +24,5 @@ import "other.proto"; // (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) message Test { RootPackageMessage message = 1; - other.OtherPackageMessage other =2; + other.OtherPackageMessage other = 2; } diff --git a/betterproto/tests/inputs/import_root_package_from_child/child.proto b/betterproto/tests/inputs/import_root_package_from_child/child.proto new file mode 100644 index 0000000..d2b29cc --- /dev/null +++ b/betterproto/tests/inputs/import_root_package_from_child/child.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package child; + +import "root.proto"; + +// Verify that we can import root message from child package + +message Test { + RootMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto b/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto deleted file mode 100644 index 9e7dbcd..0000000 --- a/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto +++ /dev/null @@ -1,11 +0,0 @@ -syntax = "proto3"; - -import "root.proto"; - -package child; - -// Tests generated imports when a message inside a child-package refers to a message defined in the root. - -message Test { - RootMessage message = 1; -} diff --git a/betterproto/tests/inputs/nested/nested.proto b/betterproto/tests/inputs/nested/nested.proto index 974bf86..2eaef59 100644 --- a/betterproto/tests/inputs/nested/nested.proto +++ b/betterproto/tests/inputs/nested/nested.proto @@ -10,7 +10,7 @@ message Test { Nested nested = 1; Sibling sibling = 2; - Sibling sibling2 = 3; +// Sibling sibling2 = 3; } message Sibling { diff --git a/betterproto/tests/inputs/nestedtwice/nestedtwice.json b/betterproto/tests/inputs/nestedtwice/nestedtwice.json index 203a660..c953132 100644 --- a/betterproto/tests/inputs/nestedtwice/nestedtwice.json +++ b/betterproto/tests/inputs/nestedtwice/nestedtwice.json @@ -1,10 +1,10 @@ { - "root": { + "top": { "name": "double-nested", - "parent": { - "child": [{"foo": "hello"}], - "enumChild": ["A"], - "rootParentChild": [{"a": "hello"}], + "middle": { + "bottom": [{"foo": "hello"}], + "enumBottom": ["A"], + "topMiddleBottom": [{"a": "hello"}], "bar": true } } diff --git a/betterproto/tests/inputs/nestedtwice/nestedtwice.proto b/betterproto/tests/inputs/nestedtwice/nestedtwice.proto index 91c8050..7e9c206 100644 --- a/betterproto/tests/inputs/nestedtwice/nestedtwice.proto +++ b/betterproto/tests/inputs/nestedtwice/nestedtwice.proto @@ -1,26 +1,26 @@ syntax = "proto3"; message Test { - message Root { - message Parent { - message RootParentChild { + message Top { + message Middle { + message TopMiddleBottom { string a = 1; } - enum EnumChild{ + enum EnumBottom{ A = 0; B = 1; } - message Child { + message Bottom { string foo = 1; } reserved 1; - repeated Child child = 2; - repeated EnumChild enumChild=3; - repeated RootParentChild rootParentChild=4; + repeated Bottom bottom = 2; + repeated EnumBottom enumBottom=3; + repeated TopMiddleBottom topMiddleBottom=4; bool bar = 5; } string name = 1; - Parent parent = 2; + Middle middle = 2; } - Root root = 1; + Top top = 1; } diff --git a/betterproto/tests/inputs/ref/ref.proto b/betterproto/tests/inputs/ref/ref.proto index 6945590..e09fb15 100644 --- a/betterproto/tests/inputs/ref/ref.proto +++ b/betterproto/tests/inputs/ref/ref.proto @@ -1,7 +1,5 @@ syntax = "proto3"; -package ref; - import "repeatedmessage.proto"; message Test { diff --git a/betterproto/tests/test_casing.py b/betterproto/tests/test_casing.py new file mode 100644 index 0000000..f777e2c --- /dev/null +++ b/betterproto/tests/test_casing.py @@ -0,0 +1,89 @@ +import pytest + +from betterproto.casing import camel_case, pascal_case, snake_case + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "A"), + ("foobar", "Foobar"), + ("FooBar", "FooBar"), + ("foo.bar", "FooBar"), + ("foo_bar", "FooBar"), + ("FOOBAR", "Foobar"), + ("FOOBar", "FooBar"), + ("UInt32", "UInt32"), + ("FOO_BAR", "FooBar"), + ("FOOBAR1", "Foobar1"), + ("FOOBAR_1", "Foobar1"), + ("FOO1BAR2", "Foo1Bar2"), + ("foo__bar", "FooBar"), + ("_foobar", "Foobar"), + ("foobaR", "FoobaR"), + ("foo~bar", "FooBar"), + ("foo:bar", "FooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_pascal_case(value, expected): + actual = pascal_case(value) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("FooBar", "fooBar"), + ("foo.bar", "fooBar"), + ("foo_bar", "fooBar"), + ("FOOBAR", "foobar"), + ("FOO_BAR", "fooBar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar1"), + ("FOO1BAR2", "foo1Bar2"), + ("foo__bar", "fooBar"), + ("_foobar", "foobar"), + ("foobaR", "foobaR"), + ("foo~bar", "fooBar"), + ("foo:bar", "fooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_camel_case(value, expected): + actual = camel_case(value) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("FooBar", "foo_bar"), + ("foo.bar", "foo_bar"), + ("foo_bar", "foo_bar"), + ("FOOBAR", "foobar"), + ("FOOBar", "foo_bar"), + ("UInt32", "u_int32"), + ("FOO_BAR", "foo_bar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar_1"), + ("FOOBAR_123", "foobar_123"), + ("FOO1BAR2", "foo1_bar2"), + ("foo__bar", "foo_bar"), + ("_foobar", "foobar"), + ("foobaR", "fooba_r"), + ("foo~bar", "foo_bar"), + ("foo:bar", "foo_bar"), + ("1foobar", "1_foobar"), + ], +) +def test_snake_case(value, expected): + actual = snake_case(value) + assert actual == expected, f"{value} => {expected} (actual: {actual})" diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index 8d72406..16a5466 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -1,9 +1,8 @@ import pytest -from ..compile.importing import get_ref_type +from ..compile.importing import get_ref_type, parse_source_type_name -@pytest.mark.skip @pytest.mark.parametrize( ["google_type", "expected_name", "expected_import"], [ @@ -33,13 +32,14 @@ def test_import_google_wellknown_types_non_wrappers( google_type: str, expected_name: str, expected_import: str ): imports = set() - name = get_ref_type(package="", imports=imports, type_name=google_type) + name = get_ref_type(package="", imports=imports, source_type=google_type) assert name == expected_name - assert imports.__contains__(expected_import) + assert imports.__contains__( + expected_import + ), f"{expected_import} not found in {imports}" -@pytest.mark.skip @pytest.mark.parametrize( ["google_type", "expected_name"], [ @@ -56,13 +56,12 @@ def test_import_google_wellknown_types_non_wrappers( ) def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str): imports = set() - name = get_ref_type(package="", imports=imports, type_name=google_type) + name = get_ref_type(package="", imports=imports, source_type=google_type) assert name == expected_name assert imports == set() -@pytest.mark.skip @pytest.mark.parametrize( ["google_type", "expected_name"], [ @@ -80,7 +79,9 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: def test_importing_google_wrappers_without_unwrapping( google_type: str, expected_name: str ): - name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False) + name = get_ref_type( + package="", imports=set(), source_type=google_type, unwrap=False + ) assert name == expected_name @@ -88,7 +89,7 @@ def test_importing_google_wrappers_without_unwrapping( def test_import_child_package_from_package(): imports = set() name = get_ref_type( - package="package", imports=imports, type_name="package.child.Message" + package="package", imports=imports, source_type="package.child.Message" ) assert imports == {"from . import child"} @@ -97,7 +98,7 @@ def test_import_child_package_from_package(): def test_import_child_package_from_root(): imports = set() - name = get_ref_type(package="", imports=imports, type_name="child.Message") + name = get_ref_type(package="", imports=imports, source_type="child.Message") assert imports == {"from . import child"} assert name == "child.Message" @@ -106,7 +107,7 @@ def test_import_child_package_from_root(): def test_import_camel_cased(): imports = set() name = get_ref_type( - package="", imports=imports, type_name="child_package.example_message" + package="", imports=imports, source_type="child_package.example_message" ) assert imports == {"from . import child_package"} @@ -115,7 +116,7 @@ def test_import_camel_cased(): def test_import_nested_child_from_root(): imports = set() - name = get_ref_type(package="", imports=imports, type_name="nested.child.Message") + name = get_ref_type(package="", imports=imports, source_type="nested.child.Message") assert imports == {"from .nested import child as nested_child"} assert name == "nested_child.Message" @@ -124,7 +125,7 @@ def test_import_nested_child_from_root(): def test_import_deeply_nested_child_from_root(): imports = set() name = get_ref_type( - package="", imports=imports, type_name="deeply.nested.child.Message" + package="", imports=imports, source_type="deeply.nested.child.Message" ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} @@ -136,7 +137,7 @@ def test_import_deeply_nested_child_from_package(): name = get_ref_type( package="package", imports=imports, - type_name="package.deeply.nested.child.Message", + source_type="package.deeply.nested.child.Message", ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} @@ -145,32 +146,32 @@ def test_import_deeply_nested_child_from_package(): def test_import_root_sibling(): imports = set() - name = get_ref_type(package="", imports=imports, type_name="Message") + name = get_ref_type(package="", imports=imports, source_type="Message") - assert imports == {"from . import Message"} - assert name == "Message" + assert imports == set() + assert name == '"Message"' def test_import_nested_siblings(): imports = set() - name = get_ref_type(package="foo", imports=imports, type_name="foo.Message") + name = get_ref_type(package="foo", imports=imports, source_type="foo.Message") - assert imports == {"from . import Message"} - assert name == "Message" + assert name == '"Message"' def test_import_deeply_nested_siblings(): imports = set() - name = get_ref_type(package="foo.bar", imports=imports, type_name="foo.bar.Message") + name = get_ref_type( + package="foo.bar", imports=imports, source_type="foo.bar.Message" + ) - assert imports == {"from . import Message"} - assert name == "Message" + assert name == '"Message"' def test_import_parent_package_from_child(): imports = set() name = get_ref_type( - package="package.child", imports=imports, type_name="package.Message" + package="package.child", imports=imports, source_type="package.Message" ) assert imports == {"from .. import Message"} @@ -182,7 +183,7 @@ def test_import_parent_package_from_deeply_nested_child(): name = get_ref_type( package="package.deeply.nested.child", imports=imports, - type_name="package.deeply.nested.Message", + source_type="package.deeply.nested.Message", ) assert imports == {"from .. import Message"} @@ -191,7 +192,7 @@ def test_import_parent_package_from_deeply_nested_child(): def test_import_root_package_from_child(): imports = set() - name = get_ref_type(package="package.child", imports=imports, type_name="Message") + name = get_ref_type(package="package.child", imports=imports, source_type="Message") assert imports == {"from ... import Message"} assert name == "Message" @@ -200,7 +201,7 @@ def test_import_root_package_from_child(): def test_import_root_package_from_deeply_nested_child(): imports = set() name = get_ref_type( - package="package.deeply.nested.child", imports=imports, type_name="Message" + package="package.deeply.nested.child", imports=imports, source_type="Message" ) assert imports == {"from ..... import Message"} @@ -209,7 +210,7 @@ def test_import_root_package_from_deeply_nested_child(): def test_import_unrelated_package(): imports = set() - name = get_ref_type(package="a", imports=imports, type_name="p.Message") + name = get_ref_type(package="a", imports=imports, source_type="p.Message") assert imports == {"from .. import p as _p"} assert name == "_p.Message" @@ -217,7 +218,7 @@ def test_import_unrelated_package(): def test_import_unrelated_nested_package(): imports = set() - name = get_ref_type(package="a.b", imports=imports, type_name="p.q.Message") + name = get_ref_type(package="a.b", imports=imports, source_type="p.q.Message") assert imports == {"from ...p import q as __p_q"} assert name == "__p_q.Message" @@ -225,7 +226,9 @@ def test_import_unrelated_nested_package(): def test_import_unrelated_deeply_nested_package(): imports = set() - name = get_ref_type(package="a.b.c.d", imports=imports, type_name="p.q.r.s.Message") + name = get_ref_type( + package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message" + ) assert imports == {"from .....p.q.r import s as ____p_q_r_s"} assert name == "____p_q_r_s.Message" @@ -233,7 +236,7 @@ def test_import_unrelated_deeply_nested_package(): def test_import_cousin_package(): imports = set() - name = get_ref_type(package="a.x", imports=imports, type_name="a.y.Message") + name = get_ref_type(package="a.x", imports=imports, source_type="a.y.Message") assert imports == {"from .. import y as _y"} assert name == "_y.Message" @@ -241,7 +244,7 @@ def test_import_cousin_package(): def test_import_far_cousin_package(): imports = set() - name = get_ref_type(package="a.x.y", imports=imports, type_name="a.b.c.Message") + name = get_ref_type(package="a.x.y", imports=imports, source_type="a.b.c.Message") assert imports == {"from ...b import c as __b_c"} assert name == "__b_c.Message" @@ -249,7 +252,22 @@ def test_import_far_cousin_package(): def test_import_far_far_cousin_package(): imports = set() - name = get_ref_type(package="a.x.y.z", imports=imports, type_name="a.b.c.d.Message") + name = get_ref_type( + package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message" + ) assert imports == {"from ....b.c import d as ___b_c_d"} assert name == "___b_c_d.Message" + + +@pytest.mark.parametrize( + ["full_name", "expected_output"], + [ + ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".service.ExampleRequest", ("service", "ExampleRequest")), + (".package.lower_case_message", ("package", "lower_case_message")), + ], +) +def test_parse_field_type_name(full_name, expected_output): + assert parse_source_type_name(full_name) == expected_output diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index cac8327..6778425 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -57,7 +57,9 @@ plugin_output_package = "betterproto.tests.output_betterproto" reference_output_package = "betterproto.tests.output_reference" -TestData = namedtuple("TestData", "plugin_module, reference_module, json_data") +TestData = namedtuple( + "TestData", ["plugin_module", "reference_module", "json_data", "entry_point"] +) @pytest.fixture @@ -75,15 +77,18 @@ def test_data(request): sys.path.append(reference_module_root) + test_package = test_case_name + test_input_config.packages.get(test_case_name, "") + yield ( TestData( plugin_module=importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" + f"{plugin_output_package}.{test_package}" ), reference_module=lambda: importlib.import_module( f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" ), json_data=get_test_case_json_data(test_case_name), + entry_point=test_package, ) ) @@ -106,7 +111,7 @@ def test_message_equality(test_data: TestData) -> None: @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) def test_message_json(repeat, test_data: TestData) -> None: - plugin_module, _, json_data = test_data + plugin_module, _, json_data, entry_point = test_data for _ in range(repeat): message: betterproto.Message = plugin_module.Test() @@ -119,13 +124,13 @@ def test_message_json(repeat, test_data: TestData) -> None: @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) def test_service_can_be_instantiated(test_data: TestData) -> None: - plugin_module, _, json_data = test_data + plugin_module, _, json_data, entry_point = test_data plugin_module.TestStub(MockChannel()) @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) def test_binary_compatibility(repeat, test_data: TestData) -> None: - plugin_module, reference_module, json_data = test_data + plugin_module, reference_module, json_data, entry_point = test_data reference_instance = Parse(json_data, reference_module().Test()) reference_binary_output = reference_instance.SerializeToString()