Merge pull request #62 from jameslan/perf/cache-fields
Cache field metadata, to avoid calling `dataclasses.fields` to get more than 10% performance improvement
This commit is contained in:
		| @@ -14,11 +14,10 @@ from typing import ( | ||||
|     Collection, | ||||
|     Dict, | ||||
|     Generator, | ||||
|     Iterable, | ||||
|     List, | ||||
|     Mapping, | ||||
|     Optional, | ||||
|     SupportsBytes, | ||||
|     Set, | ||||
|     Tuple, | ||||
|     Type, | ||||
|     TypeVar, | ||||
| @@ -435,14 +434,29 @@ T = TypeVar("T", bound="Message") | ||||
|  | ||||
|  | ||||
| class ProtoClassMetadata: | ||||
|     cls: Type["Message"] | ||||
|     oneof_group_by_field: Dict[str, str] | ||||
|     oneof_field_by_group: Dict[str, Set[dataclasses.Field]] | ||||
|     default_gen: Dict[str, Callable] | ||||
|     cls_by_field: Dict[str, Type] | ||||
|     field_name_by_number: Dict[int, str] | ||||
|     meta_by_field_name: Dict[str, FieldMetadata] | ||||
|     __slots__ = ( | ||||
|         "oneof_group_by_field", | ||||
|         "oneof_field_by_group", | ||||
|         "default_gen", | ||||
|         "cls_by_field", | ||||
|         "field_name_by_number", | ||||
|         "meta_by_field_name", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, cls: Type["Message"]): | ||||
|         self.cls = cls | ||||
|         by_field = {} | ||||
|         by_group = {} | ||||
|         by_field_name = {} | ||||
|         by_field_number = {} | ||||
|  | ||||
|         for field in dataclasses.fields(cls): | ||||
|         fields = dataclasses.fields(cls) | ||||
|         for field in fields: | ||||
|             meta = FieldMetadata.get(field) | ||||
|  | ||||
|             if meta.group: | ||||
| @@ -451,30 +465,36 @@ class ProtoClassMetadata: | ||||
|  | ||||
|                 by_group.setdefault(meta.group, set()).add(field) | ||||
|  | ||||
|             by_field_name[field.name] = meta | ||||
|             by_field_number[meta.number] = field.name | ||||
|  | ||||
|         self.oneof_group_by_field = by_field | ||||
|         self.oneof_field_by_group = by_group | ||||
|         self.field_name_by_number = by_field_number | ||||
|         self.meta_by_field_name = by_field_name | ||||
|  | ||||
|         self.init_default_gen() | ||||
|         self.init_cls_by_field() | ||||
|         self.default_gen = self._get_default_gen(cls, fields) | ||||
|         self.cls_by_field = self._get_cls_by_field(cls, fields) | ||||
|  | ||||
|     def init_default_gen(self): | ||||
|     @staticmethod | ||||
|     def _get_default_gen(cls, fields): | ||||
|         default_gen = {} | ||||
|  | ||||
|         for field in dataclasses.fields(self.cls): | ||||
|             meta = FieldMetadata.get(field) | ||||
|             default_gen[field.name] = self.cls._get_field_default_gen(field, meta) | ||||
|         for field in fields: | ||||
|             default_gen[field.name] = cls._get_field_default_gen(field) | ||||
|  | ||||
|         self.default_gen = default_gen | ||||
|         return default_gen | ||||
|  | ||||
|     def init_cls_by_field(self): | ||||
|     @staticmethod | ||||
|     def _get_cls_by_field(cls, fields): | ||||
|         field_cls = {} | ||||
|  | ||||
|         for field in dataclasses.fields(self.cls): | ||||
|         for field in fields: | ||||
|             meta = FieldMetadata.get(field) | ||||
|             if meta.proto_type == TYPE_MAP: | ||||
|                 assert meta.map_types | ||||
|                 kt = self.cls._cls_for(field, index=0) | ||||
|                 vt = self.cls._cls_for(field, index=1) | ||||
|                 kt = cls._cls_for(field, index=0) | ||||
|                 vt = cls._cls_for(field, index=1) | ||||
|                 Entry = dataclasses.make_dataclass( | ||||
|                     "Entry", | ||||
|                     [ | ||||
| @@ -486,9 +506,9 @@ class ProtoClassMetadata: | ||||
|                 field_cls[field.name] = Entry | ||||
|                 field_cls[field.name + ".value"] = vt | ||||
|             else: | ||||
|                 field_cls[field.name] = self.cls._cls_for(field) | ||||
|                 field_cls[field.name] = cls._cls_for(field) | ||||
|  | ||||
|         self.cls_by_field = field_cls | ||||
|         return field_cls | ||||
|  | ||||
|  | ||||
| class Message(ABC): | ||||
| @@ -500,53 +520,50 @@ class Message(ABC): | ||||
|  | ||||
|     _serialized_on_wire: bool | ||||
|     _unknown_fields: bytes | ||||
|     _group_map: Dict[str, dict] | ||||
|     _group_current: Dict[str, str] | ||||
|  | ||||
|     def __post_init__(self) -> None: | ||||
|         # Keep track of whether every field was default | ||||
|         all_sentinel = True | ||||
|  | ||||
|         # Set a default value for each field in the class after `__init__` has | ||||
|         # already been run. | ||||
|         group_map: Dict[str, dataclasses.Field] = {} | ||||
|         for field in dataclasses.fields(self): | ||||
|             meta = FieldMetadata.get(field) | ||||
|         # Set current field of each group after `__init__` has already been run. | ||||
|         group_current: Dict[str, str] = {} | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|  | ||||
|             if meta.group: | ||||
|                 group_map.setdefault(meta.group) | ||||
|                 group_current.setdefault(meta.group) | ||||
|  | ||||
|             if getattr(self, field.name) != PLACEHOLDER: | ||||
|             if getattr(self, field_name) != PLACEHOLDER: | ||||
|                 # Skip anything not set to the sentinel value | ||||
|                 all_sentinel = False | ||||
|  | ||||
|                 if meta.group: | ||||
|                     # This was set, so make it the selected value of the one-of. | ||||
|                     group_map[meta.group] = field | ||||
|                     group_current[meta.group] = field_name | ||||
|  | ||||
|                 continue | ||||
|  | ||||
|             setattr(self, field.name, self._get_field_default(field, meta)) | ||||
|             setattr(self, field_name, self._get_field_default(field_name)) | ||||
|  | ||||
|         # Now that all the defaults are set, reset it! | ||||
|         self.__dict__["_serialized_on_wire"] = not all_sentinel | ||||
|         self.__dict__["_unknown_fields"] = b"" | ||||
|         self.__dict__["_group_map"] = group_map | ||||
|         self.__dict__["_group_current"] = group_current | ||||
|  | ||||
|     def __setattr__(self, attr: str, value: Any) -> None: | ||||
|         if attr != "_serialized_on_wire": | ||||
|             # Track when a field has been set. | ||||
|             self.__dict__["_serialized_on_wire"] = True | ||||
|  | ||||
|         if hasattr(self, "_group_map"):  # __post_init__ had already run | ||||
|         if hasattr(self, "_group_current"):  # __post_init__ had already run | ||||
|             if attr in self._betterproto.oneof_group_by_field: | ||||
|                 group = self._betterproto.oneof_group_by_field[attr] | ||||
|                 for field in self._betterproto.oneof_field_by_group[group]: | ||||
|                     if field.name == attr: | ||||
|                         self._group_map[group] = field | ||||
|                         self._group_current[group] = field.name | ||||
|                     else: | ||||
|                         super().__setattr__( | ||||
|                             field.name, | ||||
|                             self._get_field_default(field, FieldMetadata.get(field)), | ||||
|                             field.name, self._get_field_default(field.name), | ||||
|                         ) | ||||
|  | ||||
|         super().__setattr__(attr, value) | ||||
| @@ -569,9 +586,8 @@ class Message(ABC): | ||||
|         Get the binary encoded Protobuf representation of this instance. | ||||
|         """ | ||||
|         output = b"" | ||||
|         for field in dataclasses.fields(self): | ||||
|             meta = FieldMetadata.get(field) | ||||
|             value = getattr(self, field.name) | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|             value = getattr(self, field_name) | ||||
|  | ||||
|             if value is None: | ||||
|                 # Optional items should be skipped. This is used for the Google | ||||
| @@ -582,7 +598,7 @@ class Message(ABC): | ||||
|             # currently set in a `oneof` group, so it must be serialized even | ||||
|             # if the value is the default zero value. | ||||
|             selected_in_group = False | ||||
|             if meta.group and self._group_map[meta.group] == field: | ||||
|             if meta.group and self._group_current[meta.group] == field_name: | ||||
|                 selected_in_group = True | ||||
|  | ||||
|             serialize_empty = False | ||||
| @@ -591,7 +607,7 @@ class Message(ABC): | ||||
|                 # set (or received empty). | ||||
|                 serialize_empty = True | ||||
|  | ||||
|             if value == self._get_field_default(field, meta) and not ( | ||||
|             if value == self._get_field_default(field_name) and not ( | ||||
|                 selected_in_group or serialize_empty | ||||
|             ): | ||||
|                 # Default (zero) values are not serialized. Two exceptions are | ||||
| @@ -648,13 +664,11 @@ class Message(ABC): | ||||
|             field_cls = field_cls.__args__[index] | ||||
|         return field_cls | ||||
|  | ||||
|     def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: | ||||
|         return self._betterproto.default_gen[field.name]() | ||||
|     def _get_field_default(self, field_name): | ||||
|         return self._betterproto.default_gen[field_name]() | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_field_default_gen( | ||||
|         cls, field: dataclasses.Field, meta: FieldMetadata | ||||
|     ) -> Any: | ||||
|     def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: | ||||
|         t = cls._type_hint(field.name) | ||||
|  | ||||
|         if hasattr(t, "__origin__"): | ||||
| @@ -682,7 +696,7 @@ class Message(ABC): | ||||
|             return t | ||||
|  | ||||
|     def _postprocess_single( | ||||
|         self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any | ||||
|         self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any | ||||
|     ) -> Any: | ||||
|         """Adjusts values after parsing.""" | ||||
|         if wire_type == WIRE_VARINT: | ||||
| @@ -704,7 +718,7 @@ class Message(ABC): | ||||
|             if meta.proto_type == TYPE_STRING: | ||||
|                 value = value.decode("utf-8") | ||||
|             elif meta.proto_type == TYPE_MESSAGE: | ||||
|                 cls = self._betterproto.cls_by_field[field.name] | ||||
|                 cls = self._betterproto.cls_by_field[field_name] | ||||
|  | ||||
|                 if cls == datetime: | ||||
|                     value = _Timestamp().parse(value).to_datetime() | ||||
| @@ -718,7 +732,7 @@ class Message(ABC): | ||||
|                     value = cls().parse(value) | ||||
|                     value._serialized_on_wire = True | ||||
|             elif meta.proto_type == TYPE_MAP: | ||||
|                 value = self._betterproto.cls_by_field[field.name]().parse(value) | ||||
|                 value = self._betterproto.cls_by_field[field_name]().parse(value) | ||||
|  | ||||
|         return value | ||||
|  | ||||
| @@ -727,49 +741,46 @@ class Message(ABC): | ||||
|         Parse the binary encoded Protobuf into this message instance. This | ||||
|         returns the instance itself and is therefore assignable and chainable. | ||||
|         """ | ||||
|         fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)} | ||||
|         for parsed in parse_fields(data): | ||||
|             if parsed.number in fields: | ||||
|                 field = fields[parsed.number] | ||||
|                 meta = FieldMetadata.get(field) | ||||
|  | ||||
|                 value: Any | ||||
|                 if ( | ||||
|                     parsed.wire_type == WIRE_LEN_DELIM | ||||
|                     and meta.proto_type in PACKED_TYPES | ||||
|                 ): | ||||
|                     # This is a packed repeated field. | ||||
|                     pos = 0 | ||||
|                     value = [] | ||||
|                     while pos < len(parsed.value): | ||||
|                         if meta.proto_type in ["float", "fixed32", "sfixed32"]: | ||||
|                             decoded, pos = parsed.value[pos : pos + 4], pos + 4 | ||||
|                             wire_type = WIRE_FIXED_32 | ||||
|                         elif meta.proto_type in ["double", "fixed64", "sfixed64"]: | ||||
|                             decoded, pos = parsed.value[pos : pos + 8], pos + 8 | ||||
|                             wire_type = WIRE_FIXED_64 | ||||
|                         else: | ||||
|                             decoded, pos = decode_varint(parsed.value, pos) | ||||
|                             wire_type = WIRE_VARINT | ||||
|                         decoded = self._postprocess_single( | ||||
|                             wire_type, meta, field, decoded | ||||
|                         ) | ||||
|                         value.append(decoded) | ||||
|                 else: | ||||
|                     value = self._postprocess_single( | ||||
|                         parsed.wire_type, meta, field, parsed.value | ||||
|                     ) | ||||
|  | ||||
|                 current = getattr(self, field.name) | ||||
|                 if meta.proto_type == TYPE_MAP: | ||||
|                     # Value represents a single key/value pair entry in the map. | ||||
|                     current[value.key] = value.value | ||||
|                 elif isinstance(current, list) and not isinstance(value, list): | ||||
|                     current.append(value) | ||||
|                 else: | ||||
|                     setattr(self, field.name, value) | ||||
|             else: | ||||
|             field_name = self._betterproto.field_name_by_number.get(parsed.number) | ||||
|             if not field_name: | ||||
|                 self._unknown_fields += parsed.raw | ||||
|                 continue | ||||
|  | ||||
|             meta = self._betterproto.meta_by_field_name[field_name] | ||||
|  | ||||
|             value: Any | ||||
|             if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: | ||||
|                 # This is a packed repeated field. | ||||
|                 pos = 0 | ||||
|                 value = [] | ||||
|                 while pos < len(parsed.value): | ||||
|                     if meta.proto_type in ["float", "fixed32", "sfixed32"]: | ||||
|                         decoded, pos = parsed.value[pos : pos + 4], pos + 4 | ||||
|                         wire_type = WIRE_FIXED_32 | ||||
|                     elif meta.proto_type in ["double", "fixed64", "sfixed64"]: | ||||
|                         decoded, pos = parsed.value[pos : pos + 8], pos + 8 | ||||
|                         wire_type = WIRE_FIXED_64 | ||||
|                     else: | ||||
|                         decoded, pos = decode_varint(parsed.value, pos) | ||||
|                         wire_type = WIRE_VARINT | ||||
|                     decoded = self._postprocess_single( | ||||
|                         wire_type, meta, field_name, decoded | ||||
|                     ) | ||||
|                     value.append(decoded) | ||||
|             else: | ||||
|                 value = self._postprocess_single( | ||||
|                     parsed.wire_type, meta, field_name, parsed.value | ||||
|                 ) | ||||
|  | ||||
|             current = getattr(self, field_name) | ||||
|             if meta.proto_type == TYPE_MAP: | ||||
|                 # Value represents a single key/value pair entry in the map. | ||||
|                 current[value.key] = value.value | ||||
|             elif isinstance(current, list) and not isinstance(value, list): | ||||
|                 current.append(value) | ||||
|             else: | ||||
|                 setattr(self, field_name, value) | ||||
|  | ||||
|         return self | ||||
|  | ||||
| @@ -792,10 +803,9 @@ class Message(ABC): | ||||
|         `False`. | ||||
|         """ | ||||
|         output: Dict[str, Any] = {} | ||||
|         for field in dataclasses.fields(self): | ||||
|             meta = FieldMetadata.get(field) | ||||
|             v = getattr(self, field.name) | ||||
|             cased_name = casing(field.name).rstrip("_")  # type: ignore | ||||
|         for field_name, meta in self._betterproto.meta_by_field_name.items(): | ||||
|             v = getattr(self, field_name) | ||||
|             cased_name = casing(field_name).rstrip("_")  # type: ignore | ||||
|             if meta.proto_type == "message": | ||||
|                 if isinstance(v, datetime): | ||||
|                     if v != DATETIME_ZERO or include_default_values: | ||||
| @@ -821,7 +831,7 @@ class Message(ABC): | ||||
|  | ||||
|                 if v or include_default_values: | ||||
|                     output[cased_name] = v | ||||
|             elif v != self._get_field_default(field, meta) or include_default_values: | ||||
|             elif v != self._get_field_default(field_name) or include_default_values: | ||||
|                 if meta.proto_type in INT_64_TYPES: | ||||
|                     if isinstance(v, list): | ||||
|                         output[cased_name] = [str(n) for n in v] | ||||
| @@ -834,7 +844,7 @@ class Message(ABC): | ||||
|                         output[cased_name] = b64encode(v).decode("utf8") | ||||
|                 elif meta.proto_type == TYPE_ENUM: | ||||
|                     enum_values = list( | ||||
|                         self._betterproto.cls_by_field[field.name] | ||||
|                         self._betterproto.cls_by_field[field_name] | ||||
|                     )  # type: ignore | ||||
|                     if isinstance(v, list): | ||||
|                         output[cased_name] = [enum_values[e].name for e in v] | ||||
| @@ -852,56 +862,54 @@ class Message(ABC): | ||||
|         self._serialized_on_wire = True | ||||
|         fields_by_name = {f.name: f for f in dataclasses.fields(self)} | ||||
|         for key in value: | ||||
|             snake_cased = safe_snake_case(key) | ||||
|             if snake_cased in fields_by_name: | ||||
|                 field = fields_by_name[snake_cased] | ||||
|                 meta = FieldMetadata.get(field) | ||||
|             field_name = safe_snake_case(key) | ||||
|             meta = self._betterproto.meta_by_field_name.get(field_name) | ||||
|             if not meta: | ||||
|                 continue | ||||
|  | ||||
|                 if value[key] is not None: | ||||
|                     if meta.proto_type == "message": | ||||
|                         v = getattr(self, field.name) | ||||
|                         if isinstance(v, list): | ||||
|                             cls = self._betterproto.cls_by_field[field.name] | ||||
|                             for i in range(len(value[key])): | ||||
|                                 v.append(cls().from_dict(value[key][i])) | ||||
|                         elif isinstance(v, datetime): | ||||
|                             v = datetime.fromisoformat( | ||||
|                                 value[key].replace("Z", "+00:00") | ||||
|                             ) | ||||
|                             setattr(self, field.name, v) | ||||
|                         elif isinstance(v, timedelta): | ||||
|                             v = timedelta(seconds=float(value[key][:-1])) | ||||
|                             setattr(self, field.name, v) | ||||
|                         elif meta.wraps: | ||||
|                             setattr(self, field.name, value[key]) | ||||
|                         else: | ||||
|                             v.from_dict(value[key]) | ||||
|                     elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: | ||||
|                         v = getattr(self, field.name) | ||||
|                         cls = self._betterproto.cls_by_field[field.name + ".value"] | ||||
|                         for k in value[key]: | ||||
|                             v[k] = cls().from_dict(value[key][k]) | ||||
|             if value[key] is not None: | ||||
|                 if meta.proto_type == "message": | ||||
|                     v = getattr(self, field_name) | ||||
|                     if isinstance(v, list): | ||||
|                         cls = self._betterproto.cls_by_field[field_name] | ||||
|                         for i in range(len(value[key])): | ||||
|                             v.append(cls().from_dict(value[key][i])) | ||||
|                     elif isinstance(v, datetime): | ||||
|                         v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) | ||||
|                         setattr(self, field_name, v) | ||||
|                     elif isinstance(v, timedelta): | ||||
|                         v = timedelta(seconds=float(value[key][:-1])) | ||||
|                         setattr(self, field_name, v) | ||||
|                     elif meta.wraps: | ||||
|                         setattr(self, field_name, value[key]) | ||||
|                     else: | ||||
|                         v = value[key] | ||||
|                         if meta.proto_type in INT_64_TYPES: | ||||
|                             if isinstance(value[key], list): | ||||
|                                 v = [int(n) for n in value[key]] | ||||
|                             else: | ||||
|                                 v = int(value[key]) | ||||
|                         elif meta.proto_type == TYPE_BYTES: | ||||
|                             if isinstance(value[key], list): | ||||
|                                 v = [b64decode(n) for n in value[key]] | ||||
|                             else: | ||||
|                                 v = b64decode(value[key]) | ||||
|                         elif meta.proto_type == TYPE_ENUM: | ||||
|                             enum_cls = self._betterproto.cls_by_field[field.name] | ||||
|                             if isinstance(v, list): | ||||
|                                 v = [enum_cls.from_string(e) for e in v] | ||||
|                             elif isinstance(v, str): | ||||
|                                 v = enum_cls.from_string(v) | ||||
|                         v.from_dict(value[key]) | ||||
|                 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: | ||||
|                     v = getattr(self, field_name) | ||||
|                     cls = self._betterproto.cls_by_field[field_name + ".value"] | ||||
|                     for k in value[key]: | ||||
|                         v[k] = cls().from_dict(value[key][k]) | ||||
|                 else: | ||||
|                     v = value[key] | ||||
|                     if meta.proto_type in INT_64_TYPES: | ||||
|                         if isinstance(value[key], list): | ||||
|                             v = [int(n) for n in value[key]] | ||||
|                         else: | ||||
|                             v = int(value[key]) | ||||
|                     elif meta.proto_type == TYPE_BYTES: | ||||
|                         if isinstance(value[key], list): | ||||
|                             v = [b64decode(n) for n in value[key]] | ||||
|                         else: | ||||
|                             v = b64decode(value[key]) | ||||
|                     elif meta.proto_type == TYPE_ENUM: | ||||
|                         enum_cls = self._betterproto.cls_by_field[field_name] | ||||
|                         if isinstance(v, list): | ||||
|                             v = [enum_cls.from_string(e) for e in v] | ||||
|                         elif isinstance(v, str): | ||||
|                             v = enum_cls.from_string(v) | ||||
|  | ||||
|                         if v is not None: | ||||
|                             setattr(self, field.name, v) | ||||
|                     if v is not None: | ||||
|                         setattr(self, field_name, v) | ||||
|         return self | ||||
|  | ||||
|     def to_json(self, indent: Union[None, int, str] = None) -> str: | ||||
| @@ -927,10 +935,10 @@ def serialized_on_wire(message: Message) -> bool: | ||||
|  | ||||
| def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||
|     """Return the name and value of a message's one-of field group.""" | ||||
|     field = message._group_map.get(group_name) | ||||
|     if not field: | ||||
|     field_name = message._group_current.get(group_name) | ||||
|     if not field_name: | ||||
|         return ("", None) | ||||
|     return (field.name, getattr(message, field.name)) | ||||
|     return (field_name, getattr(message, field_name)) | ||||
|  | ||||
|  | ||||
| @dataclasses.dataclass | ||||
|   | ||||
		Reference in New Issue
	
	Block a user