From 653618190217416fba98206196622b9901d503e2 Mon Sep 17 00:00:00 2001 From: Flynn Date: Mon, 9 May 2022 12:34:12 -0400 Subject: [PATCH] Add to/from_pydict methods (#203) * add to/from_pydict methods * Remove unnecessary method call * Fix formatting Co-authored-by: James Hilton-Balfe --- src/betterproto/__init__.py | 130 ++++++++++++++++++++++++++++++++++++ tests/test_features.py | 97 ++++++++++++++++++++++++++- 2 files changed, 226 insertions(+), 1 deletion(-) diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index a69efad..384c260 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1306,6 +1306,136 @@ class Message(ABC): """ return self.from_dict(json.loads(value)) + def to_pydict( + self, casing: Casing = Casing.CAMEL, include_default_values: bool = False + ) -> Dict[str, Any]: + """ + Returns a python dict representation of this object. + + Parameters + ----------- + casing: :class:`Casing` + The casing to use for key values. Default is :attr:`Casing.CAMEL` for + compatibility purposes. + include_default_values: :class:`bool` + If ``True`` will include the default values of fields. Default is ``False``. + E.g. an ``int32`` field will be included with a value of ``0`` if this is + set to ``True``, otherwise this would be ignored. + + Returns + -------- + Dict[:class:`str`, Any] + The python dict representation of this object. + """ + output: Dict[str, Any] = {} + defaults = self._betterproto.default_gen + for field_name, meta in self._betterproto.meta_by_field_name.items(): + field_is_repeated = defaults[field_name] is list + value = getattr(self, field_name) + cased_name = casing(field_name).rstrip("_") # type: ignore + if meta.proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + if ( + value != DATETIME_ZERO + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value + elif isinstance(value, timedelta): + if ( + value != timedelta(0) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value + elif meta.wraps: + if value is not None or include_default_values: + output[cased_name] = value + elif field_is_repeated: + # Convert each item. + value = [i.to_pydict(casing, include_default_values) for i in value] + if value or include_default_values: + output[cased_name] = value + elif ( + value._serialized_on_wire + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value.to_pydict(casing, include_default_values) + elif meta.proto_type == TYPE_MAP: + for k in value: + if hasattr(value[k], "to_pydict"): + value[k] = value[k].to_pydict(casing, include_default_values) + + if value or include_default_values: + output[cased_name] = value + elif ( + value != self._get_field_default(field_name) + or include_default_values + or self._include_default_value_for_oneof( + field_name=field_name, meta=meta + ) + ): + output[cased_name] = value + return output + + def from_pydict(self: T, value: Dict[str, Any]) -> T: + """ + Parse the key/value pairs into the current message instance. This returns the + instance itself and is therefore assignable and chainable. + + Parameters + ----------- + value: Dict[:class:`str`, Any] + The dictionary to parse from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + self._serialized_on_wire = True + for key in value: + field_name = safe_snake_case(key) + meta = self._betterproto.meta_by_field_name.get(field_name) + if not meta: + continue + + if value[key] is not None: + if meta.proto_type == TYPE_MESSAGE: + v = getattr(self, field_name) + if isinstance(v, list): + cls = self._betterproto.cls_by_field[field_name] + for item in value[key]: + v.append(cls().from_pydict(item)) + elif isinstance(v, datetime): + v = value[key] + elif isinstance(v, timedelta): + v = value[key] + elif meta.wraps: + v = value[key] + else: + # NOTE: `from_pydict` mutates the underlying message, so no + # assignment here is necessary. + v.from_pydict(value[key]) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + v = getattr(self, field_name) + cls = self._betterproto.cls_by_field[f"{field_name}.value"] + for k in value[key]: + v[k] = cls().from_pydict(value[key][k]) + else: + v = value[key] + + if v is not None: + setattr(self, field_name, v) + return self + def is_set(self, name: str) -> bool: """ Check if field with the given name has been set. diff --git a/tests/test_features.py b/tests/test_features.py index ffbab47..b59bfe8 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -3,7 +3,10 @@ from copy import ( deepcopy, ) from dataclasses import dataclass -from datetime import datetime +from datetime import ( + datetime, + timedelta, +) from inspect import ( Parameter, signature, @@ -79,6 +82,7 @@ def test_class_init(): foo = Foo(name="foo", child=Bar(name="bar")) assert foo.to_dict() == {"name": "foo", "child": {"name": "bar"}} + assert foo.to_pydict() == {"name": "foo", "child": {"name": "bar"}} def test_enum_as_int_json(): @@ -98,6 +102,11 @@ def test_enum_as_int_json(): foo.bar = 1 assert foo.to_dict() == {"bar": "ONE"} + # Similar expectations for pydict + foo = Foo().from_pydict({"bar": 1}) + assert foo.bar == TestEnum.ONE + assert foo.to_pydict() == {"bar": TestEnum.ONE} + def test_unknown_fields(): @dataclass @@ -188,6 +197,12 @@ def test_json_casing(): "snakeCase": 3, "kabobCase": 4, } + assert test.to_pydict() == { + "pascalCase": 1, + "camelCase": 2, + "snakeCase": 3, + "kabobCase": 4, + } assert test.to_dict(casing=betterproto.Casing.SNAKE) == { "pascal_case": 1, @@ -195,6 +210,12 @@ def test_json_casing(): "snake_case": 3, "kabob_case": 4, } + assert test.to_pydict(casing=betterproto.Casing.SNAKE) == { + "pascal_case": 1, + "camel_case": 2, + "snake_case": 3, + "kabob_case": 4, + } def test_optional_flag(): @@ -230,6 +251,15 @@ def test_to_dict_default_values(): "someBool": False, } + test = TestMessage().from_pydict({}) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + # All default values test = TestMessage().from_dict( {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} @@ -242,6 +272,17 @@ def test_to_dict_default_values(): "someBool": False, } + test = TestMessage().from_pydict( + {"someInt": 0, "someDouble": 0.0, "someStr": "", "someBool": False} + ) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 0.0, + "someStr": "", + "someBool": False, + } + # Some default and some other values @dataclass class TestMessage2(betterproto.Message): @@ -278,6 +319,30 @@ def test_to_dict_default_values(): "someDefaultBool": False, } + test = TestMessage2().from_pydict( + { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + ) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 2, + "someDouble": 1.2, + "someStr": "hello", + "someBool": True, + "someDefaultInt": 0, + "someDefaultDouble": 0.0, + "someDefaultStr": "", + "someDefaultBool": False, + } + # Nested messages @dataclass class TestChildMessage(betterproto.Message): @@ -297,6 +362,36 @@ def test_to_dict_default_values(): "someMessage": {"someOtherInt": 0}, } + test = TestParentMessage().from_pydict({"someInt": 0, "someDouble": 1.2}) + + assert test.to_pydict(include_default_values=True) == { + "someInt": 0, + "someDouble": 1.2, + "someMessage": {"someOtherInt": 0}, + } + + +def test_to_dict_datetime_values(): + @dataclass + class TestDatetimeMessage(betterproto.Message): + bar: datetime = betterproto.message_field(1) + baz: timedelta = betterproto.message_field(2) + + test = TestDatetimeMessage().from_dict( + {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} + ) + + assert test.to_dict() == {"bar": "2020-01-01T00:00:00Z", "baz": "86400.000s"} + + test = TestDatetimeMessage().from_pydict( + {"bar": datetime(year=2020, month=1, day=1), "baz": timedelta(days=1)} + ) + + assert test.to_pydict() == { + "bar": datetime(year=2020, month=1, day=1), + "baz": timedelta(days=1), + } + def test_oneof_default_value_set_causes_writes_wire(): @dataclass