2020-06-11 13:55:11 +02:00

110 lines
3.3 KiB
Python

from typing import Dict, List, Type
import stringcase
from betterproto import safe_snake_case
from betterproto.lib.google import protobuf as google_protobuf
# Maps the fully-qualified proto name of every google.protobuf wrapper type
# to the class of the same name exposed by ``betterproto.lib.google.protobuf``.
WRAPPER_TYPES: Dict[str, Type] = {
    f"google.protobuf.{wrapper_name}": getattr(google_protobuf, wrapper_name)
    for wrapper_name in (
        "DoubleValue",
        "FloatValue",
        "Int32Value",
        "Int64Value",
        "UInt32Value",
        "UInt64Value",
        "BoolValue",
        "StringValue",
        "BytesValue",
    )
}
def get_ref_type(
    package: str, imports: set, type_name: str, unwrap: bool = True
) -> str:
    """
    Return a Python type name for a proto type reference. Adds the import if
    necessary. Unwraps well known type if required.

    Args:
        package: Dotted name of the proto package currently being generated
            (may be an empty string for the root package).
        imports: Set of import statements; the statement needed to resolve
            the returned name is added to it in place.
        type_name: Proto type reference, optionally prefixed with ``.``.
        unwrap: When true, wrapper types listed in ``WRAPPER_TYPES`` become
            ``Optional[<python type>]`` and Duration / Timestamp become
            ``timedelta`` / ``datetime``.

    Returns:
        The expression to use in generated code to refer to the type.
    """
    # If the package name is a blank string, then this should still work
    # because by convention packages are lowercase and message/enum types are
    # pascal-cased. May require refactoring in the future.
    type_name = type_name.lstrip(".")

    if unwrap:
        if type_name in WRAPPER_TYPES:
            # Instantiate the wrapper to discover the Python type of its
            # default ``value`` field.
            wrapped_type = type(WRAPPER_TYPES[type_name]().value)
            return f"Optional[{wrapped_type.__name__}]"

        if type_name == "google.protobuf.Duration":
            return "timedelta"

        if type_name == "google.protobuf.Timestamp":
            return "datetime"

    # Use precompiled classes for google.protobuf.* objects
    if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
        type_name = type_name.rsplit(".", maxsplit=1)[1]
        import_package = "betterproto.lib.google.protobuf"
        import_alias = safe_snake_case(import_package)
        imports.add(f"import {import_package} as {import_alias}")
        return f"{import_alias}.{type_name}"

    importing_package: List[str] = type_name.split(".")
    importing_type: str = stringcase.pascalcase(importing_package.pop())
    current_package: List[str] = package.split(".") if package else []

    # Importing sibling (the type lives in the package being generated):
    #   package = ""         name = Foo
    #   package = foo        name = foo.Bar
    #   package = foo.bar    name = foo.bar.Baz
    # NOTE(review): this makes the generated package import a name from
    # itself; confirm the templates tolerate that.
    if importing_package == current_package:
        imports.add(f"from . import {importing_type}")
        return importing_type

    # Number of leading package components both packages have in common.
    shared = 0
    for ours, theirs in zip(current_package, importing_package):
        if ours != theirs:
            break
        shared += 1

    if shared == len(current_package):
        return _reference_descendant(
            imports, importing_package[shared:], importing_type
        )

    # How many packages we must climb to reach the shared ancestor.
    steps_up = len(current_package) - shared

    if shared == len(importing_package):
        return _reference_ancestor(
            imports, importing_package, importing_type, steps_up
        )

    return _reference_cousin(
        imports, importing_package[shared:], importing_type, steps_up
    )


def _reference_descendant(
    imports: set, relative_package: List[str], importing_type: str
) -> str:
    """Reference a type in a child or deeper descendant of the current package.

    Examples:
        package = ""    name = foo.Bar      -> ``from . import foo``
        package = ""    name = foo.bar.Baz  -> ``from .foo import bar as foo_bar``
    """
    module = relative_package[-1]
    if len(relative_package) == 1:
        imports.add(f"from . import {module}")
        return f"{module}.{importing_type}"

    # ``from . import foo.bar`` is not valid Python, so import the innermost
    # module from its parent and alias it with the full relative path to keep
    # the name unique.
    parent = ".".join(relative_package[:-1])
    alias = "_".join(relative_package)
    imports.add(f"from .{parent} import {module} as {alias}")
    return f"{alias}.{importing_type}"


def _reference_ancestor(
    imports: set, importing_package: List[str], importing_type: str, steps_up: int
) -> str:
    """Reference a type in a parent or higher ancestor of the current package.

    Examples:
        package = foo.bar      name = foo.Foo
        package = foo          name = Bar
        package = foo.bar.baz  name = Bar
    """
    # One dot is the current package itself; each further dot climbs one level.
    ancestor = "." * (steps_up + 1)
    # Aliases end with a double underscore so that a leading double underscore
    # is not name-mangled when the reference is used inside a class body.
    underscores = "_" * steps_up

    if importing_package:
        # Import the ancestor package by name from *its* parent.
        module = importing_package[-1]
        alias = f"_{underscores}{module}__"
        imports.add(f"from {ancestor}. import {module} as {alias}")
        return f"{alias}.{importing_type}"

    # The type lives in the root package: import the type itself.
    alias = f"{underscores}{importing_type}__"
    imports.add(f"from {ancestor} import {importing_type} as {alias}")
    return alias


def _reference_cousin(
    imports: set, branch: List[str], importing_type: str, steps_up: int
) -> str:
    """Reference a type in a package that is neither ancestor nor descendant.

    ``branch`` is the target package relative to the closest shared ancestor.

    Examples:
        package = foo.bar      name = baz.Foo
        package = foo.bar.baz  name = foo.example.Bar
    """
    shared_ancestor = "." * (steps_up + 1)
    parent = ".".join(branch[:-1])
    module = branch[-1]
    # Trailing double underscore avoids name mangling inside class bodies.
    alias = "_" * steps_up + "_".join(branch) + "__"
    imports.add(f"from {shared_ancestor}{parent} import {module} as {alias}")
    return f"{alias}.{importing_type}"