Support Message.from_dict() as a class and an instance method (#476)

* Make Message.from_dict() a class method

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>

* Sync 1/2 of review comments

* Sync other half

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Update src/betterproto/__init__.py

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Fix CI again

* Fix failing formatting

---------

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
Marek Pikuła 2023-10-25 23:20:23 +02:00 committed by GitHub
parent 02aa4e88b7
commit d9b7608980
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 76 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses import dataclasses
import enum as builtin_enum import enum as builtin_enum
import json import json
@ -22,8 +24,8 @@ from itertools import count
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
BinaryIO,
Callable, Callable,
ClassVar,
Dict, Dict,
Generator, Generator,
Iterable, Iterable,
@ -37,6 +39,7 @@ from typing import (
) )
from dateutil.parser import isoparse from dateutil.parser import isoparse
from typing_extensions import Self
from ._types import T from ._types import T
from ._version import __version__ from ._version import __version__
@ -47,6 +50,10 @@ from .casing import (
) )
from .enum import Enum as Enum from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)
if TYPE_CHECKING: if TYPE_CHECKING:
@ -729,6 +736,7 @@ class Message(ABC):
_serialized_on_wire: bool _serialized_on_wire: bool
_unknown_fields: bytes _unknown_fields: bytes
_group_current: Dict[str, str] _group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Keep track of whether every field was default # Keep track of whether every field was default
@ -882,17 +890,17 @@ class Message(ABC):
kwargs[name] = value kwargs[name] = value
return self.__class__(**kwargs) # type: ignore return self.__class__(**kwargs) # type: ignore
@property @classproperty
def _betterproto(self) -> ProtoClassMetadata: def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
""" """
Lazy initialize metadata for each protobuf class. Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment, It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness. but that won't affect the correctness.
""" """
meta = getattr(self.__class__, "_betterproto_meta", None) try:
if not meta: return cls._betterproto_meta
meta = ProtoClassMetadata(self.__class__) except AttributeError:
self.__class__._betterproto_meta = meta # type: ignore cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta return meta
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
@ -1512,7 +1520,91 @@ class Message(ABC):
output[cased_name] = value output[cased_name] = value
return output return output
def from_dict(self: T, value: Mapping[str, Any]) -> T: @classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Translate a dictionary into keyword arguments for constructing ``cls``.

    Each key is converted with ``safe_snake_case`` and looked up in the class
    metadata. Keys without matching field metadata and entries whose value is
    ``None`` are skipped; the remaining values are converted according to the
    field's proto type.

    Parameters
    -----------
    mapping: Mapping[:class:`str`, Any]
        The dictionary to read the field values from.

    Returns
    --------
    Mapping[:class:`str`, Any]
        The converted values keyed by snake-cased field name.
    """

    def convert(func, raw):
        # Apply ``func`` to every element of a list, or to the lone value.
        if isinstance(raw, list):
            return [func(element) for element in raw]
        return func(raw)

    def parse_duration(text):
        # Drop the final character before reading the number of seconds.
        return timedelta(seconds=float(text[:-1]))

    kwargs: Dict[str, Any] = {}
    for raw_key, raw in mapping.items():
        name = safe_snake_case(raw_key)
        try:
            meta = cls._betterproto.meta_by_field_name[name]
        except KeyError:
            continue
        if raw is None:
            continue

        if meta.proto_type == TYPE_MESSAGE:
            target = cls._betterproto.cls_by_field[name]
            if target == datetime:
                raw = convert(isoparse, raw)
            elif target == timedelta:
                raw = convert(parse_duration, raw)
            elif not meta.wraps:
                raw = convert(target.from_dict, raw)
            # Wrapped fields keep the value exactly as it was supplied.
        elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
            target = cls._betterproto.cls_by_field[f"{name}.value"]
            raw = {
                map_key: target.from_dict(item) for map_key, item in raw.items()
            }
        elif meta.proto_type in INT_64_TYPES:
            raw = convert(int, raw)
        elif meta.proto_type == TYPE_BYTES:
            raw = convert(b64decode, raw)
        elif meta.proto_type == TYPE_ENUM:
            enum_cls = cls._betterproto.cls_by_field[name]
            # Only lists and strings are converted; anything else is kept.
            if isinstance(raw, (list, str)):
                raw = convert(enum_cls.from_string, raw)
        elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
            raw = convert(_parse_float, raw)

        kwargs[name] = raw
    return kwargs
@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self:  # type: ignore
    """
    Parse the key/value pairs into a new message instance.

    Parameters
    -----------
    value: Dict[:class:`str`, Any]
        The dictionary to parse from.

    Returns
    --------
    :class:`Message`
        The initialized message.
    """
    message = cls(**cls._from_dict_init(value))
    # Set the ``_serialized_on_wire`` flag on the freshly built message.
    message._serialized_on_wire = True
    return message
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
""" """
Parse the key/value pairs into the current message instance. This returns the Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable. instance itself and is therefore assignable and chainable.
@ -1528,71 +1620,8 @@ class Message(ABC):
The initialized message. The initialized message.
""" """
self._serialized_on_wire = True self._serialized_on_wire = True
for key in value: for field, value in self._from_dict_init(value).items():
field_name = safe_snake_case(key) setattr(self, field, value)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
return self return self
def to_json( def to_json(
@ -1809,8 +1838,8 @@ class Message(ABC):
@classmethod @classmethod
def _validate_field_groups(cls, values): def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore field_name_to_meta = cls._betterproto.meta_by_field_name
for group, field_set in group_to_one_ofs.items(): for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1: if len(field_set) == 1:
@ -1837,6 +1866,9 @@ class Message(ABC):
return values return values
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
def serialized_on_wire(message: Message) -> bool: def serialized_on_wire(message: Message) -> bool:
""" """
If this message was or should be serialized on the wire. This can be used to detect If this message was or should be serialized on the wire. This can be used to detect

56
src/betterproto/utils.py Normal file
View File

@ -0,0 +1,56 @@
from __future__ import annotations
from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)
SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)
class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self
def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)
T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")


class classproperty(Generic[TT_co, T_co]):
    """
    Read-only descriptor that computes its value from the owner class.

    The wrapped function is called with the class, both for access on the class
    and for access on an instance.
    """

    def __init__(self, func: Callable[[TT_co], T_co]):
        self.__func__ = func
        # Keep the wrapped function's docstring available on the descriptor.
        self.__doc__ = func.__doc__

    def __get__(self, instance: Any, type: Optional[TT_co] = None) -> T_co:
        """
        Return the result of calling the wrapped function with the owner class.

        ``type`` may be omitted when the descriptor is invoked directly with an
        instance; the instance's class is used in that case.
        """
        if type is None:
            # The parameter shadows the ``type`` builtin, hence ``__class__``.
            type = instance.__class__
        return self.__func__(type)