Added include_default_values parameter to to_dict function.
				
					
				
			This commit is contained in:
		| @@ -700,11 +700,16 @@ class Message(ABC): | |||||||
|     def FromString(cls: Type[T], data: bytes) -> T: |     def FromString(cls: Type[T], data: bytes) -> T: | ||||||
|         return cls().parse(data) |         return cls().parse(data) | ||||||
|  |  | ||||||
|     def to_dict(self, casing: Casing = Casing.CAMEL) -> dict: |     def to_dict(self, casing: Casing = Casing.CAMEL, include_default_values: bool = False) -> dict: | ||||||
|         """ |         """ | ||||||
|         Returns a dict representation of this message instance which can be |         Returns a dict representation of this message instance which can be | ||||||
|         used to serialize to e.g. JSON. Defaults to camel casing for |         used to serialize to e.g. JSON. Defaults to camel casing for | ||||||
|         compatibility but can be set to other modes. |         compatibility but can be set to other modes. | ||||||
|  |  | ||||||
|  |         `include_default_values` can be set to `True` to include default | ||||||
|  |         values of fields. E.g. an `int32` type field with `0` value will | ||||||
|  |         not be in returned dict if `include_default_values` is set to | ||||||
|  |         `False`. | ||||||
|         """ |         """ | ||||||
|         output: Dict[str, Any] = {} |         output: Dict[str, Any] = {} | ||||||
|         for field in dataclasses.fields(self): |         for field in dataclasses.fields(self): | ||||||
| @@ -713,28 +718,29 @@ class Message(ABC): | |||||||
|             cased_name = casing(field.name).rstrip("_")  # type: ignore |             cased_name = casing(field.name).rstrip("_")  # type: ignore | ||||||
|             if meta.proto_type == "message": |             if meta.proto_type == "message": | ||||||
|                 if isinstance(v, datetime): |                 if isinstance(v, datetime): | ||||||
|                     if v != DATETIME_ZERO: |                     if v != DATETIME_ZERO or include_default_values: | ||||||
|                         output[cased_name] = _Timestamp.timestamp_to_json(v) |                         output[cased_name] = _Timestamp.timestamp_to_json(v) | ||||||
|                 elif isinstance(v, timedelta): |                 elif isinstance(v, timedelta): | ||||||
|                     if v != timedelta(0): |                     if v != timedelta(0) or include_default_values: | ||||||
|                         output[cased_name] = _Duration.delta_to_json(v) |                         output[cased_name] = _Duration.delta_to_json(v) | ||||||
|                 elif meta.wraps: |                 elif meta.wraps: | ||||||
|                     if v is not None: |                     if v is not None or include_default_values: | ||||||
|                         output[cased_name] = v |                         output[cased_name] = v | ||||||
|                 elif isinstance(v, list): |                 elif isinstance(v, list): | ||||||
|                     # Convert each item. |                     # Convert each item. | ||||||
|                     v = [i.to_dict(casing) for i in v] |                     v = [i.to_dict(casing, include_default_values) for i in v] | ||||||
|                     output[cased_name] = v |                     output[cased_name] = v | ||||||
|                 elif v._serialized_on_wire: |                 else: | ||||||
|                     output[cased_name] = v.to_dict(casing) |                     if v._serialized_on_wire or include_default_values: | ||||||
|  |                         output[cased_name] = v.to_dict(casing, include_default_values) | ||||||
|             elif meta.proto_type == "map": |             elif meta.proto_type == "map": | ||||||
|                 for k in v: |                 for k in v: | ||||||
|                     if hasattr(v[k], "to_dict"): |                     if hasattr(v[k], "to_dict"): | ||||||
|                         v[k] = v[k].to_dict(casing) |                         v[k] = v[k].to_dict(casing, include_default_values) | ||||||
|  |  | ||||||
|                 if v: |                 if v or include_default_values: | ||||||
|                     output[cased_name] = v |                     output[cased_name] = v | ||||||
|             elif v != self._get_field_default(field, meta): |             elif v != self._get_field_default(field, meta) or include_default_values: | ||||||
|                 if meta.proto_type in INT_64_TYPES: |                 if meta.proto_type in INT_64_TYPES: | ||||||
|                     if isinstance(v, list): |                     if isinstance(v, list): | ||||||
|                         output[cased_name] = [str(n) for n in v] |                         output[cased_name] = [str(n) for n in v] | ||||||
|   | |||||||
| @@ -162,3 +162,95 @@ def test_optional_flag(): | |||||||
|     # Differentiate between not passed and the zero-value. |     # Differentiate between not passed and the zero-value. | ||||||
|     assert Request().parse(b"").flag == None |     assert Request().parse(b"").flag == None | ||||||
|     assert Request().parse(b"\n\x00").flag == False |     assert Request().parse(b"\n\x00").flag == False | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_to_dict_default_values(): | ||||||
|  |     @dataclass | ||||||
|  |     class TestMessage(betterproto.Message): | ||||||
|  |         some_int: int = betterproto.int32_field(1) | ||||||
|  |         some_double: float = betterproto.double_field(2) | ||||||
|  |         some_str: str = betterproto.string_field(3) | ||||||
|  |         some_bool: bool = betterproto.bool_field(4) | ||||||
|  |  | ||||||
|  |     # Empty dict | ||||||
|  |     test = TestMessage().from_dict({}) | ||||||
|  |  | ||||||
|  |     assert test.to_dict(include_default_values=True) == { | ||||||
|  |         'someInt': 0, | ||||||
|  |         'someDouble': 0.0, | ||||||
|  |         'someStr': '', | ||||||
|  |         'someBool': False | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     # All default values | ||||||
|  |     test = TestMessage().from_dict({ | ||||||
|  |         'someInt': 0, | ||||||
|  |         'someDouble': 0.0, | ||||||
|  |         'someStr': '', | ||||||
|  |         'someBool': False | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     assert test.to_dict(include_default_values=True) == { | ||||||
|  |         'someInt': 0, | ||||||
|  |         'someDouble': 0.0, | ||||||
|  |         'someStr': '', | ||||||
|  |         'someBool': False | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     # Some default and some other values | ||||||
|  |     @dataclass | ||||||
|  |     class TestMessage2(betterproto.Message): | ||||||
|  |         some_int: int = betterproto.int32_field(1) | ||||||
|  |         some_double: float = betterproto.double_field(2) | ||||||
|  |         some_str: str = betterproto.string_field(3) | ||||||
|  |         some_bool: bool = betterproto.bool_field(4) | ||||||
|  |         some_default_int: int = betterproto.int32_field(5) | ||||||
|  |         some_default_double: float = betterproto.double_field(6) | ||||||
|  |         some_default_str: str = betterproto.string_field(7) | ||||||
|  |         some_default_bool: bool = betterproto.bool_field(8) | ||||||
|  |  | ||||||
|  |     test = TestMessage2().from_dict({ | ||||||
|  |         'someInt': 2, | ||||||
|  |         'someDouble': 1.2, | ||||||
|  |         'someStr': 'hello', | ||||||
|  |         'someBool': True, | ||||||
|  |         'someDefaultInt': 0, | ||||||
|  |         'someDefaultDouble': 0.0, | ||||||
|  |         'someDefaultStr': '', | ||||||
|  |         'someDefaultBool': False | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     assert test.to_dict(include_default_values=True) == { | ||||||
|  |         'someInt': 2, | ||||||
|  |         'someDouble': 1.2, | ||||||
|  |         'someStr': 'hello', | ||||||
|  |         'someBool': True, | ||||||
|  |         'someDefaultInt': 0, | ||||||
|  |         'someDefaultDouble': 0.0, | ||||||
|  |         'someDefaultStr': '', | ||||||
|  |         'someDefaultBool': False | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     # Nested messages | ||||||
|  |     @dataclass | ||||||
|  |     class TestChildMessage(betterproto.Message): | ||||||
|  |         some_other_int: int = betterproto.int32_field(1) | ||||||
|  |  | ||||||
|  |     @dataclass | ||||||
|  |     class TestParentMessage(betterproto.Message): | ||||||
|  |         some_int: int = betterproto.int32_field(1) | ||||||
|  |         some_double: float = betterproto.double_field(2) | ||||||
|  |         some_message: TestChildMessage = betterproto.message_field(3) | ||||||
|  |  | ||||||
|  |     test = TestParentMessage().from_dict({ | ||||||
|  |         'someInt': 0, | ||||||
|  |         'someDouble': 1.2, | ||||||
|  |     }) | ||||||
|  |  | ||||||
|  |     assert test.to_dict(include_default_values=True) == { | ||||||
|  |         'someInt': 0, | ||||||
|  |         'someDouble': 1.2, | ||||||
|  |         'someMessage': { | ||||||
|  |             'someOtherInt': 0 | ||||||
|  |         } | ||||||
|  |     } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user