Properly serialize zero-value messages in a oneof group (#176)

Also improve test utils to make it easier to have multiple json examples.

Co-authored-by: Christopher Chambers <chris@peanutcode.com>
This commit is contained in:
nat 2021-03-15 13:52:35 +01:00 committed by GitHub
parent 2f62189346
commit 342e6559dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 98 additions and 41 deletions

View File

@ -49,9 +49,9 @@ generate = { script = "tests.generate:main", help = "Generate test cases (do
test = { cmd = "pytest --cov src", help = "Run tests" } test = { cmd = "pytest --cov src", help = "Run tests" }
types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" } types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" }
format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" } format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" }
clean = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"} docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"}
bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"} bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"}
clean = { cmd = "rm -rf .asv .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
# CI tasks # CI tasks
full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" } full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" }

View File

@ -699,7 +699,7 @@ class Message(ABC):
meta.number, meta.number,
meta.proto_type, meta.proto_type,
value, value,
serialize_empty=serialize_empty, serialize_empty=serialize_empty or selected_in_group,
wraps=meta.wraps or "", wraps=meta.wraps or "",
) )

View File

@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
def test_which_count(): def test_which_count():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof")) message.from_json(get_test_case_json_data("oneof")[0])
assert betterproto.which_one_of(message, "foo") == ("count", 100) assert betterproto.which_one_of(message, "foo") == ("count", 100)
def test_which_name(): def test_which_name():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof-name.json")) message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
assert betterproto.which_one_of(message, "foo") == ("name", "foobar") assert betterproto.which_one_of(message, "foo") == ("name", "foobar")

View File

@ -0,0 +1,3 @@
{
"nothing": {}
}

View File

@ -0,0 +1,15 @@
syntax = "proto3";
message Nothing {}
message MaybeNothing {
string sometimes = 42;
}
message Test {
oneof empty {
Nothing nothing = 1;
MaybeNothing maybe1 = 2;
MaybeNothing maybe2 = 3;
}
}

View File

@ -0,0 +1,3 @@
{
"maybe1": {}
}

View File

@ -0,0 +1,5 @@
{
"maybe2": {
"sometimes": "now"
}
}

View File

@ -14,7 +14,9 @@ def test_which_one_of_returns_enum_with_default_value():
returns first field when it is enum and set with default value returns first field when it is enum and set with default value
""" """
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")) message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
)
assert message.move == Move( assert message.move == Move(
x=0, y=0 x=0, y=0
@ -28,7 +30,9 @@ def test_which_one_of_returns_enum_with_non_default_value():
returns first field when it is enum and set with non default value returns first field when it is enum and set with non default value
""" """
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")) message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
)
assert message.move == Move( assert message.move == Move(
x=0, y=0 x=0, y=0
) # Proto3 will default this as there is no null ) # Proto3 will default this as there is no null
@ -38,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
def test_which_one_of_returns_second_field_when_set(): def test_which_one_of_returns_second_field_when_set():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof_enum")) message.from_json(get_test_case_json_data("oneof_enum")[0])
assert message.move == Move(x=2, y=3) assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

View File

@ -286,17 +286,23 @@ def test_to_dict_default_values():
def test_oneof_default_value_set_causes_writes_wire(): def test_oneof_default_value_set_causes_writes_wire():
@dataclass
class Empty(betterproto.Message):
pass
@dataclass @dataclass
class Foo(betterproto.Message): class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1") bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1") baz: str = betterproto.string_field(2, group="group1")
qux: Empty = betterproto.message_field(3, group="group1")
def _round_trip_serialization(foo: Foo) -> Foo: def _round_trip_serialization(foo: Foo) -> Foo:
return Foo().parse(bytes(foo)) return Foo().parse(bytes(foo))
foo1 = Foo(bar=0) foo1 = Foo(bar=0)
foo2 = Foo(baz="") foo2 = Foo(baz="")
foo3 = Foo() foo3 = Foo(qux=Empty())
foo4 = Foo()
assert bytes(foo1) == b"\x08\x00" assert bytes(foo1) == b"\x08\x00"
assert ( assert (
@ -312,10 +318,17 @@ def test_oneof_default_value_set_causes_writes_wire():
== ("baz", "") == ("baz", "")
) )
assert bytes(foo3) == b"" assert bytes(foo3) == b"\x1a\x00"
assert ( assert (
betterproto.which_one_of(foo3, "group1") betterproto.which_one_of(foo3, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1") == betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
== ("qux", Empty())
)
assert bytes(foo4) == b""
assert (
betterproto.which_one_of(foo4, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo4), "group1")
== ("", None) == ("", None)
) )

View File

@ -126,42 +126,44 @@ def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data plugin_module, _, json_data = test_data
for _ in range(repeat): for _ in range(repeat):
message: betterproto.Message = plugin_module.Test() for json_sample in json_data:
message: betterproto.Message = plugin_module.Test()
message.from_json(json_data) message.from_json(json_sample)
message_json = message.to_json(0) message_json = message.to_json(0)
assert json.loads(message_json) == json.loads(json_data) assert json.loads(message_json) == json.loads(json_sample)
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
def test_service_can_be_instantiated(test_data: TestData) -> None: def test_service_can_be_instantiated(test_data: TestData) -> None:
plugin_module, _, json_data = test_data test_data.plugin_module.TestStub(MockChannel())
plugin_module.TestStub(MockChannel())
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True) @pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None: def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data plugin_module, reference_module, json_data = test_data
reference_instance = Parse(json_data, reference_module().Test()) for json_sample in json_data:
reference_binary_output = reference_instance.SerializeToString() reference_instance = Parse(json_sample, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()
for _ in range(repeat): for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( plugin_instance_from_json: betterproto.Message = (
json_data plugin_module.Test().from_json(json_sample)
) )
plugin_instance_from_binary = plugin_module.Test.FromString( plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output reference_binary_output
) )
# # Generally this can't be relied on, but here we are aiming to match the # # Generally this can't be relied on, but here we are aiming to match the
# # existing Python implementation and aren't doing anything tricky. # # existing Python implementation and aren't doing anything tricky.
# # https://developers.google.com/protocol-buffers/docs/encoding#implications # # https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(plugin_instance_from_json) == reference_binary_output assert bytes(plugin_instance_from_json) == reference_binary_output
assert bytes(plugin_instance_from_binary) == reference_binary_output assert bytes(plugin_instance_from_binary) == reference_binary_output
assert plugin_instance_from_json == plugin_instance_from_binary assert plugin_instance_from_json == plugin_instance_from_binary
assert ( assert (
plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() plugin_instance_from_json.to_dict()
) == plugin_instance_from_binary.to_dict()
)

View File

@ -5,7 +5,7 @@ import pathlib
import sys import sys
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Callable, Generator, Optional, Union from typing import Callable, Generator, List, Optional, Union
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -47,15 +47,27 @@ async def protoc(
return stdout, stderr, proc.returncode return stdout, stderr, proc.returncode
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
test_data_file_name = json_file_name or f"{test_case_name}.json" """
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) :return:
A list of all files found in "inputs_path/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
"""
test_case_dir = inputs_path.joinpath(test_case_name)
possible_file_paths = [
*(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names),
test_case_dir.joinpath(f"{test_case_name}.json"),
*test_case_dir.glob(f"{test_case_name}_*.json"),
]
if not test_data_file_path.exists(): result = []
return None for test_data_file_path in possible_file_paths:
if not test_data_file_path.exists():
continue
with test_data_file_path.open("r") as fh:
result.append(fh.read())
with test_data_file_path.open("r") as fh: return result
return fh.read()
def find_module( def find_module(