From de61ddab21a9e3e34a4ffa7fd3df7fa64a594b85 Mon Sep 17 00:00:00 2001 From: James Lan Date: Tue, 19 May 2020 10:26:23 -0700 Subject: [PATCH 1/4] Add option to repeatedly execute betterproto operations in test, to evaluate performance --- betterproto/tests/test_inputs.py | 120 +++++++++++++++---------------- conftest.py | 10 +++ 2 files changed, 70 insertions(+), 60 deletions(-) create mode 100644 conftest.py diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index c8fb7d3..0088d81 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -5,6 +5,7 @@ import sys import pytest import betterproto from betterproto.tests.util import get_directories, inputs_path +from collections import namedtuple # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. @@ -22,47 +23,13 @@ plugin_output_package = "betterproto.tests.output_betterproto" reference_output_package = "betterproto.tests.output_reference" -@pytest.mark.parametrize("test_case_name", test_case_names) -def test_message_can_be_imported(test_case_name: str) -> None: - importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ) +TestData = namedtuple("TestData", "plugin_module, reference_module, json_data") -@pytest.mark.parametrize("test_case_name", test_case_names) -def test_message_can_instantiated(test_case_name: str) -> None: - plugin_module = importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ) - plugin_module.Test() +@pytest.fixture(scope="module", params=test_case_names) +def test_data(request): + test_case_name = request.param - -@pytest.mark.parametrize("test_case_name", test_case_names) -def test_message_equality(test_case_name: str) -> None: - plugin_module = importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ) - message1 = plugin_module.Test() - message2 = 
plugin_module.Test() - assert message1 == message2 - - -@pytest.mark.parametrize("test_case_name", test_case_names) -def test_message_json(test_case_name: str) -> None: - plugin_module = importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ) - message: betterproto.Message = plugin_module.Test() - reference_json_data = get_test_case_json_data(test_case_name) - - message.from_json(reference_json_data) - message_json = message.to_json(0) - - assert json.loads(reference_json_data) == json.loads(message_json) - - -@pytest.mark.parametrize("test_case_name", test_case_names) -def test_binary_compatibility(test_case_name: str) -> None: # Reset the internal symbol database so we can import the `Test` message # multiple times. Ugh. sym = symbol_database.Default() @@ -74,33 +41,66 @@ def test_binary_compatibility(test_case_name: str) -> None: sys.path.append(reference_module_root) - # import reference message - reference_module = importlib.import_module( - f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + yield TestData( + plugin_module=importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ), + reference_module=importlib.import_module( + f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + ), + json_data=get_test_case_json_data(test_case_name), ) - plugin_module = importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ) - - test_data = get_test_case_json_data(test_case_name) - - reference_instance = Parse(test_data, reference_module.Test()) - reference_binary_output = reference_instance.SerializeToString() - - plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( - test_data - ) - plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output) - - # # Generally this can't be relied on, but here we are aiming to match the - # # existing Python implementation and aren't doing anything 
tricky. - # # https://developers.google.com/protocol-buffers/docs/encoding#implications - assert plugin_instance_from_json == plugin_instance_from_binary - assert plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() sys.path.remove(reference_module_root) +def test_message_can_instantiated(test_data: TestData) -> None: + plugin_module, *_ = test_data + plugin_module.Test() + + +def test_message_equality(test_data: TestData) -> None: + plugin_module, *_ = test_data + message1 = plugin_module.Test() + message2 = plugin_module.Test() + assert message1 == message2 + + +def test_message_json(repeat, test_data: TestData) -> None: + plugin_module, _, json_data = test_data + + for _ in range(repeat): + message: betterproto.Message = plugin_module.Test() + + message.from_json(json_data) + message_json = message.to_json(0) + + assert json.loads(json_data) == json.loads(message_json) + + +def test_binary_compatibility(repeat, test_data: TestData) -> None: + plugin_module, reference_module, json_data = test_data + + reference_instance = Parse(json_data, reference_module.Test()) + reference_binary_output = reference_instance.SerializeToString() + + for _ in range(repeat): + plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( + json_data + ) + plugin_instance_from_binary = plugin_module.Test.FromString( + reference_binary_output + ) + + # # Generally this can't be relied on, but here we are aiming to match the + # # existing Python implementation and aren't doing anything tricky. 
+ # # https://developers.google.com/protocol-buffers/docs/encoding#implications + assert plugin_instance_from_json == plugin_instance_from_binary + assert ( + plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict() + ) + + """ helper methods """ diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..1727782 --- /dev/null +++ b/conftest.py @@ -0,0 +1,10 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption("--repeat", type=int, default=1, help="repeat the operation multiple times") + + +@pytest.fixture(scope="session") +def repeat(request): + return request.config.getoption("repeat") From 3d001a2a1a53cadb83d6e958464124a24558c3b8 Mon Sep 17 00:00:00 2001 From: James Lan Date: Thu, 14 May 2020 15:20:23 -0700 Subject: [PATCH 2/4] Store the class metadata of fields in the class, to improve performance Cached data includes: - lookup table between groups and fields of "oneof" fields - default value creator of each field - type hint of each field --- betterproto/__init__.py | 148 ++++++++++++++++++++--------- betterproto/templates/template.py | 3 +- betterproto/tests/test_features.py | 30 +++--- 3 files changed, 117 insertions(+), 64 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index dc2566c..3584714 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -120,7 +120,11 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. 
-DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc) +def datetime_default_gen(): + return datetime(1970, 1, 1, tzinfo=timezone.utc) + + +DATETIME_ZERO = datetime_default_gen() class Casing(enum.Enum): @@ -428,6 +432,57 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: T = TypeVar("T", bound="Message") +class ProtoClassMetadata: + cls: "Message" + + def __init__(self, cls: "Message"): + self.cls = cls + by_field = {} + by_group = {} + + for field in dataclasses.fields(cls): + meta = FieldMetadata.get(field) + + if meta.group: + # This is part of a one-of group. + by_field[field.name] = meta.group + + by_group.setdefault(meta.group, set()).add(field) + + self.oneof_group_by_field = by_field + self.oneof_field_by_group = by_group + + def __getattr__(self, item): + # Lazy init because forward reference classes may not be available at the beginning. + if item == 'default_gen': + defaults = {} + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + defaults[field.name] = self.cls._get_field_default_gen(field, meta) + + self.default_gen = defaults # __getattr__ won't be called next time + return defaults + + if item == 'cls_by_field': + field_cls = {} + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + field_cls[field.name] = self.cls._type_hint(field.name) + + self.cls_by_field = field_cls # __getattr__ won't be called next time + return field_cls + + +def make_protoclass(cls): + setattr(cls, "_betterproto", ProtoClassMetadata(cls)) + + +def protoclass(*args, **kwargs): + cls = dataclasses.dataclass(*args, **kwargs) + make_protoclass(cls) + return cls + + class Message(ABC): """ A protobuf message base class. Generated code will inherit from this and @@ -445,17 +500,12 @@ class Message(ABC): # Set a default value for each field in the class after `__init__` has # already been run. 
- group_map: Dict[str, dict] = {"fields": {}, "groups": {}} + group_map: Dict[str, dataclasses.Field] = {} for field in dataclasses.fields(self): meta = FieldMetadata.get(field) if meta.group: - # This is part of a one-of group. - group_map["fields"][field.name] = meta.group - - if meta.group not in group_map["groups"]: - group_map["groups"][meta.group] = {"current": None, "fields": set()} - group_map["groups"][meta.group]["fields"].add(field) + group_map.setdefault(meta.group) if getattr(self, field.name) != PLACEHOLDER: # Skip anything not set to the sentinel value @@ -463,7 +513,7 @@ class Message(ABC): if meta.group: # This was set, so make it the selected value of the one-of. - group_map["groups"][meta.group]["current"] = field + group_map[meta.group] = field continue @@ -479,16 +529,17 @@ class Message(ABC): # Track when a field has been set. self.__dict__["_serialized_on_wire"] = True - if attr in getattr(self, "_group_map", {}).get("fields", {}): - group = self._group_map["fields"][attr] - for field in self._group_map["groups"][group]["fields"]: - if field.name == attr: - self._group_map["groups"][group]["current"] = field - else: - super().__setattr__( - field.name, - self._get_field_default(field, FieldMetadata.get(field)), - ) + if hasattr(self, "_group_map"): # __post_init__ had already run + if attr in self._betterproto.oneof_group_by_field: + group = self._betterproto.oneof_group_by_field[attr] + for field in self._betterproto.oneof_field_by_group[group]: + if field.name == attr: + self._group_map[group] = field + else: + super().__setattr__( + field.name, + self._get_field_default(field, FieldMetadata.get(field)), + ) super().__setattr__(attr, value) @@ -510,7 +561,7 @@ class Message(ABC): # currently set in a `oneof` group, so it must be serialized even # if the value is the default zero value. 
selected_in_group = False - if meta.group and self._group_map["groups"][meta.group]["current"] == field: + if meta.group and self._group_map[meta.group] == field: selected_in_group = True serialize_empty = False @@ -562,47 +613,49 @@ class Message(ABC): # For compatibility with other libraries SerializeToString = __bytes__ - def _type_hint(self, field_name: str) -> Type: - module = inspect.getmodule(self.__class__) - type_hints = get_type_hints(self.__class__, vars(module)) + @classmethod + def _type_hint(cls, field_name: str) -> Type: + module = inspect.getmodule(cls) + type_hints = get_type_hints(cls, vars(module)) return type_hints[field_name] def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: """Get the message class for a field from the type hints.""" - cls = self._type_hint(field.name) + cls = self._betterproto.cls_by_field[field.name] if hasattr(cls, "__args__") and index >= 0: cls = cls.__args__[index] return cls def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: - t = self._type_hint(field.name) + return self._betterproto.default_gen[field.name]() + + @classmethod + def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any: + t = cls._type_hint(field.name) - value: Any = 0 if hasattr(t, "__origin__"): if t.__origin__ in (dict, Dict): # This is some kind of map (dict in Python). - value = {} + return dict elif t.__origin__ in (list, List): # This is some kind of list (repeated) field. - value = [] + return list elif t.__origin__ == Union and t.__args__[1] == type(None): # This is an optional (wrapped) field. For setting the default we # really don't care what kind of field it is. - value = None + return type(None) else: - value = t() + return t elif issubclass(t, Enum): # Enums always default to zero. 
- value = 0 + return int elif t == datetime: # Offsets are relative to 1970-01-01T00:00:00Z - value = DATETIME_ZERO + return datetime_default_gen else: # This is either a primitive scalar or another message type. Calling # it should result in its zero value. - value = t() - - return value + return t def _postprocess_single( self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any @@ -654,6 +707,7 @@ class Message(ABC): ], bases=(Message,), ) + make_protoclass(Entry) value = Entry().parse(value) return value @@ -861,13 +915,13 @@ def serialized_on_wire(message: Message) -> bool: def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: """Return the name and value of a message's one-of field group.""" - field = message._group_map["groups"].get(group_name, {}).get("current") + field = message._group_map.get(group_name) if not field: return ("", None) return (field.name, getattr(message, field.name)) -@dataclasses.dataclass +@protoclass class _Duration(Message): # Signed seconds of the span of time. Must be from -315,576,000,000 to # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 @@ -892,7 +946,7 @@ class _Duration(Message): return ".".join(parts) + "s" -@dataclasses.dataclass +@protoclass class _Timestamp(Message): # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. 
@@ -942,47 +996,47 @@ class _WrappedMessage(Message): return self -@dataclasses.dataclass +@protoclass class _BoolValue(_WrappedMessage): value: bool = bool_field(1) -@dataclasses.dataclass +@protoclass class _Int32Value(_WrappedMessage): value: int = int32_field(1) -@dataclasses.dataclass +@protoclass class _UInt32Value(_WrappedMessage): value: int = uint32_field(1) -@dataclasses.dataclass +@protoclass class _Int64Value(_WrappedMessage): value: int = int64_field(1) -@dataclasses.dataclass +@protoclass class _UInt64Value(_WrappedMessage): value: int = uint64_field(1) -@dataclasses.dataclass +@protoclass class _FloatValue(_WrappedMessage): value: float = float_field(1) -@dataclasses.dataclass +@protoclass class _DoubleValue(_WrappedMessage): value: float = double_field(1) -@dataclasses.dataclass +@protoclass class _StringValue(_WrappedMessage): value: str = string_field(1) -@dataclasses.dataclass +@protoclass class _BytesValue(_WrappedMessage): value: bytes = bytes_field(1) diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 4c18ccc..73f3dac 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -1,7 +1,6 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! 
# sources: {{ ', '.join(description.files) }} # plugin: python-betterproto -from dataclasses import dataclass {% if description.datetime_imports %} from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} @@ -38,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in description.messages %} -@dataclass +@betterproto.protoclass class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 47019e1..307094f 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -4,11 +4,11 @@ from typing import Optional def test_has_field(): - @dataclass + @betterproto.protoclass class Bar(betterproto.Message): baz: int = betterproto.int32_field(1) - @dataclass + @betterproto.protoclass class Foo(betterproto.Message): bar: Bar = betterproto.message_field(1) @@ -34,11 +34,11 @@ def test_has_field(): def test_class_init(): - @dataclass + @betterproto.protoclass class Bar(betterproto.Message): name: str = betterproto.string_field(1) - @dataclass + @betterproto.protoclass class Foo(betterproto.Message): name: str = betterproto.string_field(1) child: Bar = betterproto.message_field(2) @@ -53,7 +53,7 @@ def test_enum_as_int_json(): ZERO = 0 ONE = 1 - @dataclass + @betterproto.protoclass class Foo(betterproto.Message): bar: TestEnum = betterproto.enum_field(1) @@ -67,13 +67,13 @@ def test_enum_as_int_json(): def test_unknown_fields(): - @dataclass + @betterproto.protoclass class Newer(betterproto.Message): foo: bool = betterproto.bool_field(1) bar: int = betterproto.int32_field(2) baz: str = betterproto.string_field(3) - @dataclass + @betterproto.protoclass class Older(betterproto.Message): foo: bool = betterproto.bool_field(1) @@ -89,11 +89,11 @@ def test_unknown_fields(): def test_oneof_support(): - @dataclass + 
@betterproto.protoclass class Sub(betterproto.Message): val: int = betterproto.int32_field(1) - @dataclass + @betterproto.protoclass class Foo(betterproto.Message): bar: int = betterproto.int32_field(1, group="group1") baz: str = betterproto.string_field(2, group="group1") @@ -134,7 +134,7 @@ def test_oneof_support(): def test_json_casing(): - @dataclass + @betterproto.protoclass class CasingTest(betterproto.Message): pascal_case: int = betterproto.int32_field(1) camel_case: int = betterproto.int32_field(2) @@ -165,7 +165,7 @@ def test_json_casing(): def test_optional_flag(): - @dataclass + @betterproto.protoclass class Request(betterproto.Message): flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) @@ -180,7 +180,7 @@ def test_optional_flag(): def test_to_dict_default_values(): - @dataclass + @betterproto.protoclass class TestMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -210,7 +210,7 @@ def test_to_dict_default_values(): } # Some default and some other values - @dataclass + @betterproto.protoclass class TestMessage2(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -246,11 +246,11 @@ def test_to_dict_default_values(): } # Nested messages - @dataclass + @betterproto.protoclass class TestChildMessage(betterproto.Message): some_other_int: int = betterproto.int32_field(1) - @dataclass + @betterproto.protoclass class TestParentMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) From 1f7f39049eb87d6809657a89e8fafb3f2bf4833e Mon Sep 17 00:00:00 2001 From: James Lan Date: Tue, 19 May 2020 15:42:26 -0700 Subject: [PATCH 3/4] Cache resolved classes for fields, so that there's no new data classes generated while deserializing. 
--- betterproto/__init__.py | 89 +++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 39 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 3584714..0fefb77 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -452,25 +452,49 @@ class ProtoClassMetadata: self.oneof_group_by_field = by_field self.oneof_field_by_group = by_group + def init_default_gen(self): + default_gen = {} + + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + default_gen[field.name] = self.cls._get_field_default_gen(field, meta) + + self.default_gen = default_gen + + def init_cls_by_field(self): + field_cls = {} + + for field in dataclasses.fields(self.cls): + meta = FieldMetadata.get(field) + if meta.proto_type == TYPE_MAP: + assert meta.map_types + kt = self.cls._cls_for(field, index=0) + vt = self.cls._cls_for(field, index=1) + Entry = dataclasses.make_dataclass( + "Entry", + [ + ("key", kt, dataclass_field(1, meta.map_types[0])), + ("value", vt, dataclass_field(2, meta.map_types[1])), + ], + bases=(Message,), + ) + make_protoclass(Entry) + field_cls[field.name] = Entry + field_cls[field.name + ".value"] = vt + else: + field_cls[field.name] = self.cls._cls_for(field) + + self.cls_by_field = field_cls + def __getattr__(self, item): # Lazy init because forward reference classes may not be available at the beginning. 
if item == 'default_gen': - defaults = {} - for field in dataclasses.fields(self.cls): - meta = FieldMetadata.get(field) - defaults[field.name] = self.cls._get_field_default_gen(field, meta) - - self.default_gen = defaults # __getattr__ won't be called next time - return defaults + self.init_default_gen() + return self.default_gen if item == 'cls_by_field': - field_cls = {} - for field in dataclasses.fields(self.cls): - meta = FieldMetadata.get(field) - field_cls[field.name] = self.cls._type_hint(field.name) - - self.cls_by_field = field_cls # __getattr__ won't be called next time - return field_cls + self.init_cls_by_field() + return self.cls_by_field def make_protoclass(cls): @@ -619,12 +643,13 @@ class Message(ABC): type_hints = get_type_hints(cls, vars(module)) return type_hints[field_name] - def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: + @classmethod + def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: """Get the message class for a field from the type hints.""" - cls = self._betterproto.cls_by_field[field.name] - if hasattr(cls, "__args__") and index >= 0: - cls = cls.__args__[index] - return cls + field_cls = cls._type_hint(field.name) + if hasattr(field_cls, "__args__") and index >= 0: + field_cls = field_cls.__args__[index] + return field_cls def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: return self._betterproto.default_gen[field.name]() @@ -680,7 +705,7 @@ class Message(ABC): if meta.proto_type == TYPE_STRING: value = value.decode("utf-8") elif meta.proto_type == TYPE_MESSAGE: - cls = self._cls_for(field) + cls = self._betterproto.cls_by_field[field.name] if cls == datetime: value = _Timestamp().parse(value).to_datetime() @@ -694,21 +719,7 @@ class Message(ABC): value = cls().parse(value) value._serialized_on_wire = True elif meta.proto_type == TYPE_MAP: - # TODO: This is slow, use a cache to make it faster since each - # key/value pair will recreate the class. 
- assert meta.map_types - kt = self._cls_for(field, index=0) - vt = self._cls_for(field, index=1) - Entry = dataclasses.make_dataclass( - "Entry", - [ - ("key", kt, dataclass_field(1, meta.map_types[0])), - ("value", vt, dataclass_field(2, meta.map_types[1])), - ], - bases=(Message,), - ) - make_protoclass(Entry) - value = Entry().parse(value) + value = self._betterproto.cls_by_field[field.name]().parse(value) return value @@ -823,7 +834,7 @@ class Message(ABC): else: output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: - enum_values = list(self._cls_for(field)) # type: ignore + enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore if isinstance(v, list): output[cased_name] = [enum_values[e].name for e in v] else: @@ -849,7 +860,7 @@ class Message(ABC): if meta.proto_type == "message": v = getattr(self, field.name) if isinstance(v, list): - cls = self._cls_for(field) + cls = self._betterproto.cls_by_field[field.name] for i in range(len(value[key])): v.append(cls().from_dict(value[key][i])) elif isinstance(v, datetime): @@ -866,7 +877,7 @@ class Message(ABC): v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: v = getattr(self, field.name) - cls = self._cls_for(field, index=1) + cls = self._betterproto.cls_by_field[field.name + ".value"] for k in value[key]: v[k] = cls().from_dict(value[key][k]) else: @@ -882,7 +893,7 @@ class Message(ABC): else: v = b64decode(value[key]) elif meta.proto_type == TYPE_ENUM: - enum_cls = self._cls_for(field) + enum_cls = self._betterproto.cls_by_field[field.name] if isinstance(v, list): v = [enum_cls.from_string(e) for e in v] elif isinstance(v, str): From 917de09bb6fc6c73aa034f36fd1981103442b979 Mon Sep 17 00:00:00 2001 From: James Lan Date: Thu, 21 May 2020 17:03:05 -0700 Subject: [PATCH 4/4] Replace extra decorator with property and lazy initialization so that it is backward compatible. 
--- betterproto/__init__.py | 63 ++++++++++++++---------------- betterproto/templates/template.py | 3 +- betterproto/tests/test_features.py | 30 +++++++------- 3 files changed, 46 insertions(+), 50 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 0fefb77..418378e 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -433,9 +433,9 @@ T = TypeVar("T", bound="Message") class ProtoClassMetadata: - cls: "Message" + cls: Type["Message"] - def __init__(self, cls: "Message"): + def __init__(self, cls: Type["Message"]): self.cls = cls by_field = {} by_group = {} @@ -452,6 +452,9 @@ class ProtoClassMetadata: self.oneof_group_by_field = by_field self.oneof_field_by_group = by_group + self.init_default_gen() + self.init_cls_by_field() + def init_default_gen(self): default_gen = {} @@ -478,7 +481,6 @@ class ProtoClassMetadata: ], bases=(Message,), ) - make_protoclass(Entry) field_cls[field.name] = Entry field_cls[field.name + ".value"] = vt else: @@ -486,26 +488,6 @@ class ProtoClassMetadata: self.cls_by_field = field_cls - def __getattr__(self, item): - # Lazy init because forward reference classes may not be available at the beginning. - if item == 'default_gen': - self.init_default_gen() - return self.default_gen - - if item == 'cls_by_field': - self.init_cls_by_field() - return self.cls_by_field - - -def make_protoclass(cls): - setattr(cls, "_betterproto", ProtoClassMetadata(cls)) - - -def protoclass(*args, **kwargs): - cls = dataclasses.dataclass(*args, **kwargs) - make_protoclass(cls) - return cls - class Message(ABC): """ @@ -567,6 +549,19 @@ class Message(ABC): super().__setattr__(attr, value) + @property + def _betterproto(self): + """ + Lazy initialize metadata for each protobuf class. + It may be initialized multiple times in a multi-threaded environment, + but that won't affect the correctness. 
+ """ + meta = getattr(self.__class__, "_betterproto_meta", None) + if not meta: + meta = ProtoClassMetadata(self.__class__) + self.__class__._betterproto_meta = meta + return meta + def __bytes__(self) -> bytes: """ Get the binary encoded Protobuf representation of this instance. @@ -932,7 +927,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: return (field.name, getattr(message, field.name)) -@protoclass +@dataclasses.dataclass class _Duration(Message): # Signed seconds of the span of time. Must be from -315,576,000,000 to # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 @@ -957,7 +952,7 @@ class _Duration(Message): return ".".join(parts) + "s" -@protoclass +@dataclasses.dataclass class _Timestamp(Message): # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. @@ -1007,47 +1002,47 @@ class _WrappedMessage(Message): return self -@protoclass +@dataclasses.dataclass class _BoolValue(_WrappedMessage): value: bool = bool_field(1) -@protoclass +@dataclasses.dataclass class _Int32Value(_WrappedMessage): value: int = int32_field(1) -@protoclass +@dataclasses.dataclass class _UInt32Value(_WrappedMessage): value: int = uint32_field(1) -@protoclass +@dataclasses.dataclass class _Int64Value(_WrappedMessage): value: int = int64_field(1) -@protoclass +@dataclasses.dataclass class _UInt64Value(_WrappedMessage): value: int = uint64_field(1) -@protoclass +@dataclasses.dataclass class _FloatValue(_WrappedMessage): value: float = float_field(1) -@protoclass +@dataclasses.dataclass class _DoubleValue(_WrappedMessage): value: float = double_field(1) -@protoclass +@dataclasses.dataclass class _StringValue(_WrappedMessage): value: str = string_field(1) -@protoclass +@dataclasses.dataclass class _BytesValue(_WrappedMessage): value: bytes = bytes_field(1) diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 
73f3dac..4c18ccc 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -1,6 +1,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # sources: {{ ', '.join(description.files) }} # plugin: python-betterproto +from dataclasses import dataclass {% if description.datetime_imports %} from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} @@ -37,7 +38,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in description.messages %} -@betterproto.protoclass +@dataclass class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 307094f..47019e1 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -4,11 +4,11 @@ from typing import Optional def test_has_field(): - @betterproto.protoclass + @dataclass class Bar(betterproto.Message): baz: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: Bar = betterproto.message_field(1) @@ -34,11 +34,11 @@ def test_has_field(): def test_class_init(): - @betterproto.protoclass + @dataclass class Bar(betterproto.Message): name: str = betterproto.string_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): name: str = betterproto.string_field(1) child: Bar = betterproto.message_field(2) @@ -53,7 +53,7 @@ def test_enum_as_int_json(): ZERO = 0 ONE = 1 - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: TestEnum = betterproto.enum_field(1) @@ -67,13 +67,13 @@ def test_enum_as_int_json(): def test_unknown_fields(): - @betterproto.protoclass + @dataclass class Newer(betterproto.Message): foo: bool = betterproto.bool_field(1) bar: int = betterproto.int32_field(2) baz: str = betterproto.string_field(3) - @betterproto.protoclass + 
@dataclass class Older(betterproto.Message): foo: bool = betterproto.bool_field(1) @@ -89,11 +89,11 @@ def test_unknown_fields(): def test_oneof_support(): - @betterproto.protoclass + @dataclass class Sub(betterproto.Message): val: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: int = betterproto.int32_field(1, group="group1") baz: str = betterproto.string_field(2, group="group1") @@ -134,7 +134,7 @@ def test_oneof_support(): def test_json_casing(): - @betterproto.protoclass + @dataclass class CasingTest(betterproto.Message): pascal_case: int = betterproto.int32_field(1) camel_case: int = betterproto.int32_field(2) @@ -165,7 +165,7 @@ def test_json_casing(): def test_optional_flag(): - @betterproto.protoclass + @dataclass class Request(betterproto.Message): flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) @@ -180,7 +180,7 @@ def test_optional_flag(): def test_to_dict_default_values(): - @betterproto.protoclass + @dataclass class TestMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -210,7 +210,7 @@ def test_to_dict_default_values(): } # Some default and some other values - @betterproto.protoclass + @dataclass class TestMessage2(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -246,11 +246,11 @@ def test_to_dict_default_values(): } # Nested messages - @betterproto.protoclass + @dataclass class TestChildMessage(betterproto.Message): some_other_int: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class TestParentMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2)