diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py index 4835516..5985798 100644 --- a/src/betterproto/__init__.py +++ b/src/betterproto/__init__.py @@ -428,6 +428,7 @@ class ProtoClassMetadata: "cls_by_field", "field_name_by_number", "meta_by_field_name", + "sorted_field_names", ) def __init__(self, cls: Type["Message"]): @@ -453,6 +454,9 @@ class ProtoClassMetadata: self.oneof_field_by_group = by_group self.field_name_by_number = by_field_number self.meta_by_field_name = by_field_name + self.sorted_field_names = tuple( + by_field_number[number] for number in sorted(by_field_number.keys()) + ) self.default_gen = self._get_default_gen(cls, fields) self.cls_by_field = self._get_cls_by_field(cls, fields) @@ -513,23 +517,63 @@ class Message(ABC): if meta.group: group_current.setdefault(meta.group) - if getattr(self, field_name) != PLACEHOLDER: - # Skip anything not set to the sentinel value + if self.__raw_get(field_name) != PLACEHOLDER: + # Found a non-sentinel value all_sentinel = False if meta.group: # This was set, so make it the selected value of the one-of. group_current[meta.group] = field_name - continue - - setattr(self, field_name, self._get_field_default(field_name)) - # Now that all the defaults are set, reset it! self.__dict__["_serialized_on_wire"] = not all_sentinel self.__dict__["_unknown_fields"] = b"" self.__dict__["_group_current"] = group_current + def __raw_get(self, name: str) -> Any: + return super().__getattribute__(name) + + def __eq__(self, other) -> bool: + if type(self) is not type(other): + return False + + for field_name in self._betterproto.meta_by_field_name: + self_val = self.__raw_get(field_name) + other_val = other.__raw_get(field_name) + if self_val is PLACEHOLDER: + if other_val is PLACEHOLDER: + continue + self_val = self._get_field_default(field_name) + elif other_val is PLACEHOLDER: + other_val = other._get_field_default(field_name) + + if self_val != other_val: + return False + + return True + + def __repr__(self) -> str: + parts = [ + f"{field_name}={value!r}" + for field_name in self._betterproto.sorted_field_names + for value in (self.__raw_get(field_name),) + if value is not PLACEHOLDER + ] + return f"{self.__class__.__name__}({', '.join(parts)})" + + def __getattribute__(self, name: str) -> Any: + """ + Lazily initialize default values to avoid infinite recursion for recursive + message types + """ + value = super().__getattribute__(name) + if value is not PLACEHOLDER: + return value + + value = self._get_field_default(name) + super().__setattr__(name, value) + return value + def __setattr__(self, attr: str, value: Any) -> None: if attr != "_serialized_on_wire": # Track when a field has been set. @@ -542,9 +586,7 @@ class Message(ABC): if field.name == attr: self._group_current[group] = field.name else: - super().__setattr__( - field.name, self._get_field_default(field.name) - ) + super().__setattr__(field.name, PLACEHOLDER) super().__setattr__(attr, value) diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 7fd0463..753d340 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): {% endfor %} {% endif %} {% for message in output_file.messages %} -@dataclass +@dataclass(eq=False, repr=False) class {{ message.py_name }}(betterproto.Message): {% if message.comment %} {{ message.comment }} diff --git a/tests/inputs/recursivemessage/recursivemessage.json b/tests/inputs/recursivemessage/recursivemessage.json new file mode 100644 index 0000000..e92c3fb --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.json @@ -0,0 +1,12 @@ +{ + "name": "Zues", + "child": { + "name": "Hercules" + }, + "intermediate": { + "child": { + "name": "Douglas Adams" + }, + "number": 42 + } +} diff --git a/tests/inputs/recursivemessage/recursivemessage.proto b/tests/inputs/recursivemessage/recursivemessage.proto new file mode 100644 index 0000000..f988316 --- /dev/null +++ b/tests/inputs/recursivemessage/recursivemessage.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +message Test { + string name = 1; + Test child = 2; + Intermediate intermediate = 3; +} + + +message Intermediate { + int32 number = 1; + Test child = 2; +} diff --git a/tests/test_features.py b/tests/test_features.py index b5b3811..f548264 100644 --- a/tests/test_features.py +++ b/tests/test_features.py @@ -317,3 +317,51 @@ def test_oneof_default_value_set_causes_writes_wire(): == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") == ("", None) ) + + +def test_recursive_message(): + from tests.output_betterproto.recursivemessage import Test as RecursiveMessage + + msg = RecursiveMessage() + + assert msg.child == RecursiveMessage() + + # Lazily-created zero-value children must not affect equality. + assert msg == RecursiveMessage() + + # Lazily-created zero-value children must not affect serialization. + assert bytes(msg) == b"" + + +def test_recursive_message_defaults(): + from tests.output_betterproto.recursivemessage import ( + Test as RecursiveMessage, + Intermediate, + ) + + msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # set values are as expected + assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) + + # lazy initialized works modifies the message + assert msg != RecursiveMessage( + name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") + ) + msg.child.child.name = "jude" + assert msg == RecursiveMessage( + name="bob", + intermediate=Intermediate(42), + child=RecursiveMessage(child=RecursiveMessage(name="jude")), + ) + + # lazily initialization recurses as needed + assert msg.child.child.child.child.child.child.child == RecursiveMessage() + assert msg.intermediate.child.intermediate == Intermediate() + + +def test_message_repr(): + from tests.output_betterproto.recursivemessage import Test + + assert repr(Test(name="Loki")) == "Test(name='Loki')" + assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"