Add support for recursive messages (#130)
Changes message initialization (`__post_init__`) so that default values are no longer eagerly created to prevent infinite recursion when initializing recursive messages. As a result, `PLACEHOLDER` will be present in the message for any uninitialized fields. So, an implementation of `__get_attribute__` is added that checks for `PLACEHOLDER` and lazily creates and stores default field values. And, because `PLACEHOLDER` values don't compare equal with zero values, a custom implementation of `__eq__` is provided, and the code generation template is updated so that messages generate with `@dataclass(eq=False)`. Also add new Message __repr__ implementation that skips PLACEHOLDER values and orders keys by number from the proto. Co-authored-by: Christopher Chambers <chris@peanutcode.com> Co-authored-by: nat <n@natn.me> Co-authored-by: James <50501825+Gobot1234@users.noreply.github.com>
This commit is contained in:
		| @@ -428,6 +428,7 @@ class ProtoClassMetadata: | ||||
|         "cls_by_field", | ||||
|         "field_name_by_number", | ||||
|         "meta_by_field_name", | ||||
|         "sorted_field_names", | ||||
|     ) | ||||
|  | ||||
|     def __init__(self, cls: Type["Message"]): | ||||
| @@ -453,6 +454,9 @@ class ProtoClassMetadata: | ||||
|         self.oneof_field_by_group = by_group | ||||
|         self.field_name_by_number = by_field_number | ||||
|         self.meta_by_field_name = by_field_name | ||||
|         self.sorted_field_names = tuple( | ||||
|             by_field_number[number] for number in sorted(by_field_number.keys()) | ||||
|         ) | ||||
|  | ||||
|         self.default_gen = self._get_default_gen(cls, fields) | ||||
|         self.cls_by_field = self._get_cls_by_field(cls, fields) | ||||
| @@ -513,23 +517,63 @@ class Message(ABC): | ||||
|             if meta.group: | ||||
|                 group_current.setdefault(meta.group) | ||||
|  | ||||
|             if getattr(self, field_name) != PLACEHOLDER: | ||||
|                 # Skip anything not set to the sentinel value | ||||
|             if self.__raw_get(field_name) != PLACEHOLDER: | ||||
|                 # Found a non-sentinel value | ||||
|                 all_sentinel = False | ||||
|  | ||||
|                 if meta.group: | ||||
|                     # This was set, so make it the selected value of the one-of. | ||||
|                     group_current[meta.group] = field_name | ||||
|  | ||||
|                 continue | ||||
|  | ||||
|             setattr(self, field_name, self._get_field_default(field_name)) | ||||
|  | ||||
|         # Now that all the defaults are set, reset it! | ||||
|         self.__dict__["_serialized_on_wire"] = not all_sentinel | ||||
|         self.__dict__["_unknown_fields"] = b"" | ||||
|         self.__dict__["_group_current"] = group_current | ||||
|  | ||||
|     def __raw_get(self, name: str) -> Any: | ||||
|         return super().__getattribute__(name) | ||||
|  | ||||
|     def __eq__(self, other) -> bool: | ||||
|         if type(self) is not type(other): | ||||
|             return False | ||||
|  | ||||
|         for field_name in self._betterproto.meta_by_field_name: | ||||
|             self_val = self.__raw_get(field_name) | ||||
|             other_val = other.__raw_get(field_name) | ||||
|             if self_val is PLACEHOLDER: | ||||
|                 if other_val is PLACEHOLDER: | ||||
|                     continue | ||||
|                 self_val = self._get_field_default(field_name) | ||||
|             elif other_val is PLACEHOLDER: | ||||
|                 other_val = other._get_field_default(field_name) | ||||
|  | ||||
|             if self_val != other_val: | ||||
|                 return False | ||||
|  | ||||
|         return True | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         parts = [ | ||||
|             f"{field_name}={value!r}" | ||||
|             for field_name in self._betterproto.sorted_field_names | ||||
|             for value in (self.__raw_get(field_name),) | ||||
|             if value is not PLACEHOLDER | ||||
|         ] | ||||
|         return f"{self.__class__.__name__}({', '.join(parts)})" | ||||
|  | ||||
|     def __getattribute__(self, name: str) -> Any: | ||||
|         """ | ||||
|         Lazily initialize default values to avoid infinite recursion for recursive | ||||
|         message types | ||||
|         """ | ||||
|         value = super().__getattribute__(name) | ||||
|         if value is not PLACEHOLDER: | ||||
|             return value | ||||
|  | ||||
|         value = self._get_field_default(name) | ||||
|         super().__setattr__(name, value) | ||||
|         return value | ||||
|  | ||||
|     def __setattr__(self, attr: str, value: Any) -> None: | ||||
|         if attr != "_serialized_on_wire": | ||||
|             # Track when a field has been set. | ||||
| @@ -542,9 +586,7 @@ class Message(ABC): | ||||
|                     if field.name == attr: | ||||
|                         self._group_current[group] = field.name | ||||
|                     else: | ||||
|                         super().__setattr__( | ||||
|                             field.name, self._get_field_default(field.name) | ||||
|                         ) | ||||
|                         super().__setattr__(field.name, PLACEHOLDER) | ||||
|  | ||||
|         super().__setattr__(attr, value) | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user