From d9fa6d2dd35c377bc8f049abcaefb8cafabd818c Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Fri, 12 Jun 2020 13:55:55 +0200 Subject: [PATCH] Fixes issue where generated Google Protobuf messages imported from betterproto.lib instead of using local forward references --- betterproto/compile/importing.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 115ba78..69cca91 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -42,10 +42,8 @@ def get_ref_type( Return a Python type name for a proto type reference. Adds the import if necessary. Unwraps well known type if required. """ - is_wrapper = source_type in WRAPPER_TYPES - if unwrap: - if is_wrapper: + if source_type in WRAPPER_TYPES: wrapped_type = type(WRAPPER_TYPES[source_type]().value) return f"Optional[{wrapped_type.__name__}]" @@ -57,18 +55,18 @@ def get_ref_type( source_package, source_type = parse_source_type_name(source_type) - # Use precompiled classes for google.protobuf.* objects - if source_package == "google.protobuf": - string_import = f"betterproto.lib.{source_package}" - py_type = source_type - string_alias = safe_snake_case(string_import) - imports.add(f"import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" - current_package: List[str] = package.split(".") if package else [] py_package: List[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) + compiling_google_protobuf = current_package == ["google", "protobuf"] + importing_google_protobuf = py_package == ["google", "protobuf"] + if importing_google_protobuf and not compiling_google_protobuf: + py_package = ["betterproto", "lib"] + py_package + + if py_package[:1] == ["betterproto"]: + return import_root(imports, py_package, py_type) + if py_package == current_package: return import_sibling(py_type) @@ -81,6 +79,13 @@ def get_ref_type( return import_cousin(current_package, imports, py_package, py_type) +def import_root(imports, py_package, py_type): + string_import = ".".join(py_package) + string_alias = safe_snake_case(string_import) + imports.add(f"import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + + def import_sibling(py_type): """ package =