Added support for infinite and nan floats/doubles (#215)
- Added support for the custom double values from the protobuf json spec: "Infinity", "-Infinity", and "NaN" - Added `infinite_floats` test data - Updated Message.__eq__ to consider nan values equal - Updated `test_message_json` and `test_binary_compatibility` to replace NaN float values in dictionaries before comparison (because two NaN values are not equal)
This commit is contained in:
parent
bb646fe26f
commit
7c5ee47e68
@ -2,6 +2,7 @@ import dataclasses
|
||||
import enum
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import struct
|
||||
import sys
|
||||
import typing
|
||||
@ -113,6 +114,12 @@ def datetime_default_gen() -> datetime:
|
||||
DATETIME_ZERO = datetime_default_gen()
|
||||
|
||||
|
||||
# Special protobuf json doubles
|
||||
INFINITY = "Infinity"
|
||||
NEG_INFINITY = "-Infinity"
|
||||
NAN = "NaN"
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
@ -369,6 +376,51 @@ def _serialize_single(
|
||||
return bytes(output)
|
||||
|
||||
|
||||
def _parse_float(value: Any) -> float:
|
||||
"""Parse the given value to a float
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : Any
|
||||
Value to parse
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Parsed value
|
||||
"""
|
||||
if value == INFINITY:
|
||||
return float("inf")
|
||||
if value == NEG_INFINITY:
|
||||
return -float("inf")
|
||||
if value == NAN:
|
||||
return float("nan")
|
||||
return float(value)
|
||||
|
||||
|
||||
def _dump_float(value: float) -> Union[float, str]:
|
||||
"""Dump the given float to JSON
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : float
|
||||
Value to dump
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[float, str]
|
||||
Dumped valid, either a float or the strings
|
||||
"Infinity" or "-Infinity"
|
||||
"""
|
||||
if value == float("inf"):
|
||||
return INFINITY
|
||||
if value == -float("inf"):
|
||||
return NEG_INFINITY
|
||||
if value == float("nan"):
|
||||
return NAN
|
||||
return value
|
||||
|
||||
|
||||
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Decode a single varint value from a byte buffer. Returns the value and the
|
||||
@ -564,7 +616,18 @@ class Message(ABC):
|
||||
other_val = other._get_field_default(field_name)
|
||||
|
||||
if self_val != other_val:
|
||||
return False
|
||||
# We consider two nan values to be the same for the
|
||||
# purposes of comparing messages (otherwise a message
|
||||
# is not equal to itself)
|
||||
if (
|
||||
isinstance(self_val, float)
|
||||
and isinstance(other_val, float)
|
||||
and math.isnan(self_val)
|
||||
and math.isnan(other_val)
|
||||
):
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@ -1015,6 +1078,11 @@ class Message(ABC):
|
||||
else:
|
||||
enum_class: Type[Enum] = field_types[field_name] # noqa
|
||||
output[cased_name] = enum_class(value).name
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if field_is_repeated:
|
||||
output[cased_name] = [_dump_float(n) for n in value]
|
||||
else:
|
||||
output[cased_name] = _dump_float(value)
|
||||
else:
|
||||
output[cased_name] = value
|
||||
return output
|
||||
@ -1090,6 +1158,11 @@ class Message(ABC):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if isinstance(value[key], list):
|
||||
v = [_parse_float(n) for n in value[key]]
|
||||
else:
|
||||
v = _parse_float(value[key])
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
|
9
tests/inputs/float/float.json
Normal file
9
tests/inputs/float/float.json
Normal file
@ -0,0 +1,9 @@
|
||||
{
|
||||
"positive": "Infinity",
|
||||
"negative": "-Infinity",
|
||||
"nan": "NaN",
|
||||
"three": 3.0,
|
||||
"threePointOneFour": 3.14,
|
||||
"negThree": -3.0,
|
||||
"negThreePointOneFour": -3.14
|
||||
}
|
12
tests/inputs/float/float.proto
Normal file
12
tests/inputs/float/float.proto
Normal file
@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
// Some documentation about the Test message.
|
||||
message Test {
|
||||
double positive = 1;
|
||||
double negative = 2;
|
||||
double nan = 3;
|
||||
double three = 4;
|
||||
double three_point_one_four = 5;
|
||||
double neg_three = 6;
|
||||
double neg_three_point_one_four = 7;
|
||||
}
|
@ -1,10 +1,11 @@
|
||||
import importlib
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from types import ModuleType
|
||||
from typing import Set
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
import pytest
|
||||
|
||||
@ -69,6 +70,55 @@ def module_has_entry_point(module: ModuleType):
|
||||
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
|
||||
|
||||
|
||||
def list_replace_nans(items: List) -> List[Any]:
|
||||
"""Replace float("nan") in a list with the string "NaN"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
items : List
|
||||
List to update
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Any]
|
||||
Updated list
|
||||
"""
|
||||
result = []
|
||||
for item in items:
|
||||
if isinstance(item, list):
|
||||
result.append(list_replace_nans(item))
|
||||
elif isinstance(item, dict):
|
||||
result.append(dict_replace_nans(item))
|
||||
elif isinstance(item, float) and math.isnan(item):
|
||||
result.append(betterproto.NAN)
|
||||
return result
|
||||
|
||||
|
||||
def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""Replace float("nan") in a dictionary with the string "NaN"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
input_dict : Dict[Any, Any]
|
||||
Dictionary to update
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[Any, Any]
|
||||
Updated dictionary
|
||||
"""
|
||||
result = {}
|
||||
for key, value in input_dict.items():
|
||||
if isinstance(value, dict):
|
||||
value = dict_replace_nans(value)
|
||||
elif isinstance(value, list):
|
||||
value = list_replace_nans(value)
|
||||
elif isinstance(value, float) and math.isnan(value):
|
||||
value = betterproto.NAN
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_data(request):
|
||||
test_case_name = request.param
|
||||
@ -81,7 +131,6 @@ def test_data(request):
|
||||
reference_module_root = os.path.join(
|
||||
*reference_output_package.split("."), test_case_name
|
||||
)
|
||||
|
||||
sys.path.append(reference_module_root)
|
||||
|
||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
|
||||
@ -132,7 +181,9 @@ def test_message_json(repeat, test_data: TestData) -> None:
|
||||
message.from_json(json_sample)
|
||||
message_json = message.to_json(0)
|
||||
|
||||
assert json.loads(message_json) == json.loads(json_sample)
|
||||
assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans(
|
||||
json.loads(json_sample)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
|
||||
@ -156,14 +207,13 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None:
|
||||
reference_binary_output
|
||||
)
|
||||
|
||||
# # Generally this can't be relied on, but here we are aiming to match the
|
||||
# # existing Python implementation and aren't doing anything tricky.
|
||||
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
|
||||
# Generally this can't be relied on, but here we are aiming to match the
|
||||
# existing Python implementation and aren't doing anything tricky.
|
||||
# https://developers.google.com/protocol-buffers/docs/encoding#implications
|
||||
assert bytes(plugin_instance_from_json) == reference_binary_output
|
||||
assert bytes(plugin_instance_from_binary) == reference_binary_output
|
||||
|
||||
assert plugin_instance_from_json == plugin_instance_from_binary
|
||||
assert (
|
||||
assert dict_replace_nans(
|
||||
plugin_instance_from_json.to_dict()
|
||||
== plugin_instance_from_binary.to_dict()
|
||||
)
|
||||
) == dict_replace_nans(plugin_instance_from_binary.to_dict())
|
||||
|
Loading…
x
Reference in New Issue
Block a user