diff --git a/betterproto/tests/README.md b/betterproto/tests/README.md index ea15758..de2e2d2 100644 --- a/betterproto/tests/README.md +++ b/betterproto/tests/README.md @@ -73,3 +73,18 @@ The following tests are automatically executed for all cases: - `betterproto/tests/output_reference` — *reference implementation classes* - `pipenv run test` +## Intentionally Failing tests + +The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrented in the future. + +When running `pytest`, they show up as `x` or `X` in the test results. + +``` +betterproto/tests/test_inputs.py ..x...x..x...x.X........xx........x.....x.......x.xx....x...................... [ 84%] +``` + +- `.` — PASSED +- `x` — XFAIL: expected failure +- `X` — XPASS: expected failure, but still passed + +Test cases marked for expected failure are declared in [inputs/xfail.py](inputs.xfail.py) \ No newline at end of file diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index fc3c4cd..f3b92e6 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,5 +1,7 @@ #!/usr/bin/env python +import glob import os +import shutil import sys from typing import Set @@ -17,6 +19,14 @@ from betterproto.tests.util import ( os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" +def clear_directory(path: str): + for file_or_directory in glob.glob(os.path.join(path, "*")): + if os.path.isdir(file_or_directory): + shutil.rmtree(file_or_directory) + else: + os.remove(file_or_directory) + + def generate(whitelist: Set[str]): path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)} name_whitelist = {e for e in whitelist if not os.path.exists(e)} @@ -24,26 +34,33 @@ def generate(whitelist: Set[str]): test_case_names = set(get_directories(inputs_path)) for test_case_name in sorted(test_case_names): - test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name)) + test_case_input_path = os.path.realpath( + os.path.join(inputs_path, test_case_name) + ) if ( whitelist - and test_case_path not in path_whitelist + and test_case_input_path not in path_whitelist and test_case_name not in name_whitelist ): continue - case_output_dir_reference = os.path.join(output_path_reference, test_case_name) - case_output_dir_betterproto = os.path.join( + test_case_output_path_reference = os.path.join( + output_path_reference, test_case_name + ) + test_case_output_path_betterproto = os.path.join( output_path_betterproto, test_case_name ) print(f"Generating output for {test_case_name}") - os.makedirs(case_output_dir_reference, exist_ok=True) - os.makedirs(case_output_dir_betterproto, exist_ok=True) + os.makedirs(test_case_output_path_reference, exist_ok=True) + os.makedirs(test_case_output_path_betterproto, exist_ok=True) - protoc_reference(test_case_path, case_output_dir_reference) - protoc_plugin(test_case_path, case_output_dir_betterproto) + clear_directory(test_case_output_path_reference) + clear_directory(test_case_output_path_betterproto) + + protoc_reference(test_case_input_path, test_case_output_path_reference) + protoc_plugin(test_case_input_path, test_case_output_path_betterproto) HELP = "\n".join( diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index 76c012b..fb2152b 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -43,11 +43,14 @@ async def test_channel_receives_wrapped_type( async def test_service_unwraps_response( service_method: Callable[[TestStub], Any], wrapper_class: Callable, value ): + """ + grpclib does not unwrap wrapper values returned by services + """ wrapped_value = wrapper_class() wrapped_value.value = value service = TestStub(MockChannel(responses=[wrapped_value])) response_value = await service_method(service) - assert type(response_value) == value + assert response_value == value assert type(response_value) == type(value) diff --git a/betterproto/tests/inputs/import_child_package_from_package/child.proto b/betterproto/tests/inputs/import_child_package_from_package/child.proto new file mode 100644 index 0000000..0865fc8 --- /dev/null +++ b/betterproto/tests/inputs/import_child_package_from_package/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package package.childpackage; + +message ChildMessage { + +} diff --git a/betterproto/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto b/betterproto/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto new file mode 100644 index 0000000..0d09132 --- /dev/null +++ b/betterproto/tests/inputs/import_child_package_from_package/import_child_package_from_package.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "package_message.proto"; + +// Tests generated imports when a message in a package refers to a message in a nested child package. + +message Test { + package.PackageMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_child_package_from_package/package_message.proto b/betterproto/tests/inputs/import_child_package_from_package/package_message.proto new file mode 100644 index 0000000..943282c --- /dev/null +++ b/betterproto/tests/inputs/import_child_package_from_package/package_message.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "child.proto"; + +package package; + +message PackageMessage { + package.childpackage.ChildMessage c = 1; +} diff --git a/betterproto/tests/inputs/import_child_package_from_root/child.proto b/betterproto/tests/inputs/import_child_package_from_root/child.proto new file mode 100644 index 0000000..c874e14 --- /dev/null +++ b/betterproto/tests/inputs/import_child_package_from_root/child.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package childpackage; + +message Message { + +} diff --git a/betterproto/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto b/betterproto/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto new file mode 100644 index 0000000..d0c111f --- /dev/null +++ b/betterproto/tests/inputs/import_child_package_from_root/import_child_package_from_root.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "child.proto"; + +// Tests generated imports when a message in root refers to a message in a child package. + +message Test { + childpackage.Message child = 1; +} diff --git a/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto new file mode 100644 index 0000000..589d14f --- /dev/null +++ b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +import "root.proto"; +import "other.proto"; + +// This test-case verifies that future implementations will support circular dependencies in the generated python files. +// +// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file. +// +// Scenario: +// +// The proto messages depend on each other in a non-circular way: +// +// Test -------> RootPackageMessage <--------------. +// `------------------------------------> OtherPackageMessage +// +// Test and RootPackageMessage are in different files, but belong to the same package (root): +// +// (Test -------> RootPackageMessage) <------------. +// `------------------------------------> OtherPackageMessage +// +// After grouping the packages into single files or modules, a circular dependency is created: +// +// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) +message Test { + RootPackageMessage message = 1; + other.OtherPackageMessage other =2; +} diff --git a/betterproto/tests/inputs/import_circular_dependency/other.proto b/betterproto/tests/inputs/import_circular_dependency/other.proto new file mode 100644 index 0000000..2b936a9 --- /dev/null +++ b/betterproto/tests/inputs/import_circular_dependency/other.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +import "root.proto"; +package other; + +message OtherPackageMessage { + RootPackageMessage rootPackageMessage = 1; +} diff --git a/betterproto/tests/inputs/import_circular_dependency/root.proto b/betterproto/tests/inputs/import_circular_dependency/root.proto new file mode 100644 index 0000000..63d15bf --- /dev/null +++ b/betterproto/tests/inputs/import_circular_dependency/root.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message RootPackageMessage { + +} diff --git a/betterproto/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto b/betterproto/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto new file mode 100644 index 0000000..c43c1bc --- /dev/null +++ b/betterproto/tests/inputs/import_parent_package_from_child/import_parent_package_from_child.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "parent_package_message.proto"; + +package parent.child; + +// Tests generated imports when a message refers to a message defined in its parent package + +message Test { + ParentPackageMessage message_implicit = 1; + parent.ParentPackageMessage message_explicit = 2; +} diff --git a/betterproto/tests/inputs/import_parent_package_from_child/parent_package_message.proto b/betterproto/tests/inputs/import_parent_package_from_child/parent_package_message.proto new file mode 100644 index 0000000..cea3066 --- /dev/null +++ b/betterproto/tests/inputs/import_parent_package_from_child/parent_package_message.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package parent; + +message ParentPackageMessage { +} diff --git a/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto b/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto new file mode 100644 index 0000000..9e7dbcd --- /dev/null +++ b/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +import "root.proto"; + +package child; + +// Tests generated imports when a message inside a child-package refers to a message defined in the root. + +message Test { + RootMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_root_package_from_child/root.proto b/betterproto/tests/inputs/import_root_package_from_child/root.proto new file mode 100644 index 0000000..650b29b --- /dev/null +++ b/betterproto/tests/inputs/import_root_package_from_child/root.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + + +message RootMessage { +} diff --git a/betterproto/tests/inputs/import_root_sibling/import_root_sibling.proto b/betterproto/tests/inputs/import_root_sibling/import_root_sibling.proto new file mode 100644 index 0000000..1d671b8 --- /dev/null +++ b/betterproto/tests/inputs/import_root_sibling/import_root_sibling.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "sibling.proto"; + +// Tests generated imports when a message in the root package refers to another message in the root package + +message Test { + SiblingMessage sibling = 1; +} diff --git a/betterproto/tests/inputs/import_root_sibling/sibling.proto b/betterproto/tests/inputs/import_root_sibling/sibling.proto new file mode 100644 index 0000000..870baff --- /dev/null +++ b/betterproto/tests/inputs/import_root_sibling/sibling.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message SiblingMessage { + +} diff --git a/betterproto/tests/inputs/oneof/oneof-name.json b/betterproto/tests/inputs/oneof/oneof-name.json index bde99de..45960e7 100644 --- a/betterproto/tests/inputs/oneof/oneof-name.json +++ b/betterproto/tests/inputs/oneof/oneof-name.json @@ -1,3 +1,3 @@ { - "name": "foo" + "name": "foobar" } diff --git a/betterproto/tests/inputs/oneof/oneof.json b/betterproto/tests/inputs/oneof/oneof.json index 400decb..0197c99 100644 --- a/betterproto/tests/inputs/oneof/oneof.json +++ b/betterproto/tests/inputs/oneof/oneof.json @@ -1,3 +1,3 @@ { - "count": 1 + "count": 100 } diff --git a/betterproto/tests/inputs/oneof/test_oneof.py b/betterproto/tests/inputs/oneof/test_oneof.py new file mode 100644 index 0000000..400e4fd --- /dev/null +++ b/betterproto/tests/inputs/oneof/test_oneof.py @@ -0,0 +1,15 @@ +import betterproto +from betterproto.tests.output_betterproto.oneof.oneof import Test +from betterproto.tests.util import get_test_case_json_data + + +def test_which_count(): + message = Test() + message.from_json(get_test_case_json_data("oneof")) + assert betterproto.which_one_of(message, "foo") == ("count", 100) + + +def test_which_name(): + message = Test() + message.from_json(get_test_case_json_data("oneof", "oneof-name.json")) + assert betterproto.which_one_of(message, "foo") == ("name", "foobar") diff --git a/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-0.json b/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-0.json new file mode 100644 index 0000000..be30cf0 --- /dev/null +++ b/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-0.json @@ -0,0 +1,3 @@ +{ + "signal": "PASS" +} diff --git a/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-1.json b/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-1.json new file mode 100644 index 0000000..cb63873 --- /dev/null +++ b/betterproto/tests/inputs/oneof_enum/oneof_enum-enum-1.json @@ -0,0 +1,3 @@ +{ + "signal": "RESIGN" +} diff --git a/betterproto/tests/inputs/oneof_enum/oneof_enum.json b/betterproto/tests/inputs/oneof_enum/oneof_enum.json new file mode 100644 index 0000000..3220b70 --- /dev/null +++ b/betterproto/tests/inputs/oneof_enum/oneof_enum.json @@ -0,0 +1,6 @@ +{ + "move": { + "x": 2, + "y": 3 + } +} diff --git a/betterproto/tests/inputs/oneof_enum/oneof_enum.proto b/betterproto/tests/inputs/oneof_enum/oneof_enum.proto new file mode 100644 index 0000000..dfe19d4 --- /dev/null +++ b/betterproto/tests/inputs/oneof_enum/oneof_enum.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +message Test { + oneof action { + Signal signal = 1; + Move move = 2; + } +} + +enum Signal { + PASS = 0; + RESIGN = 1; +} + +message Move { + int32 x = 1; + int32 y = 2; +} \ No newline at end of file diff --git a/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py b/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py new file mode 100644 index 0000000..1d6ea98 --- /dev/null +++ b/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py @@ -0,0 +1,42 @@ +import pytest + +import betterproto +from betterproto.tests.output_betterproto.oneof_enum.oneof_enum import ( + Move, + Signal, + Test, +) +from betterproto.tests.util import get_test_case_json_data + + +@pytest.mark.xfail +def test_which_one_of_returns_enum_with_default_value(): + """ + returns first field when it is enum and set with default value + """ + message = Test() + message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")) + assert message.move is None + assert message.signal == Signal.PASS + assert betterproto.which_one_of(message, "action") == ("signal", Signal.PASS) + + +@pytest.mark.xfail +def test_which_one_of_returns_enum_with_non_default_value(): + """ + returns first field when it is enum and set with non default value + """ + message = Test() + message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")) + assert message.move is None + assert message.signal == Signal.PASS + assert betterproto.which_one_of(message, "action") == ("signal", Signal.RESIGN) + + +@pytest.mark.xfail +def test_which_one_of_returns_second_field_when_set(): + message = Test() + message.from_json(get_test_case_json_data("oneof_enum")) + assert message.move == Move(x=2, y=3) + assert message.signal == 0 + assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/betterproto/tests/inputs/service/service.proto b/betterproto/tests/inputs/service/service.proto index aaf4254..7c931ed 100644 --- a/betterproto/tests/inputs/service/service.proto +++ b/betterproto/tests/inputs/service/service.proto @@ -10,6 +10,6 @@ message DoThingResponse { int32 successfulIterations = 1; } -service ExampleService { +service Test { rpc DoThing (DoThingRequest) returns (DoThingResponse); } diff --git a/betterproto/tests/test_service_stub.py b/betterproto/tests/inputs/service/test_service.py similarity index 95% rename from betterproto/tests/test_service_stub.py rename to betterproto/tests/inputs/service/test_service.py index b614e82..ebd9308 100644 --- a/betterproto/tests/test_service_stub.py +++ b/betterproto/tests/inputs/service/test_service.py @@ -7,7 +7,7 @@ from typing import Dict from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, - ExampleServiceStub, + TestStub as ExampleServiceStub, ) @@ -29,12 +29,12 @@ class ExampleService: def __mapping__(self) -> Dict[str, grpclib.const.Handler]: return { - "/service.ExampleService/DoThing": grpclib.const.Handler( + "/service.Test/DoThing": grpclib.const.Handler( self.DoThing, grpclib.const.Cardinality.UNARY_UNARY, DoThingRequest, DoThingResponse, - ) + ), } @@ -99,7 +99,7 @@ async def test_service_call_lower_level_with_overrides(): ) as channel: stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) response = await stub._unary_unary( - "/service.ExampleService/DoThing", + "/service.Test/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, deadline=kwarg_deadline, @@ -123,7 +123,7 @@ async def test_service_call_lower_level_with_overrides(): ) as channel: stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) response = await stub._unary_unary( - "/service.ExampleService/DoThing", + "/service.Test/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, timeout=kwarg_timeout, diff --git a/betterproto/tests/inputs/xfail.py b/betterproto/tests/inputs/xfail.py new file mode 100644 index 0000000..f80f0f4 --- /dev/null +++ b/betterproto/tests/inputs/xfail.py @@ -0,0 +1,10 @@ +# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. +# Remove from list when fixed. +tests = { + "import_root_sibling", + "import_child_package_from_package", + "import_root_package_from_child", + "import_parent_package_from_child", + "import_circular_dependency", + "oneof_enum", +} diff --git a/betterproto/tests/mocks.py b/betterproto/tests/mocks.py index 326b892..cd0efaa 100644 --- a/betterproto/tests/mocks.py +++ b/betterproto/tests/mocks.py @@ -5,8 +5,8 @@ from grpclib.client import Channel class MockChannel(Channel): # noinspection PyMissingConstructor - def __init__(self, responses: List) -> None: - self.responses = responses + def __init__(self, responses=None) -> None: + self.responses = responses if responses else [] self.requests = [] def request(self, route, cardinality, request, response_type, **kwargs): diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index 628344e..1d31348 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -2,10 +2,15 @@ import importlib import json import os import sys -import pytest -import betterproto -from betterproto.tests.util import get_directories, inputs_path from collections import namedtuple +from typing import Set + +import pytest + +import betterproto +from betterproto.tests.inputs import xfail +from betterproto.tests.mocks import MockChannel +from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. @@ -16,12 +21,34 @@ from google.protobuf.descriptor_pool import DescriptorPool from google.protobuf.json_format import Parse -excluded_test_cases = { - "googletypes_response", - "googletypes_response_embedded", - "service", -} -test_case_names = {*get_directories(inputs_path)} - excluded_test_cases +class TestCases: + def __init__(self, path, services: Set[str], xfail: Set[str]): + _all = set(get_directories(path)) + _services = services + _messages = _all - services + _messages_with_json = { + test for test in _messages if get_test_case_json_data(test) + } + + self.all = self.apply_xfail_marks(_all, xfail) + self.services = self.apply_xfail_marks(_services, xfail) + self.messages = self.apply_xfail_marks(_messages, xfail) + self.messages_with_json = self.apply_xfail_marks(_messages_with_json, xfail) + + @staticmethod + def apply_xfail_marks(test_set: Set[str], xfail: Set[str]): + return [ + pytest.param(test, marks=pytest.mark.xfail) if test in xfail else test + for test in test_set + ] + + +test_cases = TestCases( + path=inputs_path, + # test cases for services + services={"googletypes_response", "googletypes_response_embedded", "service"}, + xfail=xfail.tests, +) plugin_output_package = "betterproto.tests.output_betterproto" reference_output_package = "betterproto.tests.output_reference" @@ -30,7 +57,7 @@ reference_output_package = "betterproto.tests.output_reference" TestData = namedtuple("TestData", "plugin_module, reference_module, json_data") -@pytest.fixture(scope="module", params=test_case_names) +@pytest.fixture def test_data(request): test_case_name = request.param @@ -45,24 +72,28 @@ def test_data(request): sys.path.append(reference_module_root) - yield TestData( - plugin_module=importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ), - reference_module=importlib.import_module( - f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" - ), - json_data=get_test_case_json_data(test_case_name), + yield ( + TestData( + plugin_module=importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ), + reference_module=lambda: importlib.import_module( + f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + ), + json_data=get_test_case_json_data(test_case_name), + ) ) sys.path.remove(reference_module_root) +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) def test_message_can_instantiated(test_data: TestData) -> None: plugin_module, *_ = test_data plugin_module.Test() +@pytest.mark.parametrize("test_data", test_cases.messages, indirect=True) def test_message_equality(test_data: TestData) -> None: plugin_module, *_ = test_data message1 = plugin_module.Test() @@ -70,6 +101,7 @@ def test_message_equality(test_data: TestData) -> None: assert message1 == message2 +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) def test_message_json(repeat, test_data: TestData) -> None: plugin_module, _, json_data = test_data @@ -82,10 +114,17 @@ def test_message_json(repeat, test_data: TestData) -> None: assert json.loads(json_data) == json.loads(message_json) +@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) +def test_service_can_be_instantiated(test_data: TestData) -> None: + plugin_module, _, json_data = test_data + plugin_module.TestStub(MockChannel()) + + +@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) def test_binary_compatibility(repeat, test_data: TestData) -> None: plugin_module, reference_module, json_data = test_data - reference_instance = Parse(json_data, reference_module.Test()) + reference_instance = Parse(json_data, reference_module().Test()) reference_binary_output = reference_instance.SerializeToString() for _ in range(repeat): @@ -99,21 +138,10 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None: # # Generally this can't be relied on, but here we are aiming to match the # # existing Python implementation and aren't doing anything tricky. # # https://developers.google.com/protocol-buffers/docs/encoding#implications + assert bytes(plugin_instance_from_json) == reference_binary_output + assert bytes(plugin_instance_from_binary) == reference_binary_output + assert plugin_instance_from_json == plugin_instance_from_binary assert ( plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() ) - - -""" -helper methods -""" - - -def get_test_case_json_data(test_case_name): - test_data_path = os.path.join(inputs_path, test_case_name, f"{test_case_name}.json") - if not os.path.exists(test_data_path): - return None - - with open(test_data_path) as fh: - return fh.read() diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index 83cfd98..11d5052 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -48,3 +48,14 @@ def protoc_reference(path: str, output_dir: str): f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", shell=True, ) + + +def get_test_case_json_data(test_case_name, json_file_name=None): + test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json" + test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name) + + if not os.path.exists(test_data_file_path): + return None + + with open(test_data_file_path) as fh: + return fh.read()