Support Message.from_dict() as a class and an instance method (#476)
				
					
				
			* Make Message.from_dict() a class method Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> * Sync 1/2 of review comments * Sync other half * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Update src/betterproto/__init__.py * Update .pre-commit-config.yaml * Update __init__.py * Update utils.py * Fix CI again * Fix failing formatting --------- Signed-off-by: Marek Pikuła <marek.pikula@embevity.com> Co-authored-by: James Hilton-Balfe <gobot1234yt@gmail.com>
This commit is contained in:
		| @@ -1,3 +1,5 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| import dataclasses | ||||
| import enum as builtin_enum | ||||
| import json | ||||
| @@ -22,8 +24,8 @@ from itertools import count | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Any, | ||||
|     BinaryIO, | ||||
|     Callable, | ||||
|     ClassVar, | ||||
|     Dict, | ||||
|     Generator, | ||||
|     Iterable, | ||||
| @@ -37,6 +39,7 @@ from typing import ( | ||||
| ) | ||||
|  | ||||
| from dateutil.parser import isoparse | ||||
| from typing_extensions import Self | ||||
|  | ||||
| from ._types import T | ||||
| from ._version import __version__ | ||||
| @@ -47,6 +50,10 @@ from .casing import ( | ||||
| ) | ||||
| from .enum import Enum as Enum | ||||
| from .grpc.grpclib_client import ServiceStub as ServiceStub | ||||
| from .utils import ( | ||||
|     classproperty, | ||||
|     hybridmethod, | ||||
| ) | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| @@ -729,6 +736,7 @@ class Message(ABC): | ||||
|     _serialized_on_wire: bool | ||||
|     _unknown_fields: bytes | ||||
|     _group_current: Dict[str, str] | ||||
|     _betterproto_meta: ClassVar[ProtoClassMetadata] | ||||
|  | ||||
|     def __post_init__(self) -> None: | ||||
|         # Keep track of whether every field was default | ||||
| @@ -882,18 +890,18 @@ class Message(ABC): | ||||
|                 kwargs[name] = value | ||||
|         return self.__class__(**kwargs)  # type: ignore | ||||
|  | ||||
|     @property | ||||
|     def _betterproto(self) -> ProtoClassMetadata: | ||||
|     @classproperty | ||||
|     def _betterproto(cls: type[Self]) -> ProtoClassMetadata:  # type: ignore | ||||
|         """ | ||||
|         Lazy initialize metadata for each protobuf class. | ||||
|         It may be initialized multiple times in a multi-threaded environment, | ||||
|         but that won't affect the correctness. | ||||
|         """ | ||||
|         meta = getattr(self.__class__, "_betterproto_meta", None) | ||||
|         if not meta: | ||||
|             meta = ProtoClassMetadata(self.__class__) | ||||
|             self.__class__._betterproto_meta = meta  # type: ignore | ||||
|         return meta | ||||
|         try: | ||||
|             return cls._betterproto_meta | ||||
|         except AttributeError: | ||||
|             cls._betterproto_meta = meta = ProtoClassMetadata(cls) | ||||
|             return meta | ||||
|  | ||||
|     def dump(self, stream: "SupportsWrite[bytes]", delimit: bool = False) -> None: | ||||
|         """ | ||||
| @@ -1512,7 +1520,91 @@ class Message(ABC): | ||||
|                     output[cased_name] = value | ||||
|         return output | ||||
|  | ||||
|     def from_dict(self: T, value: Mapping[str, Any]) -> T: | ||||
|     @classmethod | ||||
|     def _from_dict_init(cls, mapping: Mapping[str, Any]) -> Mapping[str, Any]: | ||||
|         init_kwargs: Dict[str, Any] = {} | ||||
|         for key, value in mapping.items(): | ||||
|             field_name = safe_snake_case(key) | ||||
|             try: | ||||
|                 meta = cls._betterproto.meta_by_field_name[field_name] | ||||
|             except KeyError: | ||||
|                 continue | ||||
|             if value is None: | ||||
|                 continue | ||||
|  | ||||
|             if meta.proto_type == TYPE_MESSAGE: | ||||
|                 sub_cls = cls._betterproto.cls_by_field[field_name] | ||||
|                 if sub_cls == datetime: | ||||
|                     value = ( | ||||
|                         [isoparse(item) for item in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else isoparse(value) | ||||
|                     ) | ||||
|                 elif sub_cls == timedelta: | ||||
|                     value = ( | ||||
|                         [timedelta(seconds=float(item[:-1])) for item in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else timedelta(seconds=float(value[:-1])) | ||||
|                     ) | ||||
|                 elif not meta.wraps: | ||||
|                     value = ( | ||||
|                         [sub_cls.from_dict(item) for item in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else sub_cls.from_dict(value) | ||||
|                     ) | ||||
|             elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: | ||||
|                 sub_cls = cls._betterproto.cls_by_field[f"{field_name}.value"] | ||||
|                 value = {k: sub_cls.from_dict(v) for k, v in value.items()} | ||||
|             else: | ||||
|                 if meta.proto_type in INT_64_TYPES: | ||||
|                     value = ( | ||||
|                         [int(n) for n in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else int(value) | ||||
|                     ) | ||||
|                 elif meta.proto_type == TYPE_BYTES: | ||||
|                     value = ( | ||||
|                         [b64decode(n) for n in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else b64decode(value) | ||||
|                     ) | ||||
|                 elif meta.proto_type == TYPE_ENUM: | ||||
|                     enum_cls = cls._betterproto.cls_by_field[field_name] | ||||
|                     if isinstance(value, list): | ||||
|                         value = [enum_cls.from_string(e) for e in value] | ||||
|                     elif isinstance(value, str): | ||||
|                         value = enum_cls.from_string(value) | ||||
|                 elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): | ||||
|                     value = ( | ||||
|                         [_parse_float(n) for n in value] | ||||
|                         if isinstance(value, list) | ||||
|                         else _parse_float(value) | ||||
|                     ) | ||||
|  | ||||
|             init_kwargs[field_name] = value | ||||
|         return init_kwargs | ||||
|  | ||||
|     @hybridmethod | ||||
|     def from_dict(cls: type[Self], value: Mapping[str, Any]) -> Self:  # type: ignore | ||||
|         """ | ||||
|         Parse the key/value pairs into the a new message instance. | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
|         value: Dict[:class:`str`, Any] | ||||
|             The dictionary to parse from. | ||||
|  | ||||
|         Returns | ||||
|         -------- | ||||
|         :class:`Message` | ||||
|             The initialized message. | ||||
|         """ | ||||
|         self = cls(**cls._from_dict_init(value)) | ||||
|         self._serialized_on_wire = True | ||||
|         return self | ||||
|  | ||||
|     @from_dict.instancemethod | ||||
|     def from_dict(self, value: Mapping[str, Any]) -> Self: | ||||
|         """ | ||||
|         Parse the key/value pairs into the current message instance. This returns the | ||||
|         instance itself and is therefore assignable and chainable. | ||||
| @@ -1528,71 +1620,8 @@ class Message(ABC): | ||||
|             The initialized message. | ||||
|         """ | ||||
|         self._serialized_on_wire = True | ||||
|         for key in value: | ||||
|             field_name = safe_snake_case(key) | ||||
|             meta = self._betterproto.meta_by_field_name.get(field_name) | ||||
|             if not meta: | ||||
|                 continue | ||||
|  | ||||
|             if value[key] is not None: | ||||
|                 if meta.proto_type == TYPE_MESSAGE: | ||||
|                     v = self._get_field_default(field_name) | ||||
|                     cls = self._betterproto.cls_by_field[field_name] | ||||
|                     if isinstance(v, list): | ||||
|                         if cls == datetime: | ||||
|                             v = [isoparse(item) for item in value[key]] | ||||
|                         elif cls == timedelta: | ||||
|                             v = [ | ||||
|                                 timedelta(seconds=float(item[:-1])) | ||||
|                                 for item in value[key] | ||||
|                             ] | ||||
|                         else: | ||||
|                             v = [cls().from_dict(item) for item in value[key]] | ||||
|                     elif cls == datetime: | ||||
|                         v = isoparse(value[key]) | ||||
|                         setattr(self, field_name, v) | ||||
|                     elif cls == timedelta: | ||||
|                         v = timedelta(seconds=float(value[key][:-1])) | ||||
|                         setattr(self, field_name, v) | ||||
|                     elif meta.wraps: | ||||
|                         setattr(self, field_name, value[key]) | ||||
|                     elif v is None: | ||||
|                         setattr(self, field_name, cls().from_dict(value[key])) | ||||
|                     else: | ||||
|                         # NOTE: `from_dict` mutates the underlying message, so no | ||||
|                         # assignment here is necessary. | ||||
|                         v.from_dict(value[key]) | ||||
|                 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: | ||||
|                     v = getattr(self, field_name) | ||||
|                     cls = self._betterproto.cls_by_field[f"{field_name}.value"] | ||||
|                     for k in value[key]: | ||||
|                         v[k] = cls().from_dict(value[key][k]) | ||||
|                 else: | ||||
|                     v = value[key] | ||||
|                     if meta.proto_type in INT_64_TYPES: | ||||
|                         if isinstance(value[key], list): | ||||
|                             v = [int(n) for n in value[key]] | ||||
|                         else: | ||||
|                             v = int(value[key]) | ||||
|                     elif meta.proto_type == TYPE_BYTES: | ||||
|                         if isinstance(value[key], list): | ||||
|                             v = [b64decode(n) for n in value[key]] | ||||
|                         else: | ||||
|                             v = b64decode(value[key]) | ||||
|                     elif meta.proto_type == TYPE_ENUM: | ||||
|                         enum_cls = self._betterproto.cls_by_field[field_name] | ||||
|                         if isinstance(v, list): | ||||
|                             v = [enum_cls.from_string(e) for e in v] | ||||
|                         elif isinstance(v, str): | ||||
|                             v = enum_cls.from_string(v) | ||||
|                     elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE): | ||||
|                         if isinstance(value[key], list): | ||||
|                             v = [_parse_float(n) for n in value[key]] | ||||
|                         else: | ||||
|                             v = _parse_float(value[key]) | ||||
|  | ||||
|                 if v is not None: | ||||
|                     setattr(self, field_name, v) | ||||
|         for field, value in self._from_dict_init(value).items(): | ||||
|             setattr(self, field, value) | ||||
|         return self | ||||
|  | ||||
|     def to_json( | ||||
| @@ -1809,8 +1838,8 @@ class Message(ABC): | ||||
|  | ||||
|     @classmethod | ||||
|     def _validate_field_groups(cls, values): | ||||
|         group_to_one_ofs = cls._betterproto_meta.oneof_field_by_group  # type: ignore | ||||
|         field_name_to_meta = cls._betterproto_meta.meta_by_field_name  # type: ignore | ||||
|         group_to_one_ofs = cls._betterproto.oneof_field_by_group | ||||
|         field_name_to_meta = cls._betterproto.meta_by_field_name | ||||
|  | ||||
|         for group, field_set in group_to_one_ofs.items(): | ||||
|             if len(field_set) == 1: | ||||
| @@ -1837,6 +1866,9 @@ class Message(ABC): | ||||
|         return values | ||||
|  | ||||
|  | ||||
| Message.__annotations__ = {}  # HACK to avoid typing.get_type_hints breaking :) | ||||
|  | ||||
|  | ||||
| def serialized_on_wire(message: Message) -> bool: | ||||
|     """ | ||||
|     If this message was or should be serialized on the wire. This can be used to detect | ||||
|   | ||||
							
								
								
									
										56
									
								
								src/betterproto/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								src/betterproto/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| from __future__ import annotations | ||||
|  | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Callable, | ||||
|     Generic, | ||||
|     Optional, | ||||
|     Type, | ||||
|     TypeVar, | ||||
| ) | ||||
|  | ||||
| from typing_extensions import ( | ||||
|     Concatenate, | ||||
|     ParamSpec, | ||||
|     Self, | ||||
| ) | ||||
|  | ||||
|  | ||||
| SelfT = TypeVar("SelfT") | ||||
| P = ParamSpec("P") | ||||
| HybridT = TypeVar("HybridT", covariant=True) | ||||
|  | ||||
|  | ||||
| class hybridmethod(Generic[SelfT, P, HybridT]): | ||||
|     def __init__( | ||||
|         self, | ||||
|         func: Callable[ | ||||
|             Concatenate[type[SelfT], P], HybridT | ||||
|         ],  # Must be the classmethod version | ||||
|     ): | ||||
|         self.cls_func = func | ||||
|         self.__doc__ = func.__doc__ | ||||
|  | ||||
|     def instancemethod(self, func: Callable[Concatenate[SelfT, P], HybridT]) -> Self: | ||||
|         self.instance_func = func | ||||
|         return self | ||||
|  | ||||
|     def __get__( | ||||
|         self, instance: Optional[SelfT], owner: Type[SelfT] | ||||
|     ) -> Callable[P, HybridT]: | ||||
|         if instance is None or self.instance_func is None: | ||||
|             # either bound to the class, or no instance method available | ||||
|             return self.cls_func.__get__(owner, None) | ||||
|         return self.instance_func.__get__(instance, owner) | ||||
|  | ||||
|  | ||||
| T_co = TypeVar("T_co") | ||||
| TT_co = TypeVar("TT_co", bound="type[Any]") | ||||
|  | ||||
|  | ||||
| class classproperty(Generic[TT_co, T_co]): | ||||
|     def __init__(self, func: Callable[[TT_co], T_co]): | ||||
|         self.__func__ = func | ||||
|  | ||||
|     def __get__(self, instance: Any, type: TT_co) -> T_co: | ||||
|         return self.__func__(type) | ||||
		Reference in New Issue
	
	Block a user