Fix: to_dict returns wrong enum fields when numbering is not consecutive (#102)
Fixes #93 to_dict returns wrong enum fields when numbering is not consecutive
This commit is contained in:
parent
0ba0692dec
commit
6c29771f4c
@ -4,6 +4,7 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
|
import warnings
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from base64 import b64decode, b64encode
|
from base64 import b64decode, b64encode
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
@ -21,6 +22,8 @@ from typing import (
|
|||||||
get_type_hints,
|
get_type_hints,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import typing
|
||||||
|
|
||||||
from ._types import T
|
from ._types import T
|
||||||
from .casing import camel_case, safe_snake_case, snake_case
|
from .casing import camel_case, safe_snake_case, snake_case
|
||||||
from .grpc.grpclib_client import ServiceStub
|
from .grpc.grpclib_client import ServiceStub
|
||||||
@ -251,7 +254,7 @@ def map_field(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Enum(int, enum.Enum):
|
class Enum(enum.IntEnum):
|
||||||
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
|
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -635,9 +638,13 @@ class Message(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _type_hint(cls, field_name: str) -> Type:
|
def _type_hint(cls, field_name: str) -> Type:
|
||||||
|
return cls._type_hints()[field_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _type_hints(cls) -> Dict[str, Type]:
|
||||||
module = inspect.getmodule(cls)
|
module = inspect.getmodule(cls)
|
||||||
type_hints = get_type_hints(cls, vars(module))
|
type_hints = get_type_hints(cls, vars(module))
|
||||||
return type_hints[field_name]
|
return type_hints
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
@ -789,55 +796,67 @@ class Message(ABC):
|
|||||||
`False`.
|
`False`.
|
||||||
"""
|
"""
|
||||||
output: Dict[str, Any] = {}
|
output: Dict[str, Any] = {}
|
||||||
|
field_types = self._type_hints()
|
||||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
v = getattr(self, field_name)
|
field_type = field_types[field_name]
|
||||||
|
field_is_repeated = type(field_type) is type(typing.List)
|
||||||
|
value = getattr(self, field_name)
|
||||||
cased_name = casing(field_name).rstrip("_") # type: ignore
|
cased_name = casing(field_name).rstrip("_") # type: ignore
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == TYPE_MESSAGE:
|
||||||
if isinstance(v, datetime):
|
if isinstance(value, datetime):
|
||||||
if v != DATETIME_ZERO or include_default_values:
|
if value != DATETIME_ZERO or include_default_values:
|
||||||
output[cased_name] = _Timestamp.timestamp_to_json(v)
|
output[cased_name] = _Timestamp.timestamp_to_json(value)
|
||||||
elif isinstance(v, timedelta):
|
elif isinstance(value, timedelta):
|
||||||
if v != timedelta(0) or include_default_values:
|
if value != timedelta(0) or include_default_values:
|
||||||
output[cased_name] = _Duration.delta_to_json(v)
|
output[cased_name] = _Duration.delta_to_json(value)
|
||||||
elif meta.wraps:
|
elif meta.wraps:
|
||||||
if v is not None or include_default_values:
|
if value is not None or include_default_values:
|
||||||
output[cased_name] = v
|
output[cased_name] = value
|
||||||
elif isinstance(v, list):
|
elif field_is_repeated:
|
||||||
# Convert each item.
|
# Convert each item.
|
||||||
v = [i.to_dict(casing, include_default_values) for i in v]
|
value = [i.to_dict(casing, include_default_values) for i in value]
|
||||||
if v or include_default_values:
|
if value or include_default_values:
|
||||||
output[cased_name] = v
|
output[cased_name] = value
|
||||||
else:
|
else:
|
||||||
if v._serialized_on_wire or include_default_values:
|
if value._serialized_on_wire or include_default_values:
|
||||||
output[cased_name] = v.to_dict(casing, include_default_values)
|
output[cased_name] = value.to_dict(
|
||||||
elif meta.proto_type == "map":
|
casing, include_default_values
|
||||||
for k in v:
|
)
|
||||||
if hasattr(v[k], "to_dict"):
|
elif meta.proto_type == TYPE_MAP:
|
||||||
v[k] = v[k].to_dict(casing, include_default_values)
|
for k in value:
|
||||||
|
if hasattr(value[k], "to_dict"):
|
||||||
|
value[k] = value[k].to_dict(casing, include_default_values)
|
||||||
|
|
||||||
if v or include_default_values:
|
if value or include_default_values:
|
||||||
output[cased_name] = v
|
output[cased_name] = value
|
||||||
elif v != self._get_field_default(field_name) or include_default_values:
|
elif value != self._get_field_default(field_name) or include_default_values:
|
||||||
if meta.proto_type in INT_64_TYPES:
|
if meta.proto_type in INT_64_TYPES:
|
||||||
if isinstance(v, list):
|
if field_is_repeated:
|
||||||
output[cased_name] = [str(n) for n in v]
|
output[cased_name] = [str(n) for n in value]
|
||||||
else:
|
else:
|
||||||
output[cased_name] = str(v)
|
output[cased_name] = str(value)
|
||||||
elif meta.proto_type == TYPE_BYTES:
|
elif meta.proto_type == TYPE_BYTES:
|
||||||
if isinstance(v, list):
|
if field_is_repeated:
|
||||||
output[cased_name] = [b64encode(b).decode("utf8") for b in v]
|
output[cased_name] = [
|
||||||
|
b64encode(b).decode("utf8") for b in value
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
output[cased_name] = b64encode(v).decode("utf8")
|
output[cased_name] = b64encode(value).decode("utf8")
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_values = list(
|
if field_is_repeated:
|
||||||
self._betterproto.cls_by_field[field_name]
|
enum_class: Type[Enum] = field_type.__args__[0]
|
||||||
) # type: ignore
|
if isinstance(value, typing.Iterable) and not isinstance(
|
||||||
if isinstance(v, list):
|
value, str
|
||||||
output[cased_name] = [enum_values[e].name for e in v]
|
):
|
||||||
|
output[cased_name] = [enum_class(el).name for el in value]
|
||||||
else:
|
else:
|
||||||
output[cased_name] = enum_values[v].name
|
# transparently upgrade single value to repeated
|
||||||
|
output[cased_name] = [enum_class(value).name]
|
||||||
else:
|
else:
|
||||||
output[cased_name] = v
|
enum_class: Type[Enum] = field_type # noqa
|
||||||
|
output[cased_name] = enum_class(value).name
|
||||||
|
else:
|
||||||
|
output[cased_name] = value
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def from_dict(self: T, value: dict) -> T:
|
def from_dict(self: T, value: dict) -> T:
|
||||||
|
@ -5,7 +5,7 @@ xfail = {
|
|||||||
"namespace_keywords", # 70
|
"namespace_keywords", # 70
|
||||||
"namespace_builtin_types", # 53
|
"namespace_builtin_types", # 53
|
||||||
"googletypes_struct", # 9
|
"googletypes_struct", # 9
|
||||||
"googletypes_value", # 9,
|
"googletypes_value", # 9
|
||||||
"import_capitalized_package",
|
"import_capitalized_package",
|
||||||
"example", # This is the example in the readme. Not a test.
|
"example", # This is the example in the readme. Not a test.
|
||||||
}
|
}
|
||||||
|
9
tests/inputs/enum/enum.json
Normal file
9
tests/inputs/enum/enum.json
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"choice": "FOUR",
|
||||||
|
"choices": [
|
||||||
|
"ZERO",
|
||||||
|
"ONE",
|
||||||
|
"THREE",
|
||||||
|
"FOUR"
|
||||||
|
]
|
||||||
|
}
|
15
tests/inputs/enum/enum.proto
Normal file
15
tests/inputs/enum/enum.proto
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
|
||||||
|
message Test {
|
||||||
|
Choice choice = 1;
|
||||||
|
repeated Choice choices = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
enum Choice {
|
||||||
|
ZERO = 0;
|
||||||
|
ONE = 1;
|
||||||
|
// TWO = 2;
|
||||||
|
FOUR = 4;
|
||||||
|
THREE = 3;
|
||||||
|
}
|
84
tests/inputs/enum/test_enum.py
Normal file
84
tests/inputs/enum/test_enum.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
from tests.output_betterproto.enum import (
|
||||||
|
Test,
|
||||||
|
Choice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_set_and_get():
|
||||||
|
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
|
||||||
|
assert Test(choice=Choice.ONE).choice == Choice.ONE
|
||||||
|
assert Test(choice=Choice.THREE).choice == Choice.THREE
|
||||||
|
assert Test(choice=Choice.FOUR).choice == Choice.FOUR
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_set_with_int():
|
||||||
|
assert Test(choice=0).choice == Choice.ZERO
|
||||||
|
assert Test(choice=1).choice == Choice.ONE
|
||||||
|
assert Test(choice=3).choice == Choice.THREE
|
||||||
|
assert Test(choice=4).choice == Choice.FOUR
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_is_comparable_with_int():
|
||||||
|
assert Test(choice=Choice.ZERO).choice == 0
|
||||||
|
assert Test(choice=Choice.ONE).choice == 1
|
||||||
|
assert Test(choice=Choice.THREE).choice == 3
|
||||||
|
assert Test(choice=Choice.FOUR).choice == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_enum_to_dict():
|
||||||
|
assert (
|
||||||
|
"choice" not in Test(choice=Choice.ZERO).to_dict()
|
||||||
|
), "Default enum value is not serialized"
|
||||||
|
assert (
|
||||||
|
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
|
||||||
|
== "ZERO"
|
||||||
|
)
|
||||||
|
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
|
||||||
|
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
|
||||||
|
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_enum_is_comparable_with_int():
|
||||||
|
assert Test(choices=[Choice.ZERO]).choices == [0]
|
||||||
|
assert Test(choices=[Choice.ONE]).choices == [1]
|
||||||
|
assert Test(choices=[Choice.THREE]).choices == [3]
|
||||||
|
assert Test(choices=[Choice.FOUR]).choices == [4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_enum_set_and_get():
|
||||||
|
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
|
||||||
|
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
|
||||||
|
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
|
||||||
|
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_enum_to_dict():
|
||||||
|
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
|
||||||
|
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
|
||||||
|
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
|
||||||
|
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]
|
||||||
|
|
||||||
|
all_enums_dict = Test(
|
||||||
|
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
|
||||||
|
).to_dict()
|
||||||
|
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_enum_with_single_value_to_dict():
|
||||||
|
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
|
||||||
|
assert Test(choices=1).to_dict()["choices"] == ["ONE"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_enum_with_non_list_iterables_to_dict():
|
||||||
|
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
|
||||||
|
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
|
||||||
|
assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
|
||||||
|
"ONE",
|
||||||
|
"THREE",
|
||||||
|
]
|
||||||
|
|
||||||
|
def enum_generator():
|
||||||
|
yield Choice.ONE
|
||||||
|
yield Choice.THREE
|
||||||
|
|
||||||
|
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
|
@ -1,3 +0,0 @@
|
|||||||
{
|
|
||||||
"greeting": "HEY"
|
|
||||||
}
|
|
@ -1,14 +0,0 @@
|
|||||||
syntax = "proto3";
|
|
||||||
|
|
||||||
// Enum for the different greeting types
|
|
||||||
enum Greeting {
|
|
||||||
HI = 0;
|
|
||||||
HEY = 1;
|
|
||||||
// Formal greeting
|
|
||||||
HELLO = 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
message Test {
|
|
||||||
// Greeting enum example
|
|
||||||
Greeting greeting = 1;
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user