Cache resolved classes for fields, so that there's no new data classes generated while deserializing.
This commit is contained in:
parent
3d001a2a1a
commit
1f7f39049e
@ -452,25 +452,49 @@ class ProtoClassMetadata:
|
||||
self.oneof_group_by_field = by_field
|
||||
self.oneof_field_by_group = by_group
|
||||
|
||||
def init_default_gen(self):
|
||||
default_gen = {}
|
||||
|
||||
for field in dataclasses.fields(self.cls):
|
||||
meta = FieldMetadata.get(field)
|
||||
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
|
||||
|
||||
self.default_gen = default_gen
|
||||
|
||||
def init_cls_by_field(self):
|
||||
field_cls = {}
|
||||
|
||||
for field in dataclasses.fields(self.cls):
|
||||
meta = FieldMetadata.get(field)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
assert meta.map_types
|
||||
kt = self.cls._cls_for(field, index=0)
|
||||
vt = self.cls._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
make_protoclass(Entry)
|
||||
field_cls[field.name] = Entry
|
||||
field_cls[field.name + ".value"] = vt
|
||||
else:
|
||||
field_cls[field.name] = self.cls._cls_for(field)
|
||||
|
||||
self.cls_by_field = field_cls
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Lazy init because forward reference classes may not be available at the beginning.
|
||||
if item == 'default_gen':
|
||||
defaults = {}
|
||||
for field in dataclasses.fields(self.cls):
|
||||
meta = FieldMetadata.get(field)
|
||||
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
|
||||
|
||||
self.default_gen = defaults # __getattr__ won't be called next time
|
||||
return defaults
|
||||
self.init_default_gen()
|
||||
return self.default_gen
|
||||
|
||||
if item == 'cls_by_field':
|
||||
field_cls = {}
|
||||
for field in dataclasses.fields(self.cls):
|
||||
meta = FieldMetadata.get(field)
|
||||
field_cls[field.name] = self.cls._type_hint(field.name)
|
||||
|
||||
self.cls_by_field = field_cls # __getattr__ won't be called next time
|
||||
return field_cls
|
||||
self.init_cls_by_field()
|
||||
return self.cls_by_field
|
||||
|
||||
|
||||
def make_protoclass(cls):
|
||||
@ -619,12 +643,13 @@ class Message(ABC):
|
||||
type_hints = get_type_hints(cls, vars(module))
|
||||
return type_hints[field_name]
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
@classmethod
|
||||
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
cls = self._betterproto.cls_by_field[field.name]
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = cls.__args__[index]
|
||||
return cls
|
||||
field_cls = cls._type_hint(field.name)
|
||||
if hasattr(field_cls, "__args__") and index >= 0:
|
||||
field_cls = field_cls.__args__[index]
|
||||
return field_cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
return self._betterproto.default_gen[field.name]()
|
||||
@ -680,7 +705,7 @@ class Message(ABC):
|
||||
if meta.proto_type == TYPE_STRING:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
cls = self._betterproto.cls_by_field[field.name]
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
@ -694,21 +719,7 @@ class Message(ABC):
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
assert meta.map_types
|
||||
kt = self._cls_for(field, index=0)
|
||||
vt = self._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
make_protoclass(Entry)
|
||||
value = Entry().parse(value)
|
||||
value = self._betterproto.cls_by_field[field.name]().parse(value)
|
||||
|
||||
return value
|
||||
|
||||
@ -823,7 +834,7 @@ class Message(ABC):
|
||||
else:
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(self._cls_for(field)) # type: ignore
|
||||
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
else:
|
||||
@ -849,7 +860,7 @@ class Message(ABC):
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
cls = self._betterproto.cls_by_field[field.name]
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
@ -866,7 +877,7 @@ class Message(ABC):
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
cls = self._betterproto.cls_by_field[field.name + ".value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
@ -882,7 +893,7 @@ class Message(ABC):
|
||||
else:
|
||||
v = b64decode(value[key])
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_cls = self._cls_for(field)
|
||||
enum_cls = self._betterproto.cls_by_field[field.name]
|
||||
if isinstance(v, list):
|
||||
v = [enum_cls.from_string(e) for e in v]
|
||||
elif isinstance(v, str):
|
||||
|
Loading…
x
Reference in New Issue
Block a user