Add OneOf support, rework field detection
This commit is contained in:
		
							
								
								
									
										71
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										71
									
								
								README.md
									
									
									
									
									
								
							| @@ -168,19 +168,75 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries | |||||||
|  |  | ||||||
| Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example. | Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example. | ||||||
|  |  | ||||||
| Use `Message().serialized_on_wire` to determine if it was sent. This is a little bit different from the official Google generated Python code: | Use `betterproto.serialized_on_wire(message)` to determine if it was sent. This is a little bit different from the official Google generated Python code, and it lives outside the generated `Message` class to prevent name clashes. Note that it **only** supports Proto 3 and thus can **only** be used to check if `Message` fields are set. You cannot check if a scalar was sent on the wire. | ||||||
|  |  | ||||||
| ```py | ```py | ||||||
| # Old way (official Google Protobuf package) | # Old way (official Google Protobuf package) | ||||||
| >>> mymessage.HasField('myfield') | >>> mymessage.HasField('myfield') | ||||||
|  |  | ||||||
| # New way (this project) | # New way (this project) | ||||||
| >>> mymessage.myfield.serialized_on_wire | >>> betterproto.serialized_on_wire(mymessage.myfield) | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ### One-of Support | ||||||
|  |  | ||||||
|  | Protobuf supports grouping fields in a `oneof` clause. Only one of the fields in the group may be set at a given time. For example, given the proto: | ||||||
|  |  | ||||||
|  | ```protobuf | ||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | message Test { | ||||||
|  |   oneof foo { | ||||||
|  |     bool on = 1; | ||||||
|  |     int32 count = 2; | ||||||
|  |     string name = 3; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset. | ||||||
|  |  | ||||||
|  | ```py | ||||||
|  | >>> test = Test() | ||||||
|  | >>> betterproto.which_one_of(test, "foo") | ||||||
|  | ["", None] | ||||||
|  |  | ||||||
|  | >>> test.on = True | ||||||
|  | >>> betterproto.which_one_of(test, "foo") | ||||||
|  | ["on", True] | ||||||
|  |  | ||||||
|  | # Setting one member of the group resets the others. | ||||||
|  | >>> test.count = 57 | ||||||
|  | >>> betterproto.which_one_of(test, "foo") | ||||||
|  | ["count", 57] | ||||||
|  | >>> test.on | ||||||
|  | False | ||||||
|  |  | ||||||
|  | # Default (zero) values also work. | ||||||
|  | >>> test.name = "" | ||||||
|  | >>> betterproto.which_one_of(test, "foo") | ||||||
|  | ["name", ""] | ||||||
|  | >>> test.count | ||||||
|  | 0 | ||||||
|  | >>> test.on | ||||||
|  | False | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Again this is a little different than the official Google code generator: | ||||||
|  |  | ||||||
|  | ```py | ||||||
|  | # Old way (official Google protobuf package) | ||||||
|  | >>> message.WhichOneof("group") | ||||||
|  | "foo" | ||||||
|  |  | ||||||
|  | # New way (this project) | ||||||
|  | >>> betterproto.which_one_of(message, "group") | ||||||
|  | ["foo", "foo's value"] | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
| ## Development | ## Development | ||||||
|  |  | ||||||
| First, make sure you have Python 3.7+ and `pipenv` installed: | First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then: | ||||||
|  |  | ||||||
| ```sh | ```sh | ||||||
| # Get set up with the virtual env & dependencies | # Get set up with the virtual env & dependencies | ||||||
| @@ -224,10 +280,10 @@ $ pipenv run tests | |||||||
| - [x] Refs to nested types | - [x] Refs to nested types | ||||||
| - [x] Imports in proto files | - [x] Imports in proto files | ||||||
| - [x] Well-known Google types | - [x] Well-known Google types | ||||||
| - [ ] OneOf support | - [x] OneOf support | ||||||
|   - [x] Basic support on the wire |   - [x] Basic support on the wire | ||||||
|   - [ ] Check which was set from the group |   - [x] Check which was set from the group | ||||||
|   - [ ] Setting one unsets the others |   - [x] Setting one unsets the others | ||||||
| - [ ] JSON that isn't completely naive. | - [ ] JSON that isn't completely naive. | ||||||
|   - [x] 64-bit ints as strings |   - [x] 64-bit ints as strings | ||||||
|   - [x] Maps |   - [x] Maps | ||||||
| @@ -236,6 +292,7 @@ $ pipenv run tests | |||||||
|   - [ ] Any support |   - [ ] Any support | ||||||
|   - [x] Enum strings |   - [x] Enum strings | ||||||
|   - [ ] Well known types support (timestamp, duration, wrappers) |   - [ ] Well known types support (timestamp, duration, wrappers) | ||||||
|  |   - [ ] Support different casing (orig vs. camel vs. others?) | ||||||
| - [ ] Async service stubs | - [ ] Async service stubs | ||||||
|   - [x] Unary-unary |   - [x] Unary-unary | ||||||
|   - [x] Server streaming response |   - [x] Server streaming response | ||||||
| @@ -243,7 +300,7 @@ $ pipenv run tests | |||||||
| - [ ] Renaming messages and fields to conform to Python name standards | - [ ] Renaming messages and fields to conform to Python name standards | ||||||
| - [ ] Renaming clashes with language keywords and standard library top-level packages | - [ ] Renaming clashes with language keywords and standard library top-level packages | ||||||
| - [x] Python package | - [x] Python package | ||||||
| - [ ] Automate running tests | - [x] Automate running tests | ||||||
| - [ ] Cleanup! | - [ ] Cleanup! | ||||||
|  |  | ||||||
| ## License | ## License | ||||||
|   | |||||||
| @@ -130,6 +130,8 @@ class FieldMetadata: | |||||||
|     proto_type: str |     proto_type: str | ||||||
|     # Map information if the proto_type is a map |     # Map information if the proto_type is a map | ||||||
|     map_types: Optional[Tuple[str, str]] |     map_types: Optional[Tuple[str, str]] | ||||||
|  |     # Groups several "one-of" fields together | ||||||
|  |     group: Optional[str] | ||||||
|  |  | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     def get(field: dataclasses.Field) -> "FieldMetadata": |     def get(field: dataclasses.Field) -> "FieldMetadata": | ||||||
| @@ -138,12 +140,16 @@ class FieldMetadata: | |||||||
|  |  | ||||||
|  |  | ||||||
| def dataclass_field( | def dataclass_field( | ||||||
|     number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None |     number: int, | ||||||
|  |     proto_type: str, | ||||||
|  |     *, | ||||||
|  |     map_types: Optional[Tuple[str, str]] = None, | ||||||
|  |     group: Optional[str] = None, | ||||||
| ) -> dataclasses.Field: | ) -> dataclasses.Field: | ||||||
|     """Creates a dataclass field with attached protobuf metadata.""" |     """Creates a dataclass field with attached protobuf metadata.""" | ||||||
|     return dataclasses.field( |     return dataclasses.field( | ||||||
|         default=PLACEHOLDER, |         default=PLACEHOLDER, | ||||||
|         metadata={"betterproto": FieldMetadata(number, proto_type, map_types)}, |         metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)}, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| @@ -152,76 +158,80 @@ def dataclass_field( | |||||||
| # out at runtime. The generated dataclass variables are still typed correctly. | # out at runtime. The generated dataclass variables are still typed correctly. | ||||||
|  |  | ||||||
|  |  | ||||||
| def enum_field(number: int) -> Any: | def enum_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_ENUM) |     return dataclass_field(number, TYPE_ENUM, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def bool_field(number: int) -> Any: | def bool_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_BOOL) |     return dataclass_field(number, TYPE_BOOL, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def int32_field(number: int) -> Any: | def int32_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_INT32) |     return dataclass_field(number, TYPE_INT32, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def int64_field(number: int) -> Any: | def int64_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_INT64) |     return dataclass_field(number, TYPE_INT64, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def uint32_field(number: int) -> Any: | def uint32_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_UINT32) |     return dataclass_field(number, TYPE_UINT32, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def uint64_field(number: int) -> Any: | def uint64_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_UINT64) |     return dataclass_field(number, TYPE_UINT64, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def sint32_field(number: int) -> Any: | def sint32_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_SINT32) |     return dataclass_field(number, TYPE_SINT32, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def sint64_field(number: int) -> Any: | def sint64_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_SINT64) |     return dataclass_field(number, TYPE_SINT64, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def float_field(number: int) -> Any: | def float_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_FLOAT) |     return dataclass_field(number, TYPE_FLOAT, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def double_field(number: int) -> Any: | def double_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_DOUBLE) |     return dataclass_field(number, TYPE_DOUBLE, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def fixed32_field(number: int) -> Any: | def fixed32_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_FIXED32) |     return dataclass_field(number, TYPE_FIXED32, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def fixed64_field(number: int) -> Any: | def fixed64_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_FIXED64) |     return dataclass_field(number, TYPE_FIXED64, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def sfixed32_field(number: int) -> Any: | def sfixed32_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_SFIXED32) |     return dataclass_field(number, TYPE_SFIXED32, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def sfixed64_field(number: int) -> Any: | def sfixed64_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_SFIXED64) |     return dataclass_field(number, TYPE_SFIXED64, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def string_field(number: int) -> Any: | def string_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_STRING) |     return dataclass_field(number, TYPE_STRING, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def bytes_field(number: int) -> Any: | def bytes_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_BYTES) |     return dataclass_field(number, TYPE_BYTES, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def message_field(number: int) -> Any: | def message_field(number: int, group: Optional[str] = None) -> Any: | ||||||
|     return dataclass_field(number, TYPE_MESSAGE) |     return dataclass_field(number, TYPE_MESSAGE, group=group) | ||||||
|  |  | ||||||
|  |  | ||||||
| def map_field(number: int, key_type: str, value_type: str) -> Any: | def map_field( | ||||||
|     return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type)) |     number: int, key_type: str, value_type: str, group: Optional[str] = None | ||||||
|  | ) -> Any: | ||||||
|  |     return dataclass_field( | ||||||
|  |         number, TYPE_MAP, map_types=(key_type, value_type), group=group | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| class Enum(int, enum.Enum): | class Enum(int, enum.Enum): | ||||||
| @@ -383,46 +393,52 @@ class Message(ABC): | |||||||
|     to go between Python, binary and JSON protobuf message representations. |     to go between Python, binary and JSON protobuf message representations. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     # True if this message was or should be serialized on the wire. This can |  | ||||||
|     # be used to detect presence (e.g. optional wrapper message) and is used |  | ||||||
|     # internally during parsing/serialization. |  | ||||||
|     serialized_on_wire: bool |  | ||||||
|  |  | ||||||
|     def __post_init__(self) -> None: |     def __post_init__(self) -> None: | ||||||
|         # Set a default value for each field in the class after `__init__` has |         # Set a default value for each field in the class after `__init__` has | ||||||
|         # already been run. |         # already been run. | ||||||
|  |         group_map = {"fields": {}, "groups": {}} | ||||||
|         for field in dataclasses.fields(self): |         for field in dataclasses.fields(self): | ||||||
|             if getattr(self, field.name) != PLACEHOLDER: |  | ||||||
|                 # Skip anything not set (aka set to the sentinel value) |  | ||||||
|                 continue |  | ||||||
|  |  | ||||||
|             meta = FieldMetadata.get(field) |             meta = FieldMetadata.get(field) | ||||||
|  |  | ||||||
|             t = self._cls_for(field, index=-1) |             if meta.group: | ||||||
|  |                 group_map["fields"][field.name] = meta.group | ||||||
|  |  | ||||||
|             value: Any = 0 |                 if meta.group not in group_map["groups"]: | ||||||
|             if meta.proto_type == TYPE_MAP: |                     group_map["groups"][meta.group] = {"current": None, "fields": set()} | ||||||
|                 # Maps cannot be repeated, so we check these first. |                 group_map["groups"][meta.group]["fields"].add(field) | ||||||
|                 value = {} |  | ||||||
|             elif hasattr(t, "__args__") and len(t.__args__) == 1: |  | ||||||
|                 # Anything else with type args is a list. |  | ||||||
|                 value = [] |  | ||||||
|             elif meta.proto_type == TYPE_MESSAGE: |  | ||||||
|                 # Message means creating an instance of the right type. |  | ||||||
|                 value = t() |  | ||||||
|             else: |  | ||||||
|                 value = get_default(meta.proto_type) |  | ||||||
|  |  | ||||||
|             setattr(self, field.name, value) |             if getattr(self, field.name) != PLACEHOLDER: | ||||||
|  |                 # Skip anything not set to the sentinel value | ||||||
|  |  | ||||||
|  |                 if meta.group: | ||||||
|  |                     # This was set, so make it the selected value of the one-of. | ||||||
|  |                     group_map["groups"][meta.group]["current"] = field | ||||||
|  |  | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             setattr(self, field.name, self._get_field_default(field, meta)) | ||||||
|  |  | ||||||
|         # Now that all the defaults are set, reset it! |         # Now that all the defaults are set, reset it! | ||||||
|         self.__dict__["serialized_on_wire"] = False |         self.__dict__["_serialized_on_wire"] = False | ||||||
|         self.__dict__["_unknown_fields"] = b"" |         self.__dict__["_unknown_fields"] = b"" | ||||||
|  |         self.__dict__["_group_map"] = group_map | ||||||
|  |  | ||||||
|     def __setattr__(self, attr: str, value: Any) -> None: |     def __setattr__(self, attr: str, value: Any) -> None: | ||||||
|         if attr != "serialized_on_wire": |         if attr != "_serialized_on_wire": | ||||||
|             # Track when a field has been set. |             # Track when a field has been set. | ||||||
|             self.__dict__["serialized_on_wire"] = True |             self.__dict__["_serialized_on_wire"] = True | ||||||
|  |  | ||||||
|  |         if attr in getattr(self, "_group_map", {}).get("fields", {}): | ||||||
|  |             group = self._group_map["fields"][attr] | ||||||
|  |             for field in self._group_map["groups"][group]["fields"]: | ||||||
|  |                 if field.name == attr: | ||||||
|  |                     self._group_map["groups"][group]["current"] = field | ||||||
|  |                 else: | ||||||
|  |                     super().__setattr__( | ||||||
|  |                         field.name, | ||||||
|  |                         self._get_field_default(field, FieldMetadata.get(field)), | ||||||
|  |                     ) | ||||||
|  |  | ||||||
|         super().__setattr__(attr, value) |         super().__setattr__(attr, value) | ||||||
|  |  | ||||||
|     def __bytes__(self) -> bytes: |     def __bytes__(self) -> bytes: | ||||||
| @@ -434,8 +450,15 @@ class Message(ABC): | |||||||
|             meta = FieldMetadata.get(field) |             meta = FieldMetadata.get(field) | ||||||
|             value = getattr(self, field.name) |             value = getattr(self, field.name) | ||||||
|  |  | ||||||
|  |             # Being selected in a a group means this field is the one that is | ||||||
|  |             # currently set in a `oneof` group, so it must be serialized even | ||||||
|  |             # if the value is the default zero value. | ||||||
|  |             selected_in_group = False | ||||||
|  |             if meta.group and self._group_map["groups"][meta.group]["current"] == field: | ||||||
|  |                 selected_in_group = True | ||||||
|  |  | ||||||
|             if isinstance(value, list): |             if isinstance(value, list): | ||||||
|                 if not len(value): |                 if not len(value) and not selected_in_group: | ||||||
|                     # Empty values are not serialized |                     # Empty values are not serialized | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
| @@ -451,7 +474,7 @@ class Message(ABC): | |||||||
|                     for item in value: |                     for item in value: | ||||||
|                         output += _serialize_single(meta.number, meta.proto_type, item) |                         output += _serialize_single(meta.number, meta.proto_type, item) | ||||||
|             elif isinstance(value, dict): |             elif isinstance(value, dict): | ||||||
|                 if not len(value): |                 if not len(value) and not selected_in_group: | ||||||
|                     # Empty values are not serialized |                     # Empty values are not serialized | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
| @@ -461,12 +484,12 @@ class Message(ABC): | |||||||
|                     sv = _serialize_single(2, meta.map_types[1], v) |                     sv = _serialize_single(2, meta.map_types[1], v) | ||||||
|                     output += _serialize_single(meta.number, meta.proto_type, sk + sv) |                     output += _serialize_single(meta.number, meta.proto_type, sk + sv) | ||||||
|             else: |             else: | ||||||
|                 if value == get_default(meta.proto_type): |                 if value == get_default(meta.proto_type) and not selected_in_group: | ||||||
|                     # Default (zero) values are not serialized |                     # Default (zero) values are not serialized | ||||||
|                     continue |                     continue | ||||||
|  |  | ||||||
|                 serialize_empty = False |                 serialize_empty = False | ||||||
|                 if isinstance(value, Message) and value.serialized_on_wire: |                 if isinstance(value, Message) and value._serialized_on_wire: | ||||||
|                     serialize_empty = True |                     serialize_empty = True | ||||||
|                 output += _serialize_single( |                 output += _serialize_single( | ||||||
|                     meta.number, meta.proto_type, value, serialize_empty=serialize_empty |                     meta.number, meta.proto_type, value, serialize_empty=serialize_empty | ||||||
| @@ -486,6 +509,24 @@ class Message(ABC): | |||||||
|             cls = type_hints[field.name].__args__[index] |             cls = type_hints[field.name].__args__[index] | ||||||
|         return cls |         return cls | ||||||
|  |  | ||||||
|  |     def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any: | ||||||
|  |         t = self._cls_for(field, index=-1) | ||||||
|  |  | ||||||
|  |         value: Any = 0 | ||||||
|  |         if meta.proto_type == TYPE_MAP: | ||||||
|  |             # Maps cannot be repeated, so we check these first. | ||||||
|  |             value = {} | ||||||
|  |         elif hasattr(t, "__args__") and len(t.__args__) == 1: | ||||||
|  |             # Anything else with type args is a list. | ||||||
|  |             value = [] | ||||||
|  |         elif meta.proto_type == TYPE_MESSAGE: | ||||||
|  |             # Message means creating an instance of the right type. | ||||||
|  |             value = t() | ||||||
|  |         else: | ||||||
|  |             value = get_default(meta.proto_type) | ||||||
|  |  | ||||||
|  |         return value | ||||||
|  |  | ||||||
|     def _postprocess_single( |     def _postprocess_single( | ||||||
|         self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any |         self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any | ||||||
|     ) -> Any: |     ) -> Any: | ||||||
| @@ -508,7 +549,7 @@ class Message(ABC): | |||||||
|             elif meta.proto_type == TYPE_MESSAGE: |             elif meta.proto_type == TYPE_MESSAGE: | ||||||
|                 cls = self._cls_for(field) |                 cls = self._cls_for(field) | ||||||
|                 value = cls().parse(value) |                 value = cls().parse(value) | ||||||
|                 value.serialized_on_wire = True |                 value._serialized_on_wire = True | ||||||
|             elif meta.proto_type == TYPE_MAP: |             elif meta.proto_type == TYPE_MAP: | ||||||
|                 # TODO: This is slow, use a cache to make it faster since each |                 # TODO: This is slow, use a cache to make it faster since each | ||||||
|                 #       key/value pair will recreate the class. |                 #       key/value pair will recreate the class. | ||||||
| @@ -518,8 +559,8 @@ class Message(ABC): | |||||||
|                 Entry = dataclasses.make_dataclass( |                 Entry = dataclasses.make_dataclass( | ||||||
|                     "Entry", |                     "Entry", | ||||||
|                     [ |                     [ | ||||||
|                         ("key", kt, dataclass_field(1, meta.map_types[0], None)), |                         ("key", kt, dataclass_field(1, meta.map_types[0])), | ||||||
|                         ("value", vt, dataclass_field(2, meta.map_types[1], None)), |                         ("value", vt, dataclass_field(2, meta.map_types[1])), | ||||||
|                     ], |                     ], | ||||||
|                     bases=(Message,), |                     bases=(Message,), | ||||||
|                 ) |                 ) | ||||||
| @@ -597,7 +638,7 @@ class Message(ABC): | |||||||
|                     # Convert each item. |                     # Convert each item. | ||||||
|                     v = [i.to_dict() for i in v] |                     v = [i.to_dict() for i in v] | ||||||
|                     output[field.name] = v |                     output[field.name] = v | ||||||
|                 elif v.serialized_on_wire: |                 elif v._serialized_on_wire: | ||||||
|                     output[field.name] = v.to_dict() |                     output[field.name] = v.to_dict() | ||||||
|             elif meta.proto_type == "map": |             elif meta.proto_type == "map": | ||||||
|                 for k in v: |                 for k in v: | ||||||
| @@ -632,7 +673,7 @@ class Message(ABC): | |||||||
|         Parse the key/value pairs in `value` into this message instance. This |         Parse the key/value pairs in `value` into this message instance. This | ||||||
|         returns the instance itself and is therefore assignable and chainable. |         returns the instance itself and is therefore assignable and chainable. | ||||||
|         """ |         """ | ||||||
|         self.serialized_on_wire = True |         self._serialized_on_wire = True | ||||||
|         for field in dataclasses.fields(self): |         for field in dataclasses.fields(self): | ||||||
|             meta = FieldMetadata.get(field) |             meta = FieldMetadata.get(field) | ||||||
|             if field.name in value and value[field.name] is not None: |             if field.name in value and value[field.name] is not None: | ||||||
| @@ -685,6 +726,23 @@ class Message(ABC): | |||||||
|         return self.from_dict(json.loads(value)) |         return self.from_dict(json.loads(value)) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def serialized_on_wire(message: Message) -> bool: | ||||||
|  |     """ | ||||||
|  |     True if this message was or should be serialized on the wire. This can | ||||||
|  |     be used to detect presence (e.g. optional wrapper message) and is used | ||||||
|  |     internally during parsing/serialization. | ||||||
|  |     """ | ||||||
|  |     return message._serialized_on_wire | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||||
|  |     """Return the name and value of a message's one-of field group.""" | ||||||
|  |     field = message._group_map["groups"].get(group_name, {}).get("current") | ||||||
|  |     if not field: | ||||||
|  |         return ("", None) | ||||||
|  |     return (field.name, getattr(message, field.name)) | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServiceStub(ABC): | class ServiceStub(ABC): | ||||||
|     """ |     """ | ||||||
|     Base class for async gRPC service stubs. |     Base class for async gRPC service stubs. | ||||||
|   | |||||||
| @@ -242,6 +242,10 @@ def generate_code(request, response): | |||||||
|                             if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: |                             if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: | ||||||
|                                 packed = True |                                 packed = True | ||||||
|  |  | ||||||
|  |                         one_of = "" | ||||||
|  |                         if f.HasField("oneof_index"): | ||||||
|  |                             one_of = item.oneof_decl[f.oneof_index].name | ||||||
|  |  | ||||||
|                         data["properties"].append( |                         data["properties"].append( | ||||||
|                             { |                             { | ||||||
|                                 "name": f.name, |                                 "name": f.name, | ||||||
| @@ -254,6 +258,7 @@ def generate_code(request, response): | |||||||
|                                 "zero": zero, |                                 "zero": zero, | ||||||
|                                 "repeated": repeated, |                                 "repeated": repeated, | ||||||
|                                 "packed": packed, |                                 "packed": packed, | ||||||
|  |                                 "one_of": one_of, | ||||||
|                             } |                             } | ||||||
|                         ) |                         ) | ||||||
|                         # print(f, file=sys.stderr) |                         # print(f, file=sys.stderr) | ||||||
|   | |||||||
| @@ -44,7 +44,7 @@ class {{ message.name }}(betterproto.Message): | |||||||
|         {% if field.comment %} |         {% if field.comment %} | ||||||
| {{ field.comment }} | {{ field.comment }} | ||||||
|         {% endif %} |         {% endif %} | ||||||
|     {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}) |     {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) | ||||||
|     {% endfor %} |     {% endfor %} | ||||||
|     {% if not message.properties %} |     {% if not message.properties %} | ||||||
|     pass |     pass | ||||||
|   | |||||||
| @@ -13,23 +13,23 @@ def test_has_field(): | |||||||
|  |  | ||||||
|     # Unset by default |     # Unset by default | ||||||
|     foo = Foo() |     foo = Foo() | ||||||
|     assert foo.bar.serialized_on_wire == False |     assert betterproto.serialized_on_wire(foo.bar) == False | ||||||
|  |  | ||||||
|     # Serialized after setting something |     # Serialized after setting something | ||||||
|     foo.bar.baz = 1 |     foo.bar.baz = 1 | ||||||
|     assert foo.bar.serialized_on_wire == True |     assert betterproto.serialized_on_wire(foo.bar) == True | ||||||
|  |  | ||||||
|     # Still has it after setting the default value |     # Still has it after setting the default value | ||||||
|     foo.bar.baz = 0 |     foo.bar.baz = 0 | ||||||
|     assert foo.bar.serialized_on_wire == True |     assert betterproto.serialized_on_wire(foo.bar) == True | ||||||
|  |  | ||||||
|     # Manual override |     # Manual override (don't do this) | ||||||
|     foo.bar.serialized_on_wire = False |     foo.bar._serialized_on_wire = False | ||||||
|     assert foo.bar.serialized_on_wire == False |     assert betterproto.serialized_on_wire(foo.bar) == False | ||||||
|  |  | ||||||
|     # Can manually set it but defaults to false |     # Can manually set it but defaults to false | ||||||
|     foo.bar = Bar() |     foo.bar = Bar() | ||||||
|     assert foo.bar.serialized_on_wire == False |     assert betterproto.serialized_on_wire(foo.bar) == False | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_enum_as_int_json(): | def test_enum_as_int_json(): | ||||||
| @@ -70,3 +70,48 @@ def test_unknown_fields(): | |||||||
|  |  | ||||||
|     new_again = Newer().parse(round_trip) |     new_again = Newer().parse(round_trip) | ||||||
|     assert newer == new_again |     assert newer == new_again | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def test_oneof_support(): | ||||||
|  |     @dataclass | ||||||
|  |     class Sub(betterproto.Message): | ||||||
|  |         val: int = betterproto.int32_field(1) | ||||||
|  |  | ||||||
|  |     @dataclass | ||||||
|  |     class Foo(betterproto.Message): | ||||||
|  |         bar: int = betterproto.int32_field(1, group="group1") | ||||||
|  |         baz: str = betterproto.string_field(2, group="group1") | ||||||
|  |         sub: Sub = betterproto.message_field(3, group="group2") | ||||||
|  |         abc: str = betterproto.string_field(4, group="group2") | ||||||
|  |  | ||||||
|  |     foo = Foo() | ||||||
|  |  | ||||||
|  |     assert betterproto.which_one_of(foo, "group1")[0] == "" | ||||||
|  |  | ||||||
|  |     foo.bar = 1 | ||||||
|  |     foo.baz = "test" | ||||||
|  |  | ||||||
|  |     # Other oneof fields should now be unset | ||||||
|  |     assert foo.bar == 0 | ||||||
|  |     assert betterproto.which_one_of(foo, "group1")[0] == "baz" | ||||||
|  |  | ||||||
|  |     foo.sub.val = 1 | ||||||
|  |     assert betterproto.serialized_on_wire(foo.sub) | ||||||
|  |  | ||||||
|  |     foo.abc = "test" | ||||||
|  |  | ||||||
|  |     # Group 1 shouldn't be touched, group 2 should have reset | ||||||
|  |     assert foo.sub.val == 0 | ||||||
|  |     assert betterproto.serialized_on_wire(foo.sub) == False | ||||||
|  |     assert betterproto.which_one_of(foo, "group2")[0] == "abc" | ||||||
|  |  | ||||||
|  |     # Zero value should always serialize for one-of | ||||||
|  |     foo = Foo(bar=0) | ||||||
|  |     assert betterproto.which_one_of(foo, "group1")[0] == "bar" | ||||||
|  |     assert bytes(foo) == b"\x08\x00" | ||||||
|  |  | ||||||
|  |     # Round trip should also work | ||||||
|  |     foo2 = Foo().parse(bytes(foo)) | ||||||
|  |     assert betterproto.which_one_of(foo2, "group1")[0] == "bar" | ||||||
|  |     assert foo.bar == 0 | ||||||
|  |     assert betterproto.which_one_of(foo2, "group2")[0] == "" | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user