From ed33a48d64594b15458b55d67cec50a087ecd5f5 Mon Sep 17 00:00:00 2001 From: James Lan Date: Sat, 23 May 2020 18:06:04 -0700 Subject: [PATCH] Cache field metadata, to avoid calling `dataclasses.fields` to get more than 10% performance improvement --- betterproto/__init__.py | 294 +++++++++++++++++++++------------------- 1 file changed, 151 insertions(+), 143 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 11fb741..f394b41 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -14,11 +14,10 @@ from typing import ( Collection, Dict, Generator, - Iterable, List, Mapping, Optional, - SupportsBytes, + Set, Tuple, Type, TypeVar, @@ -435,14 +434,29 @@ T = TypeVar("T", bound="Message") class ProtoClassMetadata: - cls: Type["Message"] + oneof_group_by_field: Dict[str, str] + oneof_field_by_group: Dict[str, Set[dataclasses.Field]] + default_gen: Dict[str, Callable] + cls_by_field: Dict[str, Type] + field_name_by_number: Dict[int, str] + meta_by_field_name: Dict[str, FieldMetadata] + __slots__ = ( + "oneof_group_by_field", + "oneof_field_by_group", + "default_gen", + "cls_by_field", + "field_name_by_number", + "meta_by_field_name", + ) def __init__(self, cls: Type["Message"]): - self.cls = cls by_field = {} by_group = {} + by_field_name = {} + by_field_number = {} - for field in dataclasses.fields(cls): + fields = dataclasses.fields(cls) + for field in fields: meta = FieldMetadata.get(field) if meta.group: @@ -451,30 +465,36 @@ class ProtoClassMetadata: by_group.setdefault(meta.group, set()).add(field) + by_field_name[field.name] = meta + by_field_number[meta.number] = field.name + self.oneof_group_by_field = by_field self.oneof_field_by_group = by_group + self.field_name_by_number = by_field_number + self.meta_by_field_name = by_field_name - self.init_default_gen() - self.init_cls_by_field() + self.default_gen = self._get_default_gen(cls, fields) + self.cls_by_field = self._get_cls_by_field(cls, fields) - def init_default_gen(self): + @staticmethod + def _get_default_gen(cls, fields): default_gen = {} - for field in dataclasses.fields(self.cls): - meta = FieldMetadata.get(field) - default_gen[field.name] = self.cls._get_field_default_gen(field, meta) + for field in fields: + default_gen[field.name] = cls._get_field_default_gen(field) - self.default_gen = default_gen + return default_gen - def init_cls_by_field(self): + @staticmethod + def _get_cls_by_field(cls, fields): field_cls = {} - for field in dataclasses.fields(self.cls): + for field in fields: meta = FieldMetadata.get(field) if meta.proto_type == TYPE_MAP: assert meta.map_types - kt = self.cls._cls_for(field, index=0) - vt = self.cls._cls_for(field, index=1) + kt = cls._cls_for(field, index=0) + vt = cls._cls_for(field, index=1) Entry = dataclasses.make_dataclass( "Entry", [ @@ -486,9 +506,9 @@ class ProtoClassMetadata: field_cls[field.name] = Entry field_cls[field.name + ".value"] = vt else: - field_cls[field.name] = self.cls._cls_for(field) + field_cls[field.name] = cls._cls_for(field) - self.cls_by_field = field_cls + return field_cls class Message(ABC): @@ -500,53 +520,50 @@ class Message(ABC): _serialized_on_wire: bool _unknown_fields: bytes - _group_map: Dict[str, dict] + _group_current: Dict[str, str] def __post_init__(self) -> None: # Keep track of whether every field was default all_sentinel = True - # Set a default value for each field in the class after `__init__` has - # already been run. - group_map: Dict[str, dataclasses.Field] = {} - for field in dataclasses.fields(self): - meta = FieldMetadata.get(field) + # Set current field of each group after `__init__` has already been run. + group_current: Dict[str, str] = {} + for field_name, meta in self._betterproto.meta_by_field_name.items(): if meta.group: - group_map.setdefault(meta.group) + group_current.setdefault(meta.group) - if getattr(self, field.name) != PLACEHOLDER: + if getattr(self, field_name) != PLACEHOLDER: # Skip anything not set to the sentinel value all_sentinel = False if meta.group: # This was set, so make it the selected value of the one-of. - group_map[meta.group] = field + group_current[meta.group] = field_name continue - setattr(self, field.name, self._get_field_default(field, meta)) + setattr(self, field_name, self._get_field_default(field_name)) # Now that all the defaults are set, reset it! self.__dict__["_serialized_on_wire"] = not all_sentinel self.__dict__["_unknown_fields"] = b"" - self.__dict__["_group_map"] = group_map + self.__dict__["_group_current"] = group_current def __setattr__(self, attr: str, value: Any) -> None: if attr != "_serialized_on_wire": # Track when a field has been set. self.__dict__["_serialized_on_wire"] = True - if hasattr(self, "_group_map"): # __post_init__ had already run + if hasattr(self, "_group_current"): # __post_init__ had already run if attr in self._betterproto.oneof_group_by_field: group = self._betterproto.oneof_group_by_field[attr] for field in self._betterproto.oneof_field_by_group[group]: if field.name == attr: - self._group_map[group] = field + self._group_current[group] = field.name else: super().__setattr__( - field.name, - self._get_field_default(field, FieldMetadata.get(field)), + field.name, self._get_field_default(field.name), ) super().__setattr__(attr, value) @@ -569,9 +586,8 @@ class Message(ABC): Get the binary encoded Protobuf representation of this instance. """ output = b"" - for field in dataclasses.fields(self): - meta = FieldMetadata.get(field) - value = getattr(self, field.name) + for field_name, meta in self._betterproto.meta_by_field_name.items(): + value = getattr(self, field_name) if value is None: # Optional items should be skipped. This is used for the Google @@ -582,7 +598,7 @@ class Message(ABC): # currently set in a `oneof` group, so it must be serialized even # if the value is the default zero value. selected_in_group = False - if meta.group and self._group_map[meta.group] == field: + if meta.group and self._group_current[meta.group] == field_name: selected_in_group = True serialize_empty = False @@ -591,7 +607,7 @@ class Message(ABC): # set (or received empty). serialize_empty = True - if value == self._get_field_default(field, meta) and not ( + if value == self._get_field_default(field_name) and not ( selected_in_group or serialize_empty ): # Default (zero) values are not serialized. Two exceptions are @@ -648,13 +664,11 @@ class Message(ABC): field_cls = field_cls.__args__[index] return field_cls - def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: - return self._betterproto.default_gen[field.name]() + def _get_field_default(self, field_name): + return self._betterproto.default_gen[field_name]() @classmethod - def _get_field_default_gen( - cls, field: dataclasses.Field, meta: FieldMetadata - ) -> Any: + def _get_field_default_gen(cls, field: dataclasses.Field) -> Any: t = cls._type_hint(field.name) if hasattr(t, "__origin__"): @@ -682,7 +696,7 @@ class Message(ABC): return t def _postprocess_single( - self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any + self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any ) -> Any: """Adjusts values after parsing.""" if wire_type == WIRE_VARINT: @@ -704,7 +718,7 @@ class Message(ABC): if meta.proto_type == TYPE_STRING: value = value.decode("utf-8") elif meta.proto_type == TYPE_MESSAGE: - cls = self._betterproto.cls_by_field[field.name] + cls = self._betterproto.cls_by_field[field_name] if cls == datetime: value = _Timestamp().parse(value).to_datetime() @@ -718,7 +732,7 @@ class Message(ABC): value = cls().parse(value) value._serialized_on_wire = True elif meta.proto_type == TYPE_MAP: - value = self._betterproto.cls_by_field[field.name]().parse(value) + value = self._betterproto.cls_by_field[field_name]().parse(value) return value @@ -727,49 +741,46 @@ class Message(ABC): Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. """ - fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)} for parsed in parse_fields(data): - if parsed.number in fields: - field = fields[parsed.number] - meta = FieldMetadata.get(field) - - value: Any - if ( - parsed.wire_type == WIRE_LEN_DELIM - and meta.proto_type in PACKED_TYPES - ): - # This is a packed repeated field. - pos = 0 - value = [] - while pos < len(parsed.value): - if meta.proto_type in ["float", "fixed32", "sfixed32"]: - decoded, pos = parsed.value[pos : pos + 4], pos + 4 - wire_type = WIRE_FIXED_32 - elif meta.proto_type in ["double", "fixed64", "sfixed64"]: - decoded, pos = parsed.value[pos : pos + 8], pos + 8 - wire_type = WIRE_FIXED_64 - else: - decoded, pos = decode_varint(parsed.value, pos) - wire_type = WIRE_VARINT - decoded = self._postprocess_single( - wire_type, meta, field, decoded - ) - value.append(decoded) - else: - value = self._postprocess_single( - parsed.wire_type, meta, field, parsed.value - ) - - current = getattr(self, field.name) - if meta.proto_type == TYPE_MAP: - # Value represents a single key/value pair entry in the map. - current[value.key] = value.value - elif isinstance(current, list) and not isinstance(value, list): - current.append(value) - else: - setattr(self, field.name, value) - else: + field_name = self._betterproto.field_name_by_number.get(parsed.number) + if not field_name: self._unknown_fields += parsed.raw + continue + + meta = self._betterproto.meta_by_field_name[field_name] + + value: Any + if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES: + # This is a packed repeated field. + pos = 0 + value = [] + while pos < len(parsed.value): + if meta.proto_type in ["float", "fixed32", "sfixed32"]: + decoded, pos = parsed.value[pos : pos + 4], pos + 4 + wire_type = WIRE_FIXED_32 + elif meta.proto_type in ["double", "fixed64", "sfixed64"]: + decoded, pos = parsed.value[pos : pos + 8], pos + 8 + wire_type = WIRE_FIXED_64 + else: + decoded, pos = decode_varint(parsed.value, pos) + wire_type = WIRE_VARINT + decoded = self._postprocess_single( + wire_type, meta, field_name, decoded + ) + value.append(decoded) + else: + value = self._postprocess_single( + parsed.wire_type, meta, field_name, parsed.value + ) + + current = getattr(self, field_name) + if meta.proto_type == TYPE_MAP: + # Value represents a single key/value pair entry in the map. + current[value.key] = value.value + elif isinstance(current, list) and not isinstance(value, list): + current.append(value) + else: + setattr(self, field_name, value) return self @@ -792,10 +803,9 @@ class Message(ABC): `False`. """ output: Dict[str, Any] = {} - for field in dataclasses.fields(self): - meta = FieldMetadata.get(field) - v = getattr(self, field.name) - cased_name = casing(field.name).rstrip("_") # type: ignore + for field_name, meta in self._betterproto.meta_by_field_name.items(): + v = getattr(self, field_name) + cased_name = casing(field_name).rstrip("_") # type: ignore if meta.proto_type == "message": if isinstance(v, datetime): if v != DATETIME_ZERO or include_default_values: @@ -821,7 +831,7 @@ class Message(ABC): if v or include_default_values: output[cased_name] = v - elif v != self._get_field_default(field, meta) or include_default_values: + elif v != self._get_field_default(field_name) or include_default_values: if meta.proto_type in INT_64_TYPES: if isinstance(v, list): output[cased_name] = [str(n) for n in v] @@ -834,7 +844,7 @@ class Message(ABC): output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: enum_values = list( - self._betterproto.cls_by_field[field.name] + self._betterproto.cls_by_field[field_name] ) # type: ignore if isinstance(v, list): output[cased_name] = [enum_values[e].name for e in v] @@ -852,56 +862,54 @@ class Message(ABC): self._serialized_on_wire = True fields_by_name = {f.name: f for f in dataclasses.fields(self)} for key in value: - snake_cased = safe_snake_case(key) - if snake_cased in fields_by_name: - field = fields_by_name[snake_cased] - meta = FieldMetadata.get(field) + field_name = safe_snake_case(key) + meta = self._betterproto.meta_by_field_name.get(field_name) + if not meta: + continue - if value[key] is not None: - if meta.proto_type == "message": - v = getattr(self, field.name) - if isinstance(v, list): - cls = self._betterproto.cls_by_field[field.name] - for i in range(len(value[key])): - v.append(cls().from_dict(value[key][i])) - elif isinstance(v, datetime): - v = datetime.fromisoformat( - value[key].replace("Z", "+00:00") - ) - setattr(self, field.name, v) - elif isinstance(v, timedelta): - v = timedelta(seconds=float(value[key][:-1])) - setattr(self, field.name, v) - elif meta.wraps: - setattr(self, field.name, value[key]) - else: - v.from_dict(value[key]) - elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: - v = getattr(self, field.name) - cls = self._betterproto.cls_by_field[field.name + ".value"] - for k in value[key]: - v[k] = cls().from_dict(value[key][k]) + if value[key] is not None: + if meta.proto_type == "message": + v = getattr(self, field_name) + if isinstance(v, list): + cls = self._betterproto.cls_by_field[field_name] + for i in range(len(value[key])): + v.append(cls().from_dict(value[key][i])) + elif isinstance(v, datetime): + v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) + setattr(self, field_name, v) + elif isinstance(v, timedelta): + v = timedelta(seconds=float(value[key][:-1])) + setattr(self, field_name, v) + elif meta.wraps: + setattr(self, field_name, value[key]) else: - v = value[key] - if meta.proto_type in INT_64_TYPES: - if isinstance(value[key], list): - v = [int(n) for n in value[key]] - else: - v = int(value[key]) - elif meta.proto_type == TYPE_BYTES: - if isinstance(value[key], list): - v = [b64decode(n) for n in value[key]] - else: - v = b64decode(value[key]) - elif meta.proto_type == TYPE_ENUM: - enum_cls = self._betterproto.cls_by_field[field.name] - if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) + v.from_dict(value[key]) + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: + v = getattr(self, field_name) + cls = self._betterproto.cls_by_field[field_name + ".value"] + for k in value[key]: + v[k] = cls().from_dict(value[key][k]) + else: + v = value[key] + if meta.proto_type in INT_64_TYPES: + if isinstance(value[key], list): + v = [int(n) for n in value[key]] + else: + v = int(value[key]) + elif meta.proto_type == TYPE_BYTES: + if isinstance(value[key], list): + v = [b64decode(n) for n in value[key]] + else: + v = b64decode(value[key]) + elif meta.proto_type == TYPE_ENUM: + enum_cls = self._betterproto.cls_by_field[field_name] + if isinstance(v, list): + v = [enum_cls.from_string(e) for e in v] + elif isinstance(v, str): + v = enum_cls.from_string(v) - if v is not None: - setattr(self, field.name, v) + if v is not None: + setattr(self, field_name, v) return self def to_json(self, indent: Union[None, int, str] = None) -> str: @@ -927,10 +935,10 @@ def serialized_on_wire(message: Message) -> bool: def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: """Return the name and value of a message's one-of field group.""" - field = message._group_map.get(group_name) - if not field: + field_name = message._group_current.get(group_name) + if not field_name: return ("", None) - return (field.name, getattr(message, field.name)) + return (field_name, getattr(message, field_name)) @dataclasses.dataclass