Fix: to_dict returns wrong enum fields when numbering is not consecutive (#102)
Fixes #93 to_dict returns wrong enum fields when numbering is not consecutive
This commit is contained in:
		| @@ -4,6 +4,7 @@ import inspect | ||||
| import json | ||||
| import struct | ||||
| import sys | ||||
| import warnings | ||||
| from abc import ABC | ||||
| from base64 import b64decode, b64encode | ||||
| from datetime import datetime, timedelta, timezone | ||||
| @@ -21,6 +22,8 @@ from typing import ( | ||||
|     get_type_hints, | ||||
| ) | ||||
|  | ||||
| import typing | ||||
|  | ||||
| from ._types import T | ||||
| from .casing import camel_case, safe_snake_case, snake_case | ||||
| from .grpc.grpclib_client import ServiceStub | ||||
| @@ -251,7 +254,7 @@ def map_field( | ||||
|     ) | ||||
|  | ||||
|  | ||||
| class Enum(int, enum.Enum): | ||||
| class Enum(enum.IntEnum): | ||||
|     """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" | ||||
|  | ||||
|     @classmethod | ||||
| @@ -635,9 +638,13 @@ class Message(ABC): | ||||
|  | ||||
|     @classmethod | ||||
|     def _type_hint(cls, field_name: str) -> Type: | ||||
|         return cls._type_hints()[field_name] | ||||
|  | ||||
|     @classmethod | ||||
|     def _type_hints(cls) -> Dict[str, Type]: | ||||
|         module = inspect.getmodule(cls) | ||||
|         type_hints = get_type_hints(cls, vars(module)) | ||||
|         return type_hints[field_name] | ||||
|         return type_hints | ||||
|  | ||||
|     @classmethod | ||||
|     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: | ||||
| @@ -789,55 +796,67 @@ class Message(ABC): | ||||
|         `False`. | ||||
|         """ | ||||
|         output: Dict[str, Any] = {} | ||||
|         field_types = self._type_hints() | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|             v = getattr(self, field_name) | ||||
|             field_type = field_types[field_name] | ||||
|             field_is_repeated = type(field_type) is type(typing.List) | ||||
|             value = getattr(self, field_name) | ||||
|             cased_name = casing(field_name).rstrip("_")  # type: ignore | ||||
|             if meta.proto_type == "message": | ||||
|                 if isinstance(v, datetime): | ||||
|                     if v != DATETIME_ZERO or include_default_values: | ||||
|                         output[cased_name] = _Timestamp.timestamp_to_json(v) | ||||
|                 elif isinstance(v, timedelta): | ||||
|                     if v != timedelta(0) or include_default_values: | ||||
|                         output[cased_name] = _Duration.delta_to_json(v) | ||||
|             if meta.proto_type == TYPE_MESSAGE: | ||||
|                 if isinstance(value, datetime): | ||||
|                     if value != DATETIME_ZERO or include_default_values: | ||||
|                         output[cased_name] = _Timestamp.timestamp_to_json(value) | ||||
|                 elif isinstance(value, timedelta): | ||||
|                     if value != timedelta(0) or include_default_values: | ||||
|                         output[cased_name] = _Duration.delta_to_json(value) | ||||
|                 elif meta.wraps: | ||||
|                     if v is not None or include_default_values: | ||||
|                         output[cased_name] = v | ||||
|                 elif isinstance(v, list): | ||||
|                     if value is not None or include_default_values: | ||||
|                         output[cased_name] = value | ||||
|                 elif field_is_repeated: | ||||
|                     # Convert each item. | ||||
|                     v = [i.to_dict(casing, include_default_values) for i in v] | ||||
|                     if v or include_default_values: | ||||
|                         output[cased_name] = v | ||||
|                     value = [i.to_dict(casing, include_default_values) for i in value] | ||||
|                     if value or include_default_values: | ||||
|                         output[cased_name] = value | ||||
|                 else: | ||||
|                     if v._serialized_on_wire or include_default_values: | ||||
|                         output[cased_name] = v.to_dict(casing, include_default_values) | ||||
|             elif meta.proto_type == "map": | ||||
|                 for k in v: | ||||
|                     if hasattr(v[k], "to_dict"): | ||||
|                         v[k] = v[k].to_dict(casing, include_default_values) | ||||
|                     if value._serialized_on_wire or include_default_values: | ||||
|                         output[cased_name] = value.to_dict( | ||||
|                             casing, include_default_values | ||||
|                         ) | ||||
|             elif meta.proto_type == TYPE_MAP: | ||||
|                 for k in value: | ||||
|                     if hasattr(value[k], "to_dict"): | ||||
|                         value[k] = value[k].to_dict(casing, include_default_values) | ||||
|  | ||||
|                 if v or include_default_values: | ||||
|                     output[cased_name] = v | ||||
|             elif v != self._get_field_default(field_name) or include_default_values: | ||||
|                 if value or include_default_values: | ||||
|                     output[cased_name] = value | ||||
|             elif value != self._get_field_default(field_name) or include_default_values: | ||||
|                 if meta.proto_type in INT_64_TYPES: | ||||
|                     if isinstance(v, list): | ||||
|                         output[cased_name] = [str(n) for n in v] | ||||
|                     if field_is_repeated: | ||||
|                         output[cased_name] = [str(n) for n in value] | ||||
|                     else: | ||||
|                         output[cased_name] = str(v) | ||||
|                         output[cased_name] = str(value) | ||||
|                 elif meta.proto_type == TYPE_BYTES: | ||||
|                     if isinstance(v, list): | ||||
|                         output[cased_name] = [b64encode(b).decode("utf8") for b in v] | ||||
|                     if field_is_repeated: | ||||
|                         output[cased_name] = [ | ||||
|                             b64encode(b).decode("utf8") for b in value | ||||
|                         ] | ||||
|                     else: | ||||
|                         output[cased_name] = b64encode(v).decode("utf8") | ||||
|                         output[cased_name] = b64encode(value).decode("utf8") | ||||
|                 elif meta.proto_type == TYPE_ENUM: | ||||
|                     enum_values = list( | ||||
|                         self._betterproto.cls_by_field[field_name] | ||||
|                     )  # type: ignore | ||||
|                     if isinstance(v, list): | ||||
|                         output[cased_name] = [enum_values[e].name for e in v] | ||||
|                     if field_is_repeated: | ||||
|                         enum_class: Type[Enum] = field_type.__args__[0] | ||||
|                         if isinstance(value, typing.Iterable) and not isinstance( | ||||
|                             value, str | ||||
|                         ): | ||||
|                             output[cased_name] = [enum_class(el).name for el in value] | ||||
|                         else: | ||||
|                             # transparently upgrade single value to repeated | ||||
|                             output[cased_name] = [enum_class(value).name] | ||||
|                     else: | ||||
|                         output[cased_name] = enum_values[v].name | ||||
|                         enum_class: Type[Enum] = field_type  # noqa | ||||
|                         output[cased_name] = enum_class(value).name | ||||
|                 else: | ||||
|                     output[cased_name] = v | ||||
|                     output[cased_name] = value | ||||
|         return output | ||||
|  | ||||
|     def from_dict(self: T, value: dict) -> T: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user