Add support for recursive messages (#130)
Changes message initialization (`__post_init__`) so that default values are no longer eagerly created to prevent infinite recursion when initializing recursive messages. As a result, `PLACEHOLDER` will be present in the message for any uninitialized fields. So, an implementation of `__get_attribute__` is added that checks for `PLACEHOLDER` and lazily creates and stores default field values. And, because `PLACEHOLDER` values don't compare equal with zero values, a custom implementation of `__eq__` is provided, and the code generation template is updated so that messages generate with `@dataclass(eq=False)`. Also add new Message __repr__ implementation that skips PLACEHOLDER values and orders keys by number from the proto. Co-authored-by: Christopher Chambers <chris@peanutcode.com> Co-authored-by: nat <n@natn.me> Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com>
This commit is contained in:
parent
ca16b6ed34
commit
034e2e7da0
@ -428,6 +428,7 @@ class ProtoClassMetadata:
|
|||||||
"cls_by_field",
|
"cls_by_field",
|
||||||
"field_name_by_number",
|
"field_name_by_number",
|
||||||
"meta_by_field_name",
|
"meta_by_field_name",
|
||||||
|
"sorted_field_names",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, cls: Type["Message"]):
|
def __init__(self, cls: Type["Message"]):
|
||||||
@ -453,6 +454,9 @@ class ProtoClassMetadata:
|
|||||||
self.oneof_field_by_group = by_group
|
self.oneof_field_by_group = by_group
|
||||||
self.field_name_by_number = by_field_number
|
self.field_name_by_number = by_field_number
|
||||||
self.meta_by_field_name = by_field_name
|
self.meta_by_field_name = by_field_name
|
||||||
|
self.sorted_field_names = tuple(
|
||||||
|
by_field_number[number] for number in sorted(by_field_number.keys())
|
||||||
|
)
|
||||||
|
|
||||||
self.default_gen = self._get_default_gen(cls, fields)
|
self.default_gen = self._get_default_gen(cls, fields)
|
||||||
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||||
@ -513,23 +517,63 @@ class Message(ABC):
|
|||||||
if meta.group:
|
if meta.group:
|
||||||
group_current.setdefault(meta.group)
|
group_current.setdefault(meta.group)
|
||||||
|
|
||||||
if getattr(self, field_name) != PLACEHOLDER:
|
if self.__raw_get(field_name) != PLACEHOLDER:
|
||||||
# Skip anything not set to the sentinel value
|
# Found a non-sentinel value
|
||||||
all_sentinel = False
|
all_sentinel = False
|
||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
# This was set, so make it the selected value of the one-of.
|
# This was set, so make it the selected value of the one-of.
|
||||||
group_current[meta.group] = field_name
|
group_current[meta.group] = field_name
|
||||||
|
|
||||||
continue
|
|
||||||
|
|
||||||
setattr(self, field_name, self._get_field_default(field_name))
|
|
||||||
|
|
||||||
# Now that all the defaults are set, reset it!
|
# Now that all the defaults are set, reset it!
|
||||||
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
self.__dict__["_serialized_on_wire"] = not all_sentinel
|
||||||
self.__dict__["_unknown_fields"] = b""
|
self.__dict__["_unknown_fields"] = b""
|
||||||
self.__dict__["_group_current"] = group_current
|
self.__dict__["_group_current"] = group_current
|
||||||
|
|
||||||
|
def __raw_get(self, name: str) -> Any:
|
||||||
|
return super().__getattribute__(name)
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
if type(self) is not type(other):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for field_name in self._betterproto.meta_by_field_name:
|
||||||
|
self_val = self.__raw_get(field_name)
|
||||||
|
other_val = other.__raw_get(field_name)
|
||||||
|
if self_val is PLACEHOLDER:
|
||||||
|
if other_val is PLACEHOLDER:
|
||||||
|
continue
|
||||||
|
self_val = self._get_field_default(field_name)
|
||||||
|
elif other_val is PLACEHOLDER:
|
||||||
|
other_val = other._get_field_default(field_name)
|
||||||
|
|
||||||
|
if self_val != other_val:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
parts = [
|
||||||
|
f"{field_name}={value!r}"
|
||||||
|
for field_name in self._betterproto.sorted_field_names
|
||||||
|
for value in (self.__raw_get(field_name),)
|
||||||
|
if value is not PLACEHOLDER
|
||||||
|
]
|
||||||
|
return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||||
|
|
||||||
|
def __getattribute__(self, name: str) -> Any:
|
||||||
|
"""
|
||||||
|
Lazily initialize default values to avoid infinite recursion for recursive
|
||||||
|
message types
|
||||||
|
"""
|
||||||
|
value = super().__getattribute__(name)
|
||||||
|
if value is not PLACEHOLDER:
|
||||||
|
return value
|
||||||
|
|
||||||
|
value = self._get_field_default(name)
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
return value
|
||||||
|
|
||||||
def __setattr__(self, attr: str, value: Any) -> None:
|
def __setattr__(self, attr: str, value: Any) -> None:
|
||||||
if attr != "_serialized_on_wire":
|
if attr != "_serialized_on_wire":
|
||||||
# Track when a field has been set.
|
# Track when a field has been set.
|
||||||
@ -542,9 +586,7 @@ class Message(ABC):
|
|||||||
if field.name == attr:
|
if field.name == attr:
|
||||||
self._group_current[group] = field.name
|
self._group_current[group] = field.name
|
||||||
else:
|
else:
|
||||||
super().__setattr__(
|
super().__setattr__(field.name, PLACEHOLDER)
|
||||||
field.name, self._get_field_default(field.name)
|
|
||||||
)
|
|
||||||
|
|
||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum):
|
|||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% for message in output_file.messages %}
|
{% for message in output_file.messages %}
|
||||||
@dataclass
|
@dataclass(eq=False, repr=False)
|
||||||
class {{ message.py_name }}(betterproto.Message):
|
class {{ message.py_name }}(betterproto.Message):
|
||||||
{% if message.comment %}
|
{% if message.comment %}
|
||||||
{{ message.comment }}
|
{{ message.comment }}
|
||||||
|
12
tests/inputs/recursivemessage/recursivemessage.json
Normal file
12
tests/inputs/recursivemessage/recursivemessage.json
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"name": "Zues",
|
||||||
|
"child": {
|
||||||
|
"name": "Hercules"
|
||||||
|
},
|
||||||
|
"intermediate": {
|
||||||
|
"child": {
|
||||||
|
"name": "Douglas Adams"
|
||||||
|
},
|
||||||
|
"number": 42
|
||||||
|
}
|
||||||
|
}
|
13
tests/inputs/recursivemessage/recursivemessage.proto
Normal file
13
tests/inputs/recursivemessage/recursivemessage.proto
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
string name = 1;
|
||||||
|
Test child = 2;
|
||||||
|
Intermediate intermediate = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
message Intermediate {
|
||||||
|
int32 number = 1;
|
||||||
|
Test child = 2;
|
||||||
|
}
|
@ -317,3 +317,51 @@ def test_oneof_default_value_set_causes_writes_wire():
|
|||||||
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
|
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
|
||||||
== ("", None)
|
== ("", None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_message():
|
||||||
|
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||||
|
|
||||||
|
msg = RecursiveMessage()
|
||||||
|
|
||||||
|
assert msg.child == RecursiveMessage()
|
||||||
|
|
||||||
|
# Lazily-created zero-value children must not affect equality.
|
||||||
|
assert msg == RecursiveMessage()
|
||||||
|
|
||||||
|
# Lazily-created zero-value children must not affect serialization.
|
||||||
|
assert bytes(msg) == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_recursive_message_defaults():
|
||||||
|
from tests.output_betterproto.recursivemessage import (
|
||||||
|
Test as RecursiveMessage,
|
||||||
|
Intermediate,
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||||
|
|
||||||
|
# set values are as expected
|
||||||
|
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||||
|
|
||||||
|
# lazy initialized works modifies the message
|
||||||
|
assert msg != RecursiveMessage(
|
||||||
|
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
|
||||||
|
)
|
||||||
|
msg.child.child.name = "jude"
|
||||||
|
assert msg == RecursiveMessage(
|
||||||
|
name="bob",
|
||||||
|
intermediate=Intermediate(42),
|
||||||
|
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
|
||||||
|
)
|
||||||
|
|
||||||
|
# lazily initialization recurses as needed
|
||||||
|
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
|
||||||
|
assert msg.intermediate.child.intermediate == Intermediate()
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_repr():
|
||||||
|
from tests.output_betterproto.recursivemessage import Test
|
||||||
|
|
||||||
|
assert repr(Test(name="Loki")) == "Test(name='Loki')"
|
||||||
|
assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user