Fix serialization of repeated fields with empty messages (#180)

Extend test config and utils to support exclusion of certain json samples from
testing for symetry.
This commit is contained in:
nat 2021-04-06 02:50:45 +02:00 committed by GitHub
parent deb623ed14
commit 7368299a70
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 78 additions and 25 deletions

View File

@ -739,9 +739,18 @@ class Message(ABC):
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
output += (
_serialize_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or b"\n\x00"
)
elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types

View File

@ -19,3 +19,10 @@ services = {
"example_service",
"empty_service",
}
# Indicate json sample messages to skip when testing that json (de)serialization
# is symmetrical becuase some cases legitimately are not symmetrical.
# Each key references the name of the test scenario and the values in the tuple
# Are the names of the json files.
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}

View File

@ -0,0 +1,3 @@
{
"msg": [{"values":[]}]
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
message MessageA {
repeated float values = 1;
}
message Test {
repeated MessageA msg = 1;
}

View File

@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
def test_which_count():
message = Test()
message.from_json(get_test_case_json_data("oneof")[0])
message.from_json(get_test_case_json_data("oneof")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
def test_which_name():
message = Test()
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")

View File

@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value():
"""
message = Test()
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
)
assert message.move == Move(
@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
"""
message = Test()
message.from_json(
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
)
assert message.move == Move(
x=0, y=0
@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
def test_which_one_of_returns_second_field_when_set():
message = Test()
message.from_json(get_test_case_json_data("oneof_enum")[0])
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
assert message.move == Move(x=2, y=3)
assert message.signal == Signal.PASS
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))

View File

@ -5,7 +5,7 @@ import os
import sys
from collections import namedtuple
from types import ModuleType
from typing import Any, Dict, List, Set
from typing import Any, Dict, List, Set, Tuple
import pytest
@ -29,7 +29,12 @@ from google.protobuf.json_format import Parse
class TestCases:
def __init__(self, path, services: Set[str], xfail: Set[str]):
def __init__(
self,
path,
services: Set[str],
xfail: Set[str],
):
_all = set(get_directories(path)) - {"__pycache__"}
_services = services
_messages = (_all - services) - {"__pycache__"}
@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data = test_data
for _ in range(repeat):
for json_sample in json_data:
for sample in json_data:
if sample.belongs_to(test_input_config.non_symmetrical_json):
continue
message: betterproto.Message = plugin_module.Test()
message.from_json(json_sample)
message.from_json(sample.json)
message_json = message.to_json(0)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(json_sample)
)
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
json.loads(sample.json)
)
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data = test_data
for json_sample in json_data:
reference_instance = Parse(json_sample, reference_module().Test())
for sample in json_data:
reference_instance = Parse(sample.json, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()
for _ in range(repeat):
plugin_instance_from_json: betterproto.Message = (
plugin_module.Test().from_json(json_sample)
plugin_module.Test().from_json(sample.json)
)
plugin_instance_from_binary = plugin_module.Test.FromString(
reference_binary_output

View File

@ -1,11 +1,11 @@
import asyncio
from dataclasses import dataclass
import importlib
import os
import pathlib
import sys
from pathlib import Path
import sys
from types import ModuleType
from typing import Callable, Generator, List, Optional, Union
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -47,11 +47,24 @@ async def protoc(
return stdout, stderr, proc.returncode
def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
@dataclass
class TestCaseJsonFile:
json: str
test_name: str
file_name: str
def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
def get_test_case_json_data(
test_case_name: str, *json_file_names: str
) -> List[TestCaseJsonFile]:
"""
:return:
A list of all files found in "inputs_path/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
A list of all files found in "{inputs_path}/test_case_name" with names matching
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
json_file_names
"""
test_case_dir = inputs_path.joinpath(test_case_name)
possible_file_paths = [
@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
if not test_data_file_path.exists():
continue
with test_data_file_path.open("r") as fh:
result.append(fh.read())
result.append(
TestCaseJsonFile(
fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
)
)
return result
@ -86,7 +103,7 @@ def find_module(
if predicate(module):
return module
module_path = pathlib.Path(*module.__path__)
module_path = Path(*module.__path__)
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
if sub == module_path: