Added support for infinite and nan floats/doubles (#215)
- Added support for the custom double values from the protobuf json spec: "Infinity", "-Infinity", and "NaN" - Added `infinite_floats` test data - Updated Message.__eq__ to consider nan values equal - Updated `test_message_json` and `test_binary_compatibility` to replace NaN float values in dictionaries before comparison (because two NaN values are not equal)
This commit is contained in:
parent
bb646fe26f
commit
7c5ee47e68
@ -2,6 +2,7 @@ import dataclasses
|
|||||||
import enum
|
import enum
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
@ -113,6 +114,12 @@ def datetime_default_gen() -> datetime:
|
|||||||
DATETIME_ZERO = datetime_default_gen()
|
DATETIME_ZERO = datetime_default_gen()
|
||||||
|
|
||||||
|
|
||||||
|
# Special protobuf json doubles
|
||||||
|
INFINITY = "Infinity"
|
||||||
|
NEG_INFINITY = "-Infinity"
|
||||||
|
NAN = "NaN"
|
||||||
|
|
||||||
|
|
||||||
class Casing(enum.Enum):
|
class Casing(enum.Enum):
|
||||||
"""Casing constants for serialization."""
|
"""Casing constants for serialization."""
|
||||||
|
|
||||||
@ -369,6 +376,51 @@ def _serialize_single(
|
|||||||
return bytes(output)
|
return bytes(output)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_float(value: Any) -> float:
|
||||||
|
"""Parse the given value to a float
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
value : Any
|
||||||
|
Value to parse
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
float
|
||||||
|
Parsed value
|
||||||
|
"""
|
||||||
|
if value == INFINITY:
|
||||||
|
return float("inf")
|
||||||
|
if value == NEG_INFINITY:
|
||||||
|
return -float("inf")
|
||||||
|
if value == NAN:
|
||||||
|
return float("nan")
|
||||||
|
return float(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _dump_float(value: float) -> Union[float, str]:
|
||||||
|
"""Dump the given float to JSON
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
value : float
|
||||||
|
Value to dump
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Union[float, str]
|
||||||
|
Dumped valid, either a float or the strings
|
||||||
|
"Infinity" or "-Infinity"
|
||||||
|
"""
|
||||||
|
if value == float("inf"):
|
||||||
|
return INFINITY
|
||||||
|
if value == -float("inf"):
|
||||||
|
return NEG_INFINITY
|
||||||
|
if value == float("nan"):
|
||||||
|
return NAN
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
Decode a single varint value from a byte buffer. Returns the value and the
|
Decode a single varint value from a byte buffer. Returns the value and the
|
||||||
@ -564,6 +616,17 @@ class Message(ABC):
|
|||||||
other_val = other._get_field_default(field_name)
|
other_val = other._get_field_default(field_name)
|
||||||
|
|
||||||
if self_val != other_val:
|
if self_val != other_val:
|
||||||
|
# We consider two nan values to be the same for the
|
||||||
|
# purposes of comparing messages (otherwise a message
|
||||||
|
# is not equal to itself)
|
||||||
|
if (
|
||||||
|
isinstance(self_val, float)
|
||||||
|
and isinstance(other_val, float)
|
||||||
|
and math.isnan(self_val)
|
||||||
|
and math.isnan(other_val)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@ -1015,6 +1078,11 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
enum_class: Type[Enum] = field_types[field_name] # noqa
|
enum_class: Type[Enum] = field_types[field_name] # noqa
|
||||||
output[cased_name] = enum_class(value).name
|
output[cased_name] = enum_class(value).name
|
||||||
|
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||||
|
if field_is_repeated:
|
||||||
|
output[cased_name] = [_dump_float(n) for n in value]
|
||||||
|
else:
|
||||||
|
output[cased_name] = _dump_float(value)
|
||||||
else:
|
else:
|
||||||
output[cased_name] = value
|
output[cased_name] = value
|
||||||
return output
|
return output
|
||||||
@ -1090,6 +1158,11 @@ class Message(ABC):
|
|||||||
v = [enum_cls.from_string(e) for e in v]
|
v = [enum_cls.from_string(e) for e in v]
|
||||||
elif isinstance(v, str):
|
elif isinstance(v, str):
|
||||||
v = enum_cls.from_string(v)
|
v = enum_cls.from_string(v)
|
||||||
|
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||||
|
if isinstance(value[key], list):
|
||||||
|
v = [_parse_float(n) for n in value[key]]
|
||||||
|
else:
|
||||||
|
v = _parse_float(value[key])
|
||||||
|
|
||||||
if v is not None:
|
if v is not None:
|
||||||
setattr(self, field_name, v)
|
setattr(self, field_name, v)
|
||||||
|
9
tests/inputs/float/float.json
Normal file
9
tests/inputs/float/float.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"positive": "Infinity",
|
||||||
|
"negative": "-Infinity",
|
||||||
|
"nan": "NaN",
|
||||||
|
"three": 3.0,
|
||||||
|
"threePointOneFour": 3.14,
|
||||||
|
"negThree": -3.0,
|
||||||
|
"negThreePointOneFour": -3.14
|
||||||
|
}
|
12
tests/inputs/float/float.proto
Normal file
12
tests/inputs/float/float.proto
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
// Some documentation about the Test message.
|
||||||
|
message Test {
|
||||||
|
double positive = 1;
|
||||||
|
double negative = 2;
|
||||||
|
double nan = 3;
|
||||||
|
double three = 4;
|
||||||
|
double three_point_one_four = 5;
|
||||||
|
double neg_three = 6;
|
||||||
|
double neg_three_point_one_four = 7;
|
||||||
|
}
|
@ -1,10 +1,11 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Set
|
from typing import Any, Dict, List, Set
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -69,6 +70,55 @@ def module_has_entry_point(module: ModuleType):
|
|||||||
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
|
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
|
||||||
|
|
||||||
|
|
||||||
|
def list_replace_nans(items: List) -> List[Any]:
|
||||||
|
"""Replace float("nan") in a list with the string "NaN"
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
items : List
|
||||||
|
List to update
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
List[Any]
|
||||||
|
Updated list
|
||||||
|
"""
|
||||||
|
result = []
|
||||||
|
for item in items:
|
||||||
|
if isinstance(item, list):
|
||||||
|
result.append(list_replace_nans(item))
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
result.append(dict_replace_nans(item))
|
||||||
|
elif isinstance(item, float) and math.isnan(item):
|
||||||
|
result.append(betterproto.NAN)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||||
|
"""Replace float("nan") in a dictionary with the string "NaN"
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
input_dict : Dict[Any, Any]
|
||||||
|
Dictionary to update
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Dict[Any, Any]
|
||||||
|
Updated dictionary
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
for key, value in input_dict.items():
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = dict_replace_nans(value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
value = list_replace_nans(value)
|
||||||
|
elif isinstance(value, float) and math.isnan(value):
|
||||||
|
value = betterproto.NAN
|
||||||
|
result[key] = value
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def test_data(request):
|
def test_data(request):
|
||||||
test_case_name = request.param
|
test_case_name = request.param
|
||||||
@ -81,7 +131,6 @@ def test_data(request):
|
|||||||
reference_module_root = os.path.join(
|
reference_module_root = os.path.join(
|
||||||
*reference_output_package.split("."), test_case_name
|
*reference_output_package.split("."), test_case_name
|
||||||
)
|
)
|
||||||
|
|
||||||
sys.path.append(reference_module_root)
|
sys.path.append(reference_module_root)
|
||||||
|
|
||||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
|
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
|
||||||
@ -132,7 +181,9 @@ def test_message_json(repeat, test_data: TestData) -> None:
|
|||||||
message.from_json(json_sample)
|
message.from_json(json_sample)
|
||||||
message_json = message.to_json(0)
|
message_json = message.to_json(0)
|
||||||
|
|
||||||
assert json.loads(message_json) == json.loads(json_sample)
|
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
|
||||||
|
json.loads(json_sample)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
|
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
|
||||||
@ -156,14 +207,13 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None:
|
|||||||
reference_binary_output
|
reference_binary_output
|
||||||
)
|
)
|
||||||
|
|
||||||
# # Generally this can't be relied on, but here we are aiming to match the
|
# Generally this can't be relied on, but here we are aiming to match the
|
||||||
# # existing Python implementation and aren't doing anything tricky.
|
# existing Python implementation and aren't doing anything tricky.
|
||||||
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
|
# https://developers.google.com/protocol-buffers/docs/encoding#implications
|
||||||
assert bytes(plugin_instance_from_json) == reference_binary_output
|
assert bytes(plugin_instance_from_json) == reference_binary_output
|
||||||
assert bytes(plugin_instance_from_binary) == reference_binary_output
|
assert bytes(plugin_instance_from_binary) == reference_binary_output
|
||||||
|
|
||||||
assert plugin_instance_from_json == plugin_instance_from_binary
|
assert plugin_instance_from_json == plugin_instance_from_binary
|
||||||
assert (
|
assert dict_replace_nans(
|
||||||
plugin_instance_from_json.to_dict()
|
plugin_instance_from_json.to_dict()
|
||||||
== plugin_instance_from_binary.to_dict()
|
) == dict_replace_nans(plugin_instance_from_binary.to_dict())
|
||||||
)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user