Add OneOf support, rework field detection
This commit is contained in:
@@ -130,6 +130,8 @@ class FieldMetadata:
|
||||
proto_type: str
|
||||
# Map information if the proto_type is a map
|
||||
map_types: Optional[Tuple[str, str]]
|
||||
# Groups several "one-of" fields together
|
||||
group: Optional[str]
|
||||
|
||||
@staticmethod
|
||||
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||
@@ -138,12 +140,16 @@ class FieldMetadata:
|
||||
|
||||
|
||||
def dataclass_field(
|
||||
number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None
|
||||
number: int,
|
||||
proto_type: str,
|
||||
*,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
group: Optional[str] = None,
|
||||
) -> dataclasses.Field:
|
||||
"""Creates a dataclass field with attached protobuf metadata."""
|
||||
return dataclasses.field(
|
||||
default=PLACEHOLDER,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types)},
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
|
||||
)
|
||||
|
||||
|
||||
@@ -152,76 +158,80 @@ def dataclass_field(
|
||||
# out at runtime. The generated dataclass variables are still typed correctly.
|
||||
|
||||
|
||||
def enum_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_ENUM)
|
||||
def enum_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_ENUM, group=group)
|
||||
|
||||
|
||||
def bool_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_BOOL)
|
||||
def bool_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_BOOL, group=group)
|
||||
|
||||
|
||||
def int32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_INT32)
|
||||
def int32_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_INT32, group=group)
|
||||
|
||||
|
||||
def int64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_INT64)
|
||||
def int64_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_INT64, group=group)
|
||||
|
||||
|
||||
def uint32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT32)
|
||||
def uint32_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT32, group=group)
|
||||
|
||||
|
||||
def uint64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT64)
|
||||
def uint64_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT64, group=group)
|
||||
|
||||
|
||||
def sint32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT32)
|
||||
def sint32_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT32, group=group)
|
||||
|
||||
|
||||
def sint64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT64)
|
||||
def sint64_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT64, group=group)
|
||||
|
||||
|
||||
def float_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FLOAT)
|
||||
def float_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_FLOAT, group=group)
|
||||
|
||||
|
||||
def double_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_DOUBLE)
|
||||
def double_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_DOUBLE, group=group)
|
||||
|
||||
|
||||
def fixed32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED32)
|
||||
def fixed32_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED32, group=group)
|
||||
|
||||
|
||||
def fixed64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED64)
|
||||
def fixed64_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED64, group=group)
|
||||
|
||||
|
||||
def sfixed32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED32)
|
||||
def sfixed32_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED32, group=group)
|
||||
|
||||
|
||||
def sfixed64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED64)
|
||||
def sfixed64_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED64, group=group)
|
||||
|
||||
|
||||
def string_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_STRING)
|
||||
def string_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_STRING, group=group)
|
||||
|
||||
|
||||
def bytes_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES)
|
||||
def bytes_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES, group=group)
|
||||
|
||||
|
||||
def message_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE)
|
||||
def message_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, group=group)
|
||||
|
||||
|
||||
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
||||
return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
|
||||
def map_field(
|
||||
number: int, key_type: str, value_type: str, group: Optional[str] = None
|
||||
) -> Any:
|
||||
return dataclass_field(
|
||||
number, TYPE_MAP, map_types=(key_type, value_type), group=group
|
||||
)
|
||||
|
||||
|
||||
class Enum(int, enum.Enum):
|
||||
@@ -383,46 +393,52 @@ class Message(ABC):
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
|
||||
# True if this message was or should be serialized on the wire. This can
|
||||
# be used to detect presence (e.g. optional wrapper message) and is used
|
||||
# internally during parsing/serialization.
|
||||
serialized_on_wire: bool
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
group_map = {"fields": {}, "groups": {}}
|
||||
for field in dataclasses.fields(self):
|
||||
if getattr(self, field.name) != PLACEHOLDER:
|
||||
# Skip anything not set (aka set to the sentinel value)
|
||||
continue
|
||||
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
t = self._cls_for(field, index=-1)
|
||||
if meta.group:
|
||||
group_map["fields"][field.name] = meta.group
|
||||
|
||||
value: Any = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
value = {}
|
||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||
# Anything else with type args is a list.
|
||||
value = []
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
# Message means creating an instance of the right type.
|
||||
value = t()
|
||||
else:
|
||||
value = get_default(meta.proto_type)
|
||||
if meta.group not in group_map["groups"]:
|
||||
group_map["groups"][meta.group] = {"current": None, "fields": set()}
|
||||
group_map["groups"][meta.group]["fields"].add(field)
|
||||
|
||||
setattr(self, field.name, value)
|
||||
if getattr(self, field.name) != PLACEHOLDER:
|
||||
# Skip anything not set to the sentinel value
|
||||
|
||||
if meta.group:
|
||||
# This was set, so make it the selected value of the one-of.
|
||||
group_map["groups"][meta.group]["current"] = field
|
||||
|
||||
continue
|
||||
|
||||
setattr(self, field.name, self._get_field_default(field, meta))
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["serialized_on_wire"] = False
|
||||
self.__dict__["_serialized_on_wire"] = False
|
||||
self.__dict__["_unknown_fields"] = b""
|
||||
self.__dict__["_group_map"] = group_map
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "serialized_on_wire":
|
||||
if attr != "_serialized_on_wire":
|
||||
# Track when a field has been set.
|
||||
self.__dict__["serialized_on_wire"] = True
|
||||
self.__dict__["_serialized_on_wire"] = True
|
||||
|
||||
if attr in getattr(self, "_group_map", {}).get("fields", {}):
|
||||
group = self._group_map["fields"][attr]
|
||||
for field in self._group_map["groups"][group]["fields"]:
|
||||
if field.name == attr:
|
||||
self._group_map["groups"][group]["current"] = field
|
||||
else:
|
||||
super().__setattr__(
|
||||
field.name,
|
||||
self._get_field_default(field, FieldMetadata.get(field)),
|
||||
)
|
||||
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -434,8 +450,15 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
value = getattr(self, field.name)
|
||||
|
||||
# Being selected in a a group means this field is the one that is
|
||||
# currently set in a `oneof` group, so it must be serialized even
|
||||
# if the value is the default zero value.
|
||||
selected_in_group = False
|
||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
||||
selected_in_group = True
|
||||
|
||||
if isinstance(value, list):
|
||||
if not len(value):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
@@ -451,7 +474,7 @@ class Message(ABC):
|
||||
for item in value:
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
elif isinstance(value, dict):
|
||||
if not len(value):
|
||||
if not len(value) and not selected_in_group:
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
@@ -461,12 +484,12 @@ class Message(ABC):
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
if value == get_default(meta.proto_type):
|
||||
if value == get_default(meta.proto_type) and not selected_in_group:
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value.serialized_on_wire:
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
serialize_empty = True
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
@@ -486,6 +509,24 @@ class Message(ABC):
|
||||
cls = type_hints[field.name].__args__[index]
|
||||
return cls
|
||||
|
||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||
t = self._cls_for(field, index=-1)
|
||||
|
||||
value: Any = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
value = {}
|
||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||
# Anything else with type args is a list.
|
||||
value = []
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
# Message means creating an instance of the right type.
|
||||
value = t()
|
||||
else:
|
||||
value = get_default(meta.proto_type)
|
||||
|
||||
return value
|
||||
|
||||
def _postprocess_single(
|
||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||
) -> Any:
|
||||
@@ -508,7 +549,7 @@ class Message(ABC):
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
value.serialized_on_wire = True
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
@@ -518,8 +559,8 @@ class Message(ABC):
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0], None)),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1], None)),
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
@@ -597,7 +638,7 @@ class Message(ABC):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
output[field.name] = v
|
||||
elif v.serialized_on_wire:
|
||||
elif v._serialized_on_wire:
|
||||
output[field.name] = v.to_dict()
|
||||
elif meta.proto_type == "map":
|
||||
for k in v:
|
||||
@@ -632,7 +673,7 @@ class Message(ABC):
|
||||
Parse the key/value pairs in `value` into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
self.serialized_on_wire = True
|
||||
self._serialized_on_wire = True
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value and value[field.name] is not None:
|
||||
@@ -685,6 +726,23 @@ class Message(ABC):
|
||||
return self.from_dict(json.loads(value))
|
||||
|
||||
|
||||
def serialized_on_wire(message: Message) -> bool:
|
||||
"""
|
||||
True if this message was or should be serialized on the wire. This can
|
||||
be used to detect presence (e.g. optional wrapper message) and is used
|
||||
internally during parsing/serialization.
|
||||
"""
|
||||
return message._serialized_on_wire
|
||||
|
||||
|
||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
"""Return the name and value of a message's one-of field group."""
|
||||
field = message._group_map["groups"].get(group_name, {}).get("current")
|
||||
if not field:
|
||||
return ("", None)
|
||||
return (field.name, getattr(message, field.name))
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
||||
Reference in New Issue
Block a user