diff --git a/betterproto/compile/__init__.py b/betterproto/compile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py new file mode 100644 index 0000000..0c53e0b --- /dev/null +++ b/betterproto/compile/importing.py @@ -0,0 +1,70 @@ +from typing import Dict, Type + +import stringcase + +from betterproto import safe_snake_case +from betterproto.lib.google import protobuf as google_protobuf + +WRAPPER_TYPES: Dict[str, Type] = { + "google.protobuf.DoubleValue": google_protobuf.DoubleValue, + "google.protobuf.FloatValue": google_protobuf.FloatValue, + "google.protobuf.Int32Value": google_protobuf.Int32Value, + "google.protobuf.Int64Value": google_protobuf.Int64Value, + "google.protobuf.UInt32Value": google_protobuf.UInt32Value, + "google.protobuf.UInt64Value": google_protobuf.UInt64Value, + "google.protobuf.BoolValue": google_protobuf.BoolValue, + "google.protobuf.StringValue": google_protobuf.StringValue, + "google.protobuf.BytesValue": google_protobuf.BytesValue, +} + + +def get_ref_type( + package: str, imports: set, type_name: str, unwrap: bool = True +) -> str: + """ + Return a Python type name for a proto type reference. Adds the import if + necessary. Unwraps well known type if required. + """ + # If the package name is a blank string, then this should still work + # because by convention packages are lowercase and message/enum types are + # pascal-cased. May require refactoring in the future. + type_name = type_name.lstrip(".") + + is_wrapper = type_name in WRAPPER_TYPES + + if unwrap: + if is_wrapper: + wrapped_type = type(WRAPPER_TYPES[type_name]().value) + return f"Optional[{wrapped_type.__name__}]" + + if type_name == "google.protobuf.Duration": + return "timedelta" + + if type_name == "google.protobuf.Timestamp": + return "datetime" + + if type_name.startswith(package): + parts = type_name.lstrip(package).lstrip(".").split(".") + if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): + # This is the current package, which has nested types flattened. + # foo.bar_thing => FooBarThing + cased = [stringcase.pascalcase(part) for part in parts] + type_name = f'"{"".join(cased)}"' + + # Use precompiled classes for google.protobuf.* objects + if type_name.startswith("google.protobuf.") and type_name.count(".") == 2: + type_name = type_name.rsplit(".", maxsplit=1)[1] + import_package = "betterproto.lib.google.protobuf" + import_alias = safe_snake_case(import_package) + imports.add(f"import {import_package} as {import_alias}") + return f"{import_alias}.{type_name}" + + if "." in type_name: + # This is imported from another package. No need + # to use a forward ref and we need to add the import. + parts = type_name.split(".") + parts[-1] = stringcase.pascalcase(parts[-1]) + imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") + type_name = f"{parts[-2]}.{parts[-1]}" + + return type_name diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 96cb12d..5780240 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -6,9 +6,9 @@ import re import stringcase import sys import textwrap -from collections import defaultdict -from typing import Dict, List, Optional, Type +from typing import List from betterproto.casing import safe_snake_case +from betterproto.compile.importing import get_ref_type import betterproto try: @@ -35,78 +35,6 @@ except ImportError as err: raise SystemExit(1) -WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict( - lambda: None, - { - "google.protobuf.DoubleValue": google_wrappers.DoubleValue, - "google.protobuf.FloatValue": google_wrappers.FloatValue, - "google.protobuf.Int64Value": google_wrappers.Int64Value, - "google.protobuf.UInt64Value": google_wrappers.UInt64Value, - "google.protobuf.Int32Value": google_wrappers.Int32Value, - "google.protobuf.UInt32Value": google_wrappers.UInt32Value, - "google.protobuf.BoolValue": google_wrappers.BoolValue, - "google.protobuf.StringValue": google_wrappers.StringValue, - "google.protobuf.BytesValue": google_wrappers.BytesValue, - }, -) - - -def get_ref_type( - package: str, imports: set, type_name: str, unwrap: bool = True -) -> str: - """ - Return a Python type name for a proto type reference. Adds the import if - necessary. Unwraps well known type if required. - """ - # If the package name is a blank string, then this should still work - # because by convention packages are lowercase and message/enum types are - # pascal-cased. May require refactoring in the future. - type_name = type_name.lstrip(".") - - # Check if type is wrapper. - wrapper_class = WRAPPER_TYPES[type_name] - - if unwrap: - if wrapper_class: - wrapped_type = type(wrapper_class().value) - return f"Optional[{wrapped_type.__name__}]" - - if type_name == "google.protobuf.Duration": - return "timedelta" - - if type_name == "google.protobuf.Timestamp": - return "datetime" - elif wrapper_class: - imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}") - return f"{wrapper_class.__name__}" - - if type_name.startswith(package): - parts = type_name.lstrip(package).lstrip(".").split(".") - if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): - # This is the current package, which has nested types flattened. - # foo.bar_thing => FooBarThing - cased = [stringcase.pascalcase(part) for part in parts] - type_name = f'"{"".join(cased)}"' - - # Use precompiled classes for google.protobuf.* objects - if type_name.startswith("google.protobuf.") and type_name.count(".") == 2: - type_name = type_name.rsplit(".", maxsplit=1)[1] - import_package = "betterproto.lib.google.protobuf" - import_alias = safe_snake_case(import_package) - imports.add(f"import {import_package} as {import_alias}") - return f"{import_alias}.{type_name}" - - if "." in type_name: - # This is imported from another package. No need - # to use a forward ref and we need to add the import. - parts = type_name.split(".") - parts[-1] = stringcase.pascalcase(parts[-1]) - imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") - type_name = f"{parts[-2]}.{parts[-1]}" - - return type_name - - def py_type( package: str, imports: set, diff --git a/betterproto/tests/__init__.py b/betterproto/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/tests/inputs/config.py b/betterproto/tests/inputs/config.py index 7ed40f5..2525d8f 100644 --- a/betterproto/tests/inputs/config.py +++ b/betterproto/tests/inputs/config.py @@ -8,7 +8,6 @@ tests = { "import_circular_dependency", # failing because of other bugs now "import_packages_same_name", # 25 "oneof_enum", # 63 - "googletypes_service_returns_empty", # 9 "casing_message_field_uppercase", # 11 "namespace_keywords", # 70 "namespace_builtin_types", # 53 @@ -22,4 +21,5 @@ services = { "service", "import_service_input_message", "googletypes_service_returns_empty", + "googletypes_service_returns_googletype", } diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index fb2152b..02fa193 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Optional -import google.protobuf.wrappers_pb2 as wrappers +import betterproto.lib.google.protobuf as protobuf import pytest from betterproto.tests.mocks import MockChannel @@ -9,15 +9,15 @@ from betterproto.tests.output_betterproto.googletypes_response.googletypes_respo ) test_cases = [ - (TestStub.get_double, wrappers.DoubleValue, 2.5), - (TestStub.get_float, wrappers.FloatValue, 2.5), - (TestStub.get_int64, wrappers.Int64Value, -64), - (TestStub.get_u_int64, wrappers.UInt64Value, 64), - (TestStub.get_int32, wrappers.Int32Value, -32), - (TestStub.get_u_int32, wrappers.UInt32Value, 32), - (TestStub.get_bool, wrappers.BoolValue, True), - (TestStub.get_string, wrappers.StringValue, "string"), - (TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]), + (TestStub.get_double, protobuf.DoubleValue, 2.5), + (TestStub.get_float, protobuf.FloatValue, 2.5), + (TestStub.get_int64, protobuf.Int64Value, -64), + (TestStub.get_u_int64, protobuf.UInt64Value, 64), + (TestStub.get_int32, protobuf.Int32Value, -32), + (TestStub.get_u_int32, protobuf.UInt32Value, 32), + (TestStub.get_bool, protobuf.BoolValue, True), + (TestStub.get_string, protobuf.StringValue, "string"), + (TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), ] diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py new file mode 100644 index 0000000..9635356 --- /dev/null +++ b/betterproto/tests/test_get_ref_type.py @@ -0,0 +1,82 @@ +import pytest + +from ..compile.importing import get_ref_type + + +@pytest.mark.parametrize( + ["google_type", "expected_name", "expected_import"], + [ + ( + ".google.protobuf.Empty", + "betterproto_lib_google_protobuf.Empty", + "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", + ), + ( + ".google.protobuf.Struct", + "betterproto_lib_google_protobuf.Struct", + "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", + ), + ( + ".google.protobuf.ListValue", + "betterproto_lib_google_protobuf.ListValue", + "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", + ), + ( + ".google.protobuf.Value", + "betterproto_lib_google_protobuf.Value", + "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", + ), + ], +) +def test_import_google_wellknown_types_non_wrappers( + google_type: str, expected_name: str, expected_import: str +): + imports = set() + name = get_ref_type(package="", imports=imports, type_name=google_type) + + assert name == expected_name + assert imports.__contains__(expected_import) + + +@pytest.mark.parametrize( + ["google_type", "expected_name"], + [ + (".google.protobuf.DoubleValue", "Optional[float]"), + (".google.protobuf.FloatValue", "Optional[float]"), + (".google.protobuf.Int32Value", "Optional[int]"), + (".google.protobuf.Int64Value", "Optional[int]"), + (".google.protobuf.UInt32Value", "Optional[int]"), + (".google.protobuf.UInt64Value", "Optional[int]"), + (".google.protobuf.BoolValue", "Optional[bool]"), + (".google.protobuf.StringValue", "Optional[str]"), + (".google.protobuf.BytesValue", "Optional[bytes]"), + ], +) +def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str): + imports = set() + name = get_ref_type(package="", imports=imports, type_name=google_type) + + assert name == expected_name + assert imports == set() + + +@pytest.mark.parametrize( + ["google_type", "expected_name"], + [ + (".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"), + (".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"), + (".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"), + (".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"), + (".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"), + (".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"), + (".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"), + (".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"), + (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"), + ], +) +def test_importing_google_wrappers_without_unwrapping( + google_type: str, expected_name: str +): + name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False) + + assert name == expected_name