Merge pull request #62 from jameslan/perf/cache-fields

Cache field metadata to avoid repeated calls to `dataclasses.fields`, improving performance by more than 10%
This commit is contained in:
Bouke Versteegh 2020-05-29 12:17:25 +02:00 committed by GitHub
commit fcff3dff74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,11 +14,10 @@ from typing import (
Collection, Collection,
Dict, Dict,
Generator, Generator,
Iterable,
List, List,
Mapping, Mapping,
Optional, Optional,
SupportsBytes, Set,
Tuple, Tuple,
Type, Type,
TypeVar, TypeVar,
@ -435,14 +434,29 @@ T = TypeVar("T", bound="Message")
class ProtoClassMetadata: class ProtoClassMetadata:
cls: Type["Message"] oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = (
"oneof_group_by_field",
"oneof_field_by_group",
"default_gen",
"cls_by_field",
"field_name_by_number",
"meta_by_field_name",
)
def __init__(self, cls: Type["Message"]): def __init__(self, cls: Type["Message"]):
self.cls = cls
by_field = {} by_field = {}
by_group = {} by_group = {}
by_field_name = {}
by_field_number = {}
for field in dataclasses.fields(cls): fields = dataclasses.fields(cls)
for field in fields:
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
if meta.group: if meta.group:
@ -451,30 +465,36 @@ class ProtoClassMetadata:
by_group.setdefault(meta.group, set()).add(field) by_group.setdefault(meta.group, set()).add(field)
by_field_name[field.name] = meta
by_field_number[meta.number] = field.name
self.oneof_group_by_field = by_field self.oneof_group_by_field = by_field
self.oneof_field_by_group = by_group self.oneof_field_by_group = by_group
self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name
self.init_default_gen() self.default_gen = self._get_default_gen(cls, fields)
self.init_cls_by_field() self.cls_by_field = self._get_cls_by_field(cls, fields)
def init_default_gen(self): @staticmethod
def _get_default_gen(cls, fields):
default_gen = {} default_gen = {}
for field in dataclasses.fields(self.cls): for field in fields:
meta = FieldMetadata.get(field) default_gen[field.name] = cls._get_field_default_gen(field)
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
self.default_gen = default_gen return default_gen
def init_cls_by_field(self): @staticmethod
def _get_cls_by_field(cls, fields):
field_cls = {} field_cls = {}
for field in dataclasses.fields(self.cls): for field in fields:
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
if meta.proto_type == TYPE_MAP: if meta.proto_type == TYPE_MAP:
assert meta.map_types assert meta.map_types
kt = self.cls._cls_for(field, index=0) kt = cls._cls_for(field, index=0)
vt = self.cls._cls_for(field, index=1) vt = cls._cls_for(field, index=1)
Entry = dataclasses.make_dataclass( Entry = dataclasses.make_dataclass(
"Entry", "Entry",
[ [
@ -486,9 +506,9 @@ class ProtoClassMetadata:
field_cls[field.name] = Entry field_cls[field.name] = Entry
field_cls[field.name + ".value"] = vt field_cls[field.name + ".value"] = vt
else: else:
field_cls[field.name] = self.cls._cls_for(field) field_cls[field.name] = cls._cls_for(field)
self.cls_by_field = field_cls return field_cls
class Message(ABC): class Message(ABC):
@ -500,53 +520,50 @@ class Message(ABC):
_serialized_on_wire: bool _serialized_on_wire: bool
_unknown_fields: bytes _unknown_fields: bytes
_group_map: Dict[str, dict] _group_current: Dict[str, str]
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Keep track of whether every field was default # Keep track of whether every field was default
all_sentinel = True all_sentinel = True
# Set a default value for each field in the class after `__init__` has # Set current field of each group after `__init__` has already been run.
# already been run. group_current: Dict[str, str] = {}
group_map: Dict[str, dataclasses.Field] = {} for field_name, meta in self._betterproto.meta_by_field_name.items():
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if meta.group: if meta.group:
group_map.setdefault(meta.group) group_current.setdefault(meta.group)
if getattr(self, field.name) != PLACEHOLDER: if getattr(self, field_name) != PLACEHOLDER:
# Skip anything not set to the sentinel value # Skip anything not set to the sentinel value
all_sentinel = False all_sentinel = False
if meta.group: if meta.group:
# This was set, so make it the selected value of the one-of. # This was set, so make it the selected value of the one-of.
group_map[meta.group] = field group_current[meta.group] = field_name
continue continue
setattr(self, field.name, self._get_field_default(field, meta)) setattr(self, field_name, self._get_field_default(field_name))
# Now that all the defaults are set, reset it! # Now that all the defaults are set, reset it!
self.__dict__["_serialized_on_wire"] = not all_sentinel self.__dict__["_serialized_on_wire"] = not all_sentinel
self.__dict__["_unknown_fields"] = b"" self.__dict__["_unknown_fields"] = b""
self.__dict__["_group_map"] = group_map self.__dict__["_group_current"] = group_current
def __setattr__(self, attr: str, value: Any) -> None: def __setattr__(self, attr: str, value: Any) -> None:
if attr != "_serialized_on_wire": if attr != "_serialized_on_wire":
# Track when a field has been set. # Track when a field has been set.
self.__dict__["_serialized_on_wire"] = True self.__dict__["_serialized_on_wire"] = True
if hasattr(self, "_group_map"): # __post_init__ had already run if hasattr(self, "_group_current"): # __post_init__ had already run
if attr in self._betterproto.oneof_group_by_field: if attr in self._betterproto.oneof_group_by_field:
group = self._betterproto.oneof_group_by_field[attr] group = self._betterproto.oneof_group_by_field[attr]
for field in self._betterproto.oneof_field_by_group[group]: for field in self._betterproto.oneof_field_by_group[group]:
if field.name == attr: if field.name == attr:
self._group_map[group] = field self._group_current[group] = field.name
else: else:
super().__setattr__( super().__setattr__(
field.name, field.name, self._get_field_default(field.name),
self._get_field_default(field, FieldMetadata.get(field)),
) )
super().__setattr__(attr, value) super().__setattr__(attr, value)
@ -569,9 +586,8 @@ class Message(ABC):
Get the binary encoded Protobuf representation of this instance. Get the binary encoded Protobuf representation of this instance.
""" """
output = b"" output = b""
for field in dataclasses.fields(self): for field_name, meta in self._betterproto.meta_by_field_name.items():
meta = FieldMetadata.get(field) value = getattr(self, field_name)
value = getattr(self, field.name)
if value is None: if value is None:
# Optional items should be skipped. This is used for the Google # Optional items should be skipped. This is used for the Google
@ -582,7 +598,7 @@ class Message(ABC):
# currently set in a `oneof` group, so it must be serialized even # currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value. # if the value is the default zero value.
selected_in_group = False selected_in_group = False
if meta.group and self._group_map[meta.group] == field: if meta.group and self._group_current[meta.group] == field_name:
selected_in_group = True selected_in_group = True
serialize_empty = False serialize_empty = False
@ -591,7 +607,7 @@ class Message(ABC):
# set (or received empty). # set (or received empty).
serialize_empty = True serialize_empty = True
if value == self._get_field_default(field, meta) and not ( if value == self._get_field_default(field_name) and not (
selected_in_group or serialize_empty selected_in_group or serialize_empty
): ):
# Default (zero) values are not serialized. Two exceptions are # Default (zero) values are not serialized. Two exceptions are
@ -648,13 +664,11 @@ class Message(ABC):
field_cls = field_cls.__args__[index] field_cls = field_cls.__args__[index]
return field_cls return field_cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: def _get_field_default(self, field_name):
return self._betterproto.default_gen[field.name]() return self._betterproto.default_gen[field_name]()
@classmethod @classmethod
def _get_field_default_gen( def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
cls, field: dataclasses.Field, meta: FieldMetadata
) -> Any:
t = cls._type_hint(field.name) t = cls._type_hint(field.name)
if hasattr(t, "__origin__"): if hasattr(t, "__origin__"):
@ -682,7 +696,7 @@ class Message(ABC):
return t return t
def _postprocess_single( def _postprocess_single(
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
) -> Any: ) -> Any:
"""Adjusts values after parsing.""" """Adjusts values after parsing."""
if wire_type == WIRE_VARINT: if wire_type == WIRE_VARINT:
@ -704,7 +718,7 @@ class Message(ABC):
if meta.proto_type == TYPE_STRING: if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE: elif meta.proto_type == TYPE_MESSAGE:
cls = self._betterproto.cls_by_field[field.name] cls = self._betterproto.cls_by_field[field_name]
if cls == datetime: if cls == datetime:
value = _Timestamp().parse(value).to_datetime() value = _Timestamp().parse(value).to_datetime()
@ -718,7 +732,7 @@ class Message(ABC):
value = cls().parse(value) value = cls().parse(value)
value._serialized_on_wire = True value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP: elif meta.proto_type == TYPE_MAP:
value = self._betterproto.cls_by_field[field.name]().parse(value) value = self._betterproto.cls_by_field[field_name]().parse(value)
return value return value
@ -727,49 +741,46 @@ class Message(ABC):
Parse the binary encoded Protobuf into this message instance. This Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
""" """
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
for parsed in parse_fields(data): for parsed in parse_fields(data):
if parsed.number in fields: field_name = self._betterproto.field_name_by_number.get(parsed.number)
field = fields[parsed.number] if not field_name:
meta = FieldMetadata.get(field)
value: Any
if (
parsed.wire_type == WIRE_LEN_DELIM
and meta.proto_type in PACKED_TYPES
):
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = self._postprocess_single(
wire_type, meta, field, decoded
)
value.append(decoded)
else:
value = self._postprocess_single(
parsed.wire_type, meta, field, parsed.value
)
current = getattr(self, field.name)
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else:
setattr(self, field.name, value)
else:
self._unknown_fields += parsed.raw self._unknown_fields += parsed.raw
continue
meta = self._betterproto.meta_by_field_name[field_name]
value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = self._postprocess_single(
wire_type, meta, field_name, decoded
)
value.append(decoded)
else:
value = self._postprocess_single(
parsed.wire_type, meta, field_name, parsed.value
)
current = getattr(self, field_name)
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else:
setattr(self, field_name, value)
return self return self
@ -792,10 +803,9 @@ class Message(ABC):
`False`. `False`.
""" """
output: Dict[str, Any] = {} output: Dict[str, Any] = {}
for field in dataclasses.fields(self): for field_name, meta in self._betterproto.meta_by_field_name.items():
meta = FieldMetadata.get(field) v = getattr(self, field_name)
v = getattr(self, field.name) cased_name = casing(field_name).rstrip("_") # type: ignore
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message": if meta.proto_type == "message":
if isinstance(v, datetime): if isinstance(v, datetime):
if v != DATETIME_ZERO or include_default_values: if v != DATETIME_ZERO or include_default_values:
@ -821,7 +831,7 @@ class Message(ABC):
if v or include_default_values: if v or include_default_values:
output[cased_name] = v output[cased_name] = v
elif v != self._get_field_default(field, meta) or include_default_values: elif v != self._get_field_default(field_name) or include_default_values:
if meta.proto_type in INT_64_TYPES: if meta.proto_type in INT_64_TYPES:
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [str(n) for n in v] output[cased_name] = [str(n) for n in v]
@ -834,7 +844,7 @@ class Message(ABC):
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list( enum_values = list(
self._betterproto.cls_by_field[field.name] self._betterproto.cls_by_field[field_name]
) # type: ignore ) # type: ignore
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v] output[cased_name] = [enum_values[e].name for e in v]
@ -852,56 +862,54 @@ class Message(ABC):
self._serialized_on_wire = True self._serialized_on_wire = True
fields_by_name = {f.name: f for f in dataclasses.fields(self)} fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value: for key in value:
snake_cased = safe_snake_case(key) field_name = safe_snake_case(key)
if snake_cased in fields_by_name: meta = self._betterproto.meta_by_field_name.get(field_name)
field = fields_by_name[snake_cased] if not meta:
meta = FieldMetadata.get(field) continue
if value[key] is not None: if value[key] is not None:
if meta.proto_type == "message": if meta.proto_type == "message":
v = getattr(self, field.name) v = getattr(self, field_name)
if isinstance(v, list): if isinstance(v, list):
cls = self._betterproto.cls_by_field[field.name] cls = self._betterproto.cls_by_field[field_name]
for i in range(len(value[key])): for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i])) v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime): elif isinstance(v, datetime):
v = datetime.fromisoformat( v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
value[key].replace("Z", "+00:00") setattr(self, field_name, v)
) elif isinstance(v, timedelta):
setattr(self, field.name, v) v = timedelta(seconds=float(value[key][:-1]))
elif isinstance(v, timedelta): setattr(self, field_name, v)
v = timedelta(seconds=float(value[key][:-1])) elif meta.wraps:
setattr(self, field.name, v) setattr(self, field_name, value[key])
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._betterproto.cls_by_field[field.name + ".value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else: else:
v = value[key] v.from_dict(value[key])
if meta.proto_type in INT_64_TYPES: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
if isinstance(value[key], list): v = getattr(self, field_name)
v = [int(n) for n in value[key]] cls = self._betterproto.cls_by_field[field_name + ".value"]
else: for k in value[key]:
v = int(value[key]) v[k] = cls().from_dict(value[key][k])
elif meta.proto_type == TYPE_BYTES: else:
if isinstance(value[key], list): v = value[key]
v = [b64decode(n) for n in value[key]] if meta.proto_type in INT_64_TYPES:
else: if isinstance(value[key], list):
v = b64decode(value[key]) v = [int(n) for n in value[key]]
elif meta.proto_type == TYPE_ENUM: else:
enum_cls = self._betterproto.cls_by_field[field.name] v = int(value[key])
if isinstance(v, list): elif meta.proto_type == TYPE_BYTES:
v = [enum_cls.from_string(e) for e in v] if isinstance(value[key], list):
elif isinstance(v, str): v = [b64decode(n) for n in value[key]]
v = enum_cls.from_string(v) else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
if v is not None: if v is not None:
setattr(self, field.name, v) setattr(self, field_name, v)
return self return self
def to_json(self, indent: Union[None, int, str] = None) -> str: def to_json(self, indent: Union[None, int, str] = None) -> str:
@ -927,10 +935,10 @@ def serialized_on_wire(message: Message) -> bool:
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
"""Return the name and value of a message's one-of field group.""" """Return the name and value of a message's one-of field group."""
field = message._group_map.get(group_name) field_name = message._group_current.get(group_name)
if not field: if not field_name:
return ("", None) return ("", None)
return (field.name, getattr(message, field.name)) return (field_name, getattr(message, field_name))
@dataclasses.dataclass @dataclasses.dataclass