Fixes issue where generated Google Protobuf messages imported from betterproto.lib instead of using local forward references
This commit is contained in:
parent
c88edfd093
commit
d9fa6d2dd3
@ -42,10 +42,8 @@ def get_ref_type(
|
|||||||
Return a Python type name for a proto type reference. Adds the import if
|
Return a Python type name for a proto type reference. Adds the import if
|
||||||
necessary. Unwraps well known type if required.
|
necessary. Unwraps well known type if required.
|
||||||
"""
|
"""
|
||||||
is_wrapper = source_type in WRAPPER_TYPES
|
|
||||||
|
|
||||||
if unwrap:
|
if unwrap:
|
||||||
if is_wrapper:
|
if source_type in WRAPPER_TYPES:
|
||||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||||
return f"Optional[{wrapped_type.__name__}]"
|
return f"Optional[{wrapped_type.__name__}]"
|
||||||
|
|
||||||
@ -57,18 +55,18 @@ def get_ref_type(
|
|||||||
|
|
||||||
source_package, source_type = parse_source_type_name(source_type)
|
source_package, source_type = parse_source_type_name(source_type)
|
||||||
|
|
||||||
# Use precompiled classes for google.protobuf.* objects
|
|
||||||
if source_package == "google.protobuf":
|
|
||||||
string_import = f"betterproto.lib.{source_package}"
|
|
||||||
py_type = source_type
|
|
||||||
string_alias = safe_snake_case(string_import)
|
|
||||||
imports.add(f"import {string_import} as {string_alias}")
|
|
||||||
return f"{string_alias}.{py_type}"
|
|
||||||
|
|
||||||
current_package: List[str] = package.split(".") if package else []
|
current_package: List[str] = package.split(".") if package else []
|
||||||
py_package: List[str] = source_package.split(".") if source_package else []
|
py_package: List[str] = source_package.split(".") if source_package else []
|
||||||
py_type: str = pythonize_class_name(source_type)
|
py_type: str = pythonize_class_name(source_type)
|
||||||
|
|
||||||
|
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
||||||
|
importing_google_protobuf = py_package == ["google", "protobuf"]
|
||||||
|
if importing_google_protobuf and not compiling_google_protobuf:
|
||||||
|
py_package = ["betterproto", "lib"] + py_package
|
||||||
|
|
||||||
|
if py_package[:1] == ["betterproto"]:
|
||||||
|
return import_root(imports, py_package, py_type)
|
||||||
|
|
||||||
if py_package == current_package:
|
if py_package == current_package:
|
||||||
return import_sibling(py_type)
|
return import_sibling(py_type)
|
||||||
|
|
||||||
@ -81,6 +79,13 @@ def get_ref_type(
|
|||||||
return import_cousin(current_package, imports, py_package, py_type)
|
return import_cousin(current_package, imports, py_package, py_type)
|
||||||
|
|
||||||
|
|
||||||
|
def import_root(imports, py_package, py_type):
|
||||||
|
string_import = ".".join(py_package)
|
||||||
|
string_alias = safe_snake_case(string_import)
|
||||||
|
imports.add(f"import {string_import} as {string_alias}")
|
||||||
|
return f"{string_alias}.{py_type}"
|
||||||
|
|
||||||
|
|
||||||
def import_sibling(py_type):
|
def import_sibling(py_type):
|
||||||
"""
|
"""
|
||||||
package =
|
package =
|
||||||
|
Loading…
x
Reference in New Issue
Block a user