Fix: to_dict returns wrong enum fields when numbering is not consecutive (#102)

Fixes #93 to_dict returns wrong enum fields when numbering is not consecutive
This commit is contained in:
Bouke Versteegh 2020-07-12 15:06:55 +02:00 committed by GitHub
parent 0ba0692dec
commit 6c29771f4c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 166 additions and 56 deletions

View File

@ -4,6 +4,7 @@ import inspect
import json import json
import struct import struct
import sys import sys
import warnings
from abc import ABC from abc import ABC
from base64 import b64decode, b64encode from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
@ -21,6 +22,8 @@ from typing import (
get_type_hints, get_type_hints,
) )
import typing
from ._types import T from ._types import T
from .casing import camel_case, safe_snake_case, snake_case from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub from .grpc.grpclib_client import ServiceStub
@ -251,7 +254,7 @@ def map_field(
) )
class Enum(int, enum.Enum): class Enum(enum.IntEnum):
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" """Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""
@classmethod @classmethod
@ -635,9 +638,13 @@ class Message(ABC):
@classmethod @classmethod
def _type_hint(cls, field_name: str) -> Type: def _type_hint(cls, field_name: str) -> Type:
return cls._type_hints()[field_name]
@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = inspect.getmodule(cls) module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module)) type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name] return type_hints
@classmethod @classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@ -789,55 +796,67 @@ class Message(ABC):
`False`. `False`.
""" """
output: Dict[str, Any] = {} output: Dict[str, Any] = {}
field_types = self._type_hints()
for field_name, meta in self._betterproto.meta_by_field_name.items(): for field_name, meta in self._betterproto.meta_by_field_name.items():
v = getattr(self, field_name) field_type = field_types[field_name]
field_is_repeated = type(field_type) is type(typing.List)
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == "message": if meta.proto_type == TYPE_MESSAGE:
if isinstance(v, datetime): if isinstance(value, datetime):
if v != DATETIME_ZERO or include_default_values: if value != DATETIME_ZERO or include_default_values:
output[cased_name] = _Timestamp.timestamp_to_json(v) output[cased_name] = _Timestamp.timestamp_to_json(value)
elif isinstance(v, timedelta): elif isinstance(value, timedelta):
if v != timedelta(0) or include_default_values: if value != timedelta(0) or include_default_values:
output[cased_name] = _Duration.delta_to_json(v) output[cased_name] = _Duration.delta_to_json(value)
elif meta.wraps: elif meta.wraps:
if v is not None or include_default_values: if value is not None or include_default_values:
output[cased_name] = v output[cased_name] = value
elif isinstance(v, list): elif field_is_repeated:
# Convert each item. # Convert each item.
v = [i.to_dict(casing, include_default_values) for i in v] value = [i.to_dict(casing, include_default_values) for i in value]
if v or include_default_values: if value or include_default_values:
output[cased_name] = v output[cased_name] = value
else: else:
if v._serialized_on_wire or include_default_values: if value._serialized_on_wire or include_default_values:
output[cased_name] = v.to_dict(casing, include_default_values) output[cased_name] = value.to_dict(
elif meta.proto_type == "map": casing, include_default_values
for k in v: )
if hasattr(v[k], "to_dict"): elif meta.proto_type == TYPE_MAP:
v[k] = v[k].to_dict(casing, include_default_values) for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)
if v or include_default_values: if value or include_default_values:
output[cased_name] = v output[cased_name] = value
elif v != self._get_field_default(field_name) or include_default_values: elif value != self._get_field_default(field_name) or include_default_values:
if meta.proto_type in INT_64_TYPES: if meta.proto_type in INT_64_TYPES:
if isinstance(v, list): if field_is_repeated:
output[cased_name] = [str(n) for n in v] output[cased_name] = [str(n) for n in value]
else: else:
output[cased_name] = str(v) output[cased_name] = str(value)
elif meta.proto_type == TYPE_BYTES: elif meta.proto_type == TYPE_BYTES:
if isinstance(v, list): if field_is_repeated:
output[cased_name] = [b64encode(b).decode("utf8") for b in v] output[cased_name] = [
b64encode(b).decode("utf8") for b in value
]
else: else:
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(value).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list( if field_is_repeated:
self._betterproto.cls_by_field[field_name] enum_class: Type[Enum] = field_type.__args__[0]
) # type: ignore if isinstance(value, typing.Iterable) and not isinstance(
if isinstance(v, list): value, str
output[cased_name] = [enum_values[e].name for e in v] ):
output[cased_name] = [enum_class(el).name for el in value]
else:
# transparently upgrade single value to repeated
output[cased_name] = [enum_class(value).name]
else: else:
output[cased_name] = enum_values[v].name enum_class: Type[Enum] = field_type # noqa
output[cased_name] = enum_class(value).name
else: else:
output[cased_name] = v output[cased_name] = value
return output return output
def from_dict(self: T, value: dict) -> T: def from_dict(self: T, value: dict) -> T:

View File

@ -5,7 +5,7 @@ xfail = {
"namespace_keywords", # 70 "namespace_keywords", # 70
"namespace_builtin_types", # 53 "namespace_builtin_types", # 53
"googletypes_struct", # 9 "googletypes_struct", # 9
"googletypes_value", # 9, "googletypes_value", # 9
"import_capitalized_package", "import_capitalized_package",
"example", # This is the example in the readme. Not a test. "example", # This is the example in the readme. Not a test.
} }

View File

@ -0,0 +1,9 @@
{
"choice": "FOUR",
"choices": [
"ZERO",
"ONE",
"THREE",
"FOUR"
]
}

View File

@ -0,0 +1,15 @@
syntax = "proto3";
// Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values
message Test {
Choice choice = 1;
repeated Choice choices = 2;
}
enum Choice {
ZERO = 0;
ONE = 1;
// TWO = 2;
FOUR = 4;
THREE = 3;
}

View File

@ -0,0 +1,84 @@
from tests.output_betterproto.enum import (
Test,
Choice,
)
def test_enum_set_and_get():
assert Test(choice=Choice.ZERO).choice == Choice.ZERO
assert Test(choice=Choice.ONE).choice == Choice.ONE
assert Test(choice=Choice.THREE).choice == Choice.THREE
assert Test(choice=Choice.FOUR).choice == Choice.FOUR
def test_enum_set_with_int():
assert Test(choice=0).choice == Choice.ZERO
assert Test(choice=1).choice == Choice.ONE
assert Test(choice=3).choice == Choice.THREE
assert Test(choice=4).choice == Choice.FOUR
def test_enum_is_comparable_with_int():
assert Test(choice=Choice.ZERO).choice == 0
assert Test(choice=Choice.ONE).choice == 1
assert Test(choice=Choice.THREE).choice == 3
assert Test(choice=Choice.FOUR).choice == 4
def test_enum_to_dict():
assert (
"choice" not in Test(choice=Choice.ZERO).to_dict()
), "Default enum value is not serialized"
assert (
Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"]
== "ZERO"
)
assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE"
assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE"
assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR"
def test_repeated_enum_is_comparable_with_int():
assert Test(choices=[Choice.ZERO]).choices == [0]
assert Test(choices=[Choice.ONE]).choices == [1]
assert Test(choices=[Choice.THREE]).choices == [3]
assert Test(choices=[Choice.FOUR]).choices == [4]
def test_repeated_enum_set_and_get():
assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO]
assert Test(choices=[Choice.ONE]).choices == [Choice.ONE]
assert Test(choices=[Choice.THREE]).choices == [Choice.THREE]
assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR]
def test_repeated_enum_to_dict():
assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"]
assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"]
assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"]
assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"]
all_enums_dict = Test(
choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR]
).to_dict()
assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"]
def test_repeated_enum_with_single_value_to_dict():
assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"]
assert Test(choices=1).to_dict()["choices"] == ["ONE"]
def test_repeated_enum_with_non_list_iterables_to_dict():
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"]
assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [
"ONE",
"THREE",
]
def enum_generator():
yield Choice.ONE
yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]

View File

@ -1,3 +0,0 @@
{
"greeting": "HEY"
}

View File

@ -1,14 +0,0 @@
syntax = "proto3";
// Enum for the different greeting types
enum Greeting {
HI = 0;
HEY = 1;
// Formal greeting
HELLO = 2;
}
message Test {
// Greeting enum example
Greeting greeting = 1;
}