Fixes issue where generated Google Protobuf messages imported from betterproto.lib instead of using local forward references

This commit is contained in:
boukeversteegh 2020-06-12 13:55:55 +02:00
parent c88edfd093
commit d9fa6d2dd3

View File

@ -42,10 +42,8 @@ def get_ref_type(
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
is_wrapper = source_type in WRAPPER_TYPES
if unwrap:
if is_wrapper:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
@ -57,18 +55,18 @@ def get_ref_type(
source_package, source_type = parse_source_type_name(source_type)
# Use precompiled classes for google.protobuf.* objects
if source_package == "google.protobuf":
string_import = f"betterproto.lib.{source_package}"
py_type = source_type
string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
current_package: List[str] = package.split(".") if package else []
py_package: List[str] = source_package.split(".") if source_package else []
py_type: str = pythonize_class_name(source_type)
compiling_google_protobuf = current_package == ["google", "protobuf"]
importing_google_protobuf = py_package == ["google", "protobuf"]
if importing_google_protobuf and not compiling_google_protobuf:
py_package = ["betterproto", "lib"] + py_package
if py_package[:1] == ["betterproto"]:
return import_root(imports, py_package, py_type)
if py_package == current_package:
return import_sibling(py_type)
@ -81,6 +79,13 @@ def get_ref_type(
return import_cousin(current_package, imports, py_package, py_type)
def import_root(imports, py_package, py_type):
string_import = ".".join(py_package)
string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
def import_sibling(py_type):
"""
package =