diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 6890b35..7fbf317 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -1,6 +1,6 @@ import os import re -from typing import Dict, List, Type +from typing import Dict, List, Set, Type from betterproto import safe_snake_case from betterproto.compile.naming import pythonize_class_name @@ -35,7 +35,7 @@ def parse_source_type_name(field_type_name): return package, name -def get_ref_type( +def get_type_reference( package: str, imports: set, source_type: str, unwrap: bool = True, ) -> str: """ @@ -65,48 +65,43 @@ def get_ref_type( py_package = ["betterproto", "lib"] + py_package if py_package[:1] == ["betterproto"]: - return import_root(imports, py_package, py_type) + return reference_absolute(imports, py_package, py_type) if py_package == current_package: - return import_sibling(py_type) + return reference_sibling(py_type) if py_package[: len(current_package)] == current_package: - return import_descendent(current_package, imports, py_package, py_type) + return reference_descendent(current_package, imports, py_package, py_type) if current_package[: len(py_package)] == py_package: - return import_ancestor(current_package, imports, py_package, py_type) + return reference_ancestor(current_package, imports, py_package, py_type) - return import_cousin(current_package, imports, py_package, py_type) + return reference_cousin(current_package, imports, py_package, py_type) -def import_root(imports, py_package, py_type): +def reference_absolute(imports, py_package, py_type): + """ + Returns a reference to a python type located in the root, i.e. sys.path. + """ string_import = ".".join(py_package) string_alias = safe_snake_case(string_import) imports.add(f"import {string_import} as {string_alias}") return f"{string_alias}.{py_type}" -def import_sibling(py_type): +def reference_sibling(py_type: str) -> str: """ - package = - name = Foo - - package = foo - name = foo.Bar - - package = foo.bar - name = foo.bar.Baz + Returns a reference to a python type within the same package as the current package. """ return f'"{py_type}"' -def import_descendent(current_package, imports, py_package, py_type): +def reference_descendent( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: """ - package = - name = foo.Bar - - package = - name = foo.bar.Baz + Returns a reference to a python type in a package that is a descendent of the current package, + and adds the required import that is aliased to avoid name conflicts. """ importing_descendent = py_package[len(current_package) :] string_from = ".".join(importing_descendent[:-1]) @@ -120,16 +115,12 @@ def import_descendent(current_package, imports, py_package, py_type): return f"{string_import}.{py_type}" -def import_ancestor(current_package, imports, py_package, py_type): +def reference_ancestor( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: """ - package = foo.bar - name = foo.Foo - - package = foo - name = Bar - - package = foo.bar.baz - name = Bar + Returns a reference to a python type in a package which is an ancestor to the current package, + and adds the required import that is aliased (if possible) to avoid name conflicts. """ distance_up = len(current_package) - len(py_package) if py_package: @@ -144,13 +135,12 @@ def import_ancestor(current_package, imports, py_package, py_type): return py_type -def import_cousin(current_package, imports, py_package, py_type): +def reference_cousin( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: """ - package = foo.bar - name = baz.Foo - - package = foo.bar.baz - name = foo.example.Bar + Returns a reference to a python type in a package that is not descendent, ancestor or sibling, + and adds the required import that is aliased to avoid name conflicts. """ shared_ancestry = os.path.commonprefix([current_package, py_package]) distance_up = len(current_package) - len(shared_ancestry) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 8065c13..384e120 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -9,7 +9,7 @@ import textwrap from typing import List import betterproto -from betterproto.compile.importing import get_ref_type +from betterproto.compile.importing import get_type_reference from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, @@ -51,7 +51,7 @@ def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: return "str" elif field.type in [11, 14]: # Type referencing another defined Message or a named enum - return get_ref_type(package, imports, field.type_name) + return get_type_reference(package, imports, field.type_name) elif field.type == 12: return "bytes" else: @@ -306,7 +306,7 @@ def generate_code(request, response): raise NotImplementedError("Client streaming not yet supported") input_message = None - input_type = get_ref_type( + input_type = get_type_reference( package, output["imports"], method.input_type ).strip('"') for msg in output["messages"]: @@ -323,11 +323,11 @@ def generate_code(request, response): "py_name": pythonize_method_name(method.name), "comment": get_comment(proto_file, [6, i, 2, j], indent=8), "route": f"/{package}.{service.name}/{method.name}", - "input": get_ref_type( + "input": get_type_reference( package, output["imports"], method.input_type ).strip('"'), "input_message": input_message, - "output": get_ref_type( + "output": get_type_reference( package, output["imports"], method.output_type, diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index 412b4ed..5cb5f74 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -1,6 +1,6 @@ import pytest -from ..compile.importing import get_ref_type, parse_source_type_name +from ..compile.importing import get_type_reference, parse_source_type_name @pytest.mark.parametrize( @@ -28,11 +28,11 @@ from ..compile.importing import get_ref_type, parse_source_type_name ), ], ) -def test_import_google_wellknown_types_non_wrappers( +def test_reference_google_wellknown_types_non_wrappers( google_type: str, expected_name: str, expected_import: str ): imports = set() - name = get_ref_type(package="", imports=imports, source_type=google_type) + name = get_type_reference(package="", imports=imports, source_type=google_type) assert name == expected_name assert imports.__contains__( @@ -54,9 +54,11 @@ def test_import_google_wellknown_types_non_wrappers( (".google.protobuf.BytesValue", "Optional[bytes]"), ], ) -def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str): +def test_referenceing_google_wrappers_unwraps_them( + google_type: str, expected_name: str +): imports = set() - name = get_ref_type(package="", imports=imports, source_type=google_type) + name = get_type_reference(package="", imports=imports, source_type=google_type) assert name == expected_name assert imports == set() @@ -76,19 +78,19 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"), ], ) -def test_importing_google_wrappers_without_unwrapping( +def test_referenceing_google_wrappers_without_unwrapping( google_type: str, expected_name: str ): - name = get_ref_type( + name = get_type_reference( package="", imports=set(), source_type=google_type, unwrap=False ) assert name == expected_name -def test_import_child_package_from_package(): +def test_reference_child_package_from_package(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package", imports=imports, source_type="package.child.Message" ) @@ -96,17 +98,17 @@ def test_import_child_package_from_package(): assert name == "child.Message" -def test_import_child_package_from_root(): +def test_reference_child_package_from_root(): imports = set() - name = get_ref_type(package="", imports=imports, source_type="child.Message") + name = get_type_reference(package="", imports=imports, source_type="child.Message") assert imports == {"from . import child"} assert name == "child.Message" -def test_import_camel_cased(): +def test_reference_camel_cased(): imports = set() - name = get_ref_type( + name = get_type_reference( package="", imports=imports, source_type="child_package.example_message" ) @@ -114,17 +116,19 @@ def test_import_camel_cased(): assert name == "child_package.ExampleMessage" -def test_import_nested_child_from_root(): +def test_reference_nested_child_from_root(): imports = set() - name = get_ref_type(package="", imports=imports, source_type="nested.child.Message") + name = get_type_reference( + package="", imports=imports, source_type="nested.child.Message" + ) assert imports == {"from .nested import child as nested_child"} assert name == "nested_child.Message" -def test_import_deeply_nested_child_from_root(): +def test_reference_deeply_nested_child_from_root(): imports = set() - name = get_ref_type( + name = get_type_reference( package="", imports=imports, source_type="deeply.nested.child.Message" ) @@ -132,9 +136,9 @@ def test_import_deeply_nested_child_from_root(): assert name == "deeply_nested_child.Message" -def test_import_deeply_nested_child_from_package(): +def test_reference_deeply_nested_child_from_package(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package", imports=imports, source_type="package.deeply.nested.child.Message", @@ -144,33 +148,35 @@ def test_import_deeply_nested_child_from_package(): assert name == "deeply_nested_child.Message" -def test_import_root_sibling(): +def test_reference_root_sibling(): imports = set() - name = get_ref_type(package="", imports=imports, source_type="Message") + name = get_type_reference(package="", imports=imports, source_type="Message") assert imports == set() assert name == '"Message"' -def test_import_nested_siblings(): +def test_reference_nested_siblings(): imports = set() - name = get_ref_type(package="foo", imports=imports, source_type="foo.Message") + name = get_type_reference(package="foo", imports=imports, source_type="foo.Message") + assert imports == set() assert name == '"Message"' -def test_import_deeply_nested_siblings(): +def test_reference_deeply_nested_siblings(): imports = set() - name = get_ref_type( + name = get_type_reference( package="foo.bar", imports=imports, source_type="foo.bar.Message" ) + assert imports == set() assert name == '"Message"' -def test_import_parent_package_from_child(): +def test_reference_parent_package_from_child(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package.child", imports=imports, source_type="package.Message" ) @@ -178,9 +184,9 @@ def test_import_parent_package_from_child(): assert name == "__package__.Message" -def test_import_parent_package_from_deeply_nested_child(): +def test_reference_parent_package_from_deeply_nested_child(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package.deeply.nested.child", imports=imports, source_type="package.deeply.nested.Message", @@ -190,9 +196,9 @@ def test_import_parent_package_from_deeply_nested_child(): assert name == "__nested__.Message" -def test_import_ancestor_package_from_nested_child(): +def test_reference_ancestor_package_from_nested_child(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package.ancestor.nested.child", imports=imports, source_type="package.ancestor.Message", @@ -202,17 +208,19 @@ def test_import_ancestor_package_from_nested_child(): assert name == "___ancestor__.Message" -def test_import_root_package_from_child(): +def test_reference_root_package_from_child(): imports = set() - name = get_ref_type(package="package.child", imports=imports, source_type="Message") + name = get_type_reference( + package="package.child", imports=imports, source_type="Message" + ) assert imports == {"from ... import Message"} assert name == "Message" -def test_import_root_package_from_deeply_nested_child(): +def test_reference_root_package_from_deeply_nested_child(): imports = set() - name = get_ref_type( + name = get_type_reference( package="package.deeply.nested.child", imports=imports, source_type="Message" ) @@ -220,25 +228,25 @@ def test_import_root_package_from_deeply_nested_child(): assert name == "Message" -def test_import_unrelated_package(): +def test_reference_unrelated_package(): imports = set() - name = get_ref_type(package="a", imports=imports, source_type="p.Message") + name = get_type_reference(package="a", imports=imports, source_type="p.Message") assert imports == {"from .. import p as _p__"} assert name == "_p__.Message" -def test_import_unrelated_nested_package(): +def test_reference_unrelated_nested_package(): imports = set() - name = get_ref_type(package="a.b", imports=imports, source_type="p.q.Message") + name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") assert imports == {"from ...p import q as __p_q__"} assert name == "__p_q__.Message" -def test_import_unrelated_deeply_nested_package(): +def test_reference_unrelated_deeply_nested_package(): imports = set() - name = get_ref_type( + name = get_type_reference( package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message" ) @@ -246,17 +254,17 @@ def test_import_unrelated_deeply_nested_package(): assert name == "____p_q_r_s__.Message" -def test_import_cousin_package(): +def test_reference_cousin_package(): imports = set() - name = get_ref_type(package="a.x", imports=imports, source_type="a.y.Message") + name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") assert imports == {"from .. import y as _y__"} assert name == "_y__.Message" -def test_import_cousin_package_different_name(): +def test_reference_cousin_package_different_name(): imports = set() - name = get_ref_type( + name = get_type_reference( package="test.package1", imports=imports, source_type="cousin.package2.Message" ) @@ -264,9 +272,9 @@ def test_import_cousin_package_different_name(): assert name == "__cousin_package2__.Message" -def test_import_cousin_package_same_name(): +def test_reference_cousin_package_same_name(): imports = set() - name = get_ref_type( + name = get_type_reference( package="test.package", imports=imports, source_type="cousin.package.Message" ) @@ -274,17 +282,19 @@ def test_import_cousin_package_same_name(): assert name == "__cousin_package__.Message" -def test_import_far_cousin_package(): +def test_reference_far_cousin_package(): imports = set() - name = get_ref_type(package="a.x.y", imports=imports, source_type="a.b.c.Message") + name = get_type_reference( + package="a.x.y", imports=imports, source_type="a.b.c.Message" + ) assert imports == {"from ...b import c as __b_c__"} assert name == "__b_c__.Message" -def test_import_far_far_cousin_package(): +def test_reference_far_far_cousin_package(): imports = set() - name = get_ref_type( + name = get_type_reference( package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message" )