Support Message.from_dict()
as a class and an instance method (#476)
* Make Message.from_dict() a class method Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> * Sync 1/2 of review comments * Sync other half * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Update src/betterproto/__init__.py * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Fix CI again * Fix failing formatting --------- Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
parent
02aa4e88b7
commit
d9b7608980
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum as builtin_enum
|
import enum as builtin_enum
|
||||||
import json
|
import json
|
||||||
@ -22,8 +24,8 @@ from itertools import count
|
|||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
BinaryIO,
|
|
||||||
Callable,
|
Callable,
|
||||||
|
ClassVar,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
@ -37,6 +39,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from dateutil.parser import isoparse
|
from dateutil.parser import isoparse
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ._types import T
|
from ._types import T
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
@ -47,6 +50,10 @@ from .casing import (
|
|||||||
)
|
)
|
||||||
from .enum import Enum as Enum
|
from .enum import Enum as Enum
|
||||||
from .grpc.grpclib_client import ServiceStub as ServiceStub
|
from .grpc.grpclib_client import ServiceStub as ServiceStub
|
||||||
|
from .utils import (
|
||||||
|
classproperty,
|
||||||
|
hybridmethod,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -729,6 +736,7 @@ class Message(ABC):
|
|||||||
_serialized_on_wire: bool
|
_serialized_on_wire: bool
|
||||||
_unknown_fields: bytes
|
_unknown_fields: bytes
|
||||||
_group_current: Dict[str, str]
|
_group_current: Dict[str, str]
|
||||||
|
_betterproto_meta: ClassVar[ProtoClassMetadata]
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Keep track of whether every field was default
|
# Keep track of whether every field was default
|
||||||
@ -882,18 +890,18 @@ class Message(ABC):
|
|||||||
kwargs[name] = value
|
kwargs[name] = value
|
||||||
return self.__class__(**kwargs) # type: ignore
|
return self.__class__(**kwargs) # type: ignore
|
||||||
|
|
||||||
@property
|
@classproperty
|
||||||
def _betterproto(self) -> ProtoClassMetadata:
|
def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore
|
||||||
"""
|
"""
|
||||||
Lazy initialize metadata for each protobuf class.
|
Lazy initialize metadata for each protobuf class.
|
||||||
It may be initialized multiple times in a multi-threaded environment,
|
It may be initialized multiple times in a multi-threaded environment,
|
||||||
but that won't affect the correctness.
|
but that won't affect the correctness.
|
||||||
"""
|
"""
|
||||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
try:
|
||||||
if not meta:
|
return cls._betterproto_meta
|
||||||
meta = ProtoClassMetadata(self.__class__)
|
except AttributeError:
|
||||||
self.__class__._betterproto_meta = meta # type: ignore
|
cls._betterproto_meta = meta = ProtoClassMetadata(cls)
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
|
def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None:
|
||||||
"""
|
"""
|
||||||
@ -1512,7 +1520,91 @@ class Message(ABC):
|
|||||||
output[cased_name] = value
|
output[cased_name] = value
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def from_dict(self: T, value: Mapping[str, Any]) -> T:
|
@classmethod
|
||||||
|
def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
init_kwargs: Dict[str, Any] = {}
|
||||||
|
for key, value in mapping.items():
|
||||||
|
field_name = safe_snake_case(key)
|
||||||
|
try:
|
||||||
|
meta = cls._betterproto.meta_by_field_name[field_name]
|
||||||
|
except KeyError:
|
||||||
|
continue
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if meta.proto_type == TYPE_MESSAGE:
|
||||||
|
sub_cls = cls._betterproto.cls_by_field[field_name]
|
||||||
|
if sub_cls == datetime:
|
||||||
|
value = (
|
||||||
|
[isoparse(item) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else isoparse(value)
|
||||||
|
)
|
||||||
|
elif sub_cls == timedelta:
|
||||||
|
value = (
|
||||||
|
[timedelta(seconds=float(item[:-1])) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else timedelta(seconds=float(value[:-1]))
|
||||||
|
)
|
||||||
|
elif not meta.wraps:
|
||||||
|
value = (
|
||||||
|
[sub_cls.from_dict(item) for item in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else sub_cls.from_dict(value)
|
||||||
|
)
|
||||||
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
|
sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"]
|
||||||
|
value = {k: sub_cls.from_dict(v) for k, v in value.items()}
|
||||||
|
else:
|
||||||
|
if meta.proto_type in INT_64_TYPES:
|
||||||
|
value = (
|
||||||
|
[int(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else int(value)
|
||||||
|
)
|
||||||
|
elif meta.proto_type == TYPE_BYTES:
|
||||||
|
value = (
|
||||||
|
[b64decode(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else b64decode(value)
|
||||||
|
)
|
||||||
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
|
enum_cls = cls._betterproto.cls_by_field[field_name]
|
||||||
|
if isinstance(value, list):
|
||||||
|
value = [enum_cls.from_string(e) for e in value]
|
||||||
|
elif isinstance(value, str):
|
||||||
|
value = enum_cls.from_string(value)
|
||||||
|
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||||
|
value = (
|
||||||
|
[_parse_float(n) for n in value]
|
||||||
|
if isinstance(value, list)
|
||||||
|
else _parse_float(value)
|
||||||
|
)
|
||||||
|
|
||||||
|
init_kwargs[field_name] = value
|
||||||
|
return init_kwargs
|
||||||
|
|
||||||
|
@hybridmethod
|
||||||
|
def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore
|
||||||
|
"""
|
||||||
|
Parse the key/value pairs into the a new message instance.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
value: Dict[:class:`str`, Any]
|
||||||
|
The dictionary to parse from.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`Message`
|
||||||
|
The initialized message.
|
||||||
|
"""
|
||||||
|
self = cls(**cls._from_dict_init(value))
|
||||||
|
self._serialized_on_wire = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
@from_dict.instancemethod
|
||||||
|
def from_dict(self, value: Mapping[str, Any]) -> Self:
|
||||||
"""
|
"""
|
||||||
Parse the key/value pairs into the current message instance. This returns the
|
Parse the key/value pairs into the current message instance. This returns the
|
||||||
instance itself and is therefore assignable and chainable.
|
instance itself and is therefore assignable and chainable.
|
||||||
@ -1528,71 +1620,8 @@ class Message(ABC):
|
|||||||
The initialized message.
|
The initialized message.
|
||||||
"""
|
"""
|
||||||
self._serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
for key in value:
|
for field, value in self._from_dict_init(value).items():
|
||||||
field_name = safe_snake_case(key)
|
setattr(self, field, value)
|
||||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
|
||||||
if not meta:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if value[key] is not None:
|
|
||||||
if meta.proto_type == TYPE_MESSAGE:
|
|
||||||
v = self._get_field_default(field_name)
|
|
||||||
cls = self._betterproto.cls_by_field[field_name]
|
|
||||||
if isinstance(v, list):
|
|
||||||
if cls == datetime:
|
|
||||||
v = [isoparse(item) for item in value[key]]
|
|
||||||
elif cls == timedelta:
|
|
||||||
v = [
|
|
||||||
timedelta(seconds=float(item[:-1]))
|
|
||||||
for item in value[key]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
v = [cls().from_dict(item) for item in value[key]]
|
|
||||||
elif cls == datetime:
|
|
||||||
v = isoparse(value[key])
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
elif cls == timedelta:
|
|
||||||
v = timedelta(seconds=float(value[key][:-1]))
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
elif meta.wraps:
|
|
||||||
setattr(self, field_name, value[key])
|
|
||||||
elif v is None:
|
|
||||||
setattr(self, field_name, cls().from_dict(value[key]))
|
|
||||||
else:
|
|
||||||
# NOTE: `from_dict` mutates the underlying message, so no
|
|
||||||
# assignment here is necessary.
|
|
||||||
v.from_dict(value[key])
|
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
|
||||||
v = getattr(self, field_name)
|
|
||||||
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
|
||||||
for k in value[key]:
|
|
||||||
v[k] = cls().from_dict(value[key][k])
|
|
||||||
else:
|
|
||||||
v = value[key]
|
|
||||||
if meta.proto_type in INT_64_TYPES:
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [int(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = int(value[key])
|
|
||||||
elif meta.proto_type == TYPE_BYTES:
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [b64decode(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = b64decode(value[key])
|
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
|
||||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
|
||||||
if isinstance(v, list):
|
|
||||||
v = [enum_cls.from_string(e) for e in v]
|
|
||||||
elif isinstance(v, str):
|
|
||||||
v = enum_cls.from_string(v)
|
|
||||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
|
||||||
if isinstance(value[key], list):
|
|
||||||
v = [_parse_float(n) for n in value[key]]
|
|
||||||
else:
|
|
||||||
v = _parse_float(value[key])
|
|
||||||
|
|
||||||
if v is not None:
|
|
||||||
setattr(self, field_name, v)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_json(
|
def to_json(
|
||||||
@ -1809,8 +1838,8 @@ class Message(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_field_groups(cls, values):
|
def _validate_field_groups(cls, values):
|
||||||
group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore
|
group_to_one_ofs = cls._betterproto.oneof_field_by_group
|
||||||
field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore
|
field_name_to_meta = cls._betterproto.meta_by_field_name
|
||||||
|
|
||||||
for group, field_set in group_to_one_ofs.items():
|
for group, field_set in group_to_one_ofs.items():
|
||||||
if len(field_set) == 1:
|
if len(field_set) == 1:
|
||||||
@ -1837,6 +1866,9 @@ class Message(ABC):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :)
|
||||||
|
|
||||||
|
|
||||||
def serialized_on_wire(message: Message) -> bool:
|
def serialized_on_wire(message: Message) -> bool:
|
||||||
"""
|
"""
|
||||||
If this message was or should be serialized on the wire. This can be used to detect
|
If this message was or should be serialized on the wire. This can be used to detect
|
||||||
|
56
src/betterproto/utils.py
Normal file
56
src/betterproto/utils.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import (
|
||||||
|
Concatenate,
|
||||||
|
ParamSpec,
|
||||||
|
Self,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
SelfT = TypeVar("SelfT")
|
||||||
|
P = ParamSpec("P")
|
||||||
|
HybridT = TypeVar("HybridT", covariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class hybridmethod(Generic[SelfT, P, HybridT]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
func: Callable[
|
||||||
|
Concatenate[type[SelfT], P], HybridT
|
||||||
|
], # Must be the classmethod version
|
||||||
|
):
|
||||||
|
self.cls_func = func
|
||||||
|
self.__doc__ = func.__doc__
|
||||||
|
|
||||||
|
def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self:
|
||||||
|
self.instance_func = func
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __get__(
|
||||||
|
self, instance: Optional[SelfT], owner: Type[SelfT]
|
||||||
|
) -> Callable[P, HybridT]:
|
||||||
|
if instance is None or self.instance_func is None:
|
||||||
|
# either bound to the class, or no instance method available
|
||||||
|
return self.cls_func.__get__(owner, None)
|
||||||
|
return self.instance_func.__get__(instance, owner)
|
||||||
|
|
||||||
|
|
||||||
|
T_co = TypeVar("T_co")
|
||||||
|
TT_co = TypeVar("TT_co", bound="type[Any]")
|
||||||
|
|
||||||
|
|
||||||
|
class classproperty(Generic[TT_co, T_co]):
|
||||||
|
def __init__(self, func: Callable[[TT_co], T_co]):
|
||||||
|
self.__func__ = func
|
||||||
|
|
||||||
|
def __get__(self, instance: Any, type: TT_co) -> T_co:
|
||||||
|
return self.__func__(type)
|
Loading…
x
Reference in New Issue
Block a user