Fix serialization of repeated fields with empty messages (#180)
Extend test config and utils to support exclusion of certain json samples from testing for symetry.
This commit is contained in:
parent
deb623ed14
commit
7368299a70
@ -739,9 +739,18 @@ class Message(ABC):
|
||||
output += _serialize_single(meta.number, TYPE_BYTES, buf)
|
||||
else:
|
||||
for item in value:
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
|
||||
output += (
|
||||
_serialize_single(
|
||||
meta.number,
|
||||
meta.proto_type,
|
||||
item,
|
||||
wraps=meta.wraps or "",
|
||||
)
|
||||
# if it's an empty message it still needs to be represented
|
||||
# as an item in the repeated list
|
||||
or b"\n\x00"
|
||||
)
|
||||
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
assert meta.map_types
|
||||
|
@ -19,3 +19,10 @@ services = {
|
||||
"example_service",
|
||||
"empty_service",
|
||||
}
|
||||
|
||||
|
||||
# Indicate json sample messages to skip when testing that json (de)serialization
|
||||
# is symmetrical becuase some cases legitimately are not symmetrical.
|
||||
# Each key references the name of the test scenario and the values in the tuple
|
||||
# Are the names of the json files.
|
||||
non_symmetrical_json = {"empty_repeated": ("empty_repeated",)}
|
||||
|
3
tests/inputs/empty_repeated/empty_repeated.json
Normal file
3
tests/inputs/empty_repeated/empty_repeated.json
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"msg": [{"values":[]}]
|
||||
}
|
9
tests/inputs/empty_repeated/empty_repeated.proto
Normal file
9
tests/inputs/empty_repeated/empty_repeated.proto
Normal file
@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message MessageA {
|
||||
repeated float values = 1;
|
||||
}
|
||||
|
||||
message Test {
|
||||
repeated MessageA msg = 1;
|
||||
}
|
@ -5,11 +5,11 @@ from tests.util import get_test_case_json_data
|
||||
|
||||
def test_which_count():
|
||||
message = Test()
|
||||
message.from_json(get_test_case_json_data("oneof")[0])
|
||||
message.from_json(get_test_case_json_data("oneof")[0].json)
|
||||
assert betterproto.which_one_of(message, "foo") == ("pitied", 100)
|
||||
|
||||
|
||||
def test_which_name():
|
||||
message = Test()
|
||||
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0])
|
||||
message.from_json(get_test_case_json_data("oneof", "oneof_name.json")[0].json)
|
||||
assert betterproto.which_one_of(message, "foo") == ("pitier", "Mr. T")
|
||||
|
@ -15,7 +15,7 @@ def test_which_one_of_returns_enum_with_default_value():
|
||||
"""
|
||||
message = Test()
|
||||
message.from_json(
|
||||
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0]
|
||||
get_test_case_json_data("oneof_enum", "oneof_enum-enum-0.json")[0].json
|
||||
)
|
||||
|
||||
assert message.move == Move(
|
||||
@ -31,7 +31,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
|
||||
"""
|
||||
message = Test()
|
||||
message.from_json(
|
||||
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0]
|
||||
get_test_case_json_data("oneof_enum", "oneof_enum-enum-1.json")[0].json
|
||||
)
|
||||
assert message.move == Move(
|
||||
x=0, y=0
|
||||
@ -42,7 +42,7 @@ def test_which_one_of_returns_enum_with_non_default_value():
|
||||
|
||||
def test_which_one_of_returns_second_field_when_set():
|
||||
message = Test()
|
||||
message.from_json(get_test_case_json_data("oneof_enum")[0])
|
||||
message.from_json(get_test_case_json_data("oneof_enum")[0].json)
|
||||
assert message.move == Move(x=2, y=3)
|
||||
assert message.signal == Signal.PASS
|
||||
assert betterproto.which_one_of(message, "action") == ("move", Move(x=2, y=3))
|
||||
|
@ -5,7 +5,7 @@ import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from types import ModuleType
|
||||
from typing import Any, Dict, List, Set
|
||||
from typing import Any, Dict, List, Set, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
@ -29,7 +29,12 @@ from google.protobuf.json_format import Parse
|
||||
|
||||
|
||||
class TestCases:
|
||||
def __init__(self, path, services: Set[str], xfail: Set[str]):
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
services: Set[str],
|
||||
xfail: Set[str],
|
||||
):
|
||||
_all = set(get_directories(path)) - {"__pycache__"}
|
||||
_services = services
|
||||
_messages = (_all - services) - {"__pycache__"}
|
||||
@ -175,15 +180,18 @@ def test_message_json(repeat, test_data: TestData) -> None:
|
||||
plugin_module, _, json_data = test_data
|
||||
|
||||
for _ in range(repeat):
|
||||
for json_sample in json_data:
|
||||
for sample in json_data:
|
||||
if sample.belongs_to(test_input_config.non_symmetrical_json):
|
||||
continue
|
||||
|
||||
message: betterproto.Message = plugin_module.Test()
|
||||
|
||||
message.from_json(json_sample)
|
||||
message.from_json(sample.json)
|
||||
message_json = message.to_json(0)
|
||||
|
||||
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
|
||||
json.loads(json_sample)
|
||||
)
|
||||
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
|
||||
json.loads(sample.json)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
|
||||
@ -195,13 +203,13 @@ def test_service_can_be_instantiated(test_data: TestData) -> None:
|
||||
def test_binary_compatibility(repeat, test_data: TestData) -> None:
|
||||
plugin_module, reference_module, json_data = test_data
|
||||
|
||||
for json_sample in json_data:
|
||||
reference_instance = Parse(json_sample, reference_module().Test())
|
||||
for sample in json_data:
|
||||
reference_instance = Parse(sample.json, reference_module().Test())
|
||||
reference_binary_output = reference_instance.SerializeToString()
|
||||
|
||||
for _ in range(repeat):
|
||||
plugin_instance_from_json: betterproto.Message = (
|
||||
plugin_module.Test().from_json(json_sample)
|
||||
plugin_module.Test().from_json(sample.json)
|
||||
)
|
||||
plugin_instance_from_binary = plugin_module.Test.FromString(
|
||||
reference_binary_output
|
||||
|
@ -1,11 +1,11 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Callable, Generator, List, Optional, Union
|
||||
from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
@ -47,11 +47,24 @@ async def protoc(
|
||||
return stdout, stderr, proc.returncode
|
||||
|
||||
|
||||
def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[str]:
|
||||
@dataclass
|
||||
class TestCaseJsonFile:
|
||||
json: str
|
||||
test_name: str
|
||||
file_name: str
|
||||
|
||||
def belongs_to(self, non_symmetrical_json: Dict[str, Tuple[str, ...]]):
|
||||
return self.file_name in non_symmetrical_json.get(self.test_name, tuple())
|
||||
|
||||
|
||||
def get_test_case_json_data(
|
||||
test_case_name: str, *json_file_names: str
|
||||
) -> List[TestCaseJsonFile]:
|
||||
"""
|
||||
:return:
|
||||
A list of all files found in "inputs_path/test_case_name" with names matching
|
||||
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by json_file_names
|
||||
A list of all files found in "{inputs_path}/test_case_name" with names matching
|
||||
f"{test_case_name}.json" or f"{test_case_name}_*.json", OR given by
|
||||
json_file_names
|
||||
"""
|
||||
test_case_dir = inputs_path.joinpath(test_case_name)
|
||||
possible_file_paths = [
|
||||
@ -65,7 +78,11 @@ def get_test_case_json_data(test_case_name: str, *json_file_names: str) -> List[
|
||||
if not test_data_file_path.exists():
|
||||
continue
|
||||
with test_data_file_path.open("r") as fh:
|
||||
result.append(fh.read())
|
||||
result.append(
|
||||
TestCaseJsonFile(
|
||||
fh.read(), test_case_name, test_data_file_path.name.split(".")[0]
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@ -86,7 +103,7 @@ def find_module(
|
||||
if predicate(module):
|
||||
return module
|
||||
|
||||
module_path = pathlib.Path(*module.__path__)
|
||||
module_path = Path(*module.__path__)
|
||||
|
||||
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
|
||||
if sub == module_path:
|
||||
|
Loading…
x
Reference in New Issue
Block a user