Cache field metadata to avoid repeated calls to dataclasses.fields(),
which yields a performance improvement of more than 10%.
This commit is contained in:
parent
ee362a7a73
commit
ed33a48d64
@ -14,11 +14,10 @@ from typing import (
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
SupportsBytes,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@ -435,14 +434,29 @@ T = TypeVar("T", bound="Message")
|
||||
|
||||
|
||||
class ProtoClassMetadata:
|
||||
cls: Type["Message"]
|
||||
oneof_group_by_field: Dict[str, str]
|
||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||
default_gen: Dict[str, Callable]
|
||||
cls_by_field: Dict[str, Type]
|
||||
field_name_by_number: Dict[int, str]
|
||||
meta_by_field_name: Dict[str, FieldMetadata]
|
||||
__slots__ = (
|
||||
"oneof_group_by_field",
|
||||
"oneof_field_by_group",
|
||||
"default_gen",
|
||||
"cls_by_field",
|
||||
"field_name_by_number",
|
||||
"meta_by_field_name",
|
||||
)
|
||||
|
||||
def __init__(self, cls: Type["Message"]):
|
||||
self.cls = cls
|
||||
by_field = {}
|
||||
by_group = {}
|
||||
by_field_name = {}
|
||||
by_field_number = {}
|
||||
|
||||
for field in dataclasses.fields(cls):
|
||||
fields = dataclasses.fields(cls)
|
||||
for field in fields:
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if meta.group:
|
||||
@ -451,30 +465,36 @@ class ProtoClassMetadata:
|
||||
|
||||
by_group.setdefault(meta.group, set()).add(field)
|
||||
|
||||
by_field_name[field.name] = meta
|
||||
by_field_number[meta.number] = field.name
|
||||
|
||||
self.oneof_group_by_field = by_field
|
||||
self.oneof_field_by_group = by_group
|
||||
self.field_name_by_number = by_field_number
|
||||
self.meta_by_field_name = by_field_name
|
||||
|
||||
self.init_default_gen()
|
||||
self.init_cls_by_field()
|
||||
self.default_gen = self._get_default_gen(cls, fields)
|
||||
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||
|
||||
def init_default_gen(self):
|
||||
@staticmethod
|
||||
def _get_default_gen(cls, fields):
|
||||
default_gen = {}
|
||||
|
||||
for field in dataclasses.fields(self.cls):
|
||||
meta = FieldMetadata.get(field)
|
||||
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
|
||||
for field in fields:
|
||||
default_gen[field.name] = cls._get_field_default_gen(field)
|
||||
|
||||
self.default_gen = default_gen
|
||||
return default_gen
|
||||
|
||||
def init_cls_by_field(self):
|
||||
@staticmethod
|
||||
def _get_cls_by_field(cls, fields):
|
||||
field_cls = {}
|
||||
|
||||
for field in dataclasses.fields(self.cls):
|
||||
for field in fields:
|
||||
meta = FieldMetadata.get(field)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
assert meta.map_types
|
||||
kt = self.cls._cls_for(field, index=0)
|
||||
vt = self.cls._cls_for(field, index=1)
|
||||
kt = cls._cls_for(field, index=0)
|
||||
vt = cls._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
@ -486,9 +506,9 @@ class ProtoClassMetadata:
|
||||
field_cls[field.name] = Entry
|
||||
field_cls[field.name + ".value"] = vt
|
||||
else:
|
||||
field_cls[field.name] = self.cls._cls_for(field)
|
||||
field_cls[field.name] = cls._cls_for(field)
|
||||
|
||||
self.cls_by_field = field_cls
|
||||
return field_cls
|
||||
|
||||
|
||||
class Message(ABC):
|
||||
@ -500,53 +520,50 @@ class Message(ABC):
|
||||
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_map: Dict[str, dict]
|
||||
_group_current: Dict[str, str]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Keep track of whether every field was default
|
||||
all_sentinel = True
|
||||
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
group_map: Dict[str, dataclasses.Field] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
# Set current field of each group after `__init__` has already been run.
|
||||
group_current: Dict[str, str] = {}
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
|
||||
if meta.group:
|
||||
group_map.setdefault(meta.group)
|
||||
group_current.setdefault(meta.group)
|
||||
|
||||
if getattr(self, field.name) != PLACEHOLDER:
|
||||
if getattr(self, field_name) != PLACEHOLDER:
|
||||
# Skip anything not set to the sentinel value
|
||||
all_sentinel = False
|
||||
|
||||
if meta.group:
|
||||
# This was set, so make it the selected value of the one-of.
|
||||
group_map[meta.group] = field
|
||||
group_current[meta.group] = field_name
|
||||
|
||||
continue
|
||||
|
||||
setattr(self, field.name, self._get_field_default(field, meta))
|
||||
setattr(self, field_name, self._get_field_default(field_name))
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
||||
self.__dict__["_unknown_fields"] = b""
|
||||
self.__dict__["_group_map"] = group_map
|
||||
self.__dict__["_group_current"] = group_current
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "_serialized_on_wire":
|
||||
# Track when a field has been set.
|
||||
self.__dict__["_serialized_on_wire"] = True
|
||||
|
||||
if hasattr(self, "_group_map"): # __post_init__ had already run
|
||||
if hasattr(self, "_group_current"): # __post_init__ had already run
|
||||
if attr in self._betterproto.oneof_group_by_field:
|
||||
group = self._betterproto.oneof_group_by_field[attr]
|
||||
for field in self._betterproto.oneof_field_by_group[group]:
|
||||
if field.name == attr:
|
||||
self._group_map[group] = field
|
||||
self._group_current[group] = field.name
|
||||
else:
|
||||
super().__setattr__(
|
||||
field.name,
|
||||
self._get_field_default(field, FieldMetadata.get(field)),
|
||||
field.name, self._get_field_default(field.name),
|
||||
)
|
||||
|
||||
super().__setattr__(attr, value)
|
||||
@ -569,9 +586,8 @@ class Message(ABC):
|
||||
Get the binary encoded Protobuf representation of this instance.
|
||||
"""
|
||||
output = b""
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
value = getattr(self, field.name)
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
value = getattr(self, field_name)
|
||||
|
||||
if value is None:
|
||||
# Optional items should be skipped. This is used for the Google
|
||||
@ -582,7 +598,7 @@ class Message(ABC):
|
||||
# currently set in a `oneof` group, so it must be serialized even
|
||||
# if the value is the default zero value.
|
||||
selected_in_group = False
|
||||
if meta.group and self._group_map[meta.group] == field:
|
||||
if meta.group and self._group_current[meta.group] == field_name:
|
||||
selected_in_group = True
|
||||
|
||||
serialize_empty = False
|
||||
@ -591,7 +607,7 @@ class Message(ABC):
|
||||
# set (or received empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field, meta) and not (
|
||||
if value == self._get_field_default(field_name) and not (
|
||||
selected_in_group or serialize_empty
|
||||
):
|
||||
# Default (zero) values are not serialized. Two exceptions are
|
||||
@ -648,13 +664,11 @@ class Message(ABC):
|
||||
field_cls = field_cls.__args__[index]
|
||||
return field_cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
return self._betterproto.default_gen[field.name]()
|
||||
def _get_field_default(self, field_name):
|
||||
return self._betterproto.default_gen[field_name]()
|
||||
|
||||
@classmethod
|
||||
def _get_field_default_gen(
|
||||
cls, field: dataclasses.Field, meta: FieldMetadata
|
||||
) -> Any:
|
||||
def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
|
||||
t = cls._type_hint(field.name)
|
||||
|
||||
if hasattr(t, "__origin__"):
|
||||
@ -682,7 +696,7 @@ class Message(ABC):
|
||||
return t
|
||||
|
||||
def _postprocess_single(
|
||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||
self, wire_type: int, meta: FieldMetadata, field_name: str, value: Any
|
||||
) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
@ -704,7 +718,7 @@ class Message(ABC):
|
||||
if meta.proto_type == TYPE_STRING:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._betterproto.cls_by_field[field.name]
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
@ -718,7 +732,7 @@ class Message(ABC):
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
value = self._betterproto.cls_by_field[field.name]().parse(value)
|
||||
value = self._betterproto.cls_by_field[field_name]().parse(value)
|
||||
|
||||
return value
|
||||
|
||||
@ -727,49 +741,46 @@ class Message(ABC):
|
||||
Parse the binary encoded Protobuf into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
|
||||
for parsed in parse_fields(data):
|
||||
if parsed.number in fields:
|
||||
field = fields[parsed.number]
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
value: Any
|
||||
if (
|
||||
parsed.wire_type == WIRE_LEN_DELIM
|
||||
and meta.proto_type in PACKED_TYPES
|
||||
):
|
||||
# This is a packed repeated field.
|
||||
pos = 0
|
||||
value = []
|
||||
while pos < len(parsed.value):
|
||||
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
|
||||
decoded, pos = parsed.value[pos : pos + 4], pos + 4
|
||||
wire_type = WIRE_FIXED_32
|
||||
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
|
||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||
wire_type = WIRE_FIXED_64
|
||||
else:
|
||||
decoded, pos = decode_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = self._postprocess_single(
|
||||
wire_type, meta, field, decoded
|
||||
)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = self._postprocess_single(
|
||||
parsed.wire_type, meta, field, parsed.value
|
||||
)
|
||||
|
||||
current = getattr(self, field.name)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Value represents a single key/value pair entry in the map.
|
||||
current[value.key] = value.value
|
||||
elif isinstance(current, list) and not isinstance(value, list):
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(self, field.name, value)
|
||||
else:
|
||||
field_name = self._betterproto.field_name_by_number.get(parsed.number)
|
||||
if not field_name:
|
||||
self._unknown_fields += parsed.raw
|
||||
continue
|
||||
|
||||
meta = self._betterproto.meta_by_field_name[field_name]
|
||||
|
||||
value: Any
|
||||
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
|
||||
# This is a packed repeated field.
|
||||
pos = 0
|
||||
value = []
|
||||
while pos < len(parsed.value):
|
||||
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
|
||||
decoded, pos = parsed.value[pos : pos + 4], pos + 4
|
||||
wire_type = WIRE_FIXED_32
|
||||
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
|
||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||
wire_type = WIRE_FIXED_64
|
||||
else:
|
||||
decoded, pos = decode_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = self._postprocess_single(
|
||||
wire_type, meta, field_name, decoded
|
||||
)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = self._postprocess_single(
|
||||
parsed.wire_type, meta, field_name, parsed.value
|
||||
)
|
||||
|
||||
current = getattr(self, field_name)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Value represents a single key/value pair entry in the map.
|
||||
current[value.key] = value.value
|
||||
elif isinstance(current, list) and not isinstance(value, list):
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(self, field_name, value)
|
||||
|
||||
return self
|
||||
|
||||
@ -792,10 +803,9 @@ class Message(ABC):
|
||||
`False`.
|
||||
"""
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||
v = getattr(self, field_name)
|
||||
cased_name = casing(field_name).rstrip("_") # type: ignore
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, datetime):
|
||||
if v != DATETIME_ZERO or include_default_values:
|
||||
@ -821,7 +831,7 @@ class Message(ABC):
|
||||
|
||||
if v or include_default_values:
|
||||
output[cased_name] = v
|
||||
elif v != self._get_field_default(field, meta) or include_default_values:
|
||||
elif v != self._get_field_default(field_name) or include_default_values:
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [str(n) for n in v]
|
||||
@ -834,7 +844,7 @@ class Message(ABC):
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(
|
||||
self._betterproto.cls_by_field[field.name]
|
||||
self._betterproto.cls_by_field[field_name]
|
||||
) # type: ignore
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
@ -852,56 +862,54 @@ class Message(ABC):
|
||||
self._serialized_on_wire = True
|
||||
fields_by_name = {f.name: f for f in dataclasses.fields(self)}
|
||||
for key in value:
|
||||
snake_cased = safe_snake_case(key)
|
||||
if snake_cased in fields_by_name:
|
||||
field = fields_by_name[snake_cased]
|
||||
meta = FieldMetadata.get(field)
|
||||
field_name = safe_snake_case(key)
|
||||
meta = self._betterproto.meta_by_field_name.get(field_name)
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
if isinstance(v, list):
|
||||
cls = self._betterproto.cls_by_field[field.name]
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(
|
||||
value[key].replace("Z", "+00:00")
|
||||
)
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field.name, value[key])
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._betterproto.cls_by_field[field.name + ".value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field_name)
|
||||
if isinstance(v, list):
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
setattr(self, field_name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field_name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field_name, value[key])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._betterproto.cls_by_field[field.name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
cls = self._betterproto.cls_by_field[field_name + ".value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
v = value[key]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[key], list):
|
||||
v = [int(n) for n in value[key]]
|
||||
else:
|
||||
v = int(value[key])
|
||||
elif meta.proto_type == TYPE_BYTES:
|
||||
if isinstance(value[key], list):
|
||||
v = [b64decode(n) for n in value[key]]
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._betterproto.cls_by_field[field_name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
v = enum_cls.from_string(v)
|
||||
|
||||
if v is not None:
|
||||
setattr(self, field.name, v)
|
||||
if v is not None:
|
||||
setattr(self, field_name, v)
|
||||
return self
|
||||
|
||||
def to_json(self, indent: Union[None, int, str] = None) -> str:
|
||||
@ -927,10 +935,10 @@ def serialized_on_wire(message: Message) -> bool:
|
||||
|
||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
"""Return the name and value of a message's one-of field group."""
|
||||
field = message._group_map.get(group_name)
|
||||
if not field:
|
||||
field_name = message._group_current.get(group_name)
|
||||
if not field_name:
|
||||
return ("", None)
|
||||
return (field.name, getattr(message, field.name))
|
||||
return (field_name, getattr(message, field_name))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
Loading…
x
Reference in New Issue
Block a user