Add OneOf support, rework field detection
This commit is contained in:
parent
477e9cdae8
commit
5dae20970b
71
README.md
71
README.md
@ -168,19 +168,75 @@ Both serializing and parsing are supported to/from JSON and Python dictionaries
|
|||||||
|
|
||||||
Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
|
Sometimes it is useful to be able to determine whether a message has been sent on the wire. This is how the Google wrapper types work to let you know whether a value is unset, set as the default (zero value), or set as something else, for example.
|
||||||
|
|
||||||
Use `Message().serialized_on_wire` to determine if it was sent. This is a little bit different from the official Google generated Python code:
|
Use `betterproto.serialized_on_wire(message)` to determine if it was sent. This is a little bit different from the official Google generated Python code, and it lives outside the generated `Message` class to prevent name clashes. Note that it **only** supports Proto 3 and thus can **only** be used to check if `Message` fields are set. You cannot check if a scalar was sent on the wire.
|
||||||
|
|
||||||
```py
|
```py
|
||||||
# Old way (official Google Protobuf package)
|
# Old way (official Google Protobuf package)
|
||||||
>>> mymessage.HasField('myfield')
|
>>> mymessage.HasField('myfield')
|
||||||
|
|
||||||
# New way (this project)
|
# New way (this project)
|
||||||
>>> mymessage.myfield.serialized_on_wire
|
>>> betterproto.serialized_on_wire(mymessage.myfield)
|
||||||
|
```
|
||||||
|
|
||||||
|
### One-of Support
|
||||||
|
|
||||||
|
Protobuf supports grouping fields in a `oneof` clause. Only one of the fields in the group may be set at a given time. For example, given the proto:
|
||||||
|
|
||||||
|
```protobuf
|
||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
oneof foo {
|
||||||
|
bool on = 1;
|
||||||
|
int32 count = 2;
|
||||||
|
string name = 3;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
You can use `betterproto.which_one_of(message, group_name)` to determine which of the fields was set. It returns a tuple of the field name and value, or a blank string and `None` if unset.
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> test = Test()
|
||||||
|
>>> betterproto.which_one_of(test, "foo")
|
||||||
|
["", None]
|
||||||
|
|
||||||
|
>>> test.on = True
|
||||||
|
>>> betterproto.which_one_of(test, "foo")
|
||||||
|
["on", True]
|
||||||
|
|
||||||
|
# Setting one member of the group resets the others.
|
||||||
|
>>> test.count = 57
|
||||||
|
>>> betterproto.which_one_of(test, "foo")
|
||||||
|
["count", 57]
|
||||||
|
>>> test.on
|
||||||
|
False
|
||||||
|
|
||||||
|
# Default (zero) values also work.
|
||||||
|
>>> test.name = ""
|
||||||
|
>>> betterproto.which_one_of(test, "foo")
|
||||||
|
["name", ""]
|
||||||
|
>>> test.count
|
||||||
|
0
|
||||||
|
>>> test.on
|
||||||
|
False
|
||||||
|
```
|
||||||
|
|
||||||
|
Again this is a little different than the official Google code generator:
|
||||||
|
|
||||||
|
```py
|
||||||
|
# Old way (official Google protobuf package)
|
||||||
|
>>> message.WhichOneof("group")
|
||||||
|
"foo"
|
||||||
|
|
||||||
|
# New way (this project)
|
||||||
|
>>> betterproto.which_one_of(message, "group")
|
||||||
|
["foo", "foo's value"]
|
||||||
```
|
```
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
First, make sure you have Python 3.7+ and `pipenv` installed:
|
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
# Get set up with the virtual env & dependencies
|
# Get set up with the virtual env & dependencies
|
||||||
@ -224,10 +280,10 @@ $ pipenv run tests
|
|||||||
- [x] Refs to nested types
|
- [x] Refs to nested types
|
||||||
- [x] Imports in proto files
|
- [x] Imports in proto files
|
||||||
- [x] Well-known Google types
|
- [x] Well-known Google types
|
||||||
- [ ] OneOf support
|
- [x] OneOf support
|
||||||
- [x] Basic support on the wire
|
- [x] Basic support on the wire
|
||||||
- [ ] Check which was set from the group
|
- [x] Check which was set from the group
|
||||||
- [ ] Setting one unsets the others
|
- [x] Setting one unsets the others
|
||||||
- [ ] JSON that isn't completely naive.
|
- [ ] JSON that isn't completely naive.
|
||||||
- [x] 64-bit ints as strings
|
- [x] 64-bit ints as strings
|
||||||
- [x] Maps
|
- [x] Maps
|
||||||
@ -236,6 +292,7 @@ $ pipenv run tests
|
|||||||
- [ ] Any support
|
- [ ] Any support
|
||||||
- [x] Enum strings
|
- [x] Enum strings
|
||||||
- [ ] Well known types support (timestamp, duration, wrappers)
|
- [ ] Well known types support (timestamp, duration, wrappers)
|
||||||
|
- [ ] Support different casing (orig vs. camel vs. others?)
|
||||||
- [ ] Async service stubs
|
- [ ] Async service stubs
|
||||||
- [x] Unary-unary
|
- [x] Unary-unary
|
||||||
- [x] Server streaming response
|
- [x] Server streaming response
|
||||||
@ -243,7 +300,7 @@ $ pipenv run tests
|
|||||||
- [ ] Renaming messages and fields to conform to Python name standards
|
- [ ] Renaming messages and fields to conform to Python name standards
|
||||||
- [ ] Renaming clashes with language keywords and standard library top-level packages
|
- [ ] Renaming clashes with language keywords and standard library top-level packages
|
||||||
- [x] Python package
|
- [x] Python package
|
||||||
- [ ] Automate running tests
|
- [x] Automate running tests
|
||||||
- [ ] Cleanup!
|
- [ ] Cleanup!
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
@ -130,6 +130,8 @@ class FieldMetadata:
|
|||||||
proto_type: str
|
proto_type: str
|
||||||
# Map information if the proto_type is a map
|
# Map information if the proto_type is a map
|
||||||
map_types: Optional[Tuple[str, str]]
|
map_types: Optional[Tuple[str, str]]
|
||||||
|
# Groups several "one-of" fields together
|
||||||
|
group: Optional[str]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get(field: dataclasses.Field) -> "FieldMetadata":
|
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||||
@ -138,12 +140,16 @@ class FieldMetadata:
|
|||||||
|
|
||||||
|
|
||||||
def dataclass_field(
|
def dataclass_field(
|
||||||
number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None
|
number: int,
|
||||||
|
proto_type: str,
|
||||||
|
*,
|
||||||
|
map_types: Optional[Tuple[str, str]] = None,
|
||||||
|
group: Optional[str] = None,
|
||||||
) -> dataclasses.Field:
|
) -> dataclasses.Field:
|
||||||
"""Creates a dataclass field with attached protobuf metadata."""
|
"""Creates a dataclass field with attached protobuf metadata."""
|
||||||
return dataclasses.field(
|
return dataclasses.field(
|
||||||
default=PLACEHOLDER,
|
default=PLACEHOLDER,
|
||||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types)},
|
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -152,76 +158,80 @@ def dataclass_field(
|
|||||||
# out at runtime. The generated dataclass variables are still typed correctly.
|
# out at runtime. The generated dataclass variables are still typed correctly.
|
||||||
|
|
||||||
|
|
||||||
def enum_field(number: int) -> Any:
|
def enum_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_ENUM)
|
return dataclass_field(number, TYPE_ENUM, group=group)
|
||||||
|
|
||||||
|
|
||||||
def bool_field(number: int) -> Any:
|
def bool_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_BOOL)
|
return dataclass_field(number, TYPE_BOOL, group=group)
|
||||||
|
|
||||||
|
|
||||||
def int32_field(number: int) -> Any:
|
def int32_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_INT32)
|
return dataclass_field(number, TYPE_INT32, group=group)
|
||||||
|
|
||||||
|
|
||||||
def int64_field(number: int) -> Any:
|
def int64_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_INT64)
|
return dataclass_field(number, TYPE_INT64, group=group)
|
||||||
|
|
||||||
|
|
||||||
def uint32_field(number: int) -> Any:
|
def uint32_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_UINT32)
|
return dataclass_field(number, TYPE_UINT32, group=group)
|
||||||
|
|
||||||
|
|
||||||
def uint64_field(number: int) -> Any:
|
def uint64_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_UINT64)
|
return dataclass_field(number, TYPE_UINT64, group=group)
|
||||||
|
|
||||||
|
|
||||||
def sint32_field(number: int) -> Any:
|
def sint32_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_SINT32)
|
return dataclass_field(number, TYPE_SINT32, group=group)
|
||||||
|
|
||||||
|
|
||||||
def sint64_field(number: int) -> Any:
|
def sint64_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_SINT64)
|
return dataclass_field(number, TYPE_SINT64, group=group)
|
||||||
|
|
||||||
|
|
||||||
def float_field(number: int) -> Any:
|
def float_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_FLOAT)
|
return dataclass_field(number, TYPE_FLOAT, group=group)
|
||||||
|
|
||||||
|
|
||||||
def double_field(number: int) -> Any:
|
def double_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_DOUBLE)
|
return dataclass_field(number, TYPE_DOUBLE, group=group)
|
||||||
|
|
||||||
|
|
||||||
def fixed32_field(number: int) -> Any:
|
def fixed32_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_FIXED32)
|
return dataclass_field(number, TYPE_FIXED32, group=group)
|
||||||
|
|
||||||
|
|
||||||
def fixed64_field(number: int) -> Any:
|
def fixed64_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_FIXED64)
|
return dataclass_field(number, TYPE_FIXED64, group=group)
|
||||||
|
|
||||||
|
|
||||||
def sfixed32_field(number: int) -> Any:
|
def sfixed32_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_SFIXED32)
|
return dataclass_field(number, TYPE_SFIXED32, group=group)
|
||||||
|
|
||||||
|
|
||||||
def sfixed64_field(number: int) -> Any:
|
def sfixed64_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_SFIXED64)
|
return dataclass_field(number, TYPE_SFIXED64, group=group)
|
||||||
|
|
||||||
|
|
||||||
def string_field(number: int) -> Any:
|
def string_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_STRING)
|
return dataclass_field(number, TYPE_STRING, group=group)
|
||||||
|
|
||||||
|
|
||||||
def bytes_field(number: int) -> Any:
|
def bytes_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_BYTES)
|
return dataclass_field(number, TYPE_BYTES, group=group)
|
||||||
|
|
||||||
|
|
||||||
def message_field(number: int) -> Any:
|
def message_field(number: int, group: Optional[str] = None) -> Any:
|
||||||
return dataclass_field(number, TYPE_MESSAGE)
|
return dataclass_field(number, TYPE_MESSAGE, group=group)
|
||||||
|
|
||||||
|
|
||||||
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
def map_field(
|
||||||
return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
|
number: int, key_type: str, value_type: str, group: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
|
return dataclass_field(
|
||||||
|
number, TYPE_MAP, map_types=(key_type, value_type), group=group
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Enum(int, enum.Enum):
|
class Enum(int, enum.Enum):
|
||||||
@ -383,46 +393,52 @@ class Message(ABC):
|
|||||||
to go between Python, binary and JSON protobuf message representations.
|
to go between Python, binary and JSON protobuf message representations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# True if this message was or should be serialized on the wire. This can
|
|
||||||
# be used to detect presence (e.g. optional wrapper message) and is used
|
|
||||||
# internally during parsing/serialization.
|
|
||||||
serialized_on_wire: bool
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Set a default value for each field in the class after `__init__` has
|
# Set a default value for each field in the class after `__init__` has
|
||||||
# already been run.
|
# already been run.
|
||||||
|
group_map = {"fields": {}, "groups": {}}
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
if getattr(self, field.name) != PLACEHOLDER:
|
|
||||||
# Skip anything not set (aka set to the sentinel value)
|
|
||||||
continue
|
|
||||||
|
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
t = self._cls_for(field, index=-1)
|
if meta.group:
|
||||||
|
group_map["fields"][field.name] = meta.group
|
||||||
|
|
||||||
value: Any = 0
|
if meta.group not in group_map["groups"]:
|
||||||
if meta.proto_type == TYPE_MAP:
|
group_map["groups"][meta.group] = {"current": None, "fields": set()}
|
||||||
# Maps cannot be repeated, so we check these first.
|
group_map["groups"][meta.group]["fields"].add(field)
|
||||||
value = {}
|
|
||||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
|
||||||
# Anything else with type args is a list.
|
|
||||||
value = []
|
|
||||||
elif meta.proto_type == TYPE_MESSAGE:
|
|
||||||
# Message means creating an instance of the right type.
|
|
||||||
value = t()
|
|
||||||
else:
|
|
||||||
value = get_default(meta.proto_type)
|
|
||||||
|
|
||||||
setattr(self, field.name, value)
|
if getattr(self, field.name) != PLACEHOLDER:
|
||||||
|
# Skip anything not set to the sentinel value
|
||||||
|
|
||||||
|
if meta.group:
|
||||||
|
# This was set, so make it the selected value of the one-of.
|
||||||
|
group_map["groups"][meta.group]["current"] = field
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
|
setattr(self, field.name, self._get_field_default(field, meta))
|
||||||
|
|
||||||
# Now that all the defaults are set, reset it!
|
# Now that all the defaults are set, reset it!
|
||||||
self.__dict__["serialized_on_wire"] = False
|
self.__dict__["_serialized_on_wire"] = False
|
||||||
self.__dict__["_unknown_fields"] = b""
|
self.__dict__["_unknown_fields"] = b""
|
||||||
|
self.__dict__["_group_map"] = group_map
|
||||||
|
|
||||||
def __setattr__(self, attr: str, value: Any) -> None:
|
def __setattr__(self, attr: str, value: Any) -> None:
|
||||||
if attr != "serialized_on_wire":
|
if attr != "_serialized_on_wire":
|
||||||
# Track when a field has been set.
|
# Track when a field has been set.
|
||||||
self.__dict__["serialized_on_wire"] = True
|
self.__dict__["_serialized_on_wire"] = True
|
||||||
|
|
||||||
|
if attr in getattr(self, "_group_map", {}).get("fields", {}):
|
||||||
|
group = self._group_map["fields"][attr]
|
||||||
|
for field in self._group_map["groups"][group]["fields"]:
|
||||||
|
if field.name == attr:
|
||||||
|
self._group_map["groups"][group]["current"] = field
|
||||||
|
else:
|
||||||
|
super().__setattr__(
|
||||||
|
field.name,
|
||||||
|
self._get_field_default(field, FieldMetadata.get(field)),
|
||||||
|
)
|
||||||
|
|
||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
@ -434,8 +450,15 @@ class Message(ABC):
|
|||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
value = getattr(self, field.name)
|
value = getattr(self, field.name)
|
||||||
|
|
||||||
|
# Being selected in a a group means this field is the one that is
|
||||||
|
# currently set in a `oneof` group, so it must be serialized even
|
||||||
|
# if the value is the default zero value.
|
||||||
|
selected_in_group = False
|
||||||
|
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||||
|
selected_in_group = True
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if not len(value):
|
if not len(value) and not selected_in_group:
|
||||||
# Empty values are not serialized
|
# Empty values are not serialized
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -451,7 +474,7 @@ class Message(ABC):
|
|||||||
for item in value:
|
for item in value:
|
||||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
if not len(value):
|
if not len(value) and not selected_in_group:
|
||||||
# Empty values are not serialized
|
# Empty values are not serialized
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -461,12 +484,12 @@ class Message(ABC):
|
|||||||
sv = _serialize_single(2, meta.map_types[1], v)
|
sv = _serialize_single(2, meta.map_types[1], v)
|
||||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||||
else:
|
else:
|
||||||
if value == get_default(meta.proto_type):
|
if value == get_default(meta.proto_type) and not selected_in_group:
|
||||||
# Default (zero) values are not serialized
|
# Default (zero) values are not serialized
|
||||||
continue
|
continue
|
||||||
|
|
||||||
serialize_empty = False
|
serialize_empty = False
|
||||||
if isinstance(value, Message) and value.serialized_on_wire:
|
if isinstance(value, Message) and value._serialized_on_wire:
|
||||||
serialize_empty = True
|
serialize_empty = True
|
||||||
output += _serialize_single(
|
output += _serialize_single(
|
||||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||||
@ -486,6 +509,24 @@ class Message(ABC):
|
|||||||
cls = type_hints[field.name].__args__[index]
|
cls = type_hints[field.name].__args__[index]
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||||
|
t = self._cls_for(field, index=-1)
|
||||||
|
|
||||||
|
value: Any = 0
|
||||||
|
if meta.proto_type == TYPE_MAP:
|
||||||
|
# Maps cannot be repeated, so we check these first.
|
||||||
|
value = {}
|
||||||
|
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||||
|
# Anything else with type args is a list.
|
||||||
|
value = []
|
||||||
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
|
# Message means creating an instance of the right type.
|
||||||
|
value = t()
|
||||||
|
else:
|
||||||
|
value = get_default(meta.proto_type)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
def _postprocess_single(
|
def _postprocess_single(
|
||||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||||
) -> Any:
|
) -> Any:
|
||||||
@ -508,7 +549,7 @@ class Message(ABC):
|
|||||||
elif meta.proto_type == TYPE_MESSAGE:
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
cls = self._cls_for(field)
|
cls = self._cls_for(field)
|
||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
value.serialized_on_wire = True
|
value._serialized_on_wire = True
|
||||||
elif meta.proto_type == TYPE_MAP:
|
elif meta.proto_type == TYPE_MAP:
|
||||||
# TODO: This is slow, use a cache to make it faster since each
|
# TODO: This is slow, use a cache to make it faster since each
|
||||||
# key/value pair will recreate the class.
|
# key/value pair will recreate the class.
|
||||||
@ -518,8 +559,8 @@ class Message(ABC):
|
|||||||
Entry = dataclasses.make_dataclass(
|
Entry = dataclasses.make_dataclass(
|
||||||
"Entry",
|
"Entry",
|
||||||
[
|
[
|
||||||
("key", kt, dataclass_field(1, meta.map_types[0], None)),
|
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||||
("value", vt, dataclass_field(2, meta.map_types[1], None)),
|
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||||
],
|
],
|
||||||
bases=(Message,),
|
bases=(Message,),
|
||||||
)
|
)
|
||||||
@ -597,7 +638,7 @@ class Message(ABC):
|
|||||||
# Convert each item.
|
# Convert each item.
|
||||||
v = [i.to_dict() for i in v]
|
v = [i.to_dict() for i in v]
|
||||||
output[field.name] = v
|
output[field.name] = v
|
||||||
elif v.serialized_on_wire:
|
elif v._serialized_on_wire:
|
||||||
output[field.name] = v.to_dict()
|
output[field.name] = v.to_dict()
|
||||||
elif meta.proto_type == "map":
|
elif meta.proto_type == "map":
|
||||||
for k in v:
|
for k in v:
|
||||||
@ -632,7 +673,7 @@ class Message(ABC):
|
|||||||
Parse the key/value pairs in `value` into this message instance. This
|
Parse the key/value pairs in `value` into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
"""
|
"""
|
||||||
self.serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
if field.name in value and value[field.name] is not None:
|
if field.name in value and value[field.name] is not None:
|
||||||
@ -685,6 +726,23 @@ class Message(ABC):
|
|||||||
return self.from_dict(json.loads(value))
|
return self.from_dict(json.loads(value))
|
||||||
|
|
||||||
|
|
||||||
|
def serialized_on_wire(message: Message) -> bool:
|
||||||
|
"""
|
||||||
|
True if this message was or should be serialized on the wire. This can
|
||||||
|
be used to detect presence (e.g. optional wrapper message) and is used
|
||||||
|
internally during parsing/serialization.
|
||||||
|
"""
|
||||||
|
return message._serialized_on_wire
|
||||||
|
|
||||||
|
|
||||||
|
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||||
|
"""Return the name and value of a message's one-of field group."""
|
||||||
|
field = message._group_map["groups"].get(group_name, {}).get("current")
|
||||||
|
if not field:
|
||||||
|
return ("", None)
|
||||||
|
return (field.name, getattr(message, field.name))
|
||||||
|
|
||||||
|
|
||||||
class ServiceStub(ABC):
|
class ServiceStub(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for async gRPC service stubs.
|
Base class for async gRPC service stubs.
|
||||||
|
@ -242,6 +242,10 @@ def generate_code(request, response):
|
|||||||
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
||||||
packed = True
|
packed = True
|
||||||
|
|
||||||
|
one_of = ""
|
||||||
|
if f.HasField("oneof_index"):
|
||||||
|
one_of = item.oneof_decl[f.oneof_index].name
|
||||||
|
|
||||||
data["properties"].append(
|
data["properties"].append(
|
||||||
{
|
{
|
||||||
"name": f.name,
|
"name": f.name,
|
||||||
@ -254,6 +258,7 @@ def generate_code(request, response):
|
|||||||
"zero": zero,
|
"zero": zero,
|
||||||
"repeated": repeated,
|
"repeated": repeated,
|
||||||
"packed": packed,
|
"packed": packed,
|
||||||
|
"one_of": one_of,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# print(f, file=sys.stderr)
|
# print(f, file=sys.stderr)
|
||||||
|
@ -44,7 +44,7 @@ class {{ message.name }}(betterproto.Message):
|
|||||||
{% if field.comment %}
|
{% if field.comment %}
|
||||||
{{ field.comment }}
|
{{ field.comment }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
|
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% if not message.properties %}
|
{% if not message.properties %}
|
||||||
pass
|
pass
|
||||||
|
@ -13,23 +13,23 @@ def test_has_field():
|
|||||||
|
|
||||||
# Unset by default
|
# Unset by default
|
||||||
foo = Foo()
|
foo = Foo()
|
||||||
assert foo.bar.serialized_on_wire == False
|
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||||
|
|
||||||
# Serialized after setting something
|
# Serialized after setting something
|
||||||
foo.bar.baz = 1
|
foo.bar.baz = 1
|
||||||
assert foo.bar.serialized_on_wire == True
|
assert betterproto.serialized_on_wire(foo.bar) == True
|
||||||
|
|
||||||
# Still has it after setting the default value
|
# Still has it after setting the default value
|
||||||
foo.bar.baz = 0
|
foo.bar.baz = 0
|
||||||
assert foo.bar.serialized_on_wire == True
|
assert betterproto.serialized_on_wire(foo.bar) == True
|
||||||
|
|
||||||
# Manual override
|
# Manual override (don't do this)
|
||||||
foo.bar.serialized_on_wire = False
|
foo.bar._serialized_on_wire = False
|
||||||
assert foo.bar.serialized_on_wire == False
|
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||||
|
|
||||||
# Can manually set it but defaults to false
|
# Can manually set it but defaults to false
|
||||||
foo.bar = Bar()
|
foo.bar = Bar()
|
||||||
assert foo.bar.serialized_on_wire == False
|
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||||
|
|
||||||
|
|
||||||
def test_enum_as_int_json():
|
def test_enum_as_int_json():
|
||||||
@ -70,3 +70,48 @@ def test_unknown_fields():
|
|||||||
|
|
||||||
new_again = Newer().parse(round_trip)
|
new_again = Newer().parse(round_trip)
|
||||||
assert newer == new_again
|
assert newer == new_again
|
||||||
|
|
||||||
|
|
||||||
|
def test_oneof_support():
|
||||||
|
@dataclass
|
||||||
|
class Sub(betterproto.Message):
|
||||||
|
val: int = betterproto.int32_field(1)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Foo(betterproto.Message):
|
||||||
|
bar: int = betterproto.int32_field(1, group="group1")
|
||||||
|
baz: str = betterproto.string_field(2, group="group1")
|
||||||
|
sub: Sub = betterproto.message_field(3, group="group2")
|
||||||
|
abc: str = betterproto.string_field(4, group="group2")
|
||||||
|
|
||||||
|
foo = Foo()
|
||||||
|
|
||||||
|
assert betterproto.which_one_of(foo, "group1")[0] == ""
|
||||||
|
|
||||||
|
foo.bar = 1
|
||||||
|
foo.baz = "test"
|
||||||
|
|
||||||
|
# Other oneof fields should now be unset
|
||||||
|
assert foo.bar == 0
|
||||||
|
assert betterproto.which_one_of(foo, "group1")[0] == "baz"
|
||||||
|
|
||||||
|
foo.sub.val = 1
|
||||||
|
assert betterproto.serialized_on_wire(foo.sub)
|
||||||
|
|
||||||
|
foo.abc = "test"
|
||||||
|
|
||||||
|
# Group 1 shouldn't be touched, group 2 should have reset
|
||||||
|
assert foo.sub.val == 0
|
||||||
|
assert betterproto.serialized_on_wire(foo.sub) == False
|
||||||
|
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
|
||||||
|
|
||||||
|
# Zero value should always serialize for one-of
|
||||||
|
foo = Foo(bar=0)
|
||||||
|
assert betterproto.which_one_of(foo, "group1")[0] == "bar"
|
||||||
|
assert bytes(foo) == b"\x08\x00"
|
||||||
|
|
||||||
|
# Round trip should also work
|
||||||
|
foo2 = Foo().parse(bytes(foo))
|
||||||
|
assert betterproto.which_one_of(foo2, "group1")[0] == "bar"
|
||||||
|
assert foo.bar == 0
|
||||||
|
assert betterproto.which_one_of(foo2, "group2")[0] == ""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user