More features, refactoring
This commit is contained in:
@@ -17,16 +17,49 @@ import dataclasses
|
||||
|
||||
from . import parse, serialize
|
||||
|
||||
TYPE_ENUM = "enum"
|
||||
TYPE_BOOL = "bool"
|
||||
TYPE_INT32 = "int32"
|
||||
TYPE_INT64 = "int64"
|
||||
TYPE_UINT32 = "uint32"
|
||||
TYPE_UINT64 = "uint64"
|
||||
TYPE_SINT32 = "sint32"
|
||||
TYPE_SINT64 = "sint64"
|
||||
TYPE_FLOAT = "float"
|
||||
TYPE_DOUBLE = "double"
|
||||
TYPE_FIXED32 = "fixed32"
|
||||
TYPE_SFIXED32 = "sfixed32"
|
||||
TYPE_FIXED64 = "fixed64"
|
||||
TYPE_SFIXED64 = "sfixed64"
|
||||
TYPE_STRING = "string"
|
||||
TYPE_BYTES = "bytes"
|
||||
TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
FIXED_TYPES = [
|
||||
TYPE_FLOAT,
|
||||
TYPE_DOUBLE,
|
||||
TYPE_FIXED32,
|
||||
TYPE_SFIXED32,
|
||||
TYPE_FIXED64,
|
||||
TYPE_SFIXED64,
|
||||
]
|
||||
|
||||
PACKED_TYPES = [
|
||||
"bool",
|
||||
"int32",
|
||||
"int64",
|
||||
"uint32",
|
||||
"uint64",
|
||||
"sint32",
|
||||
"sint64",
|
||||
"float",
|
||||
"double",
|
||||
TYPE_ENUM,
|
||||
TYPE_BOOL,
|
||||
TYPE_INT32,
|
||||
TYPE_INT64,
|
||||
TYPE_UINT32,
|
||||
TYPE_UINT64,
|
||||
TYPE_SINT32,
|
||||
TYPE_SINT64,
|
||||
TYPE_FLOAT,
|
||||
TYPE_DOUBLE,
|
||||
TYPE_FIXED32,
|
||||
TYPE_SFIXED32,
|
||||
TYPE_FIXED64,
|
||||
TYPE_SFIXED64,
|
||||
]
|
||||
|
||||
# Wire types
|
||||
@@ -36,6 +69,21 @@ WIRE_FIXED_64 = 1
|
||||
WIRE_LEN_DELIM = 2
|
||||
WIRE_FIXED_32 = 5
|
||||
|
||||
WIRE_VARINT_TYPES = [
|
||||
TYPE_ENUM,
|
||||
TYPE_BOOL,
|
||||
TYPE_INT32,
|
||||
TYPE_INT64,
|
||||
TYPE_UINT32,
|
||||
TYPE_UINT64,
|
||||
TYPE_SINT32,
|
||||
TYPE_SINT64,
|
||||
]
|
||||
|
||||
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
||||
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _Meta:
|
||||
@@ -59,74 +107,135 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
||||
)
|
||||
|
||||
|
||||
def int32_field(
|
||||
number: int, default: Union[int, Type[Iterable]] = 0
|
||||
) -> dataclasses.Field:
|
||||
# Note: the fields below return `Any` to prevent type errors in the generated
|
||||
# data classes since the types won't match with `Field` and they get swapped
|
||||
# out at runtime. The generated dataclass variables are still typed correctly.
|
||||
|
||||
|
||||
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return field(number, "enum", default=default)
|
||||
|
||||
|
||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return field(number, "int32", default=default)
|
||||
|
||||
|
||||
def int64_field(number: int, default: int = 0) -> dataclasses.Field:
|
||||
def int64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, "int64", default=default)
|
||||
|
||||
|
||||
def uint32_field(number: int, default: int = 0) -> dataclasses.Field:
|
||||
def uint32_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, "uint32", default=default)
|
||||
|
||||
|
||||
def uint64_field(number: int, default: int = 0) -> dataclasses.Field:
|
||||
def uint64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, "uint64", default=default)
|
||||
|
||||
|
||||
def sint32_field(number: int, default: int = 0) -> dataclasses.Field:
|
||||
def sint32_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, "sint32", default=default)
|
||||
|
||||
|
||||
def sint64_field(number: int, default: int = 0) -> dataclasses.Field:
|
||||
def sint64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, "sint64", default=default)
|
||||
|
||||
|
||||
def float_field(number: int, default: float = 0.0) -> dataclasses.Field:
|
||||
def float_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "float", default=default)
|
||||
|
||||
|
||||
def double_field(number: int, default: float = 0.0) -> dataclasses.Field:
|
||||
def double_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "double", default=default)
|
||||
|
||||
|
||||
def string_field(number: int, default: str = "") -> dataclasses.Field:
|
||||
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "fixed32", default=default)
|
||||
|
||||
|
||||
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "fixed64", default=default)
|
||||
|
||||
|
||||
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "sfixed32", default=default)
|
||||
|
||||
|
||||
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, "sfixed64", default=default)
|
||||
|
||||
|
||||
def string_field(number: int, default: str = "") -> Any:
|
||||
return field(number, "string", default=default)
|
||||
|
||||
|
||||
def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field:
|
||||
def message_field(number: int, default: Type["Message"]) -> Any:
|
||||
return field(number, "message", default=default)
|
||||
|
||||
|
||||
def _serialize_single(meta: _Meta, value: Any) -> bytes:
|
||||
def _pack_fmt(proto_type: str) -> str:
|
||||
return {
|
||||
"double": "<d",
|
||||
"float": "<f",
|
||||
"fixed32": "<I",
|
||||
"fixed64": "<Q",
|
||||
"sfixed32": "<i",
|
||||
"sfixed64": "<q",
|
||||
}[proto_type]
|
||||
|
||||
|
||||
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||
value = _preprocess_single(proto_type, value)
|
||||
|
||||
output = b""
|
||||
if meta.proto_type in ["int32", "int64", "uint32", "uint64"]:
|
||||
if value < 0:
|
||||
# Handle negative numbers.
|
||||
value += 1 << 64
|
||||
output = serialize.varint(meta.number, value)
|
||||
elif meta.proto_type in ["sint32", "sint64"]:
|
||||
if value >= 0:
|
||||
value = value << 1
|
||||
else:
|
||||
value = (value << 1) ^ (~0)
|
||||
output = serialize.varint(meta.number, value)
|
||||
elif meta.proto_type == "string":
|
||||
output = serialize.len_delim(meta.number, value.encode("utf-8"))
|
||||
elif meta.proto_type == "message":
|
||||
b = bytes(value)
|
||||
if len(b):
|
||||
output = serialize.len_delim(meta.number, b)
|
||||
if proto_type in WIRE_VARINT_TYPES:
|
||||
key = serialize._varint(field_number << 3)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_FIXED_32_TYPES:
|
||||
key = serialize._varint((field_number << 3) | 5)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_FIXED_64_TYPES:
|
||||
key = serialize._varint((field_number << 3) | 1)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||
if len(value):
|
||||
key = serialize._varint((field_number << 3) | 2)
|
||||
output += key + serialize._varint(len(value)) + value
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError(proto_type)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
|
||||
def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
"""Adjusts values before serialization."""
|
||||
if proto_type in [
|
||||
TYPE_ENUM,
|
||||
TYPE_BOOL,
|
||||
TYPE_INT32,
|
||||
TYPE_INT64,
|
||||
TYPE_UINT32,
|
||||
TYPE_UINT64,
|
||||
]:
|
||||
return serialize._varint(value)
|
||||
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||
# Handle zig-zag encoding.
|
||||
if value >= 0:
|
||||
value = value << 1
|
||||
else:
|
||||
value = (value << 1) ^ (~0)
|
||||
return serialize._varint(value)
|
||||
elif proto_type in FIXED_TYPES:
|
||||
return struct.pack(_pack_fmt(proto_type), value)
|
||||
elif proto_type == TYPE_STRING:
|
||||
return value.encode("utf-8")
|
||||
elif proto_type == TYPE_MESSAGE:
|
||||
return bytes(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
if meta.proto_type in ["int32", "int64"]:
|
||||
bits = int(meta.proto_type[3:])
|
||||
@@ -136,6 +245,9 @@ def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
|
||||
elif meta.proto_type in ["sint32", "sint64"]:
|
||||
# Undo zig-zag encoding
|
||||
value = (value >> 1) ^ (-(value & 1))
|
||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in ["string"]:
|
||||
value = value.decode("utf-8")
|
||||
@@ -170,15 +282,21 @@ class Message(ABC):
|
||||
continue
|
||||
|
||||
if meta.proto_type in PACKED_TYPES:
|
||||
output += serialize.packed(meta.number, value)
|
||||
# Packed lists look like a length-delimited field. First,
|
||||
# preprocess/encode each value into a buffer and then
|
||||
# treat it like a field of raw bytes.
|
||||
buf = b""
|
||||
for item in value:
|
||||
buf += _preprocess_single(meta.proto_type, item)
|
||||
output += _serialize_single(meta.number, TYPE_BYTES, buf)
|
||||
else:
|
||||
for item in value:
|
||||
output += _serialize_single(meta, item)
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
else:
|
||||
if value == field.default:
|
||||
continue
|
||||
|
||||
output += _serialize_single(meta, value)
|
||||
output += _serialize_single(meta.number, meta.proto_type, value)
|
||||
|
||||
return output
|
||||
|
||||
@@ -201,11 +319,21 @@ class Message(ABC):
|
||||
pos = 0
|
||||
value = []
|
||||
while pos < len(parsed.value):
|
||||
decoded, pos = parse._decode_varint(parsed.value, pos)
|
||||
decoded = _parse_single(WIRE_VARINT, meta, field, decoded)
|
||||
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
|
||||
decoded, pos = parsed.value[pos : pos + 4], pos + 4
|
||||
wire_type = WIRE_FIXED_32
|
||||
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
|
||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||
wire_type = WIRE_FIXED_64
|
||||
else:
|
||||
decoded, pos = parse.parse_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = _postprocess_single(wire_type, meta, field, decoded)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = _parse_single(parsed.wire_type, meta, field, parsed.value)
|
||||
value = _postprocess_single(
|
||||
parsed.wire_type, meta, field, parsed.value
|
||||
)
|
||||
|
||||
if isinstance(getattr(self, field.name), list) and not isinstance(
|
||||
value, list
|
||||
|
||||
Reference in New Issue
Block a user