From 1f7f39049eb87d6809657a89e8fafb3f2bf4833e Mon Sep 17 00:00:00 2001 From: James Lan Date: Tue, 19 May 2020 15:42:26 -0700 Subject: [PATCH] Cache resolved classes for fields, so that there's no new data classes generated while deserializing. --- betterproto/__init__.py | 89 +++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 39 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 3584714..0fefb77 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -452,25 +452,49 @@ class ProtoClassMetadata: self.oneof_group_by_field = by_field self.oneof_field_by_group = by_group + def init_default_gen(self): + default_gen = {} + + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + default_gen[field.name] = self.cls._get_field_default_gen(field, meta) + + self.default_gen = default_gen + + def init_cls_by_field(self): + field_cls = {} + + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + if meta.proto_type == TYPE_MAP: + assert meta.map_types + kt = self.cls._cls_for(field, index=0) + vt = self.cls._cls_for(field, index=1) + Entry = dataclasses.make_dataclass( + "Entry", + [ + ("key", kt, dataclass_field(1, meta.map_types[0])), + ("value", vt, dataclass_field(2, meta.map_types[1])), + ], + bases=(Message,), + ) + make_protoclass(Entry) + field_cls[field.name] = Entry + field_cls[field.name + ".value"] = vt + else: + field_cls[field.name] = self.cls._cls_for(field) + + self.cls_by_field = field_cls + def __getattr__(self, item): # Lazy init because forward reference classes may not be available at the beginning. if item == 'default_gen': - defaults = {} - for field in dataclasses.fields(self.cls): - meta = FieldMetadata.get(field) - defaults[field.name] = self.cls._get_field_default_gen(field, meta) - - self.default_gen = defaults # __getattr__ won't be called next time - return defaults + self.init_default_gen() + return self.default_gen if item == 'cls_by_field': - field_cls = {} - for field in dataclasses.fields(self.cls): - meta = FieldMetadata.get(field) - field_cls[field.name] = self.cls._type_hint(field.name) - - self.cls_by_field = field_cls # __getattr__ won't be called next time - return field_cls + self.init_cls_by_field() + return self.cls_by_field def make_protoclass(cls): @@ -619,12 +643,13 @@ class Message(ABC): type_hints = get_type_hints(cls, vars(module)) return type_hints[field_name] - def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: + @classmethod + def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: """Get the message class for a field from the type hints.""" - cls = self._betterproto.cls_by_field[field.name] - if hasattr(cls, "__args__") and index >= 0: - cls = cls.__args__[index] - return cls + field_cls = cls._type_hint(field.name) + if hasattr(field_cls, "__args__") and index >= 0: + field_cls = field_cls.__args__[index] + return field_cls def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: return self._betterproto.default_gen[field.name]() @@ -680,7 +705,7 @@ class Message(ABC): if meta.proto_type == TYPE_STRING: value = value.decode("utf-8") elif meta.proto_type == TYPE_MESSAGE: - cls = self._cls_for(field) + cls = self._betterproto.cls_by_field[field.name] if cls == datetime: value = _Timestamp().parse(value).to_datetime() @@ -694,21 +719,7 @@ class Message(ABC): value = cls().parse(value) value._serialized_on_wire = True elif meta.proto_type == TYPE_MAP: - # TODO: This is slow, use a cache to make it faster since each - # key/value pair will recreate the class. - assert meta.map_types - kt = self._cls_for(field, index=0) - vt = self._cls_for(field, index=1) - Entry = dataclasses.make_dataclass( - "Entry", - [ - ("key", kt, dataclass_field(1, meta.map_types[0])), - ("value", vt, dataclass_field(2, meta.map_types[1])), - ], - bases=(Message,), - ) - make_protoclass(Entry) - value = Entry().parse(value) + value = self._betterproto.cls_by_field[field.name]().parse(value) return value @@ -823,7 +834,7 @@ class Message(ABC): else: output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: - enum_values = list(self._cls_for(field)) # type: ignore + enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore if isinstance(v, list): output[cased_name] = [enum_values[e].name for e in v] else: @@ -849,7 +860,7 @@ class Message(ABC): if meta.proto_type == "message": v = getattr(self, field.name) if isinstance(v, list): - cls = self._cls_for(field) + cls = self._betterproto.cls_by_field[field.name] for i in range(len(value[key])): v.append(cls().from_dict(value[key][i])) elif isinstance(v, datetime): @@ -866,7 +877,7 @@ class Message(ABC): v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: v = getattr(self, field.name) - cls = self._cls_for(field, index=1) + cls = self._betterproto.cls_by_field[field.name + ".value"] for k in value[key]: v[k] = cls().from_dict(value[key][k]) else: @@ -882,7 +893,7 @@ class Message(ABC): else: v = b64decode(value[key]) elif meta.proto_type == TYPE_ENUM: - enum_cls = self._cls_for(field) + enum_cls = self._betterproto.cls_by_field[field.name] if isinstance(v, list): v = [enum_cls.from_string(e) for e in v] elif isinstance(v, str):