130 lines
4.3 KiB
Python
130 lines
4.3 KiB
Python
from functools import reduce
|
|
from typing import Dict, List, Type
|
|
|
|
import stringcase
|
|
|
|
from betterproto import safe_snake_case
|
|
from betterproto.lib.google import protobuf as google_protobuf
|
|
|
|
# Fully qualified google.protobuf wrapper message names mapped to the
# precompiled betterproto classes that represent them. get_ref_type uses the
# type of each class's default `.value` to unwrap the wrapper into a
# primitive Python type.
WRAPPER_TYPES: Dict[str, Type] = {
    f"google.protobuf.{wrapper}": getattr(google_protobuf, wrapper)
    for wrapper in (
        "DoubleValue",
        "FloatValue",
        "Int32Value",
        "Int64Value",
        "UInt32Value",
        "UInt64Value",
        "BoolValue",
        "StringValue",
        "BytesValue",
    )
}
|
|
|
|
|
|
def get_ref_type(
    package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
    """
    Return a Python type name for a proto type reference.

    Adds the import statement needed to make the returned name resolvable to
    ``imports`` (when one is needed) and unwraps well-known types if required.

    NOTE(review): the relative imports built below presumably assume that each
    proto package is emitted as a Python package directory whose
    ``__init__.py`` holds that package's types (the descendant branch imports
    sub-packages with ``from . import <name>``) — confirm against the writer.

    :param package: dotted proto package being compiled ("" for the root).
    :param imports: set of import-statement strings for the generated module;
        mutated in place.
    :param type_name: fully qualified proto type name, with or without a
        leading "." (e.g. ".foo.bar.Baz").
    :param unwrap: when True, google.protobuf wrapper types become
        ``Optional[<type of the wrapper's default .value>]`` and
        Duration/Timestamp become ``timedelta``/``datetime``.
    :return: an expression that names the type inside the generated module.
    """
    # If the package name is a blank string this still works, because by
    # convention packages are lowercase and message/enum types are
    # PascalCased. May require refactoring in the future.
    type_name = type_name.lstrip(".")

    if unwrap:
        if type_name in WRAPPER_TYPES:
            wrapped_type = type(WRAPPER_TYPES[type_name]().value)
            return f"Optional[{wrapped_type.__name__}]"
        if type_name == "google.protobuf.Duration":
            return "timedelta"
        if type_name == "google.protobuf.Timestamp":
            return "datetime"

    # Use precompiled classes for google.protobuf.* objects.
    if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
        type_name = type_name.rsplit(".", maxsplit=1)[1]
        import_package = "betterproto.lib.google.protobuf"
        import_alias = safe_snake_case(import_package)
        imports.add(f"import {import_package} as {import_alias}")
        return f"{import_alias}.{type_name}"

    importing_package: List[str] = type_name.split(".")
    importing_type: str = stringcase.pascalcase(importing_package.pop())
    current_package: List[str] = package.split(".") if package else []

    # Sibling: the type is defined in the module being generated, e.g.
    #   package = ""        name = Foo
    #   package = foo       name = foo.Bar
    #   package = foo.bar   name = foo.bar.Baz
    # No import is needed. A `from . import Foo` here would ask the module's
    # own (still initializing) package for a name it has not defined yet.
    if importing_package == current_package:
        return importing_type

    # Descendant: the type lives below the current package, e.g.
    #   package = ""   name = foo.Bar       -> from . import foo
    #   package = ""   name = foo.bar.Baz   -> from .foo import bar as foo_bar
    if importing_package[: len(current_package)] == current_package:
        importing_descendent = importing_package[len(current_package):]
        string_from = ".".join(importing_descendent[:-1])
        string_import = importing_descendent[-1]
        if string_from:
            string_alias = "_".join(importing_descendent)
            imports.add(f"from .{string_from} import {string_import} as {string_alias}")
            return f"{string_alias}.{importing_type}"
        imports.add(f"from . import {string_import}")
        return f"{string_import}.{importing_type}"

    # Ancestor: the type lives above the current package, e.g.
    #   package = foo.bar       name = foo.Foo   -> from .. import Foo
    #   package = foo           name = Bar       -> from .. import Bar
    #   package = foo.bar.baz   name = Bar       -> from .... import Bar
    # The import is aliased so a same-named type defined in the current
    # package cannot shadow it; the distance prefix keeps aliases for
    # different ancestors distinct, and the trailing "__" keeps the alias
    # clear of private-name mangling when it is used inside a class body.
    if current_package[: len(importing_package)] == importing_package:
        distance_up = len(current_package) - len(importing_package)
        string_alias = f"{'_' * distance_up}{importing_type}__"
        imports.add(f"from .{'.' * distance_up} import {importing_type} as {string_alias}")
        return string_alias

    # Cousin / unrelated: the type lives in another branch of the package
    # tree, e.g.
    #   package = foo.bar       name = baz.Foo          -> from ... import baz
    #   package = foo.bar.baz   name = foo.example.Bar  -> from ... import example
    # Climb to the deepest common ancestor (the longest shared *prefix*, not
    # merely positions that happen to match), then descend into the target.
    shared = 0
    for mine, theirs in zip(current_package, importing_package):
        if mine != theirs:
            break
        shared += 1
    distance_up = len(current_package) - shared

    # One leading "." names the current package; each extra "." climbs a level.
    string_from = "." * (distance_up + 1) + ".".join(importing_package[shared:-1])
    string_import = importing_package[-1]
    # Leading underscores keep cousin aliases apart from descendant aliases
    # (which have none); the trailing "__" prevents private-name mangling when
    # the alias appears inside a class body (e.g. as a field annotation).
    string_alias = (
        "_" * distance_up
        + safe_snake_case(".".join(importing_package[shared:]))
        + "__"
    )
    imports.add(f"from {string_from} import {string_import} as {string_alias}")
    return f"{string_alias}.{importing_type}"
|