Fix serialization of repeated fields with empty messages (#180)

Extend test config and utils to support exclusion of certain json samples from
testing for symetry.
This commit is contained in:
nat 2021-04-06 02:50:45 +02:00 committed by GitHub
parent deb623ed14
commit 7368299a70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 78 additions and 25 deletions

View File

@ -739,9 +739,18 @@ class Message(ABC):
output += _serialize_single(meta.number, TYPE_BYTES, buf) output += _serialize_single(meta.number, TYPE_BYTES, buf)
else: else:
for item in value: for item in value:
output += _serialize_single( output += (
meta.number, meta.proto_type, item, wraps=meta.wraps or "" _serialize_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or b"\n\x00"
) )
elif isinstance(value, dict): elif isinstance(value, dict):
for k, v in value.items(): for k, v in value.items():
assert meta.map_types assert meta.map_types

View File

@ -19,3 +19,10 @@ services = {
"example_service", "example_service",
"empty_service", "empty_service",
} }
# Indicate json sample messages to skip when testing that json (de)serialization
# is symmetrical becuase some cases legitimately are not symmetrical.
# Each key references the name of the test scenario and the values in the tuple
# Are the names of the json files.
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}

View File

@ -0,0 +1,3 @@
{
"msg": [{"values":[]}]
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
message MessageA {
repeated float values = 1;
}
message Test {
repeated MessageA msg = 1;
}

View File

@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
def test_which_count(): def test_which_count():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof")[0]) message.from_json(get_test_case_json_data("oneof")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitied", 100) assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
def test_which_name(): def test_which_name():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0]) message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T") assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")

View File

@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value():
""" """
message = Test() message = Test()
message.from_json( message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0] get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
) )
assert message.move == Move( assert message.move == Move(
@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
""" """
message = Test() message = Test()
message.from_json( message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0] get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
) )
assert message.move == Move( assert message.move == Move(
x=0, y=0 x=0, y=0
@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
def test_which_one_of_returns_second_field_when_set(): def test_which_one_of_returns_second_field_when_set():
message = Test() message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0]) message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3) assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3)) assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

View File

@ -5,7 +5,7 @@ import os
import sys import sys
from collections import namedtuple from collections import namedtuple
from types import ModuleType from types import ModuleType
from typing import Any, Dict, List, Set from typing import Any, Dict, List, Set, Tuple
import pytest import pytest
@ -29,7 +29,12 @@ from google.protobuf.json_format import Parse
class TestCases: class TestCases:
def __init__(self, path, services: Set[str], xfail: Set[str]): def __init__(
self,
path,
services: Set[str],
xfail: Set[str],
):
_all = set(get_directories(path)) - {"__pycache__"} _all = set(get_directories(path)) - {"__pycache__"}
_services = services _services = services
_messages = (_all - services) - {"__pycache__"} _messages = (_all - services) - {"__pycache__"}
@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data plugin_module, _, json_data = test_data
for _ in range(repeat): for _ in range(repeat):
for json_sample in json_data: for sample in json_data:
if sample.belongs_to(test_input_config.non_symmetrical_json):
continue
message: betterproto.Message = plugin_module.Test() message: betterproto.Message = plugin_module.Test()
message.from_json(json_sample) message.from_json(sample.json)
message_json = message.to_json(0) message_json = message.to_json(0)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(json_sample) json.loads(sample.json)
) )
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
def test_binary_compatibility(repeat, test_data: TestData) -> None: def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data plugin_module, reference_module, json_data = test_data
for json_sample in json_data: for sample in json_data:
reference_instance = Parse(json_sample, reference_module().Test()) reference_instance = Parse(sample.json, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString() reference_binary_output = reference_instance.SerializeToString()
for _ in range(repeat): for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = ( plugin_instance_from_json: betterproto.Message = (
plugin_module.Test().from_json(json_sample) plugin_module.Test().from_json(sample.json)
) )
plugin_instance_from_binary = plugin_module.Test.FromString( plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output reference_binary_output

View File

@ -1,11 +1,11 @@
import asyncio import asyncio
from dataclasses import dataclass
import importlib import importlib
import os import os
import pathlib
import sys
from pathlib import Path from pathlib import Path
import sys
from types import ModuleType from types import ModuleType
from typing import Callable, Generator, List, Optional, Union from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -47,11 +47,24 @@ async def protoc(
return stdout, stderr, proc.returncode return stdout, stderr, proc.returncode
def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]: @dataclass
class TestCaseJsonFile:
json: str
test_name: str
file_name: str
def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
def get_test_case_json_data(
test_case_name: str, *json_file_names: str
) -> List[TestCaseJsonFile]:
""" """
:return: :return:
A list of all files found in "inputs_path/test_case_name" with names matching A list of all files found in "{inputs_path}/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
json_file_names
""" """
test_case_dir = inputs_path.joinpath(test_case_name) test_case_dir = inputs_path.joinpath(test_case_name)
possible_file_paths = [ possible_file_paths = [
@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
if not test_data_file_path.exists(): if not test_data_file_path.exists():
continue continue
with test_data_file_path.open("r") as fh: with test_data_file_path.open("r") as fh:
result.append(fh.read()) result.append(
TestCaseJsonFile(
fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
)
)
return result return result
@ -86,7 +103,7 @@ def find_module(
if predicate(module): if predicate(module):
return module return module
module_path = pathlib.Path(*module.__path__) module_path = Path(*module.__path__)
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]: for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
if sub == module_path: if sub == module_path: