From 0af0cf4bfbd5369a56db7f9da54963ac09f7a277 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sat, 4 Jul 2020 15:35:42 +0200 Subject: [PATCH] Fixes circular import problem when a non-circular dependency triangle is flattened into two python packages --- betterproto/compile/importing.py | 12 ++-- betterproto/plugin.py | 2 +- betterproto/templates/template.py.j2 | 16 +++--- betterproto/tests/inputs/config.py | 1 - betterproto/tests/test_get_ref_type.py | 76 +++++++++++++++----------- 5 files changed, 59 insertions(+), 48 deletions(-) diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 40441f8..57ef376 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -86,7 +86,7 @@ def reference_absolute(imports, py_package, py_type): string_import = ".".join(py_package) string_alias = safe_snake_case(string_import) imports.add(f"import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' def reference_sibling(py_type: str) -> str: @@ -109,10 +109,10 @@ def reference_descendent( if string_from: string_alias = "_".join(importing_descendent) imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' else: imports.add(f"from . import {string_import}") - return f"{string_import}.{py_type}" + return f'"{string_import}.{py_type}"' def reference_ancestor( @@ -130,11 +130,11 @@ def reference_ancestor( string_alias = f"_{'_' * distance_up}{string_import}__" string_from = f"..{'.' * distance_up}" imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' else: string_alias = f"{'_' * distance_up}{py_type}__" imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") - return string_alias + return f'"{string_alias}"' def reference_cousin( @@ -157,4 +157,4 @@ def reference_cousin( + "__" ) imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' diff --git a/betterproto/plugin.py b/betterproto/plugin.py index e835fab..9f4df64 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -329,7 +329,7 @@ def generate_code(request, response): output["imports"], method.output_type, unwrap=False, - ).strip('"'), + ), "client_streaming": method.client_streaming, "server_streaming": method.server_streaming, } diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index 3894619..b2d9112 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -16,10 +16,6 @@ import betterproto import grpclib {% endif %} -{% for i in description.imports %} -{{ i }} -{% endfor %} - {% if description.enums %}{% for enum in description.enums %} class {{ enum.py_name }}(betterproto.Enum): @@ -102,14 +98,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", request_iterator, {{ method.input }}, - {{ method.output }}, + {{ method.output.strip('"') }}, ): yield response {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, - {{ method.output }}, + {{ method.output.strip('"') }}, ): yield response @@ -120,16 +116,20 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", request_iterator, {{ method.input }}, - {{ method.output }} + {{ method.output.strip('"') }} ) {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output }} + {{ method.output.strip('"') }} ) {% endif %}{# client streaming #} {% endif %} {% endfor %} {% endfor %} + +{% for i in description.imports %} +{{ i }} +{% endfor %} \ No newline at end of file diff --git a/betterproto/tests/inputs/config.py b/betterproto/tests/inputs/config.py index eab5ea4..38e9603 100644 --- a/betterproto/tests/inputs/config.py +++ b/betterproto/tests/inputs/config.py @@ -1,7 +1,6 @@ # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Remove from list when fixed. xfail = { - "import_circular_dependency", "oneof_enum", # 63 "namespace_keywords", # 70 "namespace_builtin_types", # 53 diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index 2bedf76..5a1722b 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -8,22 +8,22 @@ from ..compile.importing import get_type_reference, parse_source_type_name [ ( ".google.protobuf.Empty", - "betterproto_lib_google_protobuf.Empty", + '"betterproto_lib_google_protobuf.Empty"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.Struct", - "betterproto_lib_google_protobuf.Struct", + '"betterproto_lib_google_protobuf.Struct"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.ListValue", - "betterproto_lib_google_protobuf.ListValue", + '"betterproto_lib_google_protobuf.ListValue"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.Value", - "betterproto_lib_google_protobuf.Value", + '"betterproto_lib_google_protobuf.Value"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ], @@ -67,15 +67,27 @@ def test_referenceing_google_wrappers_unwraps_them( @pytest.mark.parametrize( ["google_type", "expected_name"], [ - (".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"), - (".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"), - (".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"), - (".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"), - (".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"), - (".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"), - (".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"), - (".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"), - (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"), + ( + ".google.protobuf.DoubleValue", + '"betterproto_lib_google_protobuf.DoubleValue"', + ), + (".google.protobuf.FloatValue", '"betterproto_lib_google_protobuf.FloatValue"'), + (".google.protobuf.Int32Value", '"betterproto_lib_google_protobuf.Int32Value"'), + (".google.protobuf.Int64Value", '"betterproto_lib_google_protobuf.Int64Value"'), + ( + ".google.protobuf.UInt32Value", + '"betterproto_lib_google_protobuf.UInt32Value"', + ), + ( + ".google.protobuf.UInt64Value", + '"betterproto_lib_google_protobuf.UInt64Value"', + ), + (".google.protobuf.BoolValue", '"betterproto_lib_google_protobuf.BoolValue"'), + ( + ".google.protobuf.StringValue", + '"betterproto_lib_google_protobuf.StringValue"', + ), + (".google.protobuf.BytesValue", '"betterproto_lib_google_protobuf.BytesValue"'), ], ) def test_referenceing_google_wrappers_without_unwrapping( @@ -95,7 +107,7 @@ def test_reference_child_package_from_package(): ) assert imports == {"from . import child"} - assert name == "child.Message" + assert name == '"child.Message"' def test_reference_child_package_from_root(): @@ -103,7 +115,7 @@ def test_reference_child_package_from_root(): name = get_type_reference(package="", imports=imports, source_type="child.Message") assert imports == {"from . import child"} - assert name == "child.Message" + assert name == '"child.Message"' def test_reference_camel_cased(): @@ -113,7 +125,7 @@ def test_reference_camel_cased(): ) assert imports == {"from . import child_package"} - assert name == "child_package.ExampleMessage" + assert name == '"child_package.ExampleMessage"' def test_reference_nested_child_from_root(): @@ -123,7 +135,7 @@ def test_reference_nested_child_from_root(): ) assert imports == {"from .nested import child as nested_child"} - assert name == "nested_child.Message" + assert name == '"nested_child.Message"' def test_reference_deeply_nested_child_from_root(): @@ -133,7 +145,7 @@ def test_reference_deeply_nested_child_from_root(): ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == "deeply_nested_child.Message" + assert name == '"deeply_nested_child.Message"' def test_reference_deeply_nested_child_from_package(): @@ -145,7 +157,7 @@ def test_reference_deeply_nested_child_from_package(): ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == "deeply_nested_child.Message" + assert name == '"deeply_nested_child.Message"' def test_reference_root_sibling(): @@ -181,7 +193,7 @@ def test_reference_parent_package_from_child(): ) assert imports == {"from ... import package as __package__"} - assert name == "__package__.Message" + assert name == '"__package__.Message"' def test_reference_parent_package_from_deeply_nested_child(): @@ -193,7 +205,7 @@ def test_reference_parent_package_from_deeply_nested_child(): ) assert imports == {"from ... import nested as __nested__"} - assert name == "__nested__.Message" + assert name == '"__nested__.Message"' def test_reference_ancestor_package_from_nested_child(): @@ -205,7 +217,7 @@ def test_reference_ancestor_package_from_nested_child(): ) assert imports == {"from .... import ancestor as ___ancestor__"} - assert name == "___ancestor__.Message" + assert name == '"___ancestor__.Message"' def test_reference_root_package_from_child(): @@ -215,7 +227,7 @@ def test_reference_root_package_from_child(): ) assert imports == {"from ... import Message as __Message__"} - assert name == "__Message__" + assert name == '"__Message__"' def test_reference_root_package_from_deeply_nested_child(): @@ -225,7 +237,7 @@ def test_reference_root_package_from_deeply_nested_child(): ) assert imports == {"from ..... import Message as ____Message__"} - assert name == "____Message__" + assert name == '"____Message__"' def test_reference_unrelated_package(): @@ -233,7 +245,7 @@ def test_reference_unrelated_package(): name = get_type_reference(package="a", imports=imports, source_type="p.Message") assert imports == {"from .. import p as _p__"} - assert name == "_p__.Message" + assert name == '"_p__.Message"' def test_reference_unrelated_nested_package(): @@ -241,7 +253,7 @@ def test_reference_unrelated_nested_package(): name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") assert imports == {"from ...p import q as __p_q__"} - assert name == "__p_q__.Message" + assert name == '"__p_q__.Message"' def test_reference_unrelated_deeply_nested_package(): @@ -251,7 +263,7 @@ def test_reference_unrelated_deeply_nested_package(): ) assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} - assert name == "____p_q_r_s__.Message" + assert name == '"____p_q_r_s__.Message"' def test_reference_cousin_package(): @@ -259,7 +271,7 @@ def test_reference_cousin_package(): name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") assert imports == {"from .. import y as _y__"} - assert name == "_y__.Message" + assert name == '"_y__.Message"' def test_reference_cousin_package_different_name(): @@ -269,7 +281,7 @@ def test_reference_cousin_package_different_name(): ) assert imports == {"from ...cousin import package2 as __cousin_package2__"} - assert name == "__cousin_package2__.Message" + assert name == '"__cousin_package2__.Message"' def test_reference_cousin_package_same_name(): @@ -279,7 +291,7 @@ def test_reference_cousin_package_same_name(): ) assert imports == {"from ...cousin import package as __cousin_package__"} - assert name == "__cousin_package__.Message" + assert name == '"__cousin_package__.Message"' def test_reference_far_cousin_package(): @@ -289,7 +301,7 @@ def test_reference_far_cousin_package(): ) assert imports == {"from ...b import c as __b_c__"} - assert name == "__b_c__.Message" + assert name == '"__b_c__.Message"' def test_reference_far_far_cousin_package(): @@ -299,7 +311,7 @@ def test_reference_far_far_cousin_package(): ) assert imports == {"from ....b.c import d as ___b_c_d__"} - assert name == "___b_c_d__.Message" + assert name == '"___b_c_d__.Message"' @pytest.mark.parametrize(