Use betterproto wrapper classes, extract to module for testability
This commit is contained in:
parent
b813d1cedb
commit
2f658df666
0
betterproto/compile/__init__.py
Normal file
0
betterproto/compile/__init__.py
Normal file
70
betterproto/compile/importing.py
Normal file
70
betterproto/compile/importing.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
from typing import Dict, Type
|
||||||
|
|
||||||
|
import stringcase
|
||||||
|
|
||||||
|
from betterproto import safe_snake_case
|
||||||
|
from betterproto.lib.google import protobuf as google_protobuf
|
||||||
|
|
||||||
|
WRAPPER_TYPES: Dict[str, Type] = {
|
||||||
|
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||||
|
"google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||||
|
"google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||||
|
"google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||||
|
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||||
|
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||||
|
"google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||||
|
"google.protobuf.StringValue": google_protobuf.StringValue,
|
||||||
|
"google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_ref_type(
|
||||||
|
package: str, imports: set, type_name: str, unwrap: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Return a Python type name for a proto type reference. Adds the import if
|
||||||
|
necessary. Unwraps well known type if required.
|
||||||
|
"""
|
||||||
|
# If the package name is a blank string, then this should still work
|
||||||
|
# because by convention packages are lowercase and message/enum types are
|
||||||
|
# pascal-cased. May require refactoring in the future.
|
||||||
|
type_name = type_name.lstrip(".")
|
||||||
|
|
||||||
|
is_wrapper = type_name in WRAPPER_TYPES
|
||||||
|
|
||||||
|
if unwrap:
|
||||||
|
if is_wrapper:
|
||||||
|
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
|
||||||
|
return f"Optional[{wrapped_type.__name__}]"
|
||||||
|
|
||||||
|
if type_name == "google.protobuf.Duration":
|
||||||
|
return "timedelta"
|
||||||
|
|
||||||
|
if type_name == "google.protobuf.Timestamp":
|
||||||
|
return "datetime"
|
||||||
|
|
||||||
|
if type_name.startswith(package):
|
||||||
|
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||||
|
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||||
|
# This is the current package, which has nested types flattened.
|
||||||
|
# foo.bar_thing => FooBarThing
|
||||||
|
cased = [stringcase.pascalcase(part) for part in parts]
|
||||||
|
type_name = f'"{"".join(cased)}"'
|
||||||
|
|
||||||
|
# Use precompiled classes for google.protobuf.* objects
|
||||||
|
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
|
||||||
|
type_name = type_name.rsplit(".", maxsplit=1)[1]
|
||||||
|
import_package = "betterproto.lib.google.protobuf"
|
||||||
|
import_alias = safe_snake_case(import_package)
|
||||||
|
imports.add(f"import {import_package} as {import_alias}")
|
||||||
|
return f"{import_alias}.{type_name}"
|
||||||
|
|
||||||
|
if "." in type_name:
|
||||||
|
# This is imported from another package. No need
|
||||||
|
# to use a forward ref and we need to add the import.
|
||||||
|
parts = type_name.split(".")
|
||||||
|
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||||
|
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||||
|
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||||
|
|
||||||
|
return type_name
|
@ -6,9 +6,9 @@ import re
|
|||||||
import stringcase
|
import stringcase
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from collections import defaultdict
|
from typing import List
|
||||||
from typing import Dict, List, Optional, Type
|
|
||||||
from betterproto.casing import safe_snake_case
|
from betterproto.casing import safe_snake_case
|
||||||
|
from betterproto.compile.importing import get_ref_type
|
||||||
import betterproto
|
import betterproto
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -35,78 +35,6 @@ except ImportError as err:
|
|||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
|
||||||
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
|
|
||||||
lambda: None,
|
|
||||||
{
|
|
||||||
"google.protobuf.DoubleValue": google_wrappers.DoubleValue,
|
|
||||||
"google.protobuf.FloatValue": google_wrappers.FloatValue,
|
|
||||||
"google.protobuf.Int64Value": google_wrappers.Int64Value,
|
|
||||||
"google.protobuf.UInt64Value": google_wrappers.UInt64Value,
|
|
||||||
"google.protobuf.Int32Value": google_wrappers.Int32Value,
|
|
||||||
"google.protobuf.UInt32Value": google_wrappers.UInt32Value,
|
|
||||||
"google.protobuf.BoolValue": google_wrappers.BoolValue,
|
|
||||||
"google.protobuf.StringValue": google_wrappers.StringValue,
|
|
||||||
"google.protobuf.BytesValue": google_wrappers.BytesValue,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_ref_type(
|
|
||||||
package: str, imports: set, type_name: str, unwrap: bool = True
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Return a Python type name for a proto type reference. Adds the import if
|
|
||||||
necessary. Unwraps well known type if required.
|
|
||||||
"""
|
|
||||||
# If the package name is a blank string, then this should still work
|
|
||||||
# because by convention packages are lowercase and message/enum types are
|
|
||||||
# pascal-cased. May require refactoring in the future.
|
|
||||||
type_name = type_name.lstrip(".")
|
|
||||||
|
|
||||||
# Check if type is wrapper.
|
|
||||||
wrapper_class = WRAPPER_TYPES[type_name]
|
|
||||||
|
|
||||||
if unwrap:
|
|
||||||
if wrapper_class:
|
|
||||||
wrapped_type = type(wrapper_class().value)
|
|
||||||
return f"Optional[{wrapped_type.__name__}]"
|
|
||||||
|
|
||||||
if type_name == "google.protobuf.Duration":
|
|
||||||
return "timedelta"
|
|
||||||
|
|
||||||
if type_name == "google.protobuf.Timestamp":
|
|
||||||
return "datetime"
|
|
||||||
elif wrapper_class:
|
|
||||||
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
|
|
||||||
return f"{wrapper_class.__name__}"
|
|
||||||
|
|
||||||
if type_name.startswith(package):
|
|
||||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
|
||||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
|
||||||
# This is the current package, which has nested types flattened.
|
|
||||||
# foo.bar_thing => FooBarThing
|
|
||||||
cased = [stringcase.pascalcase(part) for part in parts]
|
|
||||||
type_name = f'"{"".join(cased)}"'
|
|
||||||
|
|
||||||
# Use precompiled classes for google.protobuf.* objects
|
|
||||||
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
|
|
||||||
type_name = type_name.rsplit(".", maxsplit=1)[1]
|
|
||||||
import_package = "betterproto.lib.google.protobuf"
|
|
||||||
import_alias = safe_snake_case(import_package)
|
|
||||||
imports.add(f"import {import_package} as {import_alias}")
|
|
||||||
return f"{import_alias}.{type_name}"
|
|
||||||
|
|
||||||
if "." in type_name:
|
|
||||||
# This is imported from another package. No need
|
|
||||||
# to use a forward ref and we need to add the import.
|
|
||||||
parts = type_name.split(".")
|
|
||||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
|
||||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
|
||||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
|
||||||
|
|
||||||
return type_name
|
|
||||||
|
|
||||||
|
|
||||||
def py_type(
|
def py_type(
|
||||||
package: str,
|
package: str,
|
||||||
imports: set,
|
imports: set,
|
||||||
|
0
betterproto/tests/__init__.py
Normal file
0
betterproto/tests/__init__.py
Normal file
@ -8,7 +8,6 @@ tests = {
|
|||||||
"import_circular_dependency", # failing because of other bugs now
|
"import_circular_dependency", # failing because of other bugs now
|
||||||
"import_packages_same_name", # 25
|
"import_packages_same_name", # 25
|
||||||
"oneof_enum", # 63
|
"oneof_enum", # 63
|
||||||
"googletypes_service_returns_empty", # 9
|
|
||||||
"casing_message_field_uppercase", # 11
|
"casing_message_field_uppercase", # 11
|
||||||
"namespace_keywords", # 70
|
"namespace_keywords", # 70
|
||||||
"namespace_builtin_types", # 53
|
"namespace_builtin_types", # 53
|
||||||
@ -22,4 +21,5 @@ services = {
|
|||||||
"service",
|
"service",
|
||||||
"import_service_input_message",
|
"import_service_input_message",
|
||||||
"googletypes_service_returns_empty",
|
"googletypes_service_returns_empty",
|
||||||
|
"googletypes_service_returns_googletype",
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import google.protobuf.wrappers_pb2 as wrappers
|
import betterproto.lib.google.protobuf as protobuf
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from betterproto.tests.mocks import MockChannel
|
from betterproto.tests.mocks import MockChannel
|
||||||
@ -9,15 +9,15 @@ from betterproto.tests.output_betterproto.googletypes_response.googletypes_respo
|
|||||||
)
|
)
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
(TestStub.get_double, wrappers.DoubleValue, 2.5),
|
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||||
(TestStub.get_float, wrappers.FloatValue, 2.5),
|
(TestStub.get_float, protobuf.FloatValue, 2.5),
|
||||||
(TestStub.get_int64, wrappers.Int64Value, -64),
|
(TestStub.get_int64, protobuf.Int64Value, -64),
|
||||||
(TestStub.get_u_int64, wrappers.UInt64Value, 64),
|
(TestStub.get_u_int64, protobuf.UInt64Value, 64),
|
||||||
(TestStub.get_int32, wrappers.Int32Value, -32),
|
(TestStub.get_int32, protobuf.Int32Value, -32),
|
||||||
(TestStub.get_u_int32, wrappers.UInt32Value, 32),
|
(TestStub.get_u_int32, protobuf.UInt32Value, 32),
|
||||||
(TestStub.get_bool, wrappers.BoolValue, True),
|
(TestStub.get_bool, protobuf.BoolValue, True),
|
||||||
(TestStub.get_string, wrappers.StringValue, "string"),
|
(TestStub.get_string, protobuf.StringValue, "string"),
|
||||||
(TestStub.get_bytes, wrappers.BytesValue, bytes(0xFF)[0:4]),
|
(TestStub.get_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
82
betterproto/tests/test_get_ref_type.py
Normal file
82
betterproto/tests/test_get_ref_type.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from ..compile.importing import get_ref_type
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["google_type", "expected_name", "expected_import"],
|
||||||
|
[
|
||||||
|
(
|
||||||
|
".google.protobuf.Empty",
|
||||||
|
"betterproto_lib_google_protobuf.Empty",
|
||||||
|
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
".google.protobuf.Struct",
|
||||||
|
"betterproto_lib_google_protobuf.Struct",
|
||||||
|
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
".google.protobuf.ListValue",
|
||||||
|
"betterproto_lib_google_protobuf.ListValue",
|
||||||
|
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
".google.protobuf.Value",
|
||||||
|
"betterproto_lib_google_protobuf.Value",
|
||||||
|
"import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_import_google_wellknown_types_non_wrappers(
|
||||||
|
google_type: str, expected_name: str, expected_import: str
|
||||||
|
):
|
||||||
|
imports = set()
|
||||||
|
name = get_ref_type(package="", imports=imports, type_name=google_type)
|
||||||
|
|
||||||
|
assert name == expected_name
|
||||||
|
assert imports.__contains__(expected_import)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["google_type", "expected_name"],
|
||||||
|
[
|
||||||
|
(".google.protobuf.DoubleValue", "Optional[float]"),
|
||||||
|
(".google.protobuf.FloatValue", "Optional[float]"),
|
||||||
|
(".google.protobuf.Int32Value", "Optional[int]"),
|
||||||
|
(".google.protobuf.Int64Value", "Optional[int]"),
|
||||||
|
(".google.protobuf.UInt32Value", "Optional[int]"),
|
||||||
|
(".google.protobuf.UInt64Value", "Optional[int]"),
|
||||||
|
(".google.protobuf.BoolValue", "Optional[bool]"),
|
||||||
|
(".google.protobuf.StringValue", "Optional[str]"),
|
||||||
|
(".google.protobuf.BytesValue", "Optional[bytes]"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
|
||||||
|
imports = set()
|
||||||
|
name = get_ref_type(package="", imports=imports, type_name=google_type)
|
||||||
|
|
||||||
|
assert name == expected_name
|
||||||
|
assert imports == set()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["google_type", "expected_name"],
|
||||||
|
[
|
||||||
|
(".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"),
|
||||||
|
(".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"),
|
||||||
|
(".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"),
|
||||||
|
(".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"),
|
||||||
|
(".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"),
|
||||||
|
(".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"),
|
||||||
|
(".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"),
|
||||||
|
(".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"),
|
||||||
|
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_importing_google_wrappers_without_unwrapping(
|
||||||
|
google_type: str, expected_name: str
|
||||||
|
):
|
||||||
|
name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
|
||||||
|
|
||||||
|
assert name == expected_name
|
Loading…
x
Reference in New Issue
Block a user