Store the class metadata of fields in the class, to improve preformance
Cached data include, - lookup table between groups and fields of "oneof" fields - default value creator of each field - type hint of each field
This commit is contained in:
		| @@ -120,7 +120,11 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] | ||||
|  | ||||
|  | ||||
| # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. | ||||
| DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc) | ||||
| def datetime_default_gen(): | ||||
|     return datetime(1970, 1, 1, tzinfo=timezone.utc) | ||||
|  | ||||
|  | ||||
| DATETIME_ZERO = datetime_default_gen() | ||||
|  | ||||
|  | ||||
| class Casing(enum.Enum): | ||||
| @@ -428,6 +432,57 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: | ||||
| T = TypeVar("T", bound="Message") | ||||
|  | ||||
|  | ||||
| class ProtoClassMetadata: | ||||
|     cls: "Message" | ||||
|  | ||||
|     def __init__(self, cls: "Message"): | ||||
|         self.cls = cls | ||||
|         by_field = {} | ||||
|         by_group = {} | ||||
|  | ||||
|         for field in dataclasses.fields(cls): | ||||
|             meta = FieldMetadata.get(field) | ||||
|  | ||||
|             if meta.group: | ||||
|                 # This is part of a one-of group. | ||||
|                 by_field[field.name] = meta.group | ||||
|  | ||||
|                 by_group.setdefault(meta.group, set()).add(field) | ||||
|  | ||||
|         self.oneof_group_by_field = by_field | ||||
|         self.oneof_field_by_group = by_group | ||||
|  | ||||
|     def __getattr__(self, item): | ||||
|         # Lazy init because forward reference classes may not be available at the beginning. | ||||
|         if item == 'default_gen': | ||||
|             defaults = {} | ||||
|             for field in dataclasses.fields(self.cls): | ||||
|                 meta = FieldMetadata.get(field) | ||||
|                 defaults[field.name] = self.cls._get_field_default_gen(field, meta) | ||||
|  | ||||
|             self.default_gen = defaults  # __getattr__ won't be called next time | ||||
|             return defaults | ||||
|  | ||||
|         if item == 'cls_by_field': | ||||
|             field_cls = {} | ||||
|             for field in dataclasses.fields(self.cls): | ||||
|                 meta = FieldMetadata.get(field) | ||||
|                 field_cls[field.name] = self.cls._type_hint(field.name) | ||||
|  | ||||
|             self.cls_by_field = field_cls  # __getattr__ won't be called next time | ||||
|             return field_cls | ||||
|  | ||||
|  | ||||
| def make_protoclass(cls): | ||||
|     setattr(cls, "_betterproto", ProtoClassMetadata(cls)) | ||||
|  | ||||
|  | ||||
| def protoclass(*args, **kwargs): | ||||
|     cls = dataclasses.dataclass(*args, **kwargs) | ||||
|     make_protoclass(cls) | ||||
|     return cls | ||||
|  | ||||
|  | ||||
| class Message(ABC): | ||||
|     """ | ||||
|     A protobuf message base class. Generated code will inherit from this and | ||||
| @@ -445,17 +500,12 @@ class Message(ABC): | ||||
|  | ||||
|         # Set a default value for each field in the class after `__init__` has | ||||
|         # already been run. | ||||
|         group_map: Dict[str, dict] = {"fields": {}, "groups": {}} | ||||
|         group_map: Dict[str, dataclasses.Field] = {} | ||||
|         for field in dataclasses.fields(self): | ||||
|             meta = FieldMetadata.get(field) | ||||
|  | ||||
|             if meta.group: | ||||
|                 # This is part of a one-of group. | ||||
|                 group_map["fields"][field.name] = meta.group | ||||
|  | ||||
|                 if meta.group not in group_map["groups"]: | ||||
|                     group_map["groups"][meta.group] = {"current": None, "fields": set()} | ||||
|                 group_map["groups"][meta.group]["fields"].add(field) | ||||
|                 group_map.setdefault(meta.group) | ||||
|  | ||||
|             if getattr(self, field.name) != PLACEHOLDER: | ||||
|                 # Skip anything not set to the sentinel value | ||||
| @@ -463,7 +513,7 @@ class Message(ABC): | ||||
|  | ||||
|                 if meta.group: | ||||
|                     # This was set, so make it the selected value of the one-of. | ||||
|                     group_map["groups"][meta.group]["current"] = field | ||||
|                     group_map[meta.group] = field | ||||
|  | ||||
|                 continue | ||||
|  | ||||
| @@ -479,16 +529,17 @@ class Message(ABC): | ||||
|             # Track when a field has been set. | ||||
|             self.__dict__["_serialized_on_wire"] = True | ||||
|  | ||||
|         if attr in getattr(self, "_group_map", {}).get("fields", {}): | ||||
|             group = self._group_map["fields"][attr] | ||||
|             for field in self._group_map["groups"][group]["fields"]: | ||||
|                 if field.name == attr: | ||||
|                     self._group_map["groups"][group]["current"] = field | ||||
|                 else: | ||||
|                     super().__setattr__( | ||||
|                         field.name, | ||||
|                         self._get_field_default(field, FieldMetadata.get(field)), | ||||
|                     ) | ||||
|         if hasattr(self, "_group_map"):  # __post_init__ had already run | ||||
|             if attr in self._betterproto.oneof_group_by_field: | ||||
|                 group = self._betterproto.oneof_group_by_field[attr] | ||||
|                 for field in self._betterproto.oneof_field_by_group[group]: | ||||
|                     if field.name == attr: | ||||
|                         self._group_map[group] = field | ||||
|                     else: | ||||
|                         super().__setattr__( | ||||
|                             field.name, | ||||
|                             self._get_field_default(field, FieldMetadata.get(field)), | ||||
|                         ) | ||||
|  | ||||
|         super().__setattr__(attr, value) | ||||
|  | ||||
| @@ -510,7 +561,7 @@ class Message(ABC): | ||||
|             # currently set in a `oneof` group, so it must be serialized even | ||||
|             # if the value is the default zero value. | ||||
|             selected_in_group = False | ||||
|             if meta.group and self._group_map["groups"][meta.group]["current"] == field: | ||||
|             if meta.group and self._group_map[meta.group] == field: | ||||
|                 selected_in_group = True | ||||
|  | ||||
|             serialize_empty = False | ||||
| @@ -562,47 +613,49 @@ class Message(ABC): | ||||
|     # For compatibility with other libraries | ||||
|     SerializeToString = __bytes__ | ||||
|  | ||||
|     def _type_hint(self, field_name: str) -> Type: | ||||
|         module = inspect.getmodule(self.__class__) | ||||
|         type_hints = get_type_hints(self.__class__, vars(module)) | ||||
|     @classmethod | ||||
|     def _type_hint(cls, field_name: str) -> Type: | ||||
|         module = inspect.getmodule(cls) | ||||
|         type_hints = get_type_hints(cls, vars(module)) | ||||
|         return type_hints[field_name] | ||||
|  | ||||
|     def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: | ||||
|         """Get the message class for a field from the type hints.""" | ||||
|         cls = self._type_hint(field.name) | ||||
|         cls = self._betterproto.cls_by_field[field.name] | ||||
|         if hasattr(cls, "__args__") and index >= 0: | ||||
|             cls = cls.__args__[index] | ||||
|         return cls | ||||
|  | ||||
|     def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: | ||||
|         t = self._type_hint(field.name) | ||||
|         return self._betterproto.default_gen[field.name]() | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any: | ||||
|         t = cls._type_hint(field.name) | ||||
|  | ||||
|         value: Any = 0 | ||||
|         if hasattr(t, "__origin__"): | ||||
|             if t.__origin__ in (dict, Dict): | ||||
|                 # This is some kind of map (dict in Python). | ||||
|                 value = {} | ||||
|                 return dict | ||||
|             elif t.__origin__ in (list, List): | ||||
|                 # This is some kind of list (repeated) field. | ||||
|                 value = [] | ||||
|                 return list | ||||
|             elif t.__origin__ == Union and t.__args__[1] == type(None): | ||||
|                 # This is an optional (wrapped) field. For setting the default we | ||||
|                 # really don't care what kind of field it is. | ||||
|                 value = None | ||||
|                 return type(None) | ||||
|             else: | ||||
|                 value = t() | ||||
|                 return t | ||||
|         elif issubclass(t, Enum): | ||||
|             # Enums always default to zero. | ||||
|             value = 0 | ||||
|             return int | ||||
|         elif t == datetime: | ||||
|             # Offsets are relative to 1970-01-01T00:00:00Z | ||||
|             value = DATETIME_ZERO | ||||
|             return datetime_default_gen | ||||
|         else: | ||||
|             # This is either a primitive scalar or another message type. Calling | ||||
|             # it should result in its zero value. | ||||
|             value = t() | ||||
|  | ||||
|         return value | ||||
|             return t | ||||
|  | ||||
|     def _postprocess_single( | ||||
|         self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any | ||||
| @@ -654,6 +707,7 @@ class Message(ABC): | ||||
|                     ], | ||||
|                     bases=(Message,), | ||||
|                 ) | ||||
|                 make_protoclass(Entry) | ||||
|                 value = Entry().parse(value) | ||||
|  | ||||
|         return value | ||||
| @@ -861,13 +915,13 @@ def serialized_on_wire(message: Message) -> bool: | ||||
|  | ||||
| def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||
|     """Return the name and value of a message's one-of field group.""" | ||||
|     field = message._group_map["groups"].get(group_name, {}).get("current") | ||||
|     field = message._group_map.get(group_name) | ||||
|     if not field: | ||||
|         return ("", None) | ||||
|     return (field.name, getattr(message, field.name)) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _Duration(Message): | ||||
|     # Signed seconds of the span of time. Must be from -315,576,000,000 to | ||||
|     # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 | ||||
| @@ -892,7 +946,7 @@ class _Duration(Message): | ||||
|         return ".".join(parts) + "s" | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _Timestamp(Message): | ||||
|     # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must | ||||
|     # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. | ||||
| @@ -942,47 +996,47 @@ class _WrappedMessage(Message): | ||||
|         return self | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _BoolValue(_WrappedMessage): | ||||
|     value: bool = bool_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _Int32Value(_WrappedMessage): | ||||
|     value: int = int32_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _UInt32Value(_WrappedMessage): | ||||
|     value: int = uint32_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _Int64Value(_WrappedMessage): | ||||
|     value: int = int64_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _UInt64Value(_WrappedMessage): | ||||
|     value: int = uint64_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _FloatValue(_WrappedMessage): | ||||
|     value: float = float_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _DoubleValue(_WrappedMessage): | ||||
|     value: float = double_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _StringValue(_WrappedMessage): | ||||
|     value: str = string_field(1) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
| @protoclass | ||||
| class _BytesValue(_WrappedMessage): | ||||
|     value: bytes = bytes_field(1) | ||||
|  | ||||
|   | ||||
| @@ -1,7 +1,6 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: {{ ', '.join(description.files) }} | ||||
| # plugin: python-betterproto | ||||
| from dataclasses import dataclass | ||||
| {% if description.datetime_imports %} | ||||
| from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| @@ -38,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| {% for message in description.messages %} | ||||
| @dataclass | ||||
| @betterproto.protoclass | ||||
| class {{ message.py_name }}(betterproto.Message): | ||||
|     {% if message.comment %} | ||||
| {{ message.comment }} | ||||
|   | ||||
| @@ -4,11 +4,11 @@ from typing import Optional | ||||
|  | ||||
|  | ||||
| def test_has_field(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Bar(betterproto.Message): | ||||
|         baz: int = betterproto.int32_field(1) | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Foo(betterproto.Message): | ||||
|         bar: Bar = betterproto.message_field(1) | ||||
|  | ||||
| @@ -34,11 +34,11 @@ def test_has_field(): | ||||
|  | ||||
|  | ||||
| def test_class_init(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Bar(betterproto.Message): | ||||
|         name: str = betterproto.string_field(1) | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Foo(betterproto.Message): | ||||
|         name: str = betterproto.string_field(1) | ||||
|         child: Bar = betterproto.message_field(2) | ||||
| @@ -53,7 +53,7 @@ def test_enum_as_int_json(): | ||||
|         ZERO = 0 | ||||
|         ONE = 1 | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Foo(betterproto.Message): | ||||
|         bar: TestEnum = betterproto.enum_field(1) | ||||
|  | ||||
| @@ -67,13 +67,13 @@ def test_enum_as_int_json(): | ||||
|  | ||||
|  | ||||
| def test_unknown_fields(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Newer(betterproto.Message): | ||||
|         foo: bool = betterproto.bool_field(1) | ||||
|         bar: int = betterproto.int32_field(2) | ||||
|         baz: str = betterproto.string_field(3) | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Older(betterproto.Message): | ||||
|         foo: bool = betterproto.bool_field(1) | ||||
|  | ||||
| @@ -89,11 +89,11 @@ def test_unknown_fields(): | ||||
|  | ||||
|  | ||||
| def test_oneof_support(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Sub(betterproto.Message): | ||||
|         val: int = betterproto.int32_field(1) | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Foo(betterproto.Message): | ||||
|         bar: int = betterproto.int32_field(1, group="group1") | ||||
|         baz: str = betterproto.string_field(2, group="group1") | ||||
| @@ -134,7 +134,7 @@ def test_oneof_support(): | ||||
|  | ||||
|  | ||||
| def test_json_casing(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class CasingTest(betterproto.Message): | ||||
|         pascal_case: int = betterproto.int32_field(1) | ||||
|         camel_case: int = betterproto.int32_field(2) | ||||
| @@ -165,7 +165,7 @@ def test_json_casing(): | ||||
|  | ||||
|  | ||||
| def test_optional_flag(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class Request(betterproto.Message): | ||||
|         flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) | ||||
|  | ||||
| @@ -180,7 +180,7 @@ def test_optional_flag(): | ||||
|  | ||||
|  | ||||
| def test_to_dict_default_values(): | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class TestMessage(betterproto.Message): | ||||
|         some_int: int = betterproto.int32_field(1) | ||||
|         some_double: float = betterproto.double_field(2) | ||||
| @@ -210,7 +210,7 @@ def test_to_dict_default_values(): | ||||
|     } | ||||
|  | ||||
|     # Some default and some other values | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class TestMessage2(betterproto.Message): | ||||
|         some_int: int = betterproto.int32_field(1) | ||||
|         some_double: float = betterproto.double_field(2) | ||||
| @@ -246,11 +246,11 @@ def test_to_dict_default_values(): | ||||
|     } | ||||
|  | ||||
|     # Nested messages | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class TestChildMessage(betterproto.Message): | ||||
|         some_other_int: int = betterproto.int32_field(1) | ||||
|  | ||||
|     @dataclass | ||||
|     @betterproto.protoclass | ||||
|     class TestParentMessage(betterproto.Message): | ||||
|         some_int: int = betterproto.int32_field(1) | ||||
|         some_double: float = betterproto.double_field(2) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user