Add message streaming support (#518)

This commit is contained in:
Joshua Leivers
2023-08-29 14:26:25 +01:00
committed by GitHub
parent 4cdf1bb9e0
commit 8659c51123
7 changed files with 589 additions and 35 deletions

View File

@@ -17,14 +17,16 @@ from datetime import (
timedelta,
timezone,
)
from io import BytesIO
from itertools import count
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Callable,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
Set,
@@ -46,6 +48,10 @@ from .casing import (
from .grpc.grpclib_client import ServiceStub
if TYPE_CHECKING:
from _typeshed import ReadableBuffer
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@@ -66,7 +72,6 @@ TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"
# Fields that use a fixed amount of space (4 or 8 bytes)
FIXED_TYPES = [
TYPE_FLOAT,
@@ -129,7 +134,6 @@ def datetime_default_gen() -> datetime:
DATETIME_ZERO = datetime_default_gen()
# Special protobuf json doubles
INFINITY = "Infinity"
NEG_INFINITY = "-Infinity"
@@ -343,20 +347,43 @@ def _pack_fmt(proto_type: str) -> str:
}[proto_type]
def dump_varint(value: int, stream: BinaryIO) -> None:
    """Encode a single varint and write it to the provided stream.

    Negative values down to ``-(1 << 63)`` are encoded as their 64-bit two's
    complement, which always occupies 10 bytes.

    Parameters
    -----------
    value: :class:`int`
        The integer to encode.
    stream: :class:`BinaryIO`
        The stream the encoded bytes are written to.

    Raises
    -------
    ValueError
        If ``value`` is smaller than ``-(1 << 63)``.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    elif value < 0:
        value += 1 << 64

    # Collect the encoded bytes locally so the stream receives one write
    # instead of one write per byte.
    encoded = bytearray()
    bits = value & 0x7F
    value >>= 7
    while value:
        # More groups follow: set the continuation bit on this one.
        encoded.append(0x80 | bits)
        bits = value & 0x7F
        value >>= 7
    # Final group is written without the continuation bit.
    encoded.append(bits)
    stream.write(bytes(encoded))
def encode_varint(value: int) -> bytes:
    """Return the varint encoding of ``value`` as a :class:`bytes` object."""
    with BytesIO() as buffer:
        dump_varint(value, buffer)
        # Take the contents before the buffer is closed on exit.
        encoded = buffer.getvalue()
    return encoded
def size_varint(value: int) -> int:
    """Return the number of bytes ``value`` occupies when encoded as a varint.

    Raises
    -------
    ValueError
        If ``value`` is smaller than ``-(1 << 63)``.
    """
    if value < -(1 << 63):
        raise ValueError(
            "Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
        )
    if value < 0:
        # Every remaining negative value is reported as a fixed 10 bytes.
        return 10
    if value == 0:
        # Zero has a bit length of 0 but still needs one byte.
        return 1
    # Each varint byte carries 7 payload bits; round the bit count up.
    return (value.bit_length() + 6) // 7
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
@@ -394,6 +421,41 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
return value
def _len_preprocessed_single(proto_type: str, wraps: str, value: Any) -> int:
    """Return the size in bytes of ``value`` after type-specific adjustment.

    Varint-style types are measured with ``size_varint``. Fixed-width types
    and strings are measured by packing/encoding them. Message values
    (including ``datetime``, ``timedelta`` and wrapped values) are measured by
    serializing them with ``bytes()``. Any other type is measured with
    ``len(value)``.
    """
    varint_types = (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    )
    if proto_type in varint_types:
        return size_varint(value)

    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag encoding: the adjusted integer is what gets measured.
        if value >= 0:
            zigzag = value << 1
        else:
            zigzag = (value << 1) ^ (~0)
        return size_varint(zigzag)

    if proto_type in FIXED_TYPES:
        packed = struct.pack(_pack_fmt(proto_type), value)
        return len(packed)

    if proto_type == TYPE_STRING:
        return len(value.encode("utf-8"))

    if proto_type == TYPE_MESSAGE:
        if isinstance(value, datetime):
            # A `datetime` is measured as its timestamp message.
            message = _Timestamp.from_datetime(value)
        elif isinstance(value, timedelta):
            # A `timedelta` is measured as its duration message.
            message = _Duration.from_timedelta(value)
        elif wraps:
            if value is None:
                return 0
            message = _get_wrapper(wraps)(value=value)
        else:
            message = value
        return len(bytes(message))

    return len(value)
def _serialize_single(
field_number: int,
proto_type: str,
@@ -425,6 +487,31 @@ def _serialize_single(
return bytes(output)
def _len_single(
    field_number: int,
    proto_type: str,
    value: Any,
    *,
    serialize_empty: bool = False,
    wraps: str = "",
) -> int:
    """Return the size in bytes of one serialized field: key plus value.

    A length-delimited field with an empty payload contributes 0 bytes unless
    ``serialize_empty`` or ``wraps`` is set.

    Raises
    -------
    NotImplementedError
        If ``proto_type`` is not in any of the known wire-type groups.
    """
    payload = _len_preprocessed_single(proto_type, wraps, value)
    tag = field_number << 3

    if proto_type in WIRE_VARINT_TYPES:
        return payload + size_varint(tag)
    if proto_type in WIRE_FIXED_32_TYPES:
        return payload + size_varint(tag | 5)
    if proto_type in WIRE_FIXED_64_TYPES:
        return payload + size_varint(tag | 1)
    if proto_type in WIRE_LEN_DELIM_TYPES:
        if not (payload or serialize_empty or wraps):
            return payload
        # Key, then the varint length prefix, then the payload itself.
        return payload + size_varint(tag | 2) + size_varint(payload)

    raise NotImplementedError(proto_type)
def _parse_float(value: Any) -> float:
"""Parse the given value to a float
@@ -469,22 +556,34 @@ def _dump_float(value: float) -> Union[float, str]:
return value
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
    """
    Read one varint from ``stream``. Returns the decoded value together with
    the raw bytes that were consumed.

    Raises
    -------
    EOFError
        If the stream runs out before the varint is complete.
    ValueError
        If ten bytes have been read and the varint still has not ended.
    """
    value = 0
    consumed = bytearray()
    shift = 0
    # Shifts 0, 7, ..., 63 allow at most ten bytes.
    while shift < 64:
        chunk = stream.read(1)
        if not chunk:
            raise EOFError("Stream ended unexpectedly while attempting to load varint.")
        consumed += chunk
        byte = chunk[0]
        value |= (byte & 0x7F) << shift
        if not byte & 0x80:
            # Continuation bit is clear: this was the last byte.
            return value, bytes(consumed)
        shift += 7
    raise ValueError("Too many bytes when decoding varint.")
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
    """
    Decode a single varint value from a byte buffer. Returns the value and the
    new position in the buffer.

    Parameters
    -----------
    buffer: :class:`bytes`
        The buffer containing the encoded varint.
    pos: :class:`int`
        The offset in ``buffer`` at which the varint starts.

    Raises
    -------
    EOFError
        If the buffer ends before the varint is complete.
    ValueError
        If the varint does not end within ten bytes.
    """
    with BytesIO(buffer) as stream:
        stream.seek(pos)
        value, raw = load_varint(stream)
    # The new position is the start offset plus the bytes actually consumed.
    return value, pos + len(raw)
@dataclasses.dataclass(frozen=True)
@@ -495,6 +594,34 @@ class ParsedField:
raw: bytes
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
    """Yield each field read from ``stream`` as a ``ParsedField``.

    Iteration ends when ``load_varint`` raises ``EOFError`` while reading a
    field key. An ``EOFError`` raised while reading a field's varint value or
    length prefix is not caught here.

    For wire types other than varint, fixed 64-bit, length-delimited and fixed
    32-bit, the field is yielded with ``value=None`` and ``raw`` holding only
    the key bytes.

    NOTE(review): the results of ``stream.read`` are not length-checked.
    """
    # Byte widths of the two fixed-size wire types.
    fixed_widths = {WIRE_FIXED_64: 8, WIRE_FIXED_32: 4}
    while True:
        try:
            key, raw = load_varint(stream)
        except EOFError:
            return

        number, wire_type = key >> 3, key & 0x7
        decoded: Any = None

        if wire_type == WIRE_VARINT:
            decoded, value_raw = load_varint(stream)
            raw += value_raw
        elif wire_type == WIRE_LEN_DELIM:
            length, length_raw = load_varint(stream)
            decoded = stream.read(length)
            raw += length_raw + decoded
        elif wire_type in fixed_widths:
            decoded = stream.read(fixed_widths[wire_type])
            raw += decoded

        yield ParsedField(number=number, wire_type=wire_type, value=decoded, raw=raw)
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
i = 0
while i < len(value):
@@ -775,11 +902,16 @@ class Message(ABC):
self.__class__._betterproto_meta = meta # type: ignore
return meta
def __bytes__(self) -> bytes:
def dump(self, stream: BinaryIO) -> None:
"""
Get the binary encoded Protobuf representation of this message instance.
Dumps the binary encoded Protobuf message to the stream.
Parameters
-----------
stream: :class:`BinaryIO`
The stream to dump the message to.
"""
output = bytearray()
for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
value = getattr(self, field_name)
@@ -825,10 +957,10 @@ class Message(ABC):
buf = bytearray()
for item in value:
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
stream.write(_serialize_single(meta.number, TYPE_BYTES, buf))
else:
for item in value:
output += (
stream.write(
_serialize_single(
meta.number,
meta.proto_type,
@@ -846,7 +978,9 @@ class Message(ABC):
assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
stream.write(
_serialize_single(meta.number, meta.proto_type, sk + sv)
)
else:
# If we have an empty string and we're including the default value for
# a oneof, make sure we serialize it. This ensures that the byte string
@@ -859,7 +993,111 @@ class Message(ABC):
):
serialize_empty = True
output += _serialize_single(
stream.write(
_serialize_single(
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty or bool(selected_in_group),
wraps=meta.wraps or "",
)
)
stream.write(self._unknown_fields)
def __bytes__(self) -> bytes:
    """
    Return the binary encoded Protobuf representation of this message
    instance, produced by dumping the message into an in-memory buffer.
    """
    with BytesIO() as buffer:
        self.dump(buffer)
        # Take the contents before the buffer is closed on exit.
        encoded = buffer.getvalue()
    return encoded
def __len__(self) -> int:
"""
Get the size of the encoded Protobuf representation of this message instance.
"""
size = 0
for field_name, meta in self._betterproto.meta_by_field_name.items():
try:
value = getattr(self, field_name)
except AttributeError:
continue
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types and proto3 field presence/optional fields.
continue
# Being selected in a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
#
# Note that proto3 field presence/optional fields are put in a
# synthetic single-item oneof by protoc, which helps us ensure we
# send the value even if the value is the default zero value.
selected_in_group = bool(meta.group)
# Empty messages can still be sent on the wire if they were
# set (or received empty).
serialize_empty = isinstance(value, Message) and value._serialized_on_wire
include_default_value_for_oneof = self._include_default_value_for_oneof(
field_name=field_name, meta=meta
)
if value == self._get_field_default(field_name) and not (
selected_in_group or serialize_empty or include_default_value_for_oneof
):
# Default (zero) values are not serialized. Two exceptions are
# if this is the selected oneof item or if we know we have to
# serialize an empty message (i.e. zero value was explicitly
# set by the user).
continue
if isinstance(value, list):
if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = bytearray()
for item in value:
buf += _preprocess_single(meta.proto_type, "", item)
size += _len_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
size += (
_len_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
serialize_empty=True,
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or 2
)
elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v)
size += _len_single(meta.number, meta.proto_type, sk + sv)
else:
# If we have an empty string and we're including the default value for
# a oneof, make sure we serialize it. This ensures that the byte string
# output isn't simply an empty string. This also ensures that round trip
# serialization will keep `which_one_of` calls consistent.
if (
isinstance(value, str)
and value == ""
and include_default_value_for_oneof
):
serialize_empty = True
size += _len_single(
meta.number,
meta.proto_type,
value,
@@ -867,8 +1105,8 @@ class Message(ABC):
wraps=meta.wraps or "",
)
output += self._unknown_fields
return bytes(output)
size += len(self._unknown_fields)
return size
# For compatibility with other libraries
def SerializeToString(self: T) -> bytes:
@@ -987,15 +1225,18 @@ class Message(ABC):
meta.group is not None and self._group_current.get(meta.group) == field_name
)
def parse(self: T, data: bytes) -> T:
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
Load the binary encoded Protobuf from a stream into this message instance. This
returns the instance itself and is therefore assignable and chainable.
Parameters
-----------
data: :class:`bytes`
The data to parse the protobuf from.
stream: :class:`BinaryIO`
The stream to load the message from.
size: :class:`Optional[int]`
The size of the message in the stream.
Reads stream until EOF if ``None`` is given.
Returns
--------
@@ -1005,7 +1246,8 @@ class Message(ABC):
# Got some data over the wire
self._serialized_on_wire = True
proto_meta = self._betterproto
for parsed in parse_fields(data):
read = 0
for parsed in load_fields(stream):
field_name = proto_meta.field_name_by_number.get(parsed.number)
if not field_name:
self._unknown_fields += parsed.raw
@@ -1051,8 +1293,46 @@ class Message(ABC):
else:
setattr(self, field_name, value)
# If we have now loaded the expected length of the message, stop
if size is not None:
prev = read
read += len(parsed.raw)
if read == size:
break
elif read > size:
raise ValueError(
f"Expected message of size {size}, can only read "
f"either {prev} or {read} bytes - there is no "
"message of the expected size in the stream."
)
if size is not None and read < size:
raise ValueError(
f"Expected message of size {size}, but was only able to "
f"read {read} bytes - the stream may have ended too soon,"
" or the expected size may have been incorrect."
)
return self
def parse(self: T, data: "ReadableBuffer") -> T:
    """
    Parse the binary encoded Protobuf in ``data`` into this message instance
    by wrapping it in an in-memory stream and passing that to ``load``.

    Parameters
    -----------
    data: :class:`bytes`
        The data to parse the message from.

    Returns
    --------
    :class:`Message`
        The value returned by ``load``.
    """
    with BytesIO(data) as source:
        loaded = self.load(source)
    return loaded
# For compatibility with other libraries.
@classmethod
def FromString(cls: Type[T], data: bytes) -> T: