Store the class metadata of fields in the class, to improve preformance
Cached data include, - lookup table between groups and fields of "oneof" fields - default value creator of each field - type hint of each field
This commit is contained in:
		@@ -120,7 +120,11 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
 | 
			
		||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
 | 
			
		||||
def datetime_default_gen():
 | 
			
		||||
    return datetime(1970, 1, 1, tzinfo=timezone.utc)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DATETIME_ZERO = datetime_default_gen()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Casing(enum.Enum):
 | 
			
		||||
@@ -428,6 +432,57 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
 | 
			
		||||
T = TypeVar("T", bound="Message")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ProtoClassMetadata:
 | 
			
		||||
    cls: "Message"
 | 
			
		||||
 | 
			
		||||
    def __init__(self, cls: "Message"):
 | 
			
		||||
        self.cls = cls
 | 
			
		||||
        by_field = {}
 | 
			
		||||
        by_group = {}
 | 
			
		||||
 | 
			
		||||
        for field in dataclasses.fields(cls):
 | 
			
		||||
            meta = FieldMetadata.get(field)
 | 
			
		||||
 | 
			
		||||
            if meta.group:
 | 
			
		||||
                # This is part of a one-of group.
 | 
			
		||||
                by_field[field.name] = meta.group
 | 
			
		||||
 | 
			
		||||
                by_group.setdefault(meta.group, set()).add(field)
 | 
			
		||||
 | 
			
		||||
        self.oneof_group_by_field = by_field
 | 
			
		||||
        self.oneof_field_by_group = by_group
 | 
			
		||||
 | 
			
		||||
    def __getattr__(self, item):
 | 
			
		||||
        # Lazy init because forward reference classes may not be available at the beginning.
 | 
			
		||||
        if item == 'default_gen':
 | 
			
		||||
            defaults = {}
 | 
			
		||||
            for field in dataclasses.fields(self.cls):
 | 
			
		||||
                meta = FieldMetadata.get(field)
 | 
			
		||||
                defaults[field.name] = self.cls._get_field_default_gen(field, meta)
 | 
			
		||||
 | 
			
		||||
            self.default_gen = defaults  # __getattr__ won't be called next time
 | 
			
		||||
            return defaults
 | 
			
		||||
 | 
			
		||||
        if item == 'cls_by_field':
 | 
			
		||||
            field_cls = {}
 | 
			
		||||
            for field in dataclasses.fields(self.cls):
 | 
			
		||||
                meta = FieldMetadata.get(field)
 | 
			
		||||
                field_cls[field.name] = self.cls._type_hint(field.name)
 | 
			
		||||
 | 
			
		||||
            self.cls_by_field = field_cls  # __getattr__ won't be called next time
 | 
			
		||||
            return field_cls
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def make_protoclass(cls):
 | 
			
		||||
    setattr(cls, "_betterproto", ProtoClassMetadata(cls))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def protoclass(*args, **kwargs):
 | 
			
		||||
    cls = dataclasses.dataclass(*args, **kwargs)
 | 
			
		||||
    make_protoclass(cls)
 | 
			
		||||
    return cls
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Message(ABC):
 | 
			
		||||
    """
 | 
			
		||||
    A protobuf message base class. Generated code will inherit from this and
 | 
			
		||||
@@ -445,17 +500,12 @@ class Message(ABC):
 | 
			
		||||
 | 
			
		||||
        # Set a default value for each field in the class after `__init__` has
 | 
			
		||||
        # already been run.
 | 
			
		||||
        group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
 | 
			
		||||
        group_map: Dict[str, dataclasses.Field] = {}
 | 
			
		||||
        for field in dataclasses.fields(self):
 | 
			
		||||
            meta = FieldMetadata.get(field)
 | 
			
		||||
 | 
			
		||||
            if meta.group:
 | 
			
		||||
                # This is part of a one-of group.
 | 
			
		||||
                group_map["fields"][field.name] = meta.group
 | 
			
		||||
 | 
			
		||||
                if meta.group not in group_map["groups"]:
 | 
			
		||||
                    group_map["groups"][meta.group] = {"current": None, "fields": set()}
 | 
			
		||||
                group_map["groups"][meta.group]["fields"].add(field)
 | 
			
		||||
                group_map.setdefault(meta.group)
 | 
			
		||||
 | 
			
		||||
            if getattr(self, field.name) != PLACEHOLDER:
 | 
			
		||||
                # Skip anything not set to the sentinel value
 | 
			
		||||
@@ -463,7 +513,7 @@ class Message(ABC):
 | 
			
		||||
 | 
			
		||||
                if meta.group:
 | 
			
		||||
                    # This was set, so make it the selected value of the one-of.
 | 
			
		||||
                    group_map["groups"][meta.group]["current"] = field
 | 
			
		||||
                    group_map[meta.group] = field
 | 
			
		||||
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
@@ -479,11 +529,12 @@ class Message(ABC):
 | 
			
		||||
            # Track when a field has been set.
 | 
			
		||||
            self.__dict__["_serialized_on_wire"] = True
 | 
			
		||||
 | 
			
		||||
        if attr in getattr(self, "_group_map", {}).get("fields", {}):
 | 
			
		||||
            group = self._group_map["fields"][attr]
 | 
			
		||||
            for field in self._group_map["groups"][group]["fields"]:
 | 
			
		||||
        if hasattr(self, "_group_map"):  # __post_init__ had already run
 | 
			
		||||
            if attr in self._betterproto.oneof_group_by_field:
 | 
			
		||||
                group = self._betterproto.oneof_group_by_field[attr]
 | 
			
		||||
                for field in self._betterproto.oneof_field_by_group[group]:
 | 
			
		||||
                    if field.name == attr:
 | 
			
		||||
                    self._group_map["groups"][group]["current"] = field
 | 
			
		||||
                        self._group_map[group] = field
 | 
			
		||||
                    else:
 | 
			
		||||
                        super().__setattr__(
 | 
			
		||||
                            field.name,
 | 
			
		||||
@@ -510,7 +561,7 @@ class Message(ABC):
 | 
			
		||||
            # currently set in a `oneof` group, so it must be serialized even
 | 
			
		||||
            # if the value is the default zero value.
 | 
			
		||||
            selected_in_group = False
 | 
			
		||||
            if meta.group and self._group_map["groups"][meta.group]["current"] == field:
 | 
			
		||||
            if meta.group and self._group_map[meta.group] == field:
 | 
			
		||||
                selected_in_group = True
 | 
			
		||||
 | 
			
		||||
            serialize_empty = False
 | 
			
		||||
@@ -562,47 +613,49 @@ class Message(ABC):
 | 
			
		||||
    # For compatibility with other libraries
 | 
			
		||||
    SerializeToString = __bytes__
 | 
			
		||||
 | 
			
		||||
    def _type_hint(self, field_name: str) -> Type:
 | 
			
		||||
        module = inspect.getmodule(self.__class__)
 | 
			
		||||
        type_hints = get_type_hints(self.__class__, vars(module))
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _type_hint(cls, field_name: str) -> Type:
 | 
			
		||||
        module = inspect.getmodule(cls)
 | 
			
		||||
        type_hints = get_type_hints(cls, vars(module))
 | 
			
		||||
        return type_hints[field_name]
 | 
			
		||||
 | 
			
		||||
    def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
 | 
			
		||||
        """Get the message class for a field from the type hints."""
 | 
			
		||||
        cls = self._type_hint(field.name)
 | 
			
		||||
        cls = self._betterproto.cls_by_field[field.name]
 | 
			
		||||
        if hasattr(cls, "__args__") and index >= 0:
 | 
			
		||||
            cls = cls.__args__[index]
 | 
			
		||||
        return cls
 | 
			
		||||
 | 
			
		||||
    def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
 | 
			
		||||
        t = self._type_hint(field.name)
 | 
			
		||||
        return self._betterproto.default_gen[field.name]()
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any:
 | 
			
		||||
        t = cls._type_hint(field.name)
 | 
			
		||||
 | 
			
		||||
        value: Any = 0
 | 
			
		||||
        if hasattr(t, "__origin__"):
 | 
			
		||||
            if t.__origin__ in (dict, Dict):
 | 
			
		||||
                # This is some kind of map (dict in Python).
 | 
			
		||||
                value = {}
 | 
			
		||||
                return dict
 | 
			
		||||
            elif t.__origin__ in (list, List):
 | 
			
		||||
                # This is some kind of list (repeated) field.
 | 
			
		||||
                value = []
 | 
			
		||||
                return list
 | 
			
		||||
            elif t.__origin__ == Union and t.__args__[1] == type(None):
 | 
			
		||||
                # This is an optional (wrapped) field. For setting the default we
 | 
			
		||||
                # really don't care what kind of field it is.
 | 
			
		||||
                value = None
 | 
			
		||||
                return type(None)
 | 
			
		||||
            else:
 | 
			
		||||
                value = t()
 | 
			
		||||
                return t
 | 
			
		||||
        elif issubclass(t, Enum):
 | 
			
		||||
            # Enums always default to zero.
 | 
			
		||||
            value = 0
 | 
			
		||||
            return int
 | 
			
		||||
        elif t == datetime:
 | 
			
		||||
            # Offsets are relative to 1970-01-01T00:00:00Z
 | 
			
		||||
            value = DATETIME_ZERO
 | 
			
		||||
            return datetime_default_gen
 | 
			
		||||
        else:
 | 
			
		||||
            # This is either a primitive scalar or another message type. Calling
 | 
			
		||||
            # it should result in its zero value.
 | 
			
		||||
            value = t()
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
            return t
 | 
			
		||||
 | 
			
		||||
    def _postprocess_single(
 | 
			
		||||
        self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
 | 
			
		||||
@@ -654,6 +707,7 @@ class Message(ABC):
 | 
			
		||||
                    ],
 | 
			
		||||
                    bases=(Message,),
 | 
			
		||||
                )
 | 
			
		||||
                make_protoclass(Entry)
 | 
			
		||||
                value = Entry().parse(value)
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
@@ -861,13 +915,13 @@ def serialized_on_wire(message: Message) -> bool:
 | 
			
		||||
 | 
			
		||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
 | 
			
		||||
    """Return the name and value of a message's one-of field group."""
 | 
			
		||||
    field = message._group_map["groups"].get(group_name, {}).get("current")
 | 
			
		||||
    field = message._group_map.get(group_name)
 | 
			
		||||
    if not field:
 | 
			
		||||
        return ("", None)
 | 
			
		||||
    return (field.name, getattr(message, field.name))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _Duration(Message):
 | 
			
		||||
    # Signed seconds of the span of time. Must be from -315,576,000,000 to
 | 
			
		||||
    # +315,576,000,000 inclusive. Note: these bounds are computed from: 60
 | 
			
		||||
@@ -892,7 +946,7 @@ class _Duration(Message):
 | 
			
		||||
        return ".".join(parts) + "s"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _Timestamp(Message):
 | 
			
		||||
    # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
 | 
			
		||||
    # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
 | 
			
		||||
@@ -942,47 +996,47 @@ class _WrappedMessage(Message):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _BoolValue(_WrappedMessage):
 | 
			
		||||
    value: bool = bool_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _Int32Value(_WrappedMessage):
 | 
			
		||||
    value: int = int32_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _UInt32Value(_WrappedMessage):
 | 
			
		||||
    value: int = uint32_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _Int64Value(_WrappedMessage):
 | 
			
		||||
    value: int = int64_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _UInt64Value(_WrappedMessage):
 | 
			
		||||
    value: int = uint64_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _FloatValue(_WrappedMessage):
 | 
			
		||||
    value: float = float_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _DoubleValue(_WrappedMessage):
 | 
			
		||||
    value: float = double_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _StringValue(_WrappedMessage):
 | 
			
		||||
    value: str = string_field(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclasses.dataclass
 | 
			
		||||
@protoclass
 | 
			
		||||
class _BytesValue(_WrappedMessage):
 | 
			
		||||
    value: bytes = bytes_field(1)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
# Generated by the protocol buffer compiler.  DO NOT EDIT!
 | 
			
		||||
# sources: {{ ', '.join(description.files) }}
 | 
			
		||||
# plugin: python-betterproto
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
{% if description.datetime_imports %}
 | 
			
		||||
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
 | 
			
		||||
 | 
			
		||||
@@ -38,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum):
 | 
			
		||||
{% endfor %}
 | 
			
		||||
{% endif %}
 | 
			
		||||
{% for message in description.messages %}
 | 
			
		||||
@dataclass
 | 
			
		||||
@betterproto.protoclass
 | 
			
		||||
class {{ message.py_name }}(betterproto.Message):
 | 
			
		||||
    {% if message.comment %}
 | 
			
		||||
{{ message.comment }}
 | 
			
		||||
 
 | 
			
		||||
@@ -4,11 +4,11 @@ from typing import Optional
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_field():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Bar(betterproto.Message):
 | 
			
		||||
        baz: int = betterproto.int32_field(1)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Foo(betterproto.Message):
 | 
			
		||||
        bar: Bar = betterproto.message_field(1)
 | 
			
		||||
 | 
			
		||||
@@ -34,11 +34,11 @@ def test_has_field():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_class_init():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Bar(betterproto.Message):
 | 
			
		||||
        name: str = betterproto.string_field(1)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Foo(betterproto.Message):
 | 
			
		||||
        name: str = betterproto.string_field(1)
 | 
			
		||||
        child: Bar = betterproto.message_field(2)
 | 
			
		||||
@@ -53,7 +53,7 @@ def test_enum_as_int_json():
 | 
			
		||||
        ZERO = 0
 | 
			
		||||
        ONE = 1
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Foo(betterproto.Message):
 | 
			
		||||
        bar: TestEnum = betterproto.enum_field(1)
 | 
			
		||||
 | 
			
		||||
@@ -67,13 +67,13 @@ def test_enum_as_int_json():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_unknown_fields():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Newer(betterproto.Message):
 | 
			
		||||
        foo: bool = betterproto.bool_field(1)
 | 
			
		||||
        bar: int = betterproto.int32_field(2)
 | 
			
		||||
        baz: str = betterproto.string_field(3)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Older(betterproto.Message):
 | 
			
		||||
        foo: bool = betterproto.bool_field(1)
 | 
			
		||||
 | 
			
		||||
@@ -89,11 +89,11 @@ def test_unknown_fields():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_oneof_support():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Sub(betterproto.Message):
 | 
			
		||||
        val: int = betterproto.int32_field(1)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Foo(betterproto.Message):
 | 
			
		||||
        bar: int = betterproto.int32_field(1, group="group1")
 | 
			
		||||
        baz: str = betterproto.string_field(2, group="group1")
 | 
			
		||||
@@ -134,7 +134,7 @@ def test_oneof_support():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_json_casing():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class CasingTest(betterproto.Message):
 | 
			
		||||
        pascal_case: int = betterproto.int32_field(1)
 | 
			
		||||
        camel_case: int = betterproto.int32_field(2)
 | 
			
		||||
@@ -165,7 +165,7 @@ def test_json_casing():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_optional_flag():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class Request(betterproto.Message):
 | 
			
		||||
        flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
 | 
			
		||||
 | 
			
		||||
@@ -180,7 +180,7 @@ def test_optional_flag():
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_to_dict_default_values():
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class TestMessage(betterproto.Message):
 | 
			
		||||
        some_int: int = betterproto.int32_field(1)
 | 
			
		||||
        some_double: float = betterproto.double_field(2)
 | 
			
		||||
@@ -210,7 +210,7 @@ def test_to_dict_default_values():
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Some default and some other values
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class TestMessage2(betterproto.Message):
 | 
			
		||||
        some_int: int = betterproto.int32_field(1)
 | 
			
		||||
        some_double: float = betterproto.double_field(2)
 | 
			
		||||
@@ -246,11 +246,11 @@ def test_to_dict_default_values():
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    # Nested messages
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class TestChildMessage(betterproto.Message):
 | 
			
		||||
        some_other_int: int = betterproto.int32_field(1)
 | 
			
		||||
 | 
			
		||||
    @dataclass
 | 
			
		||||
    @betterproto.protoclass
 | 
			
		||||
    class TestParentMessage(betterproto.Message):
 | 
			
		||||
        some_int: int = betterproto.int32_field(1)
 | 
			
		||||
        some_double: float = betterproto.double_field(2)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user