Map enum ints into Enums redux (#293)

Re-implement Enum to be faster along with being an open set

---------
Co-authored-by: ydylla <ydylla@gmail.com>
This commit is contained in:
James Hilton-Balfe 2023-10-16 03:32:30 +01:00 committed by GitHub
parent 8659c51123
commit c82816b8be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 725 additions and 511 deletions

897
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -19,6 +19,7 @@ importlib-metadata = { version = ">=1.6.0", python = "<3.8" }
jinja2 = { version = ">=3.0.3", optional = true } jinja2 = { version = ">=3.0.3", optional = true }
python-dateutil = "^2.8" python-dateutil = "^2.8"
isort = {version = "^5.11.5", optional = true} isort = {version = "^5.11.5", optional = true}
typing-extensions = "^4.7.1"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
asv = "^0.4.2" asv = "^0.4.2"
@ -37,7 +38,7 @@ sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0" tomlkit = "^0.7.0"
tox = "^3.15.1" tox = "^3.15.1"
pre-commit = "^2.17.0" pre-commit = "^2.17.0"
pydantic = ">=1.8.0" pydantic = ">=1.8.0,<2"
[tool.poetry.scripts] [tool.poetry.scripts]

View File

@ -1,5 +1,5 @@
import dataclasses import dataclasses
import enum import enum as builtin_enum
import json import json
import math import math
import struct import struct
@ -45,7 +45,8 @@ from .casing import (
safe_snake_case, safe_snake_case,
snake_case, snake_case,
) )
from .grpc.grpclib_client import ServiceStub from .enum import Enum as Enum
from .grpc.grpclib_client import ServiceStub as ServiceStub
if TYPE_CHECKING: if TYPE_CHECKING:
@ -140,7 +141,7 @@ NEG_INFINITY = "-Infinity"
NAN = "NaN" NAN = "NaN"
class Casing(enum.Enum): class Casing(builtin_enum.Enum):
"""Casing constants for serialization.""" """Casing constants for serialization."""
CAMEL = camel_case #: A camelCase sterilization function. CAMEL = camel_case #: A camelCase sterilization function.
@ -309,32 +310,6 @@ def map_field(
) )
class Enum(enum.IntEnum):
    """
    Common base class of every generated protobuf enumeration.

    Derives from :class:`enum.IntEnum`.
    """

    @classmethod
    def from_string(cls, name: str) -> "Enum":
        """Look up a member of this enumeration by its string name.

        Parameters
        -----------
        name: :class:`str`
            The name of the enum member to get

        Raises
        -------
        :exc:`ValueError`
            The member was not found in the Enum.
        """
        known_members = cls._member_map_
        try:
            found = known_members[name]
        except KeyError as e:
            # Report a ValueError but keep the KeyError chained as the cause.
            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
        return found  # type: ignore
def _pack_fmt(proto_type: str) -> str: def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary.""" """Returns a little-endian format string for reading/writing binary."""
return { return {
@ -1168,7 +1143,7 @@ class Message(ABC):
return t return t
elif issubclass(t, Enum): elif issubclass(t, Enum):
# Enums always default to zero. # Enums always default to zero.
return int return t.try_value
elif t is datetime: elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z # Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen return datetime_default_gen
@ -1193,6 +1168,9 @@ class Message(ABC):
elif meta.proto_type == TYPE_BOOL: elif meta.proto_type == TYPE_BOOL:
# Booleans use a varint encoding, so convert it to true/false. # Booleans use a varint encoding, so convert it to true/false.
value = value > 0 value = value > 0
elif meta.proto_type == TYPE_ENUM:
# Convert enum ints to python enum instances
value = self._betterproto.cls_by_field[field_name].try_value(value)
elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64): elif wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
fmt = _pack_fmt(meta.proto_type) fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0] value = struct.unpack(fmt, value)[0]

199
src/betterproto/enum.py Normal file
View File

@ -0,0 +1,199 @@
from __future__ import annotations
import sys
from enum import (
EnumMeta,
IntEnum,
)
from types import MappingProxyType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Optional,
Tuple,
)
if TYPE_CHECKING:
from collections.abc import (
Generator,
Mapping,
)
from typing_extensions import (
Never,
Self,
)
def _is_descriptor(obj: object) -> bool:
return (
hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__")
)
class EnumType(EnumMeta if TYPE_CHECKING else type):
    """Metaclass that builds the member tables for :class:`Enum` subclasses.

    At runtime this derives from plain ``type``; it only presents itself as
    ``enum.EnumMeta`` to static type checkers. Every class created through
    it receives its own dynamically generated metaclass, which holds the
    lookup tables ``_value_map_`` (value -> member) and ``_member_map_``
    (name -> member) together with the members themselves.
    """

    _value_map_: Mapping[int, Enum]
    _member_map_: Mapping[str, Enum]

    def __new__(
        mcs, name: str, bases: Tuple[type, ...], namespace: Dict[str, Any]
    ) -> Self:
        """Create the enum class ``name`` and one instance per declared member.

        Parameters
        -----------
        name: :class:`str`
            The name of the class being created.
        bases: Tuple[:class:`type`, ...]
            The base classes of the class being created.
        namespace: Dict[:class:`str`, Any]
            The class body. Every entry that is not a descriptor and whose
            name does not start with an underscore is turned into a member.

        Returns
        -------
        The newly created class, whose metaclass is a fresh per-class
        subclass of :class:`EnumType` named ``f"{name}Type"``.
        """
        value_map = {}
        member_map = {}

        new_mcs = type(
            f"{name}Type",
            tuple(
                dict.fromkeys(
                    [base.__class__ for base in bases if base.__class__ is not type]
                    + [EnumType, type]
                )
            ),  # reorder the bases so EnumType and type are last to avoid conflicts
            {"_value_map_": value_map, "_member_map_": member_map},
        )

        members = {
            key: value
            for key, value in namespace.items()
            if not _is_descriptor(value) and key[0] != "_"
        }

        cls = type.__new__(
            new_mcs,
            name,
            bases,
            {key: value for key, value in namespace.items() if key not in members},
        )

        # Members are stored on the per-class metaclass rather than on the
        # class itself: this allows us to disallow member access from other
        # members, as members become proper class variables.
        # ``members`` is already filtered above, so no re-check is needed here.
        for member_name, member_value in members.items():
            member = value_map.get(member_value)
            if member is None:
                member = cls.__new__(  # type: ignore
                    cls, name=member_name, value=member_value
                )
                value_map[member_value] = member
            # A repeated value becomes an alias: the new name maps to the
            # member that was created first for that value.
            member_map[member_name] = member
            type.__setattr__(new_mcs, member_name, member)

        return cls

    if not TYPE_CHECKING:

        def __call__(cls, value: int) -> Enum:
            """Return the member whose value is ``value``.

            Raises
            -------
            :exc:`ValueError`
                No member has that value (or ``value`` is unhashable).
            """
            try:
                return cls._value_map_[value]
            except (KeyError, TypeError):
                raise ValueError(f"{value!r} is not a valid {cls.__name__}") from None

        def __iter__(cls) -> Generator[Enum, None, None]:
            """Yield the entries of ``_member_map_`` in definition order."""
            yield from cls._member_map_.values()

        if sys.version_info >= (3, 8):  # 3.8 added __reversed__ to dict_values

            def __reversed__(cls) -> Generator[Enum, None, None]:
                """Yield the entries of ``_member_map_`` in reverse order."""
                yield from reversed(cls._member_map_.values())

        else:

            def __reversed__(cls) -> Generator[Enum, None, None]:
                """Yield the entries of ``_member_map_`` in reverse order."""
                yield from reversed(tuple(cls._member_map_.values()))

        def __getitem__(cls, key: str) -> Enum:
            """Return the member named ``key``; raises ``KeyError`` if unknown."""
            return cls._member_map_[key]

        @property
        def __members__(cls) -> MappingProxyType[str, Enum]:
            """A read-only view of the name -> member mapping."""
            return MappingProxyType(cls._member_map_)

    def __repr__(cls) -> str:
        return f"<enum {cls.__name__!r}>"

    def __len__(cls) -> int:
        """Return the number of names in ``_member_map_`` (aliases included)."""
        return len(cls._member_map_)

    def __setattr__(cls, name: str, value: Any) -> Never:
        """Always raise: attributes of an enum class cannot be assigned."""
        raise AttributeError(f"{cls.__name__}: cannot reassign Enum members.")

    def __delattr__(cls, name: str) -> Never:
        """Always raise: attributes of an enum class cannot be deleted."""
        raise AttributeError(f"{cls.__name__}: cannot delete Enum members.")

    def __contains__(cls, member: object) -> bool:
        """Return whether ``member`` is an instance of ``cls`` with a known name."""
        return isinstance(member, cls) and member.name in cls._member_map_
class Enum(IntEnum if TYPE_CHECKING else int, metaclass=EnumType):
    """
    The base class for protobuf enumerations, all generated enumerations will
    inherit from this. Emulates `enum.IntEnum`.

    At runtime members are plain ``int`` subclass instances carrying a
    ``name`` and a ``value`` attribute. ``name`` is ``None`` for instances
    produced by :meth:`try_value` for a value that is not a declared member.
    """

    name: Optional[str]
    value: int

    if not TYPE_CHECKING:

        def __new__(cls, *, name: Optional[str], value: int) -> Self:
            self = super().__new__(cls, value)
            # The class's own __setattr__ always raises, so the attributes
            # are written through the parent implementation instead.
            super().__setattr__(self, "name", name)
            super().__setattr__(self, "value", value)
            return self

    def __getnewargs_ex__(self) -> Tuple[Tuple[()], Dict[str, Any]]:
        """Return the arguments ``copy`` and ``pickle`` use to rebuild ``self``.

        ``__new__`` only accepts the keyword arguments ``name`` and
        ``value``. Without this hook the inherited ``int.__getnewargs__``
        supplies a single positional value, so copying or pickling a member
        fails with :exc:`TypeError`.
        """
        return (), {"name": self.name, "value": self.value}

    def __str__(self) -> str:
        """Return the member name, or the string ``"None"`` if it has none."""
        return self.name or "None"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}.{self.name}"

    def __setattr__(self, key: str, value: Any) -> Never:
        """Always raise: attributes of a member cannot be assigned."""
        raise AttributeError(
            f"{self.__class__.__name__} Cannot reassign a member's attributes."
        )

    def __delattr__(self, item: Any) -> Never:
        """Always raise: attributes of a member cannot be deleted."""
        raise AttributeError(
            f"{self.__class__.__name__} Cannot delete a member's attributes."
        )

    @classmethod
    def try_value(cls, value: int = 0) -> Self:
        """Return the member which corresponds to the integer value.

        Unlike calling the class, this never raises for an unknown value.

        Parameters
        -----------
        value: :class:`int`
            The value of the enum member to get.

        Returns
        -------
        :class:`Enum`
            The corresponding member or a new instance of the enum if
            ``value`` isn't actually a member.
        """
        try:
            return cls._value_map_[value]
        except (KeyError, TypeError):
            # Unknown value: build an unnamed instance so the raw integer
            # is preserved.
            return cls.__new__(cls, name=None, value=value)

    @classmethod
    def from_string(cls, name: str) -> Self:
        """Return the member which corresponds to the string name.

        Parameters
        -----------
        name: :class:`str`
            The name of the enum member to get.

        Raises
        -------
        :exc:`ValueError`
            The member was not found in the Enum.
        """
        try:
            return cls._member_map_[name]
        except KeyError as e:
            raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e

View File

@ -82,3 +82,23 @@ def test_repeated_enum_with_non_list_iterables_to_dict():
yield Choice.THREE yield Choice.THREE
assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"] assert Test(choices=enum_generator()).to_dict()["choices"] == ["ONE", "THREE"]
def test_enum_mapped_on_parse():
    """Check that parsed enum fields expose the ``name`` of the expected member."""
    # test default value
    b = Test().parse(bytes(Test()))
    assert b.choice.name == Choice.ZERO.name
    assert b.choices == []

    # test non default value
    a = Test().parse(bytes(Test(choice=Choice.ONE)))
    assert a.choice.name == Choice.ONE.name
    # Assert on ``a`` (the message parsed in this section), not on ``b``.
    assert a.choices == []

    # test repeated
    c = Test().parse(bytes(Test(choices=[Choice.THREE, Choice.FOUR])))
    assert c.choices[0].name == Choice.THREE.name
    assert c.choices[1].name == Choice.FOUR.name

    # bonus: defaults after empty init are also mapped
    assert Test().choice.name == Choice.ZERO.name

79
tests/test_enum.py Normal file
View File

@ -0,0 +1,79 @@
from typing import (
Optional,
Tuple,
)
import pytest
import betterproto
class Colour(betterproto.Enum):
    """Three-member enumeration used as the fixture for the tests below."""

    RED = 1
    GREEN = 2
    BLUE = 3


# An unnamed instance whose value (4) is not one of the declared members.
PURPLE = Colour.__new__(Colour, name=None, value=4)
@pytest.mark.parametrize(
    "member, str_value",
    [
        (Colour.RED, "RED"),
        (Colour.GREEN, "GREEN"),
        (Colour.BLUE, "BLUE"),
    ],
)
def test_str(member: Colour, str_value: str) -> None:
    """``str()`` of a declared member is its name."""
    rendered = str(member)
    assert rendered == str_value
@pytest.mark.parametrize(
    "member, repr_value",
    [
        (Colour.RED, "Colour.RED"),
        (Colour.GREEN, "Colour.GREEN"),
        (Colour.BLUE, "Colour.BLUE"),
    ],
)
def test_repr(member: Colour, repr_value: str) -> None:
    """``repr()`` of a declared member is ``<class name>.<member name>``."""
    rendered = repr(member)
    assert rendered == repr_value
@pytest.mark.parametrize(
    "member, values",
    [
        (Colour.RED, ("RED", 1)),
        (Colour.GREEN, ("GREEN", 2)),
        (Colour.BLUE, ("BLUE", 3)),
        (PURPLE, (None, 4)),
    ],
)
def test_name_values(member: Colour, values: Tuple[Optional[str], int]) -> None:
    """Each instance exposes the expected ``(name, value)`` pair."""
    observed = (member.name, member.value)
    assert observed == values
@pytest.mark.parametrize(
    "member, input_str",
    [
        (Colour.RED, "RED"),
        (Colour.GREEN, "GREEN"),
        (Colour.BLUE, "BLUE"),
    ],
)
def test_from_string(member: Colour, input_str: str) -> None:
    """``from_string`` returns a value equal to the member with that name."""
    looked_up = Colour.from_string(input_str)
    assert looked_up == member
@pytest.mark.parametrize(
    "member, input_int",
    [
        (Colour.RED, 1),
        (Colour.GREEN, 2),
        (Colour.BLUE, 3),
        (PURPLE, 4),
    ],
)
def test_try_value(member: Colour, input_int: int) -> None:
    """``try_value`` returns the matching member, or an unnamed ``Colour``."""
    result = Colour.try_value(input_int)
    # ``==`` alone is plain integer equality, so also pin the type and the
    # name (``None`` for the undeclared value 4).
    assert isinstance(result, Colour)
    assert result == member
    assert result.name == member.name