Cache resolved classes for fields, so that there's no new data classes generated while deserializing.
This commit is contained in:
parent
3d001a2a1a
commit
1f7f39049e
@ -452,25 +452,49 @@ class ProtoClassMetadata:
|
|||||||
self.oneof_group_by_field = by_field
|
self.oneof_group_by_field = by_field
|
||||||
self.oneof_field_by_group = by_group
|
self.oneof_field_by_group = by_group
|
||||||
|
|
||||||
|
def init_default_gen(self):
|
||||||
|
default_gen = {}
|
||||||
|
|
||||||
|
for field in dataclasses.fields(self.cls):
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
|
||||||
|
|
||||||
|
self.default_gen = default_gen
|
||||||
|
|
||||||
|
def init_cls_by_field(self):
|
||||||
|
field_cls = {}
|
||||||
|
|
||||||
|
for field in dataclasses.fields(self.cls):
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
if meta.proto_type == TYPE_MAP:
|
||||||
|
assert meta.map_types
|
||||||
|
kt = self.cls._cls_for(field, index=0)
|
||||||
|
vt = self.cls._cls_for(field, index=1)
|
||||||
|
Entry = dataclasses.make_dataclass(
|
||||||
|
"Entry",
|
||||||
|
[
|
||||||
|
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||||
|
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||||
|
],
|
||||||
|
bases=(Message,),
|
||||||
|
)
|
||||||
|
make_protoclass(Entry)
|
||||||
|
field_cls[field.name] = Entry
|
||||||
|
field_cls[field.name + ".value"] = vt
|
||||||
|
else:
|
||||||
|
field_cls[field.name] = self.cls._cls_for(field)
|
||||||
|
|
||||||
|
self.cls_by_field = field_cls
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
# Lazy init because forward reference classes may not be available at the beginning.
|
# Lazy init because forward reference classes may not be available at the beginning.
|
||||||
if item == 'default_gen':
|
if item == 'default_gen':
|
||||||
defaults = {}
|
self.init_default_gen()
|
||||||
for field in dataclasses.fields(self.cls):
|
return self.default_gen
|
||||||
meta = FieldMetadata.get(field)
|
|
||||||
defaults[field.name] = self.cls._get_field_default_gen(field, meta)
|
|
||||||
|
|
||||||
self.default_gen = defaults # __getattr__ won't be called next time
|
|
||||||
return defaults
|
|
||||||
|
|
||||||
if item == 'cls_by_field':
|
if item == 'cls_by_field':
|
||||||
field_cls = {}
|
self.init_cls_by_field()
|
||||||
for field in dataclasses.fields(self.cls):
|
return self.cls_by_field
|
||||||
meta = FieldMetadata.get(field)
|
|
||||||
field_cls[field.name] = self.cls._type_hint(field.name)
|
|
||||||
|
|
||||||
self.cls_by_field = field_cls # __getattr__ won't be called next time
|
|
||||||
return field_cls
|
|
||||||
|
|
||||||
|
|
||||||
def make_protoclass(cls):
|
def make_protoclass(cls):
|
||||||
@ -619,12 +643,13 @@ class Message(ABC):
|
|||||||
type_hints = get_type_hints(cls, vars(module))
|
type_hints = get_type_hints(cls, vars(module))
|
||||||
return type_hints[field_name]
|
return type_hints[field_name]
|
||||||
|
|
||||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
@classmethod
|
||||||
|
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
"""Get the message class for a field from the type hints."""
|
"""Get the message class for a field from the type hints."""
|
||||||
cls = self._betterproto.cls_by_field[field.name]
|
field_cls = cls._type_hint(field.name)
|
||||||
if hasattr(cls, "__args__") and index >= 0:
|
if hasattr(field_cls, "__args__") and index >= 0:
|
||||||
cls = cls.__args__[index]
|
field_cls = field_cls.__args__[index]
|
||||||
return cls
|
return field_cls
|
||||||
|
|
||||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||||
return self._betterproto.default_gen[field.name]()
|
return self._betterproto.default_gen[field.name]()
|
||||||
@ -680,7 +705,7 @@ class Message(ABC):
|
|||||||
if meta.proto_type == TYPE_STRING:
|
if meta.proto_type == TYPE_STRING:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type == TYPE_MESSAGE:
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field.name]
|
||||||
|
|
||||||
if cls == datetime:
|
if cls == datetime:
|
||||||
value = _Timestamp().parse(value).to_datetime()
|
value = _Timestamp().parse(value).to_datetime()
|
||||||
@ -694,21 +719,7 @@ class Message(ABC):
|
|||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
value._serialized_on_wire = True
|
value._serialized_on_wire = True
|
||||||
elif meta.proto_type == TYPE_MAP:
|
elif meta.proto_type == TYPE_MAP:
|
||||||
# TODO: This is slow, use a cache to make it faster since each
|
value = self._betterproto.cls_by_field[field.name]().parse(value)
|
||||||
# key/value pair will recreate the class.
|
|
||||||
assert meta.map_types
|
|
||||||
kt = self._cls_for(field, index=0)
|
|
||||||
vt = self._cls_for(field, index=1)
|
|
||||||
Entry = dataclasses.make_dataclass(
|
|
||||||
"Entry",
|
|
||||||
[
|
|
||||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
|
||||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
|
||||||
],
|
|
||||||
bases=(Message,),
|
|
||||||
)
|
|
||||||
make_protoclass(Entry)
|
|
||||||
value = Entry().parse(value)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -823,7 +834,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
output[cased_name] = b64encode(v).decode("utf8")
|
output[cased_name] = b64encode(v).decode("utf8")
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_values = list(self._cls_for(field)) # type: ignore
|
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [enum_values[e].name for e in v]
|
output[cased_name] = [enum_values[e].name for e in v]
|
||||||
else:
|
else:
|
||||||
@ -849,7 +860,7 @@ class Message(ABC):
|
|||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field.name]
|
||||||
for i in range(len(value[key])):
|
for i in range(len(value[key])):
|
||||||
v.append(cls().from_dict(value[key][i]))
|
v.append(cls().from_dict(value[key][i]))
|
||||||
elif isinstance(v, datetime):
|
elif isinstance(v, datetime):
|
||||||
@ -866,7 +877,7 @@ class Message(ABC):
|
|||||||
v.from_dict(value[key])
|
v.from_dict(value[key])
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
cls = self._cls_for(field, index=1)
|
cls = self._betterproto.cls_by_field[field.name + ".value"]
|
||||||
for k in value[key]:
|
for k in value[key]:
|
||||||
v[k] = cls().from_dict(value[key][k])
|
v[k] = cls().from_dict(value[key][k])
|
||||||
else:
|
else:
|
||||||
@ -882,7 +893,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
v = b64decode(value[key])
|
v = b64decode(value[key])
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_cls = self._cls_for(field)
|
enum_cls = self._betterproto.cls_by_field[field.name]
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = [enum_cls.from_string(e) for e in v]
|
v = [enum_cls.from_string(e) for e in v]
|
||||||
elif isinstance(v, str):
|
elif isinstance(v, str):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user