From 917de09bb6fc6c73aa034f36fd1981103442b979 Mon Sep 17 00:00:00 2001 From: James Lan Date: Thu, 21 May 2020 17:03:05 -0700 Subject: [PATCH] Replace extra decorator with property and lazy initialization so that it is backward compatible. --- betterproto/__init__.py | 63 ++++++++++++++---------------- betterproto/templates/template.py | 3 +- betterproto/tests/test_features.py | 30 +++++++------- 3 files changed, 46 insertions(+), 50 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 0fefb77..418378e 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -433,9 +433,9 @@ T = TypeVar("T", bound="Message") class ProtoClassMetadata: - cls: "Message" + cls: Type["Message"] - def __init__(self, cls: "Message"): + def __init__(self, cls: Type["Message"]): self.cls = cls by_field = {} by_group = {} @@ -452,6 +452,9 @@ class ProtoClassMetadata: self.oneof_group_by_field = by_field self.oneof_field_by_group = by_group + self.init_default_gen() + self.init_cls_by_field() + def init_default_gen(self): default_gen = {} @@ -478,7 +481,6 @@ class ProtoClassMetadata: ], bases=(Message,), ) - make_protoclass(Entry) field_cls[field.name] = Entry field_cls[field.name + ".value"] = vt else: @@ -486,26 +488,6 @@ class ProtoClassMetadata: self.cls_by_field = field_cls - def __getattr__(self, item): - # Lazy init because forward reference classes may not be available at the beginning. - if item == 'default_gen': - self.init_default_gen() - return self.default_gen - - if item == 'cls_by_field': - self.init_cls_by_field() - return self.cls_by_field - - -def make_protoclass(cls): - setattr(cls, "_betterproto", ProtoClassMetadata(cls)) - - -def protoclass(*args, **kwargs): - cls = dataclasses.dataclass(*args, **kwargs) - make_protoclass(cls) - return cls - class Message(ABC): """ @@ -567,6 +549,19 @@ class Message(ABC): super().__setattr__(attr, value) + @property + def _betterproto(self): + """ + Lazy initialize metadata for each protobuf class. + It may be initialized multiple times in a multi-threaded environment, + but that won't affect the correctness. + """ + meta = getattr(self.__class__, "_betterproto_meta", None) + if not meta: + meta = ProtoClassMetadata(self.__class__) + self.__class__._betterproto_meta = meta + return meta + def __bytes__(self) -> bytes: """ Get the binary encoded Protobuf representation of this instance. @@ -932,7 +927,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: return (field.name, getattr(message, field.name)) -@protoclass +@dataclasses.dataclass class _Duration(Message): # Signed seconds of the span of time. Must be from -315,576,000,000 to # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 @@ -957,7 +952,7 @@ class _Duration(Message): return ".".join(parts) + "s" -@protoclass +@dataclasses.dataclass class _Timestamp(Message): # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. @@ -1007,47 +1002,47 @@ class _WrappedMessage(Message): return self -@protoclass +@dataclasses.dataclass class _BoolValue(_WrappedMessage): value: bool = bool_field(1) -@protoclass +@dataclasses.dataclass class _Int32Value(_WrappedMessage): value: int = int32_field(1) -@protoclass +@dataclasses.dataclass class _UInt32Value(_WrappedMessage): value: int = uint32_field(1) -@protoclass +@dataclasses.dataclass class _Int64Value(_WrappedMessage): value: int = int64_field(1) -@protoclass +@dataclasses.dataclass class _UInt64Value(_WrappedMessage): value: int = uint64_field(1) -@protoclass +@dataclasses.dataclass class _FloatValue(_WrappedMessage): value: float = float_field(1) -@protoclass +@dataclasses.dataclass class _DoubleValue(_WrappedMessage): value: float = double_field(1) -@protoclass +@dataclasses.dataclass class _StringValue(_WrappedMessage): value: str = string_field(1) -@protoclass +@dataclasses.dataclass class _BytesValue(_WrappedMessage): value: bytes = bytes_field(1) diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 73f3dac..4c18ccc 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -1,6 +1,7 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # sources: {{ ', '.join(description.files) }} # plugin: python-betterproto +from dataclasses import dataclass {% if description.datetime_imports %} from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} @@ -37,7 +38,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in description.messages %} -@betterproto.protoclass +@dataclass class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 307094f..47019e1 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -4,11 +4,11 @@ from typing import Optional def test_has_field(): - @betterproto.protoclass + @dataclass class Bar(betterproto.Message): baz: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: Bar = betterproto.message_field(1) @@ -34,11 +34,11 @@ def test_has_field(): def test_class_init(): - @betterproto.protoclass + @dataclass class Bar(betterproto.Message): name: str = betterproto.string_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): name: str = betterproto.string_field(1) child: Bar = betterproto.message_field(2) @@ -53,7 +53,7 @@ def test_enum_as_int_json(): ZERO = 0 ONE = 1 - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: TestEnum = betterproto.enum_field(1) @@ -67,13 +67,13 @@ def test_enum_as_int_json(): def test_unknown_fields(): - @betterproto.protoclass + @dataclass class Newer(betterproto.Message): foo: bool = betterproto.bool_field(1) bar: int = betterproto.int32_field(2) baz: str = betterproto.string_field(3) - @betterproto.protoclass + @dataclass class Older(betterproto.Message): foo: bool = betterproto.bool_field(1) @@ -89,11 +89,11 @@ def test_unknown_fields(): def test_oneof_support(): - @betterproto.protoclass + @dataclass class Sub(betterproto.Message): val: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class Foo(betterproto.Message): bar: int = betterproto.int32_field(1, group="group1") baz: str = betterproto.string_field(2, group="group1") @@ -134,7 +134,7 @@ def test_oneof_support(): def test_json_casing(): - @betterproto.protoclass + @dataclass class CasingTest(betterproto.Message): pascal_case: int = betterproto.int32_field(1) camel_case: int = betterproto.int32_field(2) @@ -165,7 +165,7 @@ def test_json_casing(): def test_optional_flag(): - @betterproto.protoclass + @dataclass class Request(betterproto.Message): flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) @@ -180,7 +180,7 @@ def test_optional_flag(): def test_to_dict_default_values(): - @betterproto.protoclass + @dataclass class TestMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -210,7 +210,7 @@ def test_to_dict_default_values(): } # Some default and some other values - @betterproto.protoclass + @dataclass class TestMessage2(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2) @@ -246,11 +246,11 @@ def test_to_dict_default_values(): } # Nested messages - @betterproto.protoclass + @dataclass class TestChildMessage(betterproto.Message): some_other_int: int = betterproto.int32_field(1) - @betterproto.protoclass + @dataclass class TestParentMessage(betterproto.Message): some_int: int = betterproto.int32_field(1) some_double: float = betterproto.double_field(2)