Properly serialize zero-value messages in a oneof group (#176)

Also improve test utils to make it easier to have multiple json examples.

Co-authored-by: Christopher Chambers <chris@peanutcode.com>
This commit is contained in:
nat 2021-03-15 13:52:35 +01:00 committed by GitHub
parent 2f62189346
commit 342e6559dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 98 additions and 41 deletions

View File

@ -49,9 +49,9 @@ generate = { script = "tests.generate:main", help = "Generate test cases (do
test = { cmd = "pytest --cov src", help = "Run tests" }
types = { cmd = "mypy src --ignore-missing-imports", help = "Check types with mypy" }
format = { cmd = "black . --exclude tests/output_", help = "Apply black formatting to source code" }
clean = { cmd = "rm -rf .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
docs = { cmd = "sphinx-build docs docs/build", help = "Build the sphinx docs"}
bench = { shell = "asv run master^! && asv run HEAD^! && asv compare master HEAD", help = "Benchmark current commit vs. master branch"}
clean = { cmd = "rm -rf .asv .coverage .mypy_cache .pytest_cache dist betterproto.egg-info **/__pycache__ tests/output_*", help = "Clean out generated files from the workspace" }
# CI tasks
full-test = { shell = "poe generate && tox", help = "Run tests with multiple pythons" }

View File

@ -699,7 +699,7 @@ class Message(ABC):
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
serialize_empty=serialize_empty or selected_in_group,
wraps=meta.wraps or "",
)

View File

@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
def test_which_count():
message = Test()
message.from_json(get_test_case_json_data("oneof"))
message.from_json(get_test_case_json_data("oneof")[0])
assert betterproto.which_one_of(message, "foo") == ("count", 100)
def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof-name.json"))
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
assert betterproto.which_one_of(message, "foo") == ("name", "foobar")

View File

@ -0,0 +1,3 @@
{
"nothing": {}
}

View File

@ -0,0 +1,15 @@
syntax = "proto3";
message Nothing {}
message MaybeNothing {
string sometimes = 42;
}
message Test {
oneof empty {
Nothing nothing = 1;
MaybeNothing maybe1 = 2;
MaybeNothing maybe2 = 3;
}
}

View File

@ -0,0 +1,3 @@
{
"maybe1": {}
}

View File

@ -0,0 +1,5 @@
{
"maybe2": {
"sometimes": "now"
}
}

View File

@ -14,7 +14,9 @@ def test_which_one_of_returns_enum_with_default_value():
returns first field when it is enum and set with default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json"))
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
)
assert message.move == Move(
x=0, y=0
@ -28,7 +30,9 @@ def test_which_one_of_returns_enum_with_non_default_value():
returns first field when it is enum and set with non default value
"""
message = Test()
message.from_json(get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json"))
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
)
assert message.move == Move(
x=0, y=0
) # Proto3 will default this as there is no null
@ -38,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum"))
message.from_json(get_test_case_json_data("oneof_enum")[0])
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

View File

@ -286,17 +286,23 @@ def test_to_dict_default_values():
def test_oneof_default_value_set_causes_writes_wire():
@dataclass
class Empty(betterproto.Message):
pass
@dataclass
class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1")
qux: Empty = betterproto.message_field(3, group="group1")
def _round_trip_serialization(foo: Foo) -> Foo:
return Foo().parse(bytes(foo))
foo1 = Foo(bar=0)
foo2 = Foo(baz="")
foo3 = Foo()
foo3 = Foo(qux=Empty())
foo4 = Foo()
assert bytes(foo1) == b"\x08\x00"
assert (
@ -312,10 +318,17 @@ def test_oneof_default_value_set_causes_writes_wire():
== ("baz", "")
)
assert bytes(foo3) == b""
assert bytes(foo3) == b"\x1a\x00"
assert (
betterproto.which_one_of(foo3, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
== ("qux", Empty())
)
assert bytes(foo4) == b""
assert (
betterproto.which_one_of(foo4, "group1")
== betterproto.which_one_of(_round_trip_serialization(foo4), "group1")
== ("", None)
)

View File

@ -126,42 +126,44 @@ def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data
for _ in range(repeat):
message: betterproto.Message = plugin_module.Test()
for json_sample in json_data:
message: betterproto.Message = plugin_module.Test()
message.from_json(json_data)
message_json = message.to_json(0)
message.from_json(json_sample)
message_json = message.to_json(0)
assert json.loads(message_json) == json.loads(json_data)
assert json.loads(message_json) == json.loads(json_sample)
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
def test_service_can_be_instantiated(test_data: TestData) -> None:
plugin_module, _, json_data = test_data
plugin_module.TestStub(MockChannel())
test_data.plugin_module.TestStub(MockChannel())
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data
reference_instance = Parse(json_data, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()
for json_sample in json_data:
reference_instance = Parse(json_sample, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()
for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
json_data
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output
)
for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = (
plugin_module.Test().from_json(json_sample)
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output
)
# # Generally this can't be relied on, but here we are aiming to match the
# # existing Python implementation and aren't doing anything tricky.
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(plugin_instance_from_json) == reference_binary_output
assert bytes(plugin_instance_from_binary) == reference_binary_output
# # Generally this can't be relied on, but here we are aiming to match the
# # existing Python implementation and aren't doing anything tricky.
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(plugin_instance_from_json) == reference_binary_output
assert bytes(plugin_instance_from_binary) == reference_binary_output
assert plugin_instance_from_json == plugin_instance_from_binary
assert (
plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
)
assert plugin_instance_from_json == plugin_instance_from_binary
assert (
plugin_instance_from_json.to_dict()
== plugin_instance_from_binary.to_dict()
)

View File

@ -5,7 +5,7 @@ import pathlib
import sys
from pathlib import Path
from types import ModuleType
from typing import Callable, Generator, Optional, Union
from typing import Callable, Generator, List, Optional, Union
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -47,15 +47,27 @@ async def protoc(
return stdout, stderr, proc.returncode
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
test_data_file_name = json_file_name or f"{test_case_name}.json"
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
"""
:return:
A list of all files found in "inputs_path/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
"""
test_case_dir = inputs_path.joinpath(test_case_name)
possible_file_paths = [
*(test_case_dir.joinpath(json_file_name) for json_file_name in json_file_names),
test_case_dir.joinpath(f"{test_case_name}.json"),
*test_case_dir.glob(f"{test_case_name}_*.json"),
]
if not test_data_file_path.exists():
return None
result = []
for test_data_file_path in possible_file_paths:
if not test_data_file_path.exists():
continue
with test_data_file_path.open("r") as fh:
result.append(fh.read())
with test_data_file_path.open("r") as fh:
return fh.read()
return result
def find_module(