diff --git a/pyproject.toml b/pyproject.toml index a93a96e..114209f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,9 +49,9 @@ generate = { script = "tests.generate:main", help = "Generate test cases (do test = { cmd = "pytest --cov src", help = "Run tests" } types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" } format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" } -clean = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" } docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"} bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"} +clean = { cmd = "rm -rf .asv .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" } # CI tasks full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" } diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f73c3f4..48a7358 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -699,7 +699,7 @@ class Message(ABC): meta.number, meta.proto_type, value, - serialize_empty=serialize_empty, + serialize_empty=serialize_empty or selected_in_group, wraps=meta.wraps or "", ) diff --git a/tests/inputs/oneof/oneof-name.json b/tests/inputs/oneof/oneof_name.json similarity index 100% rename from tests/inputs/oneof/oneof-name.json rename to tests/inputs/oneof/oneof_name.json diff --git a/tests/inputs/oneof/test_oneof.py b/tests/inputs/oneof/test_oneof.py index c361b53..8a1d3ea 100644 --- a/tests/inputs/oneof/test_oneof.py +++ b/tests/inputs/oneof/test_oneof.py @@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data def test_which_count(): message = Test() - message.from_json(get_test_case_json_data("oneof")) + message.from_json(get_test_case_json_data("oneof")[0]) assert betterproto.which_one_of(message, "foo") == ("count", 100) def test_which_name(): message = Test() - message.from_json(get_test_case_json_data("oneof", "oneof-name.json")) + message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0]) assert betterproto.which_one_of(message, "foo") == ("name", "foobar") diff --git a/tests/inputs/oneof_empty/oneof_empty.json b/tests/inputs/oneof_empty/oneof_empty.json new file mode 100644 index 0000000..9d21c89 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty.json @@ -0,0 +1,3 @@ +{ + "nothing": {} +} diff --git a/tests/inputs/oneof_empty/oneof_empty.proto b/tests/inputs/oneof_empty/oneof_empty.proto new file mode 100644 index 0000000..45ca371 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty.proto @@ -0,0 +1,15 @@ +syntax = "proto3"; + +message Nothing {} + +message MaybeNothing { + string sometimes = 42; +} + +message Test { + oneof empty { + Nothing nothing = 1; + MaybeNothing maybe1 = 2; + MaybeNothing maybe2 = 3; + } +} diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe1.json b/tests/inputs/oneof_empty/oneof_empty_maybe1.json new file mode 100644 index 0000000..f7a2d27 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty_maybe1.json @@ -0,0 +1,3 @@ +{ + "maybe1": {} +} diff --git a/tests/inputs/oneof_empty/oneof_empty_maybe2.json b/tests/inputs/oneof_empty/oneof_empty_maybe2.json new file mode 100644 index 0000000..bc2b385 --- /dev/null +++ b/tests/inputs/oneof_empty/oneof_empty_maybe2.json @@ -0,0 +1,5 @@ +{ + "maybe2": { + "sometimes": "now" + } +} diff --git a/tests/inputs/oneof_empty/test_oneof_empty.py b/tests/inputs/oneof_empty/test_oneof_empty.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/inputs/oneof_enum/test_oneof_enum.py b/tests/inputs/oneof_enum/test_oneof_enum.py index e3eca13..73b37c6 100644 --- a/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/tests/inputs/oneof_enum/test_oneof_enum.py @@ -14,7 +14,9 @@ def test_which_one_of_returns_enum_with_default_value(): returns first field when it is enum and set with default value """ message = Test() - message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")) + message.from_json( + get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0] + ) assert message.move == Move( x=0, y=0 @@ -28,7 +30,9 @@ def test_which_one_of_returns_enum_with_non_default_value(): returns first field when it is enum and set with non default value """ message = Test() - message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")) + message.from_json( + get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0] + ) assert message.move == Move( x=0, y=0 ) # Proto3 will default this as there is no null @@ -38,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): def test_which_one_of_returns_second_field_when_set(): message = Test() - message.from_json(get_test_case_json_data("oneof_enum")) + message.from_json(get_test_case_json_data("oneof_enum")[0]) assert message.move == Move(x=2, y=3) assert message.signal == Signal.PASS assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) diff --git a/tests/test_features.py b/tests/test_features.py index 9bf30e6..3f44f17 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -286,17 +286,23 @@ def test_to_dict_default_values(): def test_oneof_default_value_set_causes_writes_wire(): + @dataclass + class Empty(betterproto.Message): + pass + @dataclass class Foo(betterproto.Message): bar: int = betterproto.int32_field(1, group="group1") baz: str = betterproto.string_field(2, group="group1") + qux: Empty = betterproto.message_field(3, group="group1") def _round_trip_serialization(foo: Foo) -> Foo: return Foo().parse(bytes(foo)) foo1 = Foo(bar=0) foo2 = Foo(baz="") - foo3 = Foo() + foo3 = Foo(qux=Empty()) + foo4 = Foo() assert bytes(foo1) == b"\x08\x00" assert ( @@ -312,10 +318,17 @@ def test_oneof_default_value_set_causes_writes_wire(): == ("baz", "") ) - assert bytes(foo3) == b"" + assert bytes(foo3) == b"\x1a\x00" assert ( betterproto.which_one_of(foo3, "group1") == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") + == ("qux", Empty()) + ) + + assert bytes(foo4) == b"" + assert ( + betterproto.which_one_of(foo4, "group1") + == betterproto.which_one_of(_round_trip_serialization(foo4), "group1") == ("", None) ) diff --git a/tests/test_inputs.py b/tests/test_inputs.py index f09ee79..e743f64 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -126,42 +126,44 @@ def test_message_json(repeat, test_data: TestData) -> None: plugin_module, _, json_data = test_data for _ in range(repeat): - message: betterproto.Message = plugin_module.Test() + for json_sample in json_data: + message: betterproto.Message = plugin_module.Test() - message.from_json(json_data) - message_json = message.to_json(0) + message.from_json(json_sample) + message_json = message.to_json(0) - assert json.loads(message_json) == json.loads(json_data) + assert json.loads(message_json) == json.loads(json_sample) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) def test_service_can_be_instantiated(test_data: TestData) -> None: - plugin_module, _, json_data = test_data - plugin_module.TestStub(MockChannel()) + test_data.plugin_module.TestStub(MockChannel()) @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) def test_binary_compatibility(repeat, test_data: TestData) -> None: plugin_module, reference_module, json_data = test_data - reference_instance = Parse(json_data, reference_module().Test()) - reference_binary_output = reference_instance.SerializeToString() + for json_sample in json_data: + reference_instance = Parse(json_sample, reference_module().Test()) + reference_binary_output = reference_instance.SerializeToString() - for _ in range(repeat): - plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( - json_data - ) - plugin_instance_from_binary = plugin_module.Test.FromString( - reference_binary_output - ) + for _ in range(repeat): + plugin_instance_from_json: betterproto.Message = ( + plugin_module.Test().from_json(json_sample) + ) + plugin_instance_from_binary = plugin_module.Test.FromString( + reference_binary_output + ) - # # Generally this can't be relied on, but here we are aiming to match the - # # existing Python implementation and aren't doing anything tricky. - # # https://developers.google.com/protocol-buffers/docs/encoding#implications - assert bytes(plugin_instance_from_json) == reference_binary_output - assert bytes(plugin_instance_from_binary) == reference_binary_output + # # Generally this can't be relied on, but here we are aiming to match the + # # existing Python implementation and aren't doing anything tricky. + # # https://developers.google.com/protocol-buffers/docs/encoding#implications + assert bytes(plugin_instance_from_json) == reference_binary_output + assert bytes(plugin_instance_from_binary) == reference_binary_output - assert plugin_instance_from_json == plugin_instance_from_binary - assert ( - plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() - ) + assert plugin_instance_from_json == plugin_instance_from_binary + assert ( + plugin_instance_from_json.to_dict() + == plugin_instance_from_binary.to_dict() + ) diff --git a/tests/util.py b/tests/util.py index 6c63141..5dcf155 100644 --- a/tests/util.py +++ b/tests/util.py @@ -5,7 +5,7 @@ import pathlib import sys from pathlib import Path from types import ModuleType -from typing import Callable, Generator, Optional, Union +from typing import Callable, Generator, List, Optional, Union os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -47,15 +47,27 @@ async def protoc( return stdout, stderr, proc.returncode -def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): - test_data_file_name = json_file_name or f"{test_case_name}.json" - test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) +def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]: + """ + :return: + A list of all files found in "inputs_path/test_case_name" with names matching + f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names + """ + test_case_dir = inputs_path.joinpath(test_case_name) + possible_file_paths = [ + *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names), + test_case_dir.joinpath(f"{test_case_name}.json"), + *test_case_dir.glob(f"{test_case_name}_*.json"), + ] - if not test_data_file_path.exists(): - return None + result = [] + for test_data_file_path in possible_file_paths: + if not test_data_file_path.exists(): + continue + with test_data_file_path.open("r") as fh: + result.append(fh.read()) - with test_data_file_path.open("r") as fh: - return fh.read() + return result def find_module(