Use betterproto wrapper classes, extract to module for testability

This commit is contained in:
boukeversteegh
2020-05-29 19:13:55 +02:00
parent b813d1cedb
commit 2f658df666
7 changed files with 165 additions and 85 deletions

View File

@@ -6,9 +6,9 @@ import re
import stringcase
import sys
import textwrap
from collections import defaultdict
from typing import Dict, List, Optional, Type
from typing import List
from betterproto.casing import safe_snake_case
from betterproto.compile.importing import get_ref_type
import betterproto
try:
@@ -35,78 +35,6 @@ except ImportError as err:
raise SystemExit(1)
WRAPPER_TYPES: Dict[str, Optional[Type]] = defaultdict(
lambda: None,
{
"google.protobuf.DoubleValue": google_wrappers.DoubleValue,
"google.protobuf.FloatValue": google_wrappers.FloatValue,
"google.protobuf.Int64Value": google_wrappers.Int64Value,
"google.protobuf.UInt64Value": google_wrappers.UInt64Value,
"google.protobuf.Int32Value": google_wrappers.Int32Value,
"google.protobuf.UInt32Value": google_wrappers.UInt32Value,
"google.protobuf.BoolValue": google_wrappers.BoolValue,
"google.protobuf.StringValue": google_wrappers.StringValue,
"google.protobuf.BytesValue": google_wrappers.BytesValue,
},
)
def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
# Check if type is wrapper.
wrapper_class = WRAPPER_TYPES[type_name]
if unwrap:
if wrapper_class:
wrapped_type = type(wrapper_class().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
elif wrapper_class:
imports.add(f"from {wrapper_class.__module__} import {wrapper_class.__name__}")
return f"{wrapper_class.__name__}"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
# Use precompiled classes for google.protobuf.* objects
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
type_name = type_name.rsplit(".", maxsplit=1)[1]
import_package = "betterproto.lib.google.protobuf"
import_alias = safe_snake_case(import_package)
imports.add(f"import {import_package} as {import_alias}")
return f"{import_alias}.{type_name}"
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
return type_name
def py_type(
package: str,
imports: set,