Support Message.from_dict() as a class and an instance method (#476)

* Make Message.from_dict() a class method

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>

* Sync 1/2 of review comments

* Sync other half

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Update src/betterproto/__init__.py

* Update .pre-commit-config.yaml

* Update __init__.py

* Update utils.py

* Fix CI again

* Fix failing formatting

---------

Signed-off-by: Marek Pikuła <marek.pikula@embevity.com>
Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
Marek Pikuła 2023-10-25 23:20:23 +02:00 committed by GitHub
parent 02aa4e88b7
commit d9b7608980
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 164 additions and 76 deletions

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import dataclasses
import enum as builtin_enum
import json
@ -22,8 +24,8 @@ from itertools import count
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
ClassVar,
Dict,
Generator,
Iterable,
@ -37,6 +39,7 @@ from typing import (
)
from dateutil.parser import isoparse
from typing_extensions import Self
from ._types import T
from ._version import __version__
@ -47,6 +50,10 @@ from .casing import (
)
from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
from .utils import (
classproperty,
hybridmethod,
)
if TYPE_CHECKING:
@ -729,6 +736,7 @@ class Message(ABC):
_serialized_on_wire: bool
_unknown_fields: bytes
_group_current: Dict[str, str]
_betterproto_meta: ClassVar[ProtoClassMetadata]
def __post_init__(self) -> None:
# Keep track of whether every field was default
@ -882,17 +890,17 @@ class Message(ABC):
kwargs[name] = value
return self.__class__(**kwargs) # type: ignore
@property
def _betterproto(self) -> ProtoClassMetadata:
@classproperty
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness.
"""
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta # type: ignore
try:
return cls._betterproto_meta
except AttributeError:
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
return meta
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
@ -1512,7 +1520,91 @@ class Message(ABC):
output[cased_name] = value
return output
def from_dict(self: T, value: Mapping[str, Any]) -> T:
@classmethod
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
init_kwargs: Dict[str, Any] = {}
for key, value in mapping.items():
field_name = safe_snake_case(key)
try:
meta = cls._betterproto.meta_by_field_name[field_name]
except KeyError:
continue
if value is None:
continue
if meta.proto_type == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[field_name]
if sub_cls == datetime:
value = (
[isoparse(item) for item in value]
if isinstance(value, list)
else isoparse(value)
)
elif sub_cls == timedelta:
value = (
[timedelta(seconds=float(item[:-1])) for item in value]
if isinstance(value, list)
else timedelta(seconds=float(value[:-1]))
)
elif not meta.wraps:
value = (
[sub_cls.from_dict(item) for item in value]
if isinstance(value, list)
else sub_cls.from_dict(value)
)
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
else:
if meta.proto_type in INT_64_TYPES:
value = (
[int(n) for n in value]
if isinstance(value, list)
else int(value)
)
elif meta.proto_type == TYPE_BYTES:
value = (
[b64decode(n) for n in value]
if isinstance(value, list)
else b64decode(value)
)
elif meta.proto_type == TYPE_ENUM:
enum_cls = cls._betterproto.cls_by_field[field_name]
if isinstance(value, list):
value = [enum_cls.from_string(e) for e in value]
elif isinstance(value, str):
value = enum_cls.from_string(value)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
value = (
[_parse_float(n) for n in value]
if isinstance(value, list)
else _parse_float(value)
)
init_kwargs[field_name] = value
return init_kwargs
@hybridmethod
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
"""
Parse the key/value pairs into the a new message instance.
Parameters
-----------
value: Dict[:class:`str`, Any]
The dictionary to parse from.
Returns
--------
:class:`Message`
The initialized message.
"""
self = cls(**cls._from_dict_init(value))
self._serialized_on_wire = True
return self
@from_dict.instancemethod
def from_dict(self, value: Mapping[str, Any]) -> Self:
"""
Parse the key/value pairs into the current message instance. This returns the
instance itself and is therefore assignable and chainable.
@ -1528,71 +1620,8 @@ class Message(ABC):
The initialized message.
"""
self._serialized_on_wire = True
for key in value:
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == TYPE_MESSAGE:
v = self._get_field_default(field_name)
cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
if cls == datetime:
v = [isoparse(item) for item in value[key]]
elif cls == timedelta:
v = [
timedelta(seconds=float(item[:-1]))
for item in value[key]
]
else:
v = [cls().from_dict(item) for item in value[key]]
elif cls == datetime:
v = isoparse(value[key])
setattr(self, field_name, v)
elif cls == timedelta:
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
elif v is None:
setattr(self, field_name, cls().from_dict(value[key]))
else:
# NOTE: `from_dict` mutates the underlying message, so no
# assignment here is necessary.
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
if isinstance(value[key], list):
v = [_parse_float(n) for n in value[key]]
else:
v = _parse_float(value[key])
if v is not None:
setattr(self, field_name, v)
for field, value in self._from_dict_init(value).items():
setattr(self, field, value)
return self
def to_json(
@ -1809,8 +1838,8 @@ class Message(ABC):
@classmethod
def _validate_field_groups(cls, values):
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
group_to_one_ofs = cls._betterproto.oneof_field_by_group
field_name_to_meta = cls._betterproto.meta_by_field_name
for group, field_set in group_to_one_ofs.items():
if len(field_set) == 1:
@ -1837,6 +1866,9 @@ class Message(ABC):
return values
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
def serialized_on_wire(message: Message) -> bool:
"""
If this message was or should be serialized on the wire. This can be used to detect

56
src/betterproto/utils.py Normal file
View File

@ -0,0 +1,56 @@
from __future__ import annotations
from typing import (
Any,
Callable,
Generic,
Optional,
Type,
TypeVar,
)
from typing_extensions import (
Concatenate,
ParamSpec,
Self,
)
SelfT = TypeVar("SelfT")
P = ParamSpec("P")
HybridT = TypeVar("HybridT", covariant=True)
class hybridmethod(Generic[SelfT, P, HybridT]):
def __init__(
self,
func: Callable[
Concatenate[type[SelfT], P], HybridT
], # Must be the classmethod version
):
self.cls_func = func
self.__doc__ = func.__doc__
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
self.instance_func = func
return self
def __get__(
self, instance: Optional[SelfT], owner: Type[SelfT]
) -> Callable[P, HybridT]:
if instance is None or self.instance_func is None:
# either bound to the class, or no instance method available
return self.cls_func.__get__(owner, None)
return self.instance_func.__get__(instance, owner)
T_co = TypeVar("T_co")
TT_co = TypeVar("TT_co", bound="type[Any]")
class classproperty(Generic[TT_co, T_co]):
def __init__(self, func: Callable[[TT_co], T_co]):
self.__func__ = func
def __get__(self, instance: Any, type: TT_co) -> T_co:
return self.__func__(type)