diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index f52edaa..b2a63d8 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses import enum as builtin_enum import json @@ -22,8 +24,8 @@ from itertools import count from typing import ( TYPE_CHECKING, Any, - BinaryIO, Callable, + ClassVar, Dict, Generator, Iterable, @@ -37,6 +39,7 @@ from typing import ( ) from dateutil.parser import isoparse +from typing_extensions import Self from ._types import T from ._version import __version__ @@ -47,6 +50,10 @@ from .casing import ( ) from .enum import Enum as Enum from .grpc.grpclib_client import ServiceStub as ServiceStub +from .utils import ( + classproperty, + hybridmethod, +) if TYPE_CHECKING: @@ -729,6 +736,7 @@ class Message(ABC): _serialized_on_wire: bool _unknown_fields: bytes _group_current: Dict[str, str] + _betterproto_meta: ClassVar[ProtoClassMetadata] def __post_init__(self) -> None: # Keep track of whether every field was default @@ -882,18 +890,18 @@ class Message(ABC): kwargs[name] = value return self.__class__(**kwargs) # type: ignore - @property - def _betterproto(self) -> ProtoClassMetadata: + @classproperty + def _betterproto(cls: type[Self]) -> ProtoClassMetadata: # type: ignore """ Lazy initialize metadata for each protobuf class. It may be initialized multiple times in a multi-threaded environment, but that won't affect the correctness. """ - meta = getattr(self.__class__, "_betterproto_meta", None) - if not meta: - meta = ProtoClassMetadata(self.__class__) - self.__class__._betterproto_meta = meta # type: ignore - return meta + try: + return cls._betterproto_meta + except AttributeError: + cls._betterproto_meta = meta = ProtoClassMetadata(cls) + return meta def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: """ @@ -1512,7 +1520,91 @@ class Message(ABC): output[cased_name] = value return output - def from_dict(self: T, value: Mapping[str, Any]) -> T: + @classmethod + def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: + init_kwargs: Dict[str, Any] = {} + for key, value in mapping.items(): + field_name = safe_snake_case(key) + try: + meta = cls._betterproto.meta_by_field_name[field_name] + except KeyError: + continue + if value is None: + continue + + if meta.proto_type == TYPE_MESSAGE: + sub_cls = cls._betterproto.cls_by_field[field_name] + if sub_cls == datetime: + value = ( + [isoparse(item) for item in value] + if isinstance(value, list) + else isoparse(value) + ) + elif sub_cls == timedelta: + value = ( + [timedelta(seconds=float(item[:-1])) for item in value] + if isinstance(value, list) + else timedelta(seconds=float(value[:-1])) + ) + elif not meta.wraps: + value = ( + [sub_cls.from_dict(item) for item in value] + if isinstance(value, list) + else sub_cls.from_dict(value) + ) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"] + value = {k: sub_cls.from_dict(v) for k, v in value.items()} + else: + if meta.proto_type in INT_64_TYPES: + value = ( + [int(n) for n in value] + if isinstance(value, list) + else int(value) + ) + elif meta.proto_type == TYPE_BYTES: + value = ( + [b64decode(n) for n in value] + if isinstance(value, list) + else b64decode(value) + ) + elif meta.proto_type == TYPE_ENUM: + enum_cls = cls._betterproto.cls_by_field[field_name] + if isinstance(value, list): + value = [enum_cls.from_string(e) for e in value] + elif isinstance(value, str): + value = enum_cls.from_string(value) + elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): + value = ( + [_parse_float(n) for n in value] + if isinstance(value, list) + else _parse_float(value) + ) + + init_kwargs[field_name] = value + return init_kwargs + + @hybridmethod + def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self: # type: ignore + """ + Parse the key/value pairs into the a new message instance. + + Parameters + ----------- + value: Dict[:class:`str`, Any] + The dictionary to parse from. + + Returns + -------- + :class:`Message` + The initialized message. + """ + self = cls(**cls._from_dict_init(value)) + self._serialized_on_wire = True + return self + + @from_dict.instancemethod + def from_dict(self, value: Mapping[str, Any]) -> Self: """ Parse the key/value pairs into the current message instance. This returns the instance itself and is therefore assignable and chainable. @@ -1528,71 +1620,8 @@ class Message(ABC): The initialized message. """ self._serialized_on_wire = True - for key in value: - field_name = safe_snake_case(key) - meta = self._betterproto.meta_by_field_name.get(field_name) - if not meta: - continue - - if value[key] is not None: - if meta.proto_type == TYPE_MESSAGE: - v = self._get_field_default(field_name) - cls = self._betterproto.cls_by_field[field_name] - if isinstance(v, list): - if cls == datetime: - v = [isoparse(item) for item in value[key]] - elif cls == timedelta: - v = [ - timedelta(seconds=float(item[:-1])) - for item in value[key] - ] - else: - v = [cls().from_dict(item) for item in value[key]] - elif cls == datetime: - v = isoparse(value[key]) - setattr(self, field_name, v) - elif cls == timedelta: - v = timedelta(seconds=float(value[key][:-1])) - setattr(self, field_name, v) - elif meta.wraps: - setattr(self, field_name, value[key]) - elif v is None: - setattr(self, field_name, cls().from_dict(value[key])) - else: - # NOTE: `from_dict` mutates the underlying message, so no - # assignment here is necessary. - v.from_dict(value[key]) - elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field_name) - cls = self._betterproto.cls_by_field[f"{field_name}.value"] - for k in value[key]: - v[k] = cls().from_dict(value[key][k]) - else: - v = value[key] - if meta.proto_type in INT_64_TYPES: - if isinstance(value[key], list): - v = [int(n) for n in value[key]] - else: - v = int(value[key]) - elif meta.proto_type == TYPE_BYTES: - if isinstance(value[key], list): - v = [b64decode(n) for n in value[key]] - else: - v = b64decode(value[key]) - elif meta.proto_type == TYPE_ENUM: - enum_cls = self._betterproto.cls_by_field[field_name] - if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) - elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): - if isinstance(value[key], list): - v = [_parse_float(n) for n in value[key]] - else: - v = _parse_float(value[key]) - - if v is not None: - setattr(self, field_name, v) + for field, value in self._from_dict_init(value).items(): + setattr(self, field, value) return self def to_json( @@ -1809,8 +1838,8 @@ class Message(ABC): @classmethod def _validate_field_groups(cls, values): - group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group # type: ignore - field_name_to_meta = cls._betterproto_meta.meta_by_field_name # type: ignore + group_to_one_ofs = cls._betterproto.oneof_field_by_group + field_name_to_meta = cls._betterproto.meta_by_field_name for group, field_set in group_to_one_ofs.items(): if len(field_set) == 1: @@ -1837,6 +1866,9 @@ class Message(ABC): return values +Message.__annotations__ = {} # HACK to avoid typing.get_type_hints breaking :) + + def serialized_on_wire(message: Message) -> bool: """ If this message was or should be serialized on the wire. This can be used to detect diff --git a/src/betterproto/utils.py b/src/betterproto/utils.py new file mode 100644 index 0000000..b977fc7 --- /dev/null +++ b/src/betterproto/utils.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from typing import ( + Any, + Callable, + Generic, + Optional, + Type, + TypeVar, +) + +from typing_extensions import ( + Concatenate, + ParamSpec, + Self, +) + + +SelfT = TypeVar("SelfT") +P = ParamSpec("P") +HybridT = TypeVar("HybridT", covariant=True) + + +class hybridmethod(Generic[SelfT, P, HybridT]): + def __init__( + self, + func: Callable[ + Concatenate[type[SelfT], P], HybridT + ], # Must be the classmethod version + ): + self.cls_func = func + self.__doc__ = func.__doc__ + + def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self: + self.instance_func = func + return self + + def __get__( + self, instance: Optional[SelfT], owner: Type[SelfT] + ) -> Callable[P, HybridT]: + if instance is None or self.instance_func is None: + # either bound to the class, or no instance method available + return self.cls_func.__get__(owner, None) + return self.instance_func.__get__(instance, owner) + + +T_co = TypeVar("T_co") +TT_co = TypeVar("TT_co", bound="type[Any]") + + +class classproperty(Generic[TT_co, T_co]): + def __init__(self, func: Callable[[TT_co], T_co]): + self.__func__ = func + + def __get__(self, instance: Any, type: TT_co) -> T_co: + return self.__func__(type)