Fixes issue where generated Google Protobuf messages imported from betterproto.lib instead of using local forward references
This commit is contained in:
parent
c88edfd093
commit
d9fa6d2dd3
@ -42,10 +42,8 @@ def get_ref_type(
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary. Unwraps well known type if required.
|
||||
"""
|
||||
is_wrapper = source_type in WRAPPER_TYPES
|
||||
|
||||
if unwrap:
|
||||
if is_wrapper:
|
||||
if source_type in WRAPPER_TYPES:
|
||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
@ -57,18 +55,18 @@ def get_ref_type(
|
||||
|
||||
source_package, source_type = parse_source_type_name(source_type)
|
||||
|
||||
# Use precompiled classes for google.protobuf.* objects
|
||||
if source_package == "google.protobuf":
|
||||
string_import = f"betterproto.lib.{source_package}"
|
||||
py_type = source_type
|
||||
string_alias = safe_snake_case(string_import)
|
||||
imports.add(f"import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
|
||||
current_package: List[str] = package.split(".") if package else []
|
||||
py_package: List[str] = source_package.split(".") if source_package else []
|
||||
py_type: str = pythonize_class_name(source_type)
|
||||
|
||||
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
||||
importing_google_protobuf = py_package == ["google", "protobuf"]
|
||||
if importing_google_protobuf and not compiling_google_protobuf:
|
||||
py_package = ["betterproto", "lib"] + py_package
|
||||
|
||||
if py_package[:1] == ["betterproto"]:
|
||||
return import_root(imports, py_package, py_type)
|
||||
|
||||
if py_package == current_package:
|
||||
return import_sibling(py_type)
|
||||
|
||||
@ -81,6 +79,13 @@ def get_ref_type(
|
||||
return import_cousin(current_package, imports, py_package, py_type)
|
||||
|
||||
|
||||
def import_root(imports, py_package, py_type):
|
||||
string_import = ".".join(py_package)
|
||||
string_alias = safe_snake_case(string_import)
|
||||
imports.add(f"import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
|
||||
|
||||
def import_sibling(py_type):
|
||||
"""
|
||||
package =
|
||||
|
Loading…
x
Reference in New Issue
Block a user