Cache resolved classes for fields, so that no new data classes are generated while deserializing.

This commit is contained in:
James Lan 2020-05-19 15:42:26 -07:00
parent 3d001a2a1a
commit 1f7f39049e

View File

@ -452,25 +452,49 @@ class ProtoClassMetadata:
self.oneof_group_by_field = by_field self.oneof_group_by_field = by_field
self.oneof_field_by_group = by_group self.oneof_field_by_group = by_group
def init_default_gen(self):
    """Populate ``self.default_gen`` for every dataclass field of ``self.cls``.

    The result maps each field name to whatever
    ``self.cls._get_field_default_gen(field, meta)`` returns for that field.
    Callers invoke the stored object to obtain the field's default value.
    """
    self.default_gen = {
        field.name: self.cls._get_field_default_gen(field, FieldMetadata.get(field))
        for field in dataclasses.fields(self.cls)
    }
def init_cls_by_field(self):
    """Resolve the class for every dataclass field of ``self.cls`` and cache it.

    The mapping is stored on ``self.cls_by_field``:

    * For a non-map field, ``field.name`` maps to ``self.cls._cls_for(field)``.
    * For a map field, ``field.name`` maps to a generated ``Entry`` message
      class with a ``key`` and a ``value`` field, and ``field.name + ".value"``
      maps to the value class alone.

    Building the ``Entry`` class here, once per field, means parsing a map
    entry can look the class up instead of generating a new dataclass.
    """
    resolved = {}
    for fld in dataclasses.fields(self.cls):
        name = fld.name
        meta = FieldMetadata.get(fld)

        if meta.proto_type != TYPE_MAP:
            resolved[name] = self.cls._cls_for(fld)
            continue

        assert meta.map_types
        # Index 0 / 1 select the key / value type from the field's type hint.
        key_cls = self.cls._cls_for(fld, index=0)
        value_cls = self.cls._cls_for(fld, index=1)
        entry_fields = [
            ("key", key_cls, dataclass_field(1, meta.map_types[0])),
            ("value", value_cls, dataclass_field(2, meta.map_types[1])),
        ]
        entry_cls = dataclasses.make_dataclass(
            "Entry",
            entry_fields,
            bases=(Message,),
        )
        make_protoclass(entry_cls)

        resolved[name] = entry_cls
        resolved[name + ".value"] = value_cls
    self.cls_by_field = resolved
def __getattr__(self, item): def __getattr__(self, item):
# Lazy init because forward reference classes may not be available at the beginning. # Lazy init because forward reference classes may not be available at the beginning.
if item == 'default_gen': if item == 'default_gen':
defaults = {} self.init_default_gen()
for field in dataclasses.fields(self.cls): return self.default_gen
meta = FieldMetadata.get(field)
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
self.default_gen = defaults # __getattr__ won't be called next time
return defaults
if item == 'cls_by_field': if item == 'cls_by_field':
field_cls = {} self.init_cls_by_field()
for field in dataclasses.fields(self.cls): return self.cls_by_field
meta = FieldMetadata.get(field)
field_cls[field.name] = self.cls._type_hint(field.name)
self.cls_by_field = field_cls # __getattr__ won't be called next time
return field_cls
def make_protoclass(cls): def make_protoclass(cls):
@ -619,12 +643,13 @@ class Message(ABC):
type_hints = get_type_hints(cls, vars(module)) type_hints = get_type_hints(cls, vars(module))
return type_hints[field_name] return type_hints[field_name]
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: @classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints.""" """Get the message class for a field from the type hints."""
cls = self._betterproto.cls_by_field[field.name] field_cls = cls._type_hint(field.name)
if hasattr(cls, "__args__") and index >= 0: if hasattr(field_cls, "__args__") and index >= 0:
cls = cls.__args__[index] field_cls = field_cls.__args__[index]
return cls return field_cls
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
return self._betterproto.default_gen[field.name]() return self._betterproto.default_gen[field.name]()
@ -680,7 +705,7 @@ class Message(ABC):
if meta.proto_type == TYPE_STRING: if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE: elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field) cls = self._betterproto.cls_by_field[field.name]
if cls == datetime: if cls == datetime:
value = _Timestamp().parse(value).to_datetime() value = _Timestamp().parse(value).to_datetime()
@ -694,21 +719,7 @@ class Message(ABC):
value = cls().parse(value) value = cls().parse(value)
value._serialized_on_wire = True value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP: elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each value = self._betterproto.cls_by_field[field.name]().parse(value)
# key/value pair will recreate the class.
assert meta.map_types
kt = self._cls_for(field, index=0)
vt = self._cls_for(field, index=1)
Entry = dataclasses.make_dataclass(
"Entry",
[
("key", kt, dataclass_field(1, meta.map_types[0])),
("value", vt, dataclass_field(2, meta.map_types[1])),
],
bases=(Message,),
)
make_protoclass(Entry)
value = Entry().parse(value)
return value return value
@ -823,7 +834,7 @@ class Message(ABC):
else: else:
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field)) # type: ignore enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v] output[cased_name] = [enum_values[e].name for e in v]
else: else:
@ -849,7 +860,7 @@ class Message(ABC):
if meta.proto_type == "message": if meta.proto_type == "message":
v = getattr(self, field.name) v = getattr(self, field.name)
if isinstance(v, list): if isinstance(v, list):
cls = self._cls_for(field) cls = self._betterproto.cls_by_field[field.name]
for i in range(len(value[key])): for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i])) v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime): elif isinstance(v, datetime):
@ -866,7 +877,7 @@ class Message(ABC):
v.from_dict(value[key]) v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name) v = getattr(self, field.name)
cls = self._cls_for(field, index=1) cls = self._betterproto.cls_by_field[field.name + ".value"]
for k in value[key]: for k in value[key]:
v[k] = cls().from_dict(value[key][k]) v[k] = cls().from_dict(value[key][k])
else: else:
@ -882,7 +893,7 @@ class Message(ABC):
else: else:
v = b64decode(value[key]) v = b64decode(value[key])
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_cls = self._cls_for(field) enum_cls = self._betterproto.cls_by_field[field.name]
if isinstance(v, list): if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v] v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str): elif isinstance(v, str):