Use betterproto wrapper classes, extract to module for testability
This commit is contained in:
0
betterproto/compile/__init__.py
Normal file
0
betterproto/compile/__init__.py
Normal file
70
betterproto/compile/importing.py
Normal file
70
betterproto/compile/importing.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from typing import Dict, Type
|
||||
|
||||
import stringcase
|
||||
|
||||
from betterproto import safe_snake_case
|
||||
from betterproto.lib.google import protobuf as google_protobuf
|
||||
|
||||
WRAPPER_TYPES: Dict[str, Type] = {
|
||||
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||
"google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||
"google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||
"google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||
"google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||
"google.protobuf.StringValue": google_protobuf.StringValue,
|
||||
"google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(
|
||||
package: str, imports: set, type_name: str, unwrap: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary. Unwraps well known type if required.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
is_wrapper = type_name in WRAPPER_TYPES
|
||||
|
||||
if unwrap:
|
||||
if is_wrapper:
|
||||
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
|
||||
# Use precompiled classes for google.protobuf.* objects
|
||||
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
|
||||
type_name = type_name.rsplit(".", maxsplit=1)[1]
|
||||
import_package = "betterproto.lib.google.protobuf"
|
||||
import_alias = safe_snake_case(import_package)
|
||||
imports.add(f"import {import_package} as {import_alias}")
|
||||
return f"{import_alias}.{type_name}"
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
# to use a forward ref and we need to add the import.
|
||||
parts = type_name.split(".")
|
||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||
|
||||
return type_name
|
||||
Reference in New Issue
Block a user