Cache field metadata to avoid repeated calls to dataclasses.fields, giving a performance improvement of more than 10%

This commit is contained in:
James Lan 2020-05-23 18:06:04 -07:00
parent ee362a7a73
commit ed33a48d64

View File

@ -14,11 +14,10 @@ from typing import (
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
SupportsBytes,
Set,
Tuple,
Type,
TypeVar,
@ -435,14 +434,29 @@ T = TypeVar("T", bound="Message")
class ProtoClassMetadata:
cls: Type["Message"]
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = (
"oneof_group_by_field",
"oneof_field_by_group",
"default_gen",
"cls_by_field",
"field_name_by_number",
"meta_by_field_name",
)
def __init__(self, cls: Type["Message"]):
self.cls = cls
by_field = {}
by_group = {}
by_field_name = {}
by_field_number = {}
for field in dataclasses.fields(cls):
fields = dataclasses.fields(cls)
for field in fields:
meta = FieldMetadata.get(field)
if meta.group:
@ -451,30 +465,36 @@ class ProtoClassMetadata:
by_group.setdefault(meta.group, set()).add(field)
by_field_name[field.name] = meta
by_field_number[meta.number] = field.name
self.oneof_group_by_field = by_field
self.oneof_field_by_group = by_group
self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name
self.init_default_gen()
self.init_cls_by_field()
self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields)
def init_default_gen(self):
@staticmethod
def _get_default_gen(cls, fields):
default_gen = {}
for field in dataclasses.fields(self.cls):
meta = FieldMetadata.get(field)
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
for field in fields:
default_gen[field.name] = cls._get_field_default_gen(field)
self.default_gen = default_gen
return default_gen
def init_cls_by_field(self):
@staticmethod
def _get_cls_by_field(cls, fields):
field_cls = {}
for field in dataclasses.fields(self.cls):
for field in fields:
meta = FieldMetadata.get(field)
if meta.proto_type == TYPE_MAP:
assert meta.map_types
kt = self.cls._cls_for(field, index=0)
vt = self.cls._cls_for(field, index=1)
kt = cls._cls_for(field, index=0)
vt = cls._cls_for(field, index=1)
Entry = dataclasses.make_dataclass(
"Entry",
[
@ -486,9 +506,9 @@ class ProtoClassMetadata:
field_cls[field.name] = Entry
field_cls[field.name + ".value"] = vt
else:
field_cls[field.name] = self.cls._cls_for(field)
field_cls[field.name] = cls._cls_for(field)
self.cls_by_field = field_cls
return field_cls
class Message(ABC):
@ -500,53 +520,50 @@ class Message(ABC):
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
_group_current: Dict[str, str]
def __post_init__(self) -> None:
    """Replace placeholder field values with defaults and set up bookkeeping.

    Walks the cached per-field metadata instead of calling
    ``dataclasses.fields``. Every field still equal to ``PLACEHOLDER`` gets
    its default value; every other field is treated as explicitly set and,
    when it belongs to a one-of group, recorded as that group's current
    field.
    """
    # Keep track of whether every field was default
    all_sentinel = True

    # Set current field of each group after `__init__` has already been run.
    group_current: Dict[str, str] = {}
    for field_name, meta in self._betterproto.meta_by_field_name.items():
        if meta.group:
            # Register the group (with a None value) even when none of its
            # fields end up being set.
            group_current.setdefault(meta.group)

        if getattr(self, field_name) != PLACEHOLDER:
            # Skip anything not set to the sentinel value
            all_sentinel = False

            if meta.group:
                # This was set, so make it the selected value of the one-of.
                group_current[meta.group] = field_name

            continue

        setattr(self, field_name, self._get_field_default(field_name))

    # The setattr calls above went through __setattr__, which flags the
    # message as serialized. Now that all the defaults are set, reset it!
    # Writing to __dict__ directly bypasses __setattr__.
    self.__dict__["_serialized_on_wire"] = not all_sentinel
    self.__dict__["_unknown_fields"] = b""
    self.__dict__["_group_current"] = group_current
def __setattr__(self, attr: str, value: Any) -> None:
    """Set an attribute while maintaining wire and one-of bookkeeping.

    Any assignment other than to ``_serialized_on_wire`` itself marks the
    message as serialized on the wire. Once ``__post_init__`` has run,
    assigning to a member of a one-of group records that member as the
    group's current field and resets every other member of the group to
    its default value.
    """
    if attr != "_serialized_on_wire":
        # Track when a field has been set.
        self.__dict__["_serialized_on_wire"] = True

    if hasattr(self, "_group_current"):  # __post_init__ had already run
        if attr in self._betterproto.oneof_group_by_field:
            group = self._betterproto.oneof_group_by_field[attr]
            for field in self._betterproto.oneof_field_by_group[group]:
                if field.name == attr:
                    self._group_current[group] = field.name
                else:
                    # Use the parent implementation so resetting a sibling
                    # does not re-enter this one-of handling.
                    super().__setattr__(
                        field.name, self._get_field_default(field.name)
                    )

    super().__setattr__(attr, value)
@ -569,9 +586,8 @@ class Message(ABC):
Get the binary encoded Protobuf representation of this instance.
"""
output = b""
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
value = getattr(self, field.name)
for field_name, meta in self._betterproto.meta_by_field_name.items():
value = getattr(self, field_name)
if value is None:
# Optional items should be skipped. This is used for the Google
@ -582,7 +598,7 @@ class Message(ABC):
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
selected_in_group = False
if meta.group and self._group_map[meta.group] == field:
if meta.group and self._group_current[meta.group] == field_name:
selected_in_group = True
serialize_empty = False
@ -591,7 +607,7 @@ class Message(ABC):
# set (or received empty).
serialize_empty = True
if value == self._get_field_default(field, meta) and not (
if value == self._get_field_default(field_name) and not (
selected_in_group or serialize_empty
):
# Default (zero) values are not serialized. Two exceptions are
@ -648,13 +664,11 @@ class Message(ABC):
field_cls = field_cls.__args__[index]
return field_cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
return self._betterproto.default_gen[field.name]()
def _get_field_default(self, field_name):
return self._betterproto.default_gen[field_name]()
@classmethod
def _get_field_default_gen(
cls, field: dataclasses.Field, meta: FieldMetadata
) -> Any:
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
t = cls._type_hint(field.name)
if hasattr(t, "__origin__"):
@ -682,7 +696,7 @@ class Message(ABC):
return t
def _postprocess_single(
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
) -> Any:
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT:
@ -704,7 +718,7 @@ class Message(ABC):
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE:
cls = self._betterproto.cls_by_field[field.name]
cls = self._betterproto.cls_by_field[field_name]
if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
@ -718,7 +732,7 @@ class Message(ABC):
value = cls().parse(value)
value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
value = self._betterproto.cls_by_field[field.name]().parse(value)
value = self._betterproto.cls_by_field[field_name]().parse(value)
return value
@ -727,49 +741,46 @@ class Message(ABC):
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
for parsed in parse_fields(data):
if parsed.number in fields:
field = fields[parsed.number]
meta = FieldMetadata.get(field)
value: Any
if (
parsed.wire_type == WIRE_LEN_DELIM
and meta.proto_type in PACKED_TYPES
):
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = self._postprocess_single(
wire_type, meta, field, decoded
)
value.append(decoded)
else:
value = self._postprocess_single(
parsed.wire_type, meta, field, parsed.value
)
current = getattr(self, field.name)
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else:
setattr(self, field.name, value)
else:
field_name = self._betterproto.field_name_by_number.get(parsed.number)
if not field_name:
self._unknown_fields += parsed.raw
continue
meta = self._betterproto.meta_by_field_name[field_name]
value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = self._postprocess_single(
wire_type, meta, field_name, decoded
)
value.append(decoded)
else:
value = self._postprocess_single(
parsed.wire_type, meta, field_name, parsed.value
)
current = getattr(self, field_name)
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else:
setattr(self, field_name, value)
return self
@ -792,10 +803,9 @@ class Message(ABC):
`False`.
"""
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") # type: ignore
for field_name, meta in self._betterproto.meta_by_field_name.items():
v = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO or include_default_values:
@ -821,7 +831,7 @@ class Message(ABC):
if v or include_default_values:
output[cased_name] = v
elif v != self._get_field_default(field, meta) or include_default_values:
elif v != self._get_field_default(field_name) or include_default_values:
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[cased_name] = [str(n) for n in v]
@ -834,7 +844,7 @@ class Message(ABC):
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(
self._betterproto.cls_by_field[field.name]
self._betterproto.cls_by_field[field_name]
) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
@ -852,56 +862,54 @@ class Message(ABC):
self._serialized_on_wire = True
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
for key in value:
snake_cased = safe_snake_case(key)
if snake_cased in fields_by_name:
field = fields_by_name[snake_cased]
meta = FieldMetadata.get(field)
field_name = safe_snake_case(key)
meta = self._betterproto.meta_by_field_name.get(field_name)
if not meta:
continue
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field.name]
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._betterproto.cls_by_field[field.name + ".value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field_name)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name]
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field_name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field_name, v)
elif meta.wraps:
setattr(self, field_name, value[key])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field.name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[field_name + ".value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
v = value[key]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[key], list):
v = [int(n) for n in value[key]]
else:
v = int(value[key])
elif meta.proto_type == TYPE_BYTES:
if isinstance(value[key], list):
v = [b64decode(n) for n in value[key]]
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
if v is not None:
setattr(self, field.name, v)
if v is not None:
setattr(self, field_name, v)
return self
def to_json(self, indent: Union[None, int, str] = None) -> str:
@ -927,10 +935,10 @@ def serialized_on_wire(message: Message) -> bool:
def which_one_of(message: "Message", group_name: str) -> Tuple[str, Any]:
    """Return the name and value of a message's one-of field group.

    Args:
        message: The message whose ``_group_current`` mapping is consulted.
        group_name: Name of the one-of group to look up.

    Returns:
        ``(field_name, value)`` for the field currently recorded for the
        group, or ``("", None)`` when the group is unknown or has no field
        recorded.
    """
    field_name = message._group_current.get(group_name)
    if not field_name:
        return ("", None)
    return (field_name, getattr(message, field_name))
@dataclasses.dataclass