Properly serialize zero-value messages in a oneof group (#176)
Also improve test utils to make it easier to have multiple json examples. Co-authored-by: Christopher Chambers <chris@peanutcode.com>
This commit is contained in:
		| @@ -49,9 +49,9 @@ generate    = { script = "tests.generate:main", help = "Generate test cases (do | |||||||
| test        = { cmd = "pytest --cov src", help = "Run tests" } | test        = { cmd = "pytest --cov src", help = "Run tests" } | ||||||
| types       = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" } | types       = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" } | ||||||
| format      = { cmd = "black . --exclude tests/output_",   help = "Apply black formatting to source code" } | format      = { cmd = "black . --exclude tests/output_",   help = "Apply black formatting to source code" } | ||||||
| clean       = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" } |  | ||||||
| docs        = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"} | docs        = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"} | ||||||
| bench       = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"} | bench       = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"} | ||||||
|  | clean       = { cmd = "rm -rf .asv .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help  = "Clean out generated files from the workspace" } | ||||||
|  |  | ||||||
| # CI tasks | # CI tasks | ||||||
| full-test   = { shell = "poe generate && tox", help = "Run tests with multiple pythons" } | full-test   = { shell = "poe generate && tox", help = "Run tests with multiple pythons" } | ||||||
|   | |||||||
| @@ -699,7 +699,7 @@ class Message(ABC): | |||||||
|                     meta.number, |                     meta.number, | ||||||
|                     meta.proto_type, |                     meta.proto_type, | ||||||
|                     value, |                     value, | ||||||
|                     serialize_empty=serialize_empty, |                     serialize_empty=serialize_empty or selected_in_group, | ||||||
|                     wraps=meta.wraps or "", |                     wraps=meta.wraps or "", | ||||||
|                 ) |                 ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data | |||||||
|  |  | ||||||
| def test_which_count(): | def test_which_count(): | ||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof")) |     message.from_json(get_test_case_json_data("oneof")[0]) | ||||||
|     assert betterproto.which_one_of(message, "foo") == ("count", 100) |     assert betterproto.which_one_of(message, "foo") == ("count", 100) | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_which_name(): | def test_which_name(): | ||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof", "oneof-name.json")) |     message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0]) | ||||||
|     assert betterproto.which_one_of(message, "foo") == ("name", "foobar") |     assert betterproto.which_one_of(message, "foo") == ("name", "foobar") | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								tests/inputs/oneof_empty/oneof_empty.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tests/inputs/oneof_empty/oneof_empty.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | { | ||||||
|  |   "nothing": {} | ||||||
|  | } | ||||||
							
								
								
									
										15
									
								
								tests/inputs/oneof_empty/oneof_empty.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								tests/inputs/oneof_empty/oneof_empty.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | message Nothing {} | ||||||
|  |  | ||||||
|  | message MaybeNothing { | ||||||
|  |   string sometimes = 42; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Test { | ||||||
|  |   oneof empty { | ||||||
|  |     Nothing nothing = 1; | ||||||
|  |     MaybeNothing maybe1 = 2; | ||||||
|  |     MaybeNothing maybe2 = 3; | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								tests/inputs/oneof_empty/oneof_empty_maybe1.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tests/inputs/oneof_empty/oneof_empty_maybe1.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | { | ||||||
|  |   "maybe1": {} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								tests/inputs/oneof_empty/oneof_empty_maybe2.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								tests/inputs/oneof_empty/oneof_empty_maybe2.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | { | ||||||
|  |   "maybe2": { | ||||||
|  |     "sometimes": "now" | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										0
									
								
								tests/inputs/oneof_empty/test_oneof_empty.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tests/inputs/oneof_empty/test_oneof_empty.py
									
									
									
									
									
										Normal file
									
								
							| @@ -14,7 +14,9 @@ def test_which_one_of_returns_enum_with_default_value(): | |||||||
|     returns first field when it is enum and set with default value |     returns first field when it is enum and set with default value | ||||||
|     """ |     """ | ||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")) |     message.from_json( | ||||||
|  |         get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0] | ||||||
|  |     ) | ||||||
|  |  | ||||||
|     assert message.move == Move( |     assert message.move == Move( | ||||||
|         x=0, y=0 |         x=0, y=0 | ||||||
| @@ -28,7 +30,9 @@ def test_which_one_of_returns_enum_with_non_default_value(): | |||||||
|     returns first field when it is enum and set with non default value |     returns first field when it is enum and set with non default value | ||||||
|     """ |     """ | ||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")) |     message.from_json( | ||||||
|  |         get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0] | ||||||
|  |     ) | ||||||
|     assert message.move == Move( |     assert message.move == Move( | ||||||
|         x=0, y=0 |         x=0, y=0 | ||||||
|     )  # Proto3 will default this as there is no null |     )  # Proto3 will default this as there is no null | ||||||
| @@ -38,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value(): | |||||||
|  |  | ||||||
| def test_which_one_of_returns_second_field_when_set(): | def test_which_one_of_returns_second_field_when_set(): | ||||||
|     message = Test() |     message = Test() | ||||||
|     message.from_json(get_test_case_json_data("oneof_enum")) |     message.from_json(get_test_case_json_data("oneof_enum")[0]) | ||||||
|     assert message.move == Move(x=2, y=3) |     assert message.move == Move(x=2, y=3) | ||||||
|     assert message.signal == Signal.PASS |     assert message.signal == Signal.PASS | ||||||
|     assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) |     assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) | ||||||
|   | |||||||
| @@ -286,17 +286,23 @@ def test_to_dict_default_values(): | |||||||
|  |  | ||||||
|  |  | ||||||
| def test_oneof_default_value_set_causes_writes_wire(): | def test_oneof_default_value_set_causes_writes_wire(): | ||||||
|  |     @dataclass | ||||||
|  |     class Empty(betterproto.Message): | ||||||
|  |         pass | ||||||
|  |  | ||||||
|     @dataclass |     @dataclass | ||||||
|     class Foo(betterproto.Message): |     class Foo(betterproto.Message): | ||||||
|         bar: int = betterproto.int32_field(1, group="group1") |         bar: int = betterproto.int32_field(1, group="group1") | ||||||
|         baz: str = betterproto.string_field(2, group="group1") |         baz: str = betterproto.string_field(2, group="group1") | ||||||
|  |         qux: Empty = betterproto.message_field(3, group="group1") | ||||||
|  |  | ||||||
|     def _round_trip_serialization(foo: Foo) -> Foo: |     def _round_trip_serialization(foo: Foo) -> Foo: | ||||||
|         return Foo().parse(bytes(foo)) |         return Foo().parse(bytes(foo)) | ||||||
|  |  | ||||||
|     foo1 = Foo(bar=0) |     foo1 = Foo(bar=0) | ||||||
|     foo2 = Foo(baz="") |     foo2 = Foo(baz="") | ||||||
|     foo3 = Foo() |     foo3 = Foo(qux=Empty()) | ||||||
|  |     foo4 = Foo() | ||||||
|  |  | ||||||
|     assert bytes(foo1) == b"\x08\x00" |     assert bytes(foo1) == b"\x08\x00" | ||||||
|     assert ( |     assert ( | ||||||
| @@ -312,10 +318,17 @@ def test_oneof_default_value_set_causes_writes_wire(): | |||||||
|         == ("baz", "") |         == ("baz", "") | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     assert bytes(foo3) == b"" |     assert bytes(foo3) == b"\x1a\x00" | ||||||
|     assert ( |     assert ( | ||||||
|         betterproto.which_one_of(foo3, "group1") |         betterproto.which_one_of(foo3, "group1") | ||||||
|         == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") |         == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") | ||||||
|  |         == ("qux", Empty()) | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     assert bytes(foo4) == b"" | ||||||
|  |     assert ( | ||||||
|  |         betterproto.which_one_of(foo4, "group1") | ||||||
|  |         == betterproto.which_one_of(_round_trip_serialization(foo4), "group1") | ||||||
|         == ("", None) |         == ("", None) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -126,30 +126,31 @@ def test_message_json(repeat, test_data: TestData) -> None: | |||||||
|     plugin_module, _, json_data = test_data |     plugin_module, _, json_data = test_data | ||||||
|  |  | ||||||
|     for _ in range(repeat): |     for _ in range(repeat): | ||||||
|  |         for json_sample in json_data: | ||||||
|             message: betterproto.Message = plugin_module.Test() |             message: betterproto.Message = plugin_module.Test() | ||||||
|  |  | ||||||
|         message.from_json(json_data) |             message.from_json(json_sample) | ||||||
|             message_json = message.to_json(0) |             message_json = message.to_json(0) | ||||||
|  |  | ||||||
|         assert json.loads(message_json) == json.loads(json_data) |             assert json.loads(message_json) == json.loads(json_sample) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) | @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) | ||||||
| def test_service_can_be_instantiated(test_data: TestData) -> None: | def test_service_can_be_instantiated(test_data: TestData) -> None: | ||||||
|     plugin_module, _, json_data = test_data |     test_data.plugin_module.TestStub(MockChannel()) | ||||||
|     plugin_module.TestStub(MockChannel()) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) | @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) | ||||||
| def test_binary_compatibility(repeat, test_data: TestData) -> None: | def test_binary_compatibility(repeat, test_data: TestData) -> None: | ||||||
|     plugin_module, reference_module, json_data = test_data |     plugin_module, reference_module, json_data = test_data | ||||||
|  |  | ||||||
|     reference_instance = Parse(json_data, reference_module().Test()) |     for json_sample in json_data: | ||||||
|  |         reference_instance = Parse(json_sample, reference_module().Test()) | ||||||
|         reference_binary_output = reference_instance.SerializeToString() |         reference_binary_output = reference_instance.SerializeToString() | ||||||
|  |  | ||||||
|         for _ in range(repeat): |         for _ in range(repeat): | ||||||
|         plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( |             plugin_instance_from_json: betterproto.Message = ( | ||||||
|             json_data |                 plugin_module.Test().from_json(json_sample) | ||||||
|             ) |             ) | ||||||
|             plugin_instance_from_binary = plugin_module.Test.FromString( |             plugin_instance_from_binary = plugin_module.Test.FromString( | ||||||
|                 reference_binary_output |                 reference_binary_output | ||||||
| @@ -163,5 +164,6 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None: | |||||||
|  |  | ||||||
|             assert plugin_instance_from_json == plugin_instance_from_binary |             assert plugin_instance_from_json == plugin_instance_from_binary | ||||||
|             assert ( |             assert ( | ||||||
|             plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() |                 plugin_instance_from_json.to_dict() | ||||||
|  |                 == plugin_instance_from_binary.to_dict() | ||||||
|             ) |             ) | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ import pathlib | |||||||
| import sys | import sys | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from types import ModuleType | from types import ModuleType | ||||||
| from typing import Callable, Generator, Optional, Union | from typing import Callable, Generator, List, Optional, Union | ||||||
|  |  | ||||||
| os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | ||||||
|  |  | ||||||
| @@ -47,15 +47,27 @@ async def protoc( | |||||||
|     return stdout, stderr, proc.returncode |     return stdout, stderr, proc.returncode | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): | def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]: | ||||||
|     test_data_file_name = json_file_name or f"{test_case_name}.json" |     """ | ||||||
|     test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) |     :return: | ||||||
|  |         A list of all files found in "inputs_path/test_case_name" with names matching | ||||||
|  |         f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names | ||||||
|  |     """ | ||||||
|  |     test_case_dir = inputs_path.joinpath(test_case_name) | ||||||
|  |     possible_file_paths = [ | ||||||
|  |         *(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names), | ||||||
|  |         test_case_dir.joinpath(f"{test_case_name}.json"), | ||||||
|  |         *test_case_dir.glob(f"{test_case_name}_*.json"), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     result = [] | ||||||
|  |     for test_data_file_path in possible_file_paths: | ||||||
|         if not test_data_file_path.exists(): |         if not test_data_file_path.exists(): | ||||||
|         return None |             continue | ||||||
|  |  | ||||||
|         with test_data_file_path.open("r") as fh: |         with test_data_file_path.open("r") as fh: | ||||||
|         return fh.read() |             result.append(fh.read()) | ||||||
|  |  | ||||||
|  |     return result | ||||||
|  |  | ||||||
|  |  | ||||||
| def find_module( | def find_module( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user