Support Message.from_dict()
as a class and an instance method (#476)
* Make Message.from_dict() a class method Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> * Sync 1/2 of review comments * Sync other half * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Update src/betterproto/__init__.py * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Fix CI again * Fix failing formatting --------- Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
parent
02aa4e88b7
commit
d9b7608980
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import enum as builtin_enum
|
||||
import json
|
||||
@ -22,8 +24,8 @@ from itertools import count
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
BinaryIO,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
@ -37,6 +39,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from dateutil.parser import isoparse
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._types import T
|
||||
from ._version import __version__
|
||||
@ -47,6 +50,10 @@ from .casing import (
|
||||
)
|
||||
from .enum import Enum as Enum
|
||||
from .grpc.grpclib_client import ServiceStub as ServiceStub
|
||||
from .utils import (
|
||||
classproperty,
|
||||
hybridmethod,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -729,6 +736,7 @@ class Message(ABC):
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_current: Dict[str, str]
|
||||
_betterproto_meta: ClassVar[ProtoClassMetadata]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Keep track of whether every field was default
|
||||
@ -882,17 +890,17 @@ class Message(ABC):
|
||||
kwargs[name] = value
|
||||
return self.__class__(**kwargs) # type: ignore
|
||||
|
||||
@property
|
||||
def _betterproto(self) -> ProtoClassMetadata:
|
||||
@classproperty
|
||||
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
|
||||
"""
|
||||
Lazy initialize metadata for each protobuf class.
|
||||
It may be initialized multiple times in a multi-threaded environment,
|
||||
but that won't affect the correctness.
|
||||
"""
|
||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||
if not meta:
|
||||
meta = ProtoClassMetadata(self.__class__)
|
||||
self.__class__._betterproto_meta = meta # type: ignore
|
||||
try:
|
||||
return cls._betterproto_meta
|
||||
except AttributeError:
|
||||
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
|
||||
return meta
|
||||
|
||||
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
|
||||
@ -1512,7 +1520,91 @@ class Message(ABC):
|
||||
output[cased_name] = value
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: Mapping[str, Any]) -> T:
|
||||
@classmethod
|
||||
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||
init_kwargs: Dict[str, Any] = {}
|
||||
for key, value in mapping.items():
|
||||
field_name = safe_snake_case(key)
|
||||
try:
|
||||
meta = cls._betterproto.meta_by_field_name[field_name]
|
||||
except KeyError:
|
||||
continue
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
if meta.proto_type == TYPE_MESSAGE:
|
||||
sub_cls = cls._betterproto.cls_by_field[field_name]
|
||||
if sub_cls == datetime:
|
||||
value = (
|
||||
[isoparse(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else isoparse(value)
|
||||
)
|
||||
elif sub_cls == timedelta:
|
||||
value = (
|
||||
[timedelta(seconds=float(item[:-1])) for item in value]
|
||||
if isinstance(value, list)
|
||||
else timedelta(seconds=float(value[:-1]))
|
||||
)
|
||||
elif not meta.wraps:
|
||||
value = (
|
||||
[sub_cls.from_dict(item) for item in value]
|
||||
if isinstance(value, list)
|
||||
else sub_cls.from_dict(value)
|
||||
)
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
|
||||
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
|
||||
else:
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
value = (
|
||||
[int(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else int(value)
|
||||
)
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
value = (
|
||||
[b64decode(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else b64decode(value)
|
||||
)
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = cls._betterproto.cls_by_field[field_name]
|
||||
if isinstance(value, list):
|
||||
value = [enum_cls.from_string(e) for e in value]
|
||||
elif isinstance(value, str):
|
||||
value = enum_cls.from_string(value)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
value = (
|
||||
[_parse_float(n) for n in value]
|
||||
if isinstance(value, list)
|
||||
else _parse_float(value)
|
||||
)
|
||||
|
||||
init_kwargs[field_name] = value
|
||||
return init_kwargs
|
||||
|
||||
@hybridmethod
|
||||
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
|
||||
"""
|
||||
Parse the key/value pairs into the a new message instance.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
value: Dict[:class:`str`, Any]
|
||||
The dictionary to parse from.
|
||||
|
||||
Returns
|
||||
--------
|
||||
:class:`Message`
|
||||
The initialized message.
|
||||
"""
|
||||
self = cls(**cls._from_dict_init(value))
|
||||
self._serialized_on_wire = True
|
||||
return self
|
||||
|
||||
@from_dict.instancemethod
|
||||
def from_dict(self, value: Mapping[str, Any]) -> Self:
|
||||
"""
|
||||
Parse the key/value pairs into the current message instance. This returns the
|
||||
instance itself and is therefore assignable and chainable.
|
||||
@ -1528,71 +1620,8 @@ class Message(ABC):
|
||||
The initialized message.
|
||||
"""
|
||||
self._serialized_on_wire = True
|
||||
for key in value:
|
||||
field_name = safe_snake_case(key)
|
||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == TYPE_MESSAGE:
|
||||
v = self._get_field_default(field_name)
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
if cls == datetime:
|
||||
v = [isoparse(item) for item in value[key]]
|
||||
elif cls == timedelta:
|
||||
v = [
|
||||
timedelta(seconds=float(item[:-1]))
|
||||
for item in value[key]
|
||||
]
|
||||
else:
|
||||
v = [cls().from_dict(item) for item in value[key]]
|
||||
elif cls == datetime:
|
||||
v = isoparse(value[key])
|
||||
setattr(self, field_name, v)
|
||||
elif cls == timedelta:
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field_name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field_name, value[key])
|
||||
elif v is None:
|
||||
setattr(self, field_name, cls().from_dict(value[key]))
|
||||
else:
|
||||
# NOTE: `from_dict` mutates the underlying message, so no
|
||||
# assignment here is necessary.
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if isinstance(value[key], list):
|
||||
v = [_parse_float(n) for n in value[key]]
|
||||
else:
|
||||
v = _parse_float(value[key])
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
for field, value in self._from_dict_init(value).items():
|
||||
setattr(self, field, value)
|
||||
return self
|
||||
|
||||
def to_json(
|
||||
@ -1809,8 +1838,8 @@ class Message(ABC):
|
||||
|
||||
@classmethod
|
||||
def _validate_field_groups(cls, values):
|
||||
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
||||
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
|
||||
group_to_one_ofs = cls._betterproto.oneof_field_by_group
|
||||
field_name_to_meta = cls._betterproto.meta_by_field_name
|
||||
|
||||
for group, field_set in group_to_one_ofs.items():
|
||||
if len(field_set) == 1:
|
||||
@ -1837,6 +1866,9 @@ class Message(ABC):
|
||||
return values
|
||||
|
||||
|
||||
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
|
||||
|
||||
|
||||
def serialized_on_wire(message: Message) -> bool:
|
||||
"""
|
||||
If this message was or should be serialized on the wire. This can be used to detect
|
||||
|
56
src/betterproto/utils.py
Normal file
56
src/betterproto/utils.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from typing_extensions import (
|
||||
Concatenate,
|
||||
ParamSpec,
|
||||
Self,
|
||||
)
|
||||
|
||||
|
||||
SelfT = TypeVar("SelfT")
|
||||
P = ParamSpec("P")
|
||||
HybridT = TypeVar("HybridT", covariant=True)
|
||||
|
||||
|
||||
class hybridmethod(Generic[SelfT, P, HybridT]):
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[
|
||||
Concatenate[type[SelfT], P], HybridT
|
||||
], # Must be the classmethod version
|
||||
):
|
||||
self.cls_func = func
|
||||
self.__doc__ = func.__doc__
|
||||
|
||||
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
|
||||
self.instance_func = func
|
||||
return self
|
||||
|
||||
def __get__(
|
||||
self, instance: Optional[SelfT], owner: Type[SelfT]
|
||||
) -> Callable[P, HybridT]:
|
||||
if instance is None or self.instance_func is None:
|
||||
# either bound to the class, or no instance method available
|
||||
return self.cls_func.__get__(owner, None)
|
||||
return self.instance_func.__get__(instance, owner)
|
||||
|
||||
|
||||
T_co = TypeVar("T_co")
|
||||
TT_co = TypeVar("TT_co", bound="type[Any]")
|
||||
|
||||
|
||||
class classproperty(Generic[TT_co, T_co]):
|
||||
def __init__(self, func: Callable[[TT_co], T_co]):
|
||||
self.__func__ = func
|
||||
|
||||
def __get__(self, instance: Any, type: TT_co) -> T_co:
|
||||
return self.__func__(type)
|
Loading…
x
Reference in New Issue
Block a user