Replace extra decorator with property and lazy initialization so that it is backward compatible.

This commit is contained in:
James Lan 2020-05-21 17:03:05 -07:00
parent 1f7f39049e
commit 917de09bb6
3 changed files with 46 additions and 50 deletions

View File

@ -433,9 +433,9 @@ T = TypeVar("T", bound="Message")
class ProtoClassMetadata: class ProtoClassMetadata:
cls: "Message" cls: Type["Message"]
def __init__(self, cls: "Message"): def __init__(self, cls: Type["Message"]):
self.cls = cls self.cls = cls
by_field = {} by_field = {}
by_group = {} by_group = {}
@ -452,6 +452,9 @@ class ProtoClassMetadata:
self.oneof_group_by_field = by_field self.oneof_group_by_field = by_field
self.oneof_field_by_group = by_group self.oneof_field_by_group = by_group
self.init_default_gen()
self.init_cls_by_field()
def init_default_gen(self): def init_default_gen(self):
default_gen = {} default_gen = {}
@ -478,7 +481,6 @@ class ProtoClassMetadata:
], ],
bases=(Message,), bases=(Message,),
) )
make_protoclass(Entry)
field_cls[field.name] = Entry field_cls[field.name] = Entry
field_cls[field.name + ".value"] = vt field_cls[field.name + ".value"] = vt
else: else:
@ -486,26 +488,6 @@ class ProtoClassMetadata:
self.cls_by_field = field_cls self.cls_by_field = field_cls
def __getattr__(self, item):
# Lazy init because forward reference classes may not be available at the beginning.
if item == 'default_gen':
self.init_default_gen()
return self.default_gen
if item == 'cls_by_field':
self.init_cls_by_field()
return self.cls_by_field
def make_protoclass(cls):
setattr(cls, "_betterproto", ProtoClassMetadata(cls))
def protoclass(*args, **kwargs):
cls = dataclasses.dataclass(*args, **kwargs)
make_protoclass(cls)
return cls
class Message(ABC): class Message(ABC):
""" """
@ -567,6 +549,19 @@ class Message(ABC):
super().__setattr__(attr, value) super().__setattr__(attr, value)
@property
def _betterproto(self):
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
but that won't affect the correctness.
"""
meta = getattr(self.__class__, "_betterproto_meta", None)
if not meta:
meta = ProtoClassMetadata(self.__class__)
self.__class__._betterproto_meta = meta
return meta
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
""" """
Get the binary encoded Protobuf representation of this instance. Get the binary encoded Protobuf representation of this instance.
@ -932,7 +927,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field.name, getattr(message, field.name)) return (field.name, getattr(message, field.name))
@protoclass @dataclasses.dataclass
class _Duration(Message): class _Duration(Message):
# Signed seconds of the span of time. Must be from -315,576,000,000 to # Signed seconds of the span of time. Must be from -315,576,000,000 to
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60 # +315,576,000,000 inclusive. Note: these bounds are computed from: 60
@ -957,7 +952,7 @@ class _Duration(Message):
return ".".join(parts) + "s" return ".".join(parts) + "s"
@protoclass @dataclasses.dataclass
class _Timestamp(Message): class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
@ -1007,47 +1002,47 @@ class _WrappedMessage(Message):
return self return self
@protoclass @dataclasses.dataclass
class _BoolValue(_WrappedMessage): class _BoolValue(_WrappedMessage):
value: bool = bool_field(1) value: bool = bool_field(1)
@protoclass @dataclasses.dataclass
class _Int32Value(_WrappedMessage): class _Int32Value(_WrappedMessage):
value: int = int32_field(1) value: int = int32_field(1)
@protoclass @dataclasses.dataclass
class _UInt32Value(_WrappedMessage): class _UInt32Value(_WrappedMessage):
value: int = uint32_field(1) value: int = uint32_field(1)
@protoclass @dataclasses.dataclass
class _Int64Value(_WrappedMessage): class _Int64Value(_WrappedMessage):
value: int = int64_field(1) value: int = int64_field(1)
@protoclass @dataclasses.dataclass
class _UInt64Value(_WrappedMessage): class _UInt64Value(_WrappedMessage):
value: int = uint64_field(1) value: int = uint64_field(1)
@protoclass @dataclasses.dataclass
class _FloatValue(_WrappedMessage): class _FloatValue(_WrappedMessage):
value: float = float_field(1) value: float = float_field(1)
@protoclass @dataclasses.dataclass
class _DoubleValue(_WrappedMessage): class _DoubleValue(_WrappedMessage):
value: float = double_field(1) value: float = double_field(1)
@protoclass @dataclasses.dataclass
class _StringValue(_WrappedMessage): class _StringValue(_WrappedMessage):
value: str = string_field(1) value: str = string_field(1)
@protoclass @dataclasses.dataclass
class _BytesValue(_WrappedMessage): class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1) value: bytes = bytes_field(1)

View File

@ -1,6 +1,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(description.files) }} # sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto # plugin: python-betterproto
from dataclasses import dataclass
{% if description.datetime_imports %} {% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
@ -37,7 +38,7 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endfor %} {% endfor %}
{% endif %} {% endif %}
{% for message in description.messages %} {% for message in description.messages %}
@betterproto.protoclass @dataclass
class {{ message.py_name }}(betterproto.Message): class {{ message.py_name }}(betterproto.Message):
{% if message.comment %} {% if message.comment %}
{{ message.comment }} {{ message.comment }}

View File

@ -4,11 +4,11 @@ from typing import Optional
def test_has_field(): def test_has_field():
@betterproto.protoclass @dataclass
class Bar(betterproto.Message): class Bar(betterproto.Message):
baz: int = betterproto.int32_field(1) baz: int = betterproto.int32_field(1)
@betterproto.protoclass @dataclass
class Foo(betterproto.Message): class Foo(betterproto.Message):
bar: Bar = betterproto.message_field(1) bar: Bar = betterproto.message_field(1)
@ -34,11 +34,11 @@ def test_has_field():
def test_class_init(): def test_class_init():
@betterproto.protoclass @dataclass
class Bar(betterproto.Message): class Bar(betterproto.Message):
name: str = betterproto.string_field(1) name: str = betterproto.string_field(1)
@betterproto.protoclass @dataclass
class Foo(betterproto.Message): class Foo(betterproto.Message):
name: str = betterproto.string_field(1) name: str = betterproto.string_field(1)
child: Bar = betterproto.message_field(2) child: Bar = betterproto.message_field(2)
@ -53,7 +53,7 @@ def test_enum_as_int_json():
ZERO = 0 ZERO = 0
ONE = 1 ONE = 1
@betterproto.protoclass @dataclass
class Foo(betterproto.Message): class Foo(betterproto.Message):
bar: TestEnum = betterproto.enum_field(1) bar: TestEnum = betterproto.enum_field(1)
@ -67,13 +67,13 @@ def test_enum_as_int_json():
def test_unknown_fields(): def test_unknown_fields():
@betterproto.protoclass @dataclass
class Newer(betterproto.Message): class Newer(betterproto.Message):
foo: bool = betterproto.bool_field(1) foo: bool = betterproto.bool_field(1)
bar: int = betterproto.int32_field(2) bar: int = betterproto.int32_field(2)
baz: str = betterproto.string_field(3) baz: str = betterproto.string_field(3)
@betterproto.protoclass @dataclass
class Older(betterproto.Message): class Older(betterproto.Message):
foo: bool = betterproto.bool_field(1) foo: bool = betterproto.bool_field(1)
@ -89,11 +89,11 @@ def test_unknown_fields():
def test_oneof_support(): def test_oneof_support():
@betterproto.protoclass @dataclass
class Sub(betterproto.Message): class Sub(betterproto.Message):
val: int = betterproto.int32_field(1) val: int = betterproto.int32_field(1)
@betterproto.protoclass @dataclass
class Foo(betterproto.Message): class Foo(betterproto.Message):
bar: int = betterproto.int32_field(1, group="group1") bar: int = betterproto.int32_field(1, group="group1")
baz: str = betterproto.string_field(2, group="group1") baz: str = betterproto.string_field(2, group="group1")
@ -134,7 +134,7 @@ def test_oneof_support():
def test_json_casing(): def test_json_casing():
@betterproto.protoclass @dataclass
class CasingTest(betterproto.Message): class CasingTest(betterproto.Message):
pascal_case: int = betterproto.int32_field(1) pascal_case: int = betterproto.int32_field(1)
camel_case: int = betterproto.int32_field(2) camel_case: int = betterproto.int32_field(2)
@ -165,7 +165,7 @@ def test_json_casing():
def test_optional_flag(): def test_optional_flag():
@betterproto.protoclass @dataclass
class Request(betterproto.Message): class Request(betterproto.Message):
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
@ -180,7 +180,7 @@ def test_optional_flag():
def test_to_dict_default_values(): def test_to_dict_default_values():
@betterproto.protoclass @dataclass
class TestMessage(betterproto.Message): class TestMessage(betterproto.Message):
some_int: int = betterproto.int32_field(1) some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2) some_double: float = betterproto.double_field(2)
@ -210,7 +210,7 @@ def test_to_dict_default_values():
} }
# Some default and some other values # Some default and some other values
@betterproto.protoclass @dataclass
class TestMessage2(betterproto.Message): class TestMessage2(betterproto.Message):
some_int: int = betterproto.int32_field(1) some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2) some_double: float = betterproto.double_field(2)
@ -246,11 +246,11 @@ def test_to_dict_default_values():
} }
# Nested messages # Nested messages
@betterproto.protoclass @dataclass
class TestChildMessage(betterproto.Message): class TestChildMessage(betterproto.Message):
some_other_int: int = betterproto.int32_field(1) some_other_int: int = betterproto.int32_field(1)
@betterproto.protoclass @dataclass
class TestParentMessage(betterproto.Message): class TestParentMessage(betterproto.Message):
some_int: int = betterproto.int32_field(1) some_int: int = betterproto.int32_field(1)
some_double: float = betterproto.double_field(2) some_double: float = betterproto.double_field(2)