diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 61d9fc2..f6f2d1a 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -2,6 +2,7 @@ import dataclasses import enum import inspect import json +import math import struct import sys import typing @@ -113,6 +114,12 @@ def datetime_default_gen() -> datetime: DATETIME_ZERO = datetime_default_gen() +# Special protobuf json doubles +INFINITY = "Infinity" +NEG_INFINITY = "-Infinity" +NAN = "NaN" + + class Casing(enum.Enum): """Casing constants for serialization.""" @@ -369,6 +376,51 @@ def _serialize_single( return bytes(output) +def _parse_float(value: Any) -> float: + """Parse the given value to a float + + Parameters + ---------- + value : Any + Value to parse + + Returns + ------- + float + Parsed value + """ + if value == INFINITY: + return float("inf") + if value == NEG_INFINITY: + return -float("inf") + if value == NAN: + return float("nan") + return float(value) + + +def _dump_float(value: float) -> Union[float, str]: + """Dump the given float to JSON + + Parameters + ---------- + value : float + Value to dump + + Returns + ------- + Union[float, str] + Dumped valid, either a float or the strings + "Infinity" or "-Infinity" + """ + if value == float("inf"): + return INFINITY + if value == -float("inf"): + return NEG_INFINITY + if value == float("nan"): + return NAN + return value + + def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]: """ Decode a single varint value from a byte buffer. Returns the value and the @@ -564,7 +616,18 @@ class Message(ABC): other_val = other._get_field_default(field_name) if self_val != other_val: - return False + # We consider two nan values to be the same for the + # purposes of comparing messages (otherwise a message + # is not equal to itself) + if ( + isinstance(self_val, float) + and isinstance(other_val, float) + and math.isnan(self_val) + and math.isnan(other_val) + ): + continue + else: + return False return True @@ -1015,6 +1078,11 @@ class Message(ABC): else: enum_class: Type[Enum] = field_types[field_name] # noqa output[cased_name] = enum_class(value).name + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + if field_is_repeated: + output[cased_name] = [_dump_float(n) for n in value] + else: + output[cased_name] = _dump_float(value) else: output[cased_name] = value return output @@ -1090,6 +1158,11 @@ class Message(ABC): v = [enum_cls.from_string(e) for e in v] elif isinstance(v, str): v = enum_cls.from_string(v) + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + if isinstance(value[key], list): + v = [_parse_float(n) for n in value[key]] + else: + v = _parse_float(value[key]) if v is not None: setattr(self, field_name, v) diff --git a/tests/inputs/float/float.json b/tests/inputs/float/float.json new file mode 100644 index 0000000..3adac97 --- /dev/null +++ b/tests/inputs/float/float.json @@ -0,0 +1,9 @@ +{ + "positive": "Infinity", + "negative": "-Infinity", + "nan": "NaN", + "three": 3.0, + "threePointOneFour": 3.14, + "negThree": -3.0, + "negThreePointOneFour": -3.14 + } diff --git a/tests/inputs/float/float.proto b/tests/inputs/float/float.proto new file mode 100644 index 0000000..79922af --- /dev/null +++ b/tests/inputs/float/float.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +// Some documentation about the Test message. +message Test { + double positive = 1; + double negative = 2; + double nan = 3; + double three = 4; + double three_point_one_four = 5; + double neg_three = 6; + double neg_three_point_one_four = 7; +} diff --git a/tests/test_inputs.py b/tests/test_inputs.py index e743f64..6d6907c 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -1,10 +1,11 @@ import importlib import json +import math import os import sys from collections import namedtuple from types import ModuleType -from typing import Set +from typing import Any, Dict, List, Set import pytest @@ -69,6 +70,55 @@ def module_has_entry_point(module: ModuleType): return any(hasattr(module, attr) for attr in ["Test", "TestStub"]) +def list_replace_nans(items: List) -> List[Any]: + """Replace float("nan") in a list with the string "NaN" + + Parameters + ---------- + items : List + List to update + + Returns + ------- + List[Any] + Updated list + """ + result = [] + for item in items: + if isinstance(item, list): + result.append(list_replace_nans(item)) + elif isinstance(item, dict): + result.append(dict_replace_nans(item)) + elif isinstance(item, float) and math.isnan(item): + result.append(betterproto.NAN) + return result + + +def dict_replace_nans(input_dict: Dict[Any, Any]) -> Dict[Any, Any]: + """Replace float("nan") in a dictionary with the string "NaN" + + Parameters + ---------- + input_dict : Dict[Any, Any] + Dictionary to update + + Returns + ------- + Dict[Any, Any] + Updated dictionary + """ + result = {} + for key, value in input_dict.items(): + if isinstance(value, dict): + value = dict_replace_nans(value) + elif isinstance(value, list): + value = list_replace_nans(value) + elif isinstance(value, float) and math.isnan(value): + value = betterproto.NAN + result[key] = value + return result + + @pytest.fixture def test_data(request): test_case_name = request.param @@ -81,7 +131,6 @@ def test_data(request): reference_module_root = os.path.join( *reference_output_package.split("."), test_case_name ) - sys.path.append(reference_module_root) plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") @@ -132,7 +181,9 @@ def test_message_json(repeat, test_data: TestData) -> None: message.from_json(json_sample) message_json = message.to_json(0) - assert json.loads(message_json) == json.loads(json_sample) + assert dict_replace_nans(json.loads(message_json)) == dict_replace_nans( + json.loads(json_sample) + ) @pytest.mark.parametrize("test_data", test_cases.services, indirect=True) @@ -156,14 +207,13 @@ def test_binary_compatibility(repeat, test_data: TestData) -> None: reference_binary_output ) - # # Generally this can't be relied on, but here we are aiming to match the - # # existing Python implementation and aren't doing anything tricky. - # # https://developers.google.com/protocol-buffers/docs/encoding#implications + # Generally this can't be relied on, but here we are aiming to match the + # existing Python implementation and aren't doing anything tricky. + # https://developers.google.com/protocol-buffers/docs/encoding#implications assert bytes(plugin_instance_from_json) == reference_binary_output assert bytes(plugin_instance_from_binary) == reference_binary_output assert plugin_instance_from_json == plugin_instance_from_binary - assert ( + assert dict_replace_nans( plugin_instance_from_json.to_dict() - == plugin_instance_from_binary.to_dict() - ) + ) == dict_replace_nans(plugin_instance_from_binary.to_dict())