Fixes issue where importing cousin where path has a package with the same name broke import
This commit is contained in:
		| @@ -1,3 +1,4 @@ | ||||
| import os | ||||
| import re | ||||
| from typing import Dict, List, Type | ||||
|  | ||||
| @@ -138,16 +139,14 @@ def import_cousin(current_package, imports, py_package, py_type): | ||||
|     package = foo.bar.baz | ||||
|     name    = foo.example.Bar | ||||
|     """ | ||||
|     shared_ancestory = [ | ||||
|         pair[0] for pair in zip(current_package, py_package) if pair[0] == pair[1] | ||||
|     ] | ||||
|     distance_up = len(current_package) - len(shared_ancestory) | ||||
|     shared_ancestry = os.path.commonprefix([current_package, py_package]) | ||||
|     distance_up = len(current_package) - len(shared_ancestry) | ||||
|     string_from = f".{'.' * distance_up}" + ".".join( | ||||
|         py_package[len(shared_ancestory) : -1] | ||||
|         py_package[len(shared_ancestry) : -1] | ||||
|     ) | ||||
|     string_import = py_package[-1] | ||||
|     alias = f"{'_' * distance_up}" + safe_snake_case( | ||||
|         ".".join(py_package[len(shared_ancestory) :]) | ||||
|         ".".join(py_package[len(shared_ancestry) :]) | ||||
|     ) | ||||
|     imports.add(f"from {string_from} import {string_import} as {alias}") | ||||
|     return f"{alias}.{py_type}" | ||||
|   | ||||
| @@ -242,6 +242,26 @@ def test_import_cousin_package(): | ||||
|     assert name == "_y.Message" | ||||
|  | ||||
|  | ||||
| def test_import_cousin_package_different_name(): | ||||
|     imports = set() | ||||
|     name = get_ref_type( | ||||
|         package="test.package1", imports=imports, source_type="cousin.package2.Message" | ||||
|     ) | ||||
|  | ||||
|     assert imports == {"from ...cousin import package2 as __cousin_package2"} | ||||
|     assert name == "__cousin_package2.Message" | ||||
|  | ||||
|  | ||||
| def test_import_cousin_package_same_name(): | ||||
|     imports = set() | ||||
|     name = get_ref_type( | ||||
|         package="test.package", imports=imports, source_type="cousin.package.Message" | ||||
|     ) | ||||
|  | ||||
|     assert imports == {"from ...cousin import package as __cousin_package"} | ||||
|     assert name == "__cousin_package.Message" | ||||
|  | ||||
|  | ||||
| def test_import_far_cousin_package(): | ||||
|     imports = set() | ||||
|     name = get_ref_type(package="a.x.y", imports=imports, source_type="a.b.c.Message") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user