diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 3988718..faab18f 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -64,14 +64,25 @@ def get_ref_type( imports.add(f"import {string_import} as {string_alias}") return f"{string_alias}.{py_type}" + current_package: List[str] = package.split(".") if package else [] py_package: List[str] = source_package.split(".") if source_package else [] py_type: str = pythonize_class_name(source_type) - current_package: List[str] = package.split(".") if package else [] + if py_package == current_package: + return import_sibling(py_type) - # importing sibling + if py_package[0 : len(current_package)] == current_package: + return import_descendent(current_package, imports, py_package, py_type) + + if current_package[0 : len(py_package)] == py_package: + return import_ancestor(current_package, imports, py_package, py_type) + + return import_cousin(current_package, imports, py_package, py_type) + + +def import_sibling(py_type): """ - package = + package = name = Foo package = foo @@ -80,47 +91,46 @@ def get_ref_type( package = foo.bar name = foo.bar.Baz """ - if py_package == current_package: - return f'"{py_type}"' + return f'"{py_type}"' - # importing child & descendent: + +def import_descendent(current_package, imports, py_package, py_type): """ - package = + package = name = foo.Bar - - package = + + package = name = foo.bar.Baz """ - if py_package[0 : len(current_package)] == current_package: - importing_descendent = py_package[len(current_package) :] - string_from = ".".join(importing_descendent[0:-1]) - string_import = importing_descendent[-1] + importing_descendent = py_package[len(current_package) :] + string_from = ".".join(importing_descendent[0:-1]) + string_import = importing_descendent[-1] + if string_from: + string_alias = "_".join(importing_descendent) + imports.add(f"from .{string_from} import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + else: + imports.add(f"from . import {string_import}") + return f"{string_import}.{py_type}" - if string_from: - string_alias = "_".join(importing_descendent) - imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" - else: - imports.add(f"from . import {string_import}") - return f"{string_import}.{py_type}" - # importing parent & ancestor +def import_ancestor(current_package, imports, py_package, py_type): """ package = foo.bar name = foo.Foo - + package = foo name = Bar - + package = foo.bar.baz name = Bar """ - if current_package[0 : len(py_package)] == py_package: - distance_up = len(current_package) - len(py_package) - imports.add(f"from .{'.' * distance_up} import {py_type}") - return py_type + distance_up = len(current_package) - len(py_package) + imports.add(f"from .{'.' * distance_up} import {py_type}") + return py_type - # importing unrelated or cousin + +def import_cousin(current_package, imports, py_package, py_type): """ package = foo.bar name = baz.Foo