Cache resolved classes for fields, so that there's no new data classes generated while deserializing.

This commit is contained in:
James Lan 2020-05-19 15:42:26 -07:00
parent 3d001a2a1a
commit 1f7f39049e

View File

@ -452,25 +452,49 @@ class ProtoClassMetadata:
self.oneof_group_by_field = by_field
self.oneof_field_by_group = by_group
def init_default_gen(self):
default_gen = {}
for field in dataclasses.fields(self.cls):
meta = FieldMetadata.get(field)
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
self.default_gen = default_gen
def init_cls_by_field(self):
field_cls = {}
for field in dataclasses.fields(self.cls):
meta = FieldMetadata.get(field)
if meta.proto_type == TYPE_MAP:
assert meta.map_types
kt = self.cls._cls_for(field, index=0)
vt = self.cls._cls_for(field, index=1)
Entry = dataclasses.make_dataclass(
"Entry",
[
("key", kt, dataclass_field(1, meta.map_types[0])),
("value", vt, dataclass_field(2, meta.map_types[1])),
],
bases=(Message,),
)
make_protoclass(Entry)
field_cls[field.name] = Entry
field_cls[field.name + ".value"] = vt
else:
field_cls[field.name] = self.cls._cls_for(field)
self.cls_by_field = field_cls
def __getattr__(self, item):
# Lazy init because forward reference classes may not be available at the beginning.
if item == 'default_gen':
defaults = {}
for field in dataclasses.fields(self.cls):
meta = FieldMetadata.get(field)
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
self.default_gen = defaults # __getattr__ won't be called next time
return defaults
self.init_default_gen()
return self.default_gen
if item == 'cls_by_field':
field_cls = {}
for field in dataclasses.fields(self.cls):
meta = FieldMetadata.get(field)
field_cls[field.name] = self.cls._type_hint(field.name)
self.cls_by_field = field_cls # __getattr__ won't be called next time
return field_cls
self.init_cls_by_field()
return self.cls_by_field
def make_protoclass(cls):
@ -619,12 +643,13 @@ class Message(ABC):
type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name]
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
cls = self._betterproto.cls_by_field[field.name]
if hasattr(cls, "__args__") and index >= 0:
cls = cls.__args__[index]
return cls
field_cls = cls._type_hint(field.name)
if hasattr(field_cls, "__args__") and index >= 0:
field_cls = field_cls.__args__[index]
return field_cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
return self._betterproto.default_gen[field.name]()
@ -680,7 +705,7 @@ class Message(ABC):
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
cls = self._betterproto.cls_by_field[field.name]
if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
@ -694,21 +719,7 @@ class Message(ABC):
value = cls().parse(value)
value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
assert meta.map_types
kt = self._cls_for(field, index=0)
vt = self._cls_for(field, index=1)
Entry = dataclasses.make_dataclass(
"Entry",
[
("key", kt, dataclass_field(1, meta.map_types[0])),
("value", vt, dataclass_field(2, meta.map_types[1])),
],
bases=(Message,),
)
make_protoclass(Entry)
value = Entry().parse(value)
value = self._betterproto.cls_by_field[field.name]().parse(value)
return value
@ -823,7 +834,7 @@ class Message(ABC):
else:
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field)) # type: ignore
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
else:
@ -849,7 +860,7 @@ class Message(ABC):
if meta.proto_type == "message":
v = getattr(self, field.name)
if isinstance(v, list):
cls = self._cls_for(field)
cls = self._betterproto.cls_by_field[field.name]
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
@ -866,7 +877,7 @@ class Message(ABC):
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name)
cls = self._cls_for(field, index=1)
cls = self._betterproto.cls_by_field[field.name + ".value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
@ -882,7 +893,7 @@ class Message(ABC):
else:
v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field)
enum_cls = self._betterproto.cls_by_field[field.name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):