Added support for infinite and nan floats/doubles (#215)
- Added support for the custom double values from the protobuf json spec: "Infinity", "-Infinity", and "NaN" - Added `infinite_floats` test data - Updated Message.__eq__ to consider nan values equal - Updated `test_message_json` and `test_binary_compatibility` to replace NaN float values in dictionaries before comparison (because two NaN values are not equal)
This commit is contained in:
@@ -2,6 +2,7 @@ import dataclasses
|
||||
import enum
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import struct
|
||||
import sys
|
||||
import typing
|
||||
@@ -113,6 +114,12 @@ def datetime_default_gen() -> datetime:
|
||||
DATETIME_ZERO = datetime_default_gen()
|
||||
|
||||
|
||||
# Special protobuf json doubles
|
||||
INFINITY = "Infinity"
|
||||
NEG_INFINITY = "-Infinity"
|
||||
NAN = "NaN"
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
@@ -369,6 +376,51 @@ def _serialize_single(
|
||||
return bytes(output)
|
||||
|
||||
|
||||
def _parse_float(value: Any) -> float:
|
||||
"""Parse the given value to a float
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : Any
|
||||
Value to parse
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Parsed value
|
||||
"""
|
||||
if value == INFINITY:
|
||||
return float("inf")
|
||||
if value == NEG_INFINITY:
|
||||
return -float("inf")
|
||||
if value == NAN:
|
||||
return float("nan")
|
||||
return float(value)
|
||||
|
||||
|
||||
def _dump_float(value: float) -> Union[float, str]:
|
||||
"""Dump the given float to JSON
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : float
|
||||
Value to dump
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[float, str]
|
||||
Dumped valid, either a float or the strings
|
||||
"Infinity" or "-Infinity"
|
||||
"""
|
||||
if value == float("inf"):
|
||||
return INFINITY
|
||||
if value == -float("inf"):
|
||||
return NEG_INFINITY
|
||||
if value == float("nan"):
|
||||
return NAN
|
||||
return value
|
||||
|
||||
|
||||
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
||||
"""
|
||||
Decode a single varint value from a byte buffer. Returns the value and the
|
||||
@@ -564,7 +616,18 @@ class Message(ABC):
|
||||
other_val = other._get_field_default(field_name)
|
||||
|
||||
if self_val != other_val:
|
||||
return False
|
||||
# We consider two nan values to be the same for the
|
||||
# purposes of comparing messages (otherwise a message
|
||||
# is not equal to itself)
|
||||
if (
|
||||
isinstance(self_val, float)
|
||||
and isinstance(other_val, float)
|
||||
and math.isnan(self_val)
|
||||
and math.isnan(other_val)
|
||||
):
|
||||
continue
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -1015,6 +1078,11 @@ class Message(ABC):
|
||||
else:
|
||||
enum_class: Type[Enum] = field_types[field_name] # noqa
|
||||
output[cased_name] = enum_class(value).name
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if field_is_repeated:
|
||||
output[cased_name] = [_dump_float(n) for n in value]
|
||||
else:
|
||||
output[cased_name] = _dump_float(value)
|
||||
else:
|
||||
output[cased_name] = value
|
||||
return output
|
||||
@@ -1090,6 +1158,11 @@ class Message(ABC):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if isinstance(value[key], list):
|
||||
v = [_parse_float(n) for n in value[key]]
|
||||
else:
|
||||
v = _parse_float(value[key])
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
|
Reference in New Issue
Block a user