diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 824e57e..8b9ee3f 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -44,7 +44,6 @@ def get_ref_type( if type_name == "google.protobuf.Timestamp": return "datetime" - # Use precompiled classes for google.protobuf.* objects if type_name.startswith("google.protobuf.") and type_name.count(".") == 2: type_name = type_name.rsplit(".", maxsplit=1)[1] @@ -53,12 +52,12 @@ def get_ref_type( imports.add(f"import {import_package} as {import_alias}") return f"{import_alias}.{type_name}" - importing_package: List[str] = type_name.split('.') + importing_package: List[str] = type_name.split(".") importing_type: str = stringcase.pascalcase(importing_package.pop()) - current_package: List[str] = package.split('.') if package else [] + current_package: List[str] = package.split(".") if package else [] # importing sibling - ''' + """ package = name = Foo @@ -67,26 +66,26 @@ def get_ref_type( package = foo.bar name = foo.bar.Baz - ''' + """ if importing_package == current_package: imports.add(f"from . import {importing_type}") return importing_type # importing child & descendent: - ''' + """ package = name = foo.Bar package = name = foo.bar.Baz - ''' - if importing_package[0:len(current_package)] == current_package: - importing_descendent = importing_package[len(current_package):] - string_from = '.'.join(importing_descendent[0:-1]) + """ + if importing_package[0 : len(current_package)] == current_package: + importing_descendent = importing_package[len(current_package) :] + string_from = ".".join(importing_descendent[0:-1]) string_import = importing_descendent[-1] if string_from: - string_alias = '_'.join(importing_descendent) + string_alias = "_".join(importing_descendent) imports.add(f"from .{string_from} import {string_import} as {string_alias}") return f"{string_alias}.{importing_type}" else: @@ -94,7 +93,7 @@ def get_ref_type( return f"{string_import}.{importing_type}" # importing parent & ancestor - ''' + """ package = foo.bar name = foo.Foo @@ -103,27 +102,34 @@ def get_ref_type( package = foo.bar.baz name = Bar - ''' - if current_package[0:len(importing_package)] == importing_package: - distance = len(current_package) - len(importing_package) - imports.add(f"from .{'.' * distance} import {importing_type}") + """ + if current_package[0 : len(importing_package)] == importing_package: + distance_up = len(current_package) - len(importing_package) + imports.add(f"from .{'.' * distance_up} import {importing_type}") return importing_type # importing unrelated or cousin - ''' + """ package = foo.bar name = baz.Foo package = foo.bar.baz name = foo.example.Bar - ''' - root_distance = len(current_package) - shared_ancestory_length = reduce(lambda l, pair: l + (pair[0] == pair[1]), zip(current_package, importing_package), 0) + """ - string_from = f"{'.' * (shared_ancestory_length+1)}.{'.'.join(importing_package[0:-1])}" + shared_ancestory = [ + pair[0] + for pair in zip(current_package, importing_package) + if pair[0] == pair[1] + ] + distance_up = len(current_package) - len(shared_ancestory) + + string_from = f".{'.' * distance_up}" + ".".join( + importing_package[len(shared_ancestory) : -1] + ) string_import = importing_package[-1] - string_alias = '_' * root_distance + safe_snake_case('.'.join(importing_package)) + string_alias = f"{'_' * distance_up}" + safe_snake_case( + ".".join(importing_package[len(shared_ancestory) :]) + ) imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{importing_type}" - # return type_name diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index aa27352..8d72406 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -134,7 +134,9 @@ def test_import_deeply_nested_child_from_root(): def test_import_deeply_nested_child_from_package(): imports = set() name = get_ref_type( - package="package", imports=imports, type_name="package.deeply.nested.child.Message" + package="package", + imports=imports, + type_name="package.deeply.nested.child.Message", ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} @@ -207,23 +209,47 @@ def test_import_root_package_from_deeply_nested_child(): def test_import_unrelated_package(): imports = set() - name = get_ref_type(package="a", imports=imports, type_name="b.Message") + name = get_ref_type(package="a", imports=imports, type_name="p.Message") - assert imports == {"from .. import b as _b"} - assert name == "_b.Message" + assert imports == {"from .. import p as _p"} + assert name == "_p.Message" + + +def test_import_unrelated_nested_package(): + imports = set() + name = get_ref_type(package="a.b", imports=imports, type_name="p.q.Message") + + assert imports == {"from ...p import q as __p_q"} + assert name == "__p_q.Message" + + +def test_import_unrelated_deeply_nested_package(): + imports = set() + name = get_ref_type(package="a.b.c.d", imports=imports, type_name="p.q.r.s.Message") + + assert imports == {"from .....p.q.r import s as ____p_q_r_s"} + assert name == "____p_q_r_s.Message" def test_import_cousin_package(): imports = set() - name = get_ref_type(package="a.a", imports=imports, type_name="a.b.Message") + name = get_ref_type(package="a.x", imports=imports, type_name="a.y.Message") - assert imports == {"from .. import b as __b"} - assert name == "__b.Message" + assert imports == {"from .. import y as _y"} + assert name == "_y.Message" def test_import_far_cousin_package(): imports = set() - name = get_ref_type(package="a.a.a", imports=imports, type_name="a.b.c.Message") + name = get_ref_type(package="a.x.y", imports=imports, type_name="a.b.c.Message") - assert imports == {"from ... import c as ___c"} - assert name == "___c.Message" + assert imports == {"from ...b import c as __b_c"} + assert name == "__b_c.Message" + + +def test_import_far_far_cousin_package(): + imports = set() + name = get_ref_type(package="a.x.y.z", imports=imports, type_name="a.b.c.d.Message") + + assert imports == {"from ....b.c import d as ___b_c_d"} + assert name == "___b_c_d.Message"