Fix: to_dict returns wrong enum fields when numbering is not consecutive (#102)
Fixes #93 to_dict returns wrong enum fields when numbering is not consecutive
This commit is contained in:
		| @@ -4,6 +4,7 @@ import inspect | |||||||
| import json | import json | ||||||
| import struct | import struct | ||||||
| import sys | import sys | ||||||
|  | import warnings | ||||||
| from abc import ABC | from abc import ABC | ||||||
| from base64 import b64decode, b64encode | from base64 import b64decode, b64encode | ||||||
| from datetime import datetime, timedelta, timezone | from datetime import datetime, timedelta, timezone | ||||||
| @@ -21,6 +22,8 @@ from typing import ( | |||||||
|     get_type_hints, |     get_type_hints, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | import typing | ||||||
|  |  | ||||||
| from ._types import T | from ._types import T | ||||||
| from .casing import camel_case, safe_snake_case, snake_case | from .casing import camel_case, safe_snake_case, snake_case | ||||||
| from .grpc.grpclib_client import ServiceStub | from .grpc.grpclib_client import ServiceStub | ||||||
| @@ -251,7 +254,7 @@ def map_field( | |||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Enum(int, enum.Enum): | class Enum(enum.IntEnum): | ||||||
|     """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" |     """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
| @@ -635,9 +638,13 @@ class Message(ABC): | |||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _type_hint(cls, field_name: str) -> Type: |     def _type_hint(cls, field_name: str) -> Type: | ||||||
|  |         return cls._type_hints()[field_name] | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def _type_hints(cls) -> Dict[str, Type]: | ||||||
|         module = inspect.getmodule(cls) |         module = inspect.getmodule(cls) | ||||||
|         type_hints = get_type_hints(cls, vars(module)) |         type_hints = get_type_hints(cls, vars(module)) | ||||||
|         return type_hints[field_name] |         return type_hints | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: |     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: | ||||||
| @@ -789,55 +796,67 @@ class Message(ABC): | |||||||
|         `False`. |         `False`. | ||||||
|         """ |         """ | ||||||
|         output: Dict[str, Any] = {} |         output: Dict[str, Any] = {} | ||||||
|  |         field_types = self._type_hints() | ||||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): |         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||||
|             v = getattr(self, field_name) |             field_type = field_types[field_name] | ||||||
|  |             field_is_repeated = type(field_type) is type(typing.List) | ||||||
|  |             value = getattr(self, field_name) | ||||||
|             cased_name = casing(field_name).rstrip("_")  # type: ignore |             cased_name = casing(field_name).rstrip("_")  # type: ignore | ||||||
|             if meta.proto_type == "message": |             if meta.proto_type == TYPE_MESSAGE: | ||||||
|                 if isinstance(v, datetime): |                 if isinstance(value, datetime): | ||||||
|                     if v != DATETIME_ZERO or include_default_values: |                     if value != DATETIME_ZERO or include_default_values: | ||||||
|                         output[cased_name] = _Timestamp.timestamp_to_json(v) |                         output[cased_name] = _Timestamp.timestamp_to_json(value) | ||||||
|                 elif isinstance(v, timedelta): |                 elif isinstance(value, timedelta): | ||||||
|                     if v != timedelta(0) or include_default_values: |                     if value != timedelta(0) or include_default_values: | ||||||
|                         output[cased_name] = _Duration.delta_to_json(v) |                         output[cased_name] = _Duration.delta_to_json(value) | ||||||
|                 elif meta.wraps: |                 elif meta.wraps: | ||||||
|                     if v is not None or include_default_values: |                     if value is not None or include_default_values: | ||||||
|                         output[cased_name] = v |                         output[cased_name] = value | ||||||
|                 elif isinstance(v, list): |                 elif field_is_repeated: | ||||||
|                     # Convert each item. |                     # Convert each item. | ||||||
|                     v = [i.to_dict(casing, include_default_values) for i in v] |                     value = [i.to_dict(casing, include_default_values) for i in value] | ||||||
|                     if v or include_default_values: |                     if value or include_default_values: | ||||||
|                         output[cased_name] = v |                         output[cased_name] = value | ||||||
|                 else: |                 else: | ||||||
|                     if v._serialized_on_wire or include_default_values: |                     if value._serialized_on_wire or include_default_values: | ||||||
|                         output[cased_name] = v.to_dict(casing, include_default_values) |                         output[cased_name] = value.to_dict( | ||||||
|             elif meta.proto_type == "map": |                             casing, include_default_values | ||||||
|                 for k in v: |                         ) | ||||||
|                     if hasattr(v[k], "to_dict"): |             elif meta.proto_type == TYPE_MAP: | ||||||
|                         v[k] = v[k].to_dict(casing, include_default_values) |                 for k in value: | ||||||
|  |                     if hasattr(value[k], "to_dict"): | ||||||
|  |                         value[k] = value[k].to_dict(casing, include_default_values) | ||||||
|  |  | ||||||
|                 if v or include_default_values: |                 if value or include_default_values: | ||||||
|                     output[cased_name] = v |                     output[cased_name] = value | ||||||
|             elif v != self._get_field_default(field_name) or include_default_values: |             elif value != self._get_field_default(field_name) or include_default_values: | ||||||
|                 if meta.proto_type in INT_64_TYPES: |                 if meta.proto_type in INT_64_TYPES: | ||||||
|                     if isinstance(v, list): |                     if field_is_repeated: | ||||||
|                         output[cased_name] = [str(n) for n in v] |                         output[cased_name] = [str(n) for n in value] | ||||||
|                     else: |                     else: | ||||||
|                         output[cased_name] = str(v) |                         output[cased_name] = str(value) | ||||||
|                 elif meta.proto_type == TYPE_BYTES: |                 elif meta.proto_type == TYPE_BYTES: | ||||||
|                     if isinstance(v, list): |                     if field_is_repeated: | ||||||
|                         output[cased_name] = [b64encode(b).decode("utf8") for b in v] |                         output[cased_name] = [ | ||||||
|  |                             b64encode(b).decode("utf8") for b in value | ||||||
|  |                         ] | ||||||
|                     else: |                     else: | ||||||
|                         output[cased_name] = b64encode(v).decode("utf8") |                         output[cased_name] = b64encode(value).decode("utf8") | ||||||
|                 elif meta.proto_type == TYPE_ENUM: |                 elif meta.proto_type == TYPE_ENUM: | ||||||
|                     enum_values = list( |                     if field_is_repeated: | ||||||
|                         self._betterproto.cls_by_field[field_name] |                         enum_class: Type[Enum] = field_type.__args__[0] | ||||||
|                     )  # type: ignore |                         if isinstance(value, typing.Iterable) and not isinstance( | ||||||
|                     if isinstance(v, list): |                             value, str | ||||||
|                         output[cased_name] = [enum_values[e].name for e in v] |                         ): | ||||||
|  |                             output[cased_name] = [enum_class(el).name for el in value] | ||||||
|                         else: |                         else: | ||||||
|                         output[cased_name] = enum_values[v].name |                             # transparently upgrade single value to repeated | ||||||
|  |                             output[cased_name] = [enum_class(value).name] | ||||||
|                     else: |                     else: | ||||||
|                     output[cased_name] = v |                         enum_class: Type[Enum] = field_type  # noqa | ||||||
|  |                         output[cased_name] = enum_class(value).name | ||||||
|  |                 else: | ||||||
|  |                     output[cased_name] = value | ||||||
|         return output |         return output | ||||||
|  |  | ||||||
|     def from_dict(self: T, value: dict) -> T: |     def from_dict(self: T, value: dict) -> T: | ||||||
|   | |||||||
| @@ -5,7 +5,7 @@ xfail = { | |||||||
|     "namespace_keywords",  # 70 |     "namespace_keywords",  # 70 | ||||||
|     "namespace_builtin_types",  # 53 |     "namespace_builtin_types",  # 53 | ||||||
|     "googletypes_struct",  # 9 |     "googletypes_struct",  # 9 | ||||||
|     "googletypes_value",  # 9, |     "googletypes_value",  # 9 | ||||||
|     "import_capitalized_package", |     "import_capitalized_package", | ||||||
|     "example",  # This is the example in the readme. Not a test. |     "example",  # This is the example in the readme. Not a test. | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										9
									
								
								tests/inputs/enum/enum.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								tests/inputs/enum/enum.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | { | ||||||
|  |   "choice": "FOUR", | ||||||
|  |   "choices": [ | ||||||
|  |     "ZERO", | ||||||
|  |     "ONE", | ||||||
|  |     "THREE", | ||||||
|  |     "FOUR" | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										15
									
								
								tests/inputs/enum/enum.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								tests/inputs/enum/enum.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | // Tests that enums are correctly serialized and that it correctly handles skipped and out-of-order enum values | ||||||
|  | message Test { | ||||||
|  |   Choice choice = 1; | ||||||
|  |   repeated Choice choices = 2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | enum Choice { | ||||||
|  |   ZERO = 0; | ||||||
|  |   ONE = 1; | ||||||
|  |   // TWO = 2; | ||||||
|  |   FOUR = 4; | ||||||
|  |   THREE = 3; | ||||||
|  | } | ||||||
							
								
								
									
										84
									
								
								tests/inputs/enum/test_enum.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tests/inputs/enum/test_enum.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | |||||||
|  | from tests.output_betterproto.enum import ( | ||||||
|  |     Test, | ||||||
|  |     Choice, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_enum_set_and_get(): | ||||||
|  |     assert Test(choice=Choice.ZERO).choice == Choice.ZERO | ||||||
|  |     assert Test(choice=Choice.ONE).choice == Choice.ONE | ||||||
|  |     assert Test(choice=Choice.THREE).choice == Choice.THREE | ||||||
|  |     assert Test(choice=Choice.FOUR).choice == Choice.FOUR | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_enum_set_with_int(): | ||||||
|  |     assert Test(choice=0).choice == Choice.ZERO | ||||||
|  |     assert Test(choice=1).choice == Choice.ONE | ||||||
|  |     assert Test(choice=3).choice == Choice.THREE | ||||||
|  |     assert Test(choice=4).choice == Choice.FOUR | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_enum_is_comparable_with_int(): | ||||||
|  |     assert Test(choice=Choice.ZERO).choice == 0 | ||||||
|  |     assert Test(choice=Choice.ONE).choice == 1 | ||||||
|  |     assert Test(choice=Choice.THREE).choice == 3 | ||||||
|  |     assert Test(choice=Choice.FOUR).choice == 4 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_enum_to_dict(): | ||||||
|  |     assert ( | ||||||
|  |         "choice" not in Test(choice=Choice.ZERO).to_dict() | ||||||
|  |     ), "Default enum value is not serialized" | ||||||
|  |     assert ( | ||||||
|  |         Test(choice=Choice.ZERO).to_dict(include_default_values=True)["choice"] | ||||||
|  |         == "ZERO" | ||||||
|  |     ) | ||||||
|  |     assert Test(choice=Choice.ONE).to_dict()["choice"] == "ONE" | ||||||
|  |     assert Test(choice=Choice.THREE).to_dict()["choice"] == "THREE" | ||||||
|  |     assert Test(choice=Choice.FOUR).to_dict()["choice"] == "FOUR" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_repeated_enum_is_comparable_with_int(): | ||||||
|  |     assert Test(choices=[Choice.ZERO]).choices == [0] | ||||||
|  |     assert Test(choices=[Choice.ONE]).choices == [1] | ||||||
|  |     assert Test(choices=[Choice.THREE]).choices == [3] | ||||||
|  |     assert Test(choices=[Choice.FOUR]).choices == [4] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_repeated_enum_set_and_get(): | ||||||
|  |     assert Test(choices=[Choice.ZERO]).choices == [Choice.ZERO] | ||||||
|  |     assert Test(choices=[Choice.ONE]).choices == [Choice.ONE] | ||||||
|  |     assert Test(choices=[Choice.THREE]).choices == [Choice.THREE] | ||||||
|  |     assert Test(choices=[Choice.FOUR]).choices == [Choice.FOUR] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_repeated_enum_to_dict(): | ||||||
|  |     assert Test(choices=[Choice.ZERO]).to_dict()["choices"] == ["ZERO"] | ||||||
|  |     assert Test(choices=[Choice.ONE]).to_dict()["choices"] == ["ONE"] | ||||||
|  |     assert Test(choices=[Choice.THREE]).to_dict()["choices"] == ["THREE"] | ||||||
|  |     assert Test(choices=[Choice.FOUR]).to_dict()["choices"] == ["FOUR"] | ||||||
|  |  | ||||||
|  |     all_enums_dict = Test( | ||||||
|  |         choices=[Choice.ZERO, Choice.ONE, Choice.THREE, Choice.FOUR] | ||||||
|  |     ).to_dict() | ||||||
|  |     assert (all_enums_dict["choices"]) == ["ZERO", "ONE", "THREE", "FOUR"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_repeated_enum_with_single_value_to_dict(): | ||||||
|  |     assert Test(choices=Choice.ONE).to_dict()["choices"] == ["ONE"] | ||||||
|  |     assert Test(choices=1).to_dict()["choices"] == ["ONE"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_repeated_enum_with_non_list_iterables_to_dict(): | ||||||
|  |     assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] | ||||||
|  |     assert Test(choices=(1, 3)).to_dict()["choices"] == ["ONE", "THREE"] | ||||||
|  |     assert Test(choices=(Choice.ONE, Choice.THREE)).to_dict()["choices"] == [ | ||||||
|  |         "ONE", | ||||||
|  |         "THREE", | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |     def enum_generator(): | ||||||
|  |         yield Choice.ONE | ||||||
|  |         yield Choice.THREE | ||||||
|  |  | ||||||
|  |     assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] | ||||||
| @@ -1,3 +0,0 @@ | |||||||
| { |  | ||||||
|   "greeting": "HEY" |  | ||||||
| } |  | ||||||
| @@ -1,14 +0,0 @@ | |||||||
| syntax = "proto3"; |  | ||||||
|  |  | ||||||
| // Enum for the different greeting types |  | ||||||
| enum Greeting { |  | ||||||
|   HI = 0; |  | ||||||
|   HEY = 1; |  | ||||||
|   // Formal greeting |  | ||||||
|   HELLO = 2; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| message Test { |  | ||||||
|   // Greeting enum example |  | ||||||
|   Greeting greeting = 1; |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user