Add support for recursive messages (#130)
Changes message initialization (`__post_init__`) so that default values are no longer eagerly created to prevent infinite recursion when initializing recursive messages. As a result, `PLACEHOLDER` will be present in the message for any uninitialized fields. So, an implementation of `__get_attribute__` is added that checks for `PLACEHOLDER` and lazily creates and stores default field values. And, because `PLACEHOLDER` values don't compare equal with zero values, a custom implementation of `__eq__` is provided, and the code generation template is updated so that messages generate with `@dataclass(eq=False)`. Also add new Message __repr__ implementation that skips PLACEHOLDER values and orders keys by number from the proto. Co-authored-by: Christopher Chambers <chris@peanutcode.com> Co-authored-by: nat <n@natn.me> Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com>
This commit is contained in:
		| @@ -428,6 +428,7 @@ class ProtoClassMetadata: | |||||||
|         "cls_by_field", |         "cls_by_field", | ||||||
|         "field_name_by_number", |         "field_name_by_number", | ||||||
|         "meta_by_field_name", |         "meta_by_field_name", | ||||||
|  |         "sorted_field_names", | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     def __init__(self, cls: Type["Message"]): |     def __init__(self, cls: Type["Message"]): | ||||||
| @@ -453,6 +454,9 @@ class ProtoClassMetadata: | |||||||
|         self.oneof_field_by_group = by_group |         self.oneof_field_by_group = by_group | ||||||
|         self.field_name_by_number = by_field_number |         self.field_name_by_number = by_field_number | ||||||
|         self.meta_by_field_name = by_field_name |         self.meta_by_field_name = by_field_name | ||||||
|  |         self.sorted_field_names = tuple( | ||||||
|  |             by_field_number[number] for number in sorted(by_field_number.keys()) | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         self.default_gen = self._get_default_gen(cls, fields) |         self.default_gen = self._get_default_gen(cls, fields) | ||||||
|         self.cls_by_field = self._get_cls_by_field(cls, fields) |         self.cls_by_field = self._get_cls_by_field(cls, fields) | ||||||
| @@ -513,23 +517,63 @@ class Message(ABC): | |||||||
|             if meta.group: |             if meta.group: | ||||||
|                 group_current.setdefault(meta.group) |                 group_current.setdefault(meta.group) | ||||||
|  |  | ||||||
|             if getattr(self, field_name) != PLACEHOLDER: |             if self.__raw_get(field_name) != PLACEHOLDER: | ||||||
|                 # Skip anything not set to the sentinel value |                 # Found a non-sentinel value | ||||||
|                 all_sentinel = False |                 all_sentinel = False | ||||||
|  |  | ||||||
|                 if meta.group: |                 if meta.group: | ||||||
|                     # This was set, so make it the selected value of the one-of. |                     # This was set, so make it the selected value of the one-of. | ||||||
|                     group_current[meta.group] = field_name |                     group_current[meta.group] = field_name | ||||||
|  |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             setattr(self, field_name, self._get_field_default(field_name)) |  | ||||||
|  |  | ||||||
|         # Now that all the defaults are set, reset it! |         # Now that all the defaults are set, reset it! | ||||||
|         self.__dict__["_serialized_on_wire"] = not all_sentinel |         self.__dict__["_serialized_on_wire"] = not all_sentinel | ||||||
|         self.__dict__["_unknown_fields"] = b"" |         self.__dict__["_unknown_fields"] = b"" | ||||||
|         self.__dict__["_group_current"] = group_current |         self.__dict__["_group_current"] = group_current | ||||||
|  |  | ||||||
|  |     def __raw_get(self, name: str) -> Any: | ||||||
|  |         return super().__getattribute__(name) | ||||||
|  |  | ||||||
|  |     def __eq__(self, other) -> bool: | ||||||
|  |         if type(self) is not type(other): | ||||||
|  |             return False | ||||||
|  |  | ||||||
|  |         for field_name in self._betterproto.meta_by_field_name: | ||||||
|  |             self_val = self.__raw_get(field_name) | ||||||
|  |             other_val = other.__raw_get(field_name) | ||||||
|  |             if self_val is PLACEHOLDER: | ||||||
|  |                 if other_val is PLACEHOLDER: | ||||||
|  |                     continue | ||||||
|  |                 self_val = self._get_field_default(field_name) | ||||||
|  |             elif other_val is PLACEHOLDER: | ||||||
|  |                 other_val = other._get_field_default(field_name) | ||||||
|  |  | ||||||
|  |             if self_val != other_val: | ||||||
|  |                 return False | ||||||
|  |  | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def __repr__(self) -> str: | ||||||
|  |         parts = [ | ||||||
|  |             f"{field_name}={value!r}" | ||||||
|  |             for field_name in self._betterproto.sorted_field_names | ||||||
|  |             for value in (self.__raw_get(field_name),) | ||||||
|  |             if value is not PLACEHOLDER | ||||||
|  |         ] | ||||||
|  |         return f"{self.__class__.__name__}({', '.join(parts)})" | ||||||
|  |  | ||||||
|  |     def __getattribute__(self, name: str) -> Any: | ||||||
|  |         """ | ||||||
|  |         Lazily initialize default values to avoid infinite recursion for recursive | ||||||
|  |         message types | ||||||
|  |         """ | ||||||
|  |         value = super().__getattribute__(name) | ||||||
|  |         if value is not PLACEHOLDER: | ||||||
|  |             return value | ||||||
|  |  | ||||||
|  |         value = self._get_field_default(name) | ||||||
|  |         super().__setattr__(name, value) | ||||||
|  |         return value | ||||||
|  |  | ||||||
|     def __setattr__(self, attr: str, value: Any) -> None: |     def __setattr__(self, attr: str, value: Any) -> None: | ||||||
|         if attr != "_serialized_on_wire": |         if attr != "_serialized_on_wire": | ||||||
|             # Track when a field has been set. |             # Track when a field has been set. | ||||||
| @@ -542,9 +586,7 @@ class Message(ABC): | |||||||
|                     if field.name == attr: |                     if field.name == attr: | ||||||
|                         self._group_current[group] = field.name |                         self._group_current[group] = field.name | ||||||
|                     else: |                     else: | ||||||
|                         super().__setattr__( |                         super().__setattr__(field.name, PLACEHOLDER) | ||||||
|                             field.name, self._get_field_default(field.name) |  | ||||||
|                         ) |  | ||||||
|  |  | ||||||
|         super().__setattr__(attr, value) |         super().__setattr__(attr, value) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum): | |||||||
| {% endfor %} | {% endfor %} | ||||||
| {% endif %} | {% endif %} | ||||||
| {% for message in output_file.messages %} | {% for message in output_file.messages %} | ||||||
| @dataclass | @dataclass(eq=False, repr=False) | ||||||
| class {{ message.py_name }}(betterproto.Message): | class {{ message.py_name }}(betterproto.Message): | ||||||
|     {% if message.comment %} |     {% if message.comment %} | ||||||
| {{ message.comment }} | {{ message.comment }} | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								tests/inputs/recursivemessage/recursivemessage.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								tests/inputs/recursivemessage/recursivemessage.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | |||||||
|  | { | ||||||
|  |   "name": "Zues", | ||||||
|  |   "child": { | ||||||
|  |     "name": "Hercules" | ||||||
|  |   }, | ||||||
|  |   "intermediate": { | ||||||
|  |     "child": { | ||||||
|  |       "name": "Douglas Adams" | ||||||
|  |     }, | ||||||
|  |     "number": 42 | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										13
									
								
								tests/inputs/recursivemessage/recursivemessage.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								tests/inputs/recursivemessage/recursivemessage.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | message Test { | ||||||
|  |     string name = 1; | ||||||
|  |     Test child = 2; | ||||||
|  |     Intermediate intermediate = 3; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | message Intermediate { | ||||||
|  |     int32 number = 1; | ||||||
|  |     Test child = 2; | ||||||
|  | } | ||||||
| @@ -317,3 +317,51 @@ def test_oneof_default_value_set_causes_writes_wire(): | |||||||
|         == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") |         == betterproto.which_one_of(_round_trip_serialization(foo3), "group1") | ||||||
|         == ("", None) |         == ("", None) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_recursive_message(): | ||||||
|  |     from tests.output_betterproto.recursivemessage import Test as RecursiveMessage | ||||||
|  |  | ||||||
|  |     msg = RecursiveMessage() | ||||||
|  |  | ||||||
|  |     assert msg.child == RecursiveMessage() | ||||||
|  |  | ||||||
|  |     # Lazily-created zero-value children must not affect equality. | ||||||
|  |     assert msg == RecursiveMessage() | ||||||
|  |  | ||||||
|  |     # Lazily-created zero-value children must not affect serialization. | ||||||
|  |     assert bytes(msg) == b"" | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_recursive_message_defaults(): | ||||||
|  |     from tests.output_betterproto.recursivemessage import ( | ||||||
|  |         Test as RecursiveMessage, | ||||||
|  |         Intermediate, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     msg = RecursiveMessage(name="bob", intermediate=Intermediate(42)) | ||||||
|  |  | ||||||
|  |     # set values are as expected | ||||||
|  |     assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42)) | ||||||
|  |  | ||||||
|  |     # lazy initialized works modifies the message | ||||||
|  |     assert msg != RecursiveMessage( | ||||||
|  |         name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude") | ||||||
|  |     ) | ||||||
|  |     msg.child.child.name = "jude" | ||||||
|  |     assert msg == RecursiveMessage( | ||||||
|  |         name="bob", | ||||||
|  |         intermediate=Intermediate(42), | ||||||
|  |         child=RecursiveMessage(child=RecursiveMessage(name="jude")), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     # lazily initialization recurses as needed | ||||||
|  |     assert msg.child.child.child.child.child.child.child == RecursiveMessage() | ||||||
|  |     assert msg.intermediate.child.intermediate == Intermediate() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_message_repr(): | ||||||
|  |     from tests.output_betterproto.recursivemessage import Test | ||||||
|  |  | ||||||
|  |     assert repr(Test(name="Loki")) == "Test(name='Loki')" | ||||||
|  |     assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user