diff --git a/betterproto/__init__.py b/betterproto/__init__.py index b826cc0..0e09de9 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -704,11 +704,16 @@ class Message(ABC): def FromString(cls: Type[T], data: bytes) -> T: return cls().parse(data) - def to_dict(self, casing: Casing = Casing.CAMEL) -> dict: + def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> dict: """ Returns a dict representation of this message instance which can be used to serialize to e.g. JSON. Defaults to camel casing for compatibility but can be set to other modes. + + `include_default_values` can be set to `True` to include default + values of fields. E.g. an `int32` type field with `0` value will + not be in returned dict if `include_default_values` is set to + `False`. """ output: Dict[str, Any] = {} for field in dataclasses.fields(self): @@ -717,29 +722,30 @@ class Message(ABC): cased_name = casing(field.name).rstrip("_") # type: ignore if meta.proto_type == "message": if isinstance(v, datetime): - if v != DATETIME_ZERO: + if v != DATETIME_ZERO or include_default_values: output[cased_name] = _Timestamp.timestamp_to_json(v) elif isinstance(v, timedelta): - if v != timedelta(0): + if v != timedelta(0) or include_default_values: output[cased_name] = _Duration.delta_to_json(v) elif meta.wraps: - if v is not None: + if v is not None or include_default_values: output[cased_name] = v elif isinstance(v, list): # Convert each item. - v = [i.to_dict(casing) for i in v] + v = [i.to_dict(casing, include_default_values) for i in v] if v: output[cased_name] = v - elif v._serialized_on_wire: - output[cased_name] = v.to_dict(casing) + else: + if v._serialized_on_wire or include_default_values: + output[cased_name] = v.to_dict(casing, include_default_values) elif meta.proto_type == "map": for k in v: if hasattr(v[k], "to_dict"): - v[k] = v[k].to_dict(casing) + v[k] = v[k].to_dict(casing, include_default_values) - if v: + if v or include_default_values: output[cased_name] = v - elif v != self._get_field_default(field, meta): + elif v != self._get_field_default(field, meta) or include_default_values: if meta.proto_type in INT_64_TYPES: if isinstance(v, list): output[cased_name] = [str(n) for n in v] diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index ed6deba..dca8cc3 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -177,3 +177,95 @@ def test_optional_flag(): # Differentiate between not passed and the zero-value. assert Request().parse(b"").flag == None assert Request().parse(b"\n\x00").flag == False + + +def test_to_dict_default_values(): + @dataclass + class TestMessage(betterproto.Message): + some_int: int = betterproto.int32_field(1) + some_double: float = betterproto.double_field(2) + some_str: str = betterproto.string_field(3) + some_bool: bool = betterproto.bool_field(4) + + # Empty dict + test = TestMessage().from_dict({}) + + assert test.to_dict(include_default_values=True) == { + 'someInt': 0, + 'someDouble': 0.0, + 'someStr': '', + 'someBool': False + } + + # All default values + test = TestMessage().from_dict({ + 'someInt': 0, + 'someDouble': 0.0, + 'someStr': '', + 'someBool': False + }) + + assert test.to_dict(include_default_values=True) == { + 'someInt': 0, + 'someDouble': 0.0, + 'someStr': '', + 'someBool': False + } + + # Some default and some other values + @dataclass + class TestMessage2(betterproto.Message): + some_int: int = betterproto.int32_field(1) + some_double: float = betterproto.double_field(2) + some_str: str = betterproto.string_field(3) + some_bool: bool = betterproto.bool_field(4) + some_default_int: int = betterproto.int32_field(5) + some_default_double: float = betterproto.double_field(6) + some_default_str: str = betterproto.string_field(7) + some_default_bool: bool = betterproto.bool_field(8) + + test = TestMessage2().from_dict({ + 'someInt': 2, + 'someDouble': 1.2, + 'someStr': 'hello', + 'someBool': True, + 'someDefaultInt': 0, + 'someDefaultDouble': 0.0, + 'someDefaultStr': '', + 'someDefaultBool': False + }) + + assert test.to_dict(include_default_values=True) == { + 'someInt': 2, + 'someDouble': 1.2, + 'someStr': 'hello', + 'someBool': True, + 'someDefaultInt': 0, + 'someDefaultDouble': 0.0, + 'someDefaultStr': '', + 'someDefaultBool': False + } + + # Nested messages + @dataclass + class TestChildMessage(betterproto.Message): + some_other_int: int = betterproto.int32_field(1) + + @dataclass + class TestParentMessage(betterproto.Message): + some_int: int = betterproto.int32_field(1) + some_double: float = betterproto.double_field(2) + some_message: TestChildMessage = betterproto.message_field(3) + + test = TestParentMessage().from_dict({ + 'someInt': 0, + 'someDouble': 1.2, + }) + + assert test.to_dict(include_default_values=True) == { + 'someInt': 0, + 'someDouble': 1.2, + 'someMessage': { + 'someOtherInt': 0 + } + }