Add message streaming support (#518)
This commit is contained in:
parent
4cdf1bb9e0
commit
8659c51123
@ -17,14 +17,16 @@ from datetime import (
|
|||||||
timedelta,
|
timedelta,
|
||||||
timezone,
|
timezone,
|
||||||
)
|
)
|
||||||
|
from io import BytesIO
|
||||||
|
from itertools import count
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
BinaryIO,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
List,
|
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
@ -46,6 +48,10 @@ from .casing import (
|
|||||||
from .grpc.grpclib_client import ServiceStub
|
from .grpc.grpclib_client import ServiceStub
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from _typeshed import ReadableBuffer
|
||||||
|
|
||||||
|
|
||||||
# Proto 3 data types
|
# Proto 3 data types
|
||||||
TYPE_ENUM = "enum"
|
TYPE_ENUM = "enum"
|
||||||
TYPE_BOOL = "bool"
|
TYPE_BOOL = "bool"
|
||||||
@ -66,7 +72,6 @@ TYPE_BYTES = "bytes"
|
|||||||
TYPE_MESSAGE = "message"
|
TYPE_MESSAGE = "message"
|
||||||
TYPE_MAP = "map"
|
TYPE_MAP = "map"
|
||||||
|
|
||||||
|
|
||||||
# Fields that use a fixed amount of space (4 or 8 bytes)
|
# Fields that use a fixed amount of space (4 or 8 bytes)
|
||||||
FIXED_TYPES = [
|
FIXED_TYPES = [
|
||||||
TYPE_FLOAT,
|
TYPE_FLOAT,
|
||||||
@ -129,7 +134,6 @@ def datetime_default_gen() -> datetime:
|
|||||||
|
|
||||||
DATETIME_ZERO = datetime_default_gen()
|
DATETIME_ZERO = datetime_default_gen()
|
||||||
|
|
||||||
|
|
||||||
# Special protobuf json doubles
|
# Special protobuf json doubles
|
||||||
INFINITY = "Infinity"
|
INFINITY = "Infinity"
|
||||||
NEG_INFINITY = "-Infinity"
|
NEG_INFINITY = "-Infinity"
|
||||||
@ -343,20 +347,43 @@ def _pack_fmt(proto_type: str) -> str:
|
|||||||
}[proto_type]
|
}[proto_type]
|
||||||
|
|
||||||
|
|
||||||
def encode_varint(value: int) -> bytes:
|
def dump_varint(value: int, stream: BinaryIO) -> None:
|
||||||
"""Encodes a single varint value for serialization."""
|
"""Encodes a single varint and dumps it into the provided stream."""
|
||||||
b: List[int] = []
|
if value < -(1 << 63):
|
||||||
|
raise ValueError(
|
||||||
if value < 0:
|
"Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
|
||||||
|
)
|
||||||
|
elif value < 0:
|
||||||
value += 1 << 64
|
value += 1 << 64
|
||||||
|
|
||||||
bits = value & 0x7F
|
bits = value & 0x7F
|
||||||
value >>= 7
|
value >>= 7
|
||||||
while value:
|
while value:
|
||||||
b.append(0x80 | bits)
|
stream.write((0x80 | bits).to_bytes(1, "little"))
|
||||||
bits = value & 0x7F
|
bits = value & 0x7F
|
||||||
value >>= 7
|
value >>= 7
|
||||||
return bytes(b + [bits])
|
stream.write(bits.to_bytes(1, "little"))
|
||||||
|
|
||||||
|
|
||||||
|
def encode_varint(value: int) -> bytes:
|
||||||
|
"""Encodes a single varint value for serialization."""
|
||||||
|
with BytesIO() as stream:
|
||||||
|
dump_varint(value, stream)
|
||||||
|
return stream.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def size_varint(value: int) -> int:
|
||||||
|
"""Calculates the size in bytes that a value would take as a varint."""
|
||||||
|
if value < -(1 << 63):
|
||||||
|
raise ValueError(
|
||||||
|
"Negative value is not representable as a 64-bit integer - unable to encode a varint within 10 bytes."
|
||||||
|
)
|
||||||
|
elif value < 0:
|
||||||
|
return 10
|
||||||
|
elif value == 0:
|
||||||
|
return 1
|
||||||
|
else:
|
||||||
|
return math.ceil(value.bit_length() / 7)
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
||||||
@ -394,6 +421,41 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _len_preprocessed_single(proto_type: str, wraps: str, value: Any) -> int:
|
||||||
|
"""Calculate the size of adjusted values for serialization without fully serializing them."""
|
||||||
|
if proto_type in (
|
||||||
|
TYPE_ENUM,
|
||||||
|
TYPE_BOOL,
|
||||||
|
TYPE_INT32,
|
||||||
|
TYPE_INT64,
|
||||||
|
TYPE_UINT32,
|
||||||
|
TYPE_UINT64,
|
||||||
|
):
|
||||||
|
return size_varint(value)
|
||||||
|
elif proto_type in (TYPE_SINT32, TYPE_SINT64):
|
||||||
|
# Handle zig-zag encoding.
|
||||||
|
return size_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
|
||||||
|
elif proto_type in FIXED_TYPES:
|
||||||
|
return len(struct.pack(_pack_fmt(proto_type), value))
|
||||||
|
elif proto_type == TYPE_STRING:
|
||||||
|
return len(value.encode("utf-8"))
|
||||||
|
elif proto_type == TYPE_MESSAGE:
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
# Convert the `datetime` to a timestamp message.
|
||||||
|
value = _Timestamp.from_datetime(value)
|
||||||
|
elif isinstance(value, timedelta):
|
||||||
|
# Convert the `timedelta` to a duration message.
|
||||||
|
value = _Duration.from_timedelta(value)
|
||||||
|
elif wraps:
|
||||||
|
if value is None:
|
||||||
|
return 0
|
||||||
|
value = _get_wrapper(wraps)(value=value)
|
||||||
|
|
||||||
|
return len(bytes(value))
|
||||||
|
|
||||||
|
return len(value)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_single(
|
def _serialize_single(
|
||||||
field_number: int,
|
field_number: int,
|
||||||
proto_type: str,
|
proto_type: str,
|
||||||
@ -425,6 +487,31 @@ def _serialize_single(
|
|||||||
return bytes(output)
|
return bytes(output)
|
||||||
|
|
||||||
|
|
||||||
|
def _len_single(
|
||||||
|
field_number: int,
|
||||||
|
proto_type: str,
|
||||||
|
value: Any,
|
||||||
|
*,
|
||||||
|
serialize_empty: bool = False,
|
||||||
|
wraps: str = "",
|
||||||
|
) -> int:
|
||||||
|
"""Calculates the size of a serialized single field and value."""
|
||||||
|
size = _len_preprocessed_single(proto_type, wraps, value)
|
||||||
|
if proto_type in WIRE_VARINT_TYPES:
|
||||||
|
size += size_varint(field_number << 3)
|
||||||
|
elif proto_type in WIRE_FIXED_32_TYPES:
|
||||||
|
size += size_varint((field_number << 3) | 5)
|
||||||
|
elif proto_type in WIRE_FIXED_64_TYPES:
|
||||||
|
size += size_varint((field_number << 3) | 1)
|
||||||
|
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||||
|
if size or serialize_empty or wraps:
|
||||||
|
size += size_varint((field_number << 3) | 2) + size_varint(size)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(proto_type)
|
||||||
|
|
||||||
|
return size
|
||||||
|
|
||||||
|
|
||||||
def _parse_float(value: Any) -> float:
|
def _parse_float(value: Any) -> float:
|
||||||
"""Parse the given value to a float
|
"""Parse the given value to a float
|
||||||
|
|
||||||
@ -469,22 +556,34 @@ def _dump_float(value: float) -> Union[float, str]:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def load_varint(stream: BinaryIO) -> Tuple[int, bytes]:
|
||||||
|
"""
|
||||||
|
Load a single varint value from a stream. Returns the value and the raw bytes read.
|
||||||
|
"""
|
||||||
|
result = 0
|
||||||
|
raw = b""
|
||||||
|
for shift in count(0, 7):
|
||||||
|
if shift >= 64:
|
||||||
|
raise ValueError("Too many bytes when decoding varint.")
|
||||||
|
b = stream.read(1)
|
||||||
|
if not b:
|
||||||
|
raise EOFError("Stream ended unexpectedly while attempting to load varint.")
|
||||||
|
raw += b
|
||||||
|
b_int = int.from_bytes(b, byteorder="little")
|
||||||
|
result |= (b_int & 0x7F) << shift
|
||||||
|
if not (b_int & 0x80):
|
||||||
|
return result, raw
|
||||||
|
|
||||||
|
|
||||||
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
Decode a single varint value from a byte buffer. Returns the value and the
|
Decode a single varint value from a byte buffer. Returns the value and the
|
||||||
new position in the buffer.
|
new position in the buffer.
|
||||||
"""
|
"""
|
||||||
result = 0
|
with BytesIO(buffer) as stream:
|
||||||
shift = 0
|
stream.seek(pos)
|
||||||
while 1:
|
value, raw = load_varint(stream)
|
||||||
b = buffer[pos]
|
return value, pos + len(raw)
|
||||||
result |= (b & 0x7F) << shift
|
|
||||||
pos += 1
|
|
||||||
if not (b & 0x80):
|
|
||||||
return result, pos
|
|
||||||
shift += 7
|
|
||||||
if shift >= 64:
|
|
||||||
raise ValueError("Too many bytes when decoding varint.")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
@ -495,6 +594,34 @@ class ParsedField:
|
|||||||
raw: bytes
|
raw: bytes
|
||||||
|
|
||||||
|
|
||||||
|
def load_fields(stream: BinaryIO) -> Generator[ParsedField, None, None]:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
num_wire, raw = load_varint(stream)
|
||||||
|
except EOFError:
|
||||||
|
return
|
||||||
|
number = num_wire >> 3
|
||||||
|
wire_type = num_wire & 0x7
|
||||||
|
|
||||||
|
decoded: Any = None
|
||||||
|
if wire_type == WIRE_VARINT:
|
||||||
|
decoded, r = load_varint(stream)
|
||||||
|
raw += r
|
||||||
|
elif wire_type == WIRE_FIXED_64:
|
||||||
|
decoded = stream.read(8)
|
||||||
|
raw += decoded
|
||||||
|
elif wire_type == WIRE_LEN_DELIM:
|
||||||
|
length, r = load_varint(stream)
|
||||||
|
decoded = stream.read(length)
|
||||||
|
raw += r
|
||||||
|
raw += decoded
|
||||||
|
elif wire_type == WIRE_FIXED_32:
|
||||||
|
decoded = stream.read(4)
|
||||||
|
raw += decoded
|
||||||
|
|
||||||
|
yield ParsedField(number=number, wire_type=wire_type, value=decoded, raw=raw)
|
||||||
|
|
||||||
|
|
||||||
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(value):
|
while i < len(value):
|
||||||
@ -775,11 +902,16 @@ class Message(ABC):
|
|||||||
self.__class__._betterproto_meta = meta # type: ignore
|
self.__class__._betterproto_meta = meta # type: ignore
|
||||||
return meta
|
return meta
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def dump(self, stream: BinaryIO) -> None:
|
||||||
"""
|
"""
|
||||||
Get the binary encoded Protobuf representation of this message instance.
|
Dumps the binary encoded Protobuf message to the stream.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
stream: :class:`BinaryIO`
|
||||||
|
The stream to dump the message to.
|
||||||
"""
|
"""
|
||||||
output = bytearray()
|
|
||||||
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
try:
|
try:
|
||||||
value = getattr(self, field_name)
|
value = getattr(self, field_name)
|
||||||
@ -825,10 +957,10 @@ class Message(ABC):
|
|||||||
buf = bytearray()
|
buf = bytearray()
|
||||||
for item in value:
|
for item in value:
|
||||||
buf += _preprocess_single(meta.proto_type, "", item)
|
buf += _preprocess_single(meta.proto_type, "", item)
|
||||||
output += _serialize_single(meta.number, TYPE_BYTES, buf)
|
stream.write(_serialize_single(meta.number, TYPE_BYTES, buf))
|
||||||
else:
|
else:
|
||||||
for item in value:
|
for item in value:
|
||||||
output += (
|
stream.write(
|
||||||
_serialize_single(
|
_serialize_single(
|
||||||
meta.number,
|
meta.number,
|
||||||
meta.proto_type,
|
meta.proto_type,
|
||||||
@ -846,7 +978,9 @@ class Message(ABC):
|
|||||||
assert meta.map_types
|
assert meta.map_types
|
||||||
sk = _serialize_single(1, meta.map_types[0], k)
|
sk = _serialize_single(1, meta.map_types[0], k)
|
||||||
sv = _serialize_single(2, meta.map_types[1], v)
|
sv = _serialize_single(2, meta.map_types[1], v)
|
||||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
stream.write(
|
||||||
|
_serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# If we have an empty string and we're including the default value for
|
# If we have an empty string and we're including the default value for
|
||||||
# a oneof, make sure we serialize it. This ensures that the byte string
|
# a oneof, make sure we serialize it. This ensures that the byte string
|
||||||
@ -859,7 +993,111 @@ class Message(ABC):
|
|||||||
):
|
):
|
||||||
serialize_empty = True
|
serialize_empty = True
|
||||||
|
|
||||||
output += _serialize_single(
|
stream.write(
|
||||||
|
_serialize_single(
|
||||||
|
meta.number,
|
||||||
|
meta.proto_type,
|
||||||
|
value,
|
||||||
|
serialize_empty=serialize_empty or bool(selected_in_group),
|
||||||
|
wraps=meta.wraps or "",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
stream.write(self._unknown_fields)
|
||||||
|
|
||||||
|
def __bytes__(self) -> bytes:
|
||||||
|
"""
|
||||||
|
Get the binary encoded Protobuf representation of this message instance.
|
||||||
|
"""
|
||||||
|
with BytesIO() as stream:
|
||||||
|
self.dump(stream)
|
||||||
|
return stream.getvalue()
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Get the size of the encoded Protobuf representation of this message instance.
|
||||||
|
"""
|
||||||
|
size = 0
|
||||||
|
for field_name, meta in self._betterproto.meta_by_field_name.items():
|
||||||
|
try:
|
||||||
|
value = getattr(self, field_name)
|
||||||
|
except AttributeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
# Optional items should be skipped. This is used for the Google
|
||||||
|
# wrapper types and proto3 field presence/optional fields.
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Being selected in a group means this field is the one that is
|
||||||
|
# currently set in a `oneof` group, so it must be serialized even
|
||||||
|
# if the value is the default zero value.
|
||||||
|
#
|
||||||
|
# Note that proto3 field presence/optional fields are put in a
|
||||||
|
# synthetic single-item oneof by protoc, which helps us ensure we
|
||||||
|
# send the value even if the value is the default zero value.
|
||||||
|
selected_in_group = bool(meta.group)
|
||||||
|
|
||||||
|
# Empty messages can still be sent on the wire if they were
|
||||||
|
# set (or received empty).
|
||||||
|
serialize_empty = isinstance(value, Message) and value._serialized_on_wire
|
||||||
|
|
||||||
|
include_default_value_for_oneof = self._include_default_value_for_oneof(
|
||||||
|
field_name=field_name, meta=meta
|
||||||
|
)
|
||||||
|
|
||||||
|
if value == self._get_field_default(field_name) and not (
|
||||||
|
selected_in_group or serialize_empty or include_default_value_for_oneof
|
||||||
|
):
|
||||||
|
# Default (zero) values are not serialized. Two exceptions are
|
||||||
|
# if this is the selected oneof item or if we know we have to
|
||||||
|
# serialize an empty message (i.e. zero value was explicitly
|
||||||
|
# set by the user).
|
||||||
|
continue
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
if meta.proto_type in PACKED_TYPES:
|
||||||
|
# Packed lists look like a length-delimited field. First,
|
||||||
|
# preprocess/encode each value into a buffer and then
|
||||||
|
# treat it like a field of raw bytes.
|
||||||
|
buf = bytearray()
|
||||||
|
for item in value:
|
||||||
|
buf += _preprocess_single(meta.proto_type, "", item)
|
||||||
|
size += _len_single(meta.number, TYPE_BYTES, buf)
|
||||||
|
else:
|
||||||
|
for item in value:
|
||||||
|
size += (
|
||||||
|
_len_single(
|
||||||
|
meta.number,
|
||||||
|
meta.proto_type,
|
||||||
|
item,
|
||||||
|
wraps=meta.wraps or "",
|
||||||
|
serialize_empty=True,
|
||||||
|
)
|
||||||
|
# if it's an empty message it still needs to be represented
|
||||||
|
# as an item in the repeated list
|
||||||
|
or 2
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
for k, v in value.items():
|
||||||
|
assert meta.map_types
|
||||||
|
sk = _serialize_single(1, meta.map_types[0], k)
|
||||||
|
sv = _serialize_single(2, meta.map_types[1], v)
|
||||||
|
size += _len_single(meta.number, meta.proto_type, sk + sv)
|
||||||
|
else:
|
||||||
|
# If we have an empty string and we're including the default value for
|
||||||
|
# a oneof, make sure we serialize it. This ensures that the byte string
|
||||||
|
# output isn't simply an empty string. This also ensures that round trip
|
||||||
|
# serialization will keep `which_one_of` calls consistent.
|
||||||
|
if (
|
||||||
|
isinstance(value, str)
|
||||||
|
and value == ""
|
||||||
|
and include_default_value_for_oneof
|
||||||
|
):
|
||||||
|
serialize_empty = True
|
||||||
|
|
||||||
|
size += _len_single(
|
||||||
meta.number,
|
meta.number,
|
||||||
meta.proto_type,
|
meta.proto_type,
|
||||||
value,
|
value,
|
||||||
@ -867,8 +1105,8 @@ class Message(ABC):
|
|||||||
wraps=meta.wraps or "",
|
wraps=meta.wraps or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
output += self._unknown_fields
|
size += len(self._unknown_fields)
|
||||||
return bytes(output)
|
return size
|
||||||
|
|
||||||
# For compatibility with other libraries
|
# For compatibility with other libraries
|
||||||
def SerializeToString(self: T) -> bytes:
|
def SerializeToString(self: T) -> bytes:
|
||||||
@ -987,15 +1225,18 @@ class Message(ABC):
|
|||||||
meta.group is not None and self._group_current.get(meta.group) == field_name
|
meta.group is not None and self._group_current.get(meta.group) == field_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse(self: T, data: bytes) -> T:
|
def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
|
||||||
"""
|
"""
|
||||||
Parse the binary encoded Protobuf into this message instance. This
|
Load the binary encoded Protobuf from a stream into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
data: :class:`bytes`
|
stream: :class:`bytes`
|
||||||
The data to parse the protobuf from.
|
The stream to load the message from.
|
||||||
|
size: :class:`Optional[int]`
|
||||||
|
The size of the message in the stream.
|
||||||
|
Reads stream until EOF if ``None`` is given.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
--------
|
--------
|
||||||
@ -1005,7 +1246,8 @@ class Message(ABC):
|
|||||||
# Got some data over the wire
|
# Got some data over the wire
|
||||||
self._serialized_on_wire = True
|
self._serialized_on_wire = True
|
||||||
proto_meta = self._betterproto
|
proto_meta = self._betterproto
|
||||||
for parsed in parse_fields(data):
|
read = 0
|
||||||
|
for parsed in load_fields(stream):
|
||||||
field_name = proto_meta.field_name_by_number.get(parsed.number)
|
field_name = proto_meta.field_name_by_number.get(parsed.number)
|
||||||
if not field_name:
|
if not field_name:
|
||||||
self._unknown_fields += parsed.raw
|
self._unknown_fields += parsed.raw
|
||||||
@ -1051,8 +1293,46 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
setattr(self, field_name, value)
|
setattr(self, field_name, value)
|
||||||
|
|
||||||
|
# If we have now loaded the expected length of the message, stop
|
||||||
|
if size is not None:
|
||||||
|
prev = read
|
||||||
|
read += len(parsed.raw)
|
||||||
|
if read == size:
|
||||||
|
break
|
||||||
|
elif read > size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected message of size {size}, can only read "
|
||||||
|
f"either {prev} or {read} bytes - there is no "
|
||||||
|
"message of the expected size in the stream."
|
||||||
|
)
|
||||||
|
|
||||||
|
if size is not None and read < size:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected message of size {size}, but was only able to "
|
||||||
|
f"read {read} bytes - the stream may have ended too soon,"
|
||||||
|
" or the expected size may have been incorrect."
|
||||||
|
)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def parse(self: T, data: "ReadableBuffer") -> T:
|
||||||
|
"""
|
||||||
|
Parse the binary encoded Protobuf into this message instance. This
|
||||||
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
data: :class:`bytes`
|
||||||
|
The data to parse the message from.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`Message`
|
||||||
|
The initialized message.
|
||||||
|
"""
|
||||||
|
with BytesIO(data) as stream:
|
||||||
|
return self.load(stream)
|
||||||
|
|
||||||
# For compatibility with other libraries.
|
# For compatibility with other libraries.
|
||||||
@classmethod
|
@classmethod
|
||||||
def FromString(cls: Type[T], data: bytes) -> T:
|
def FromString(cls: Type[T], data: bytes) -> T:
|
||||||
|
1
tests/streams/dump_varint_negative.expected
Normal file
1
tests/streams/dump_varint_negative.expected
Normal file
@ -0,0 +1 @@
|
|||||||
|
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><01>ӝ<EFBFBD><D39D><EFBFBD><EFBFBD><EFBFBD><EFBFBD><01><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><01><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
1
tests/streams/dump_varint_positive.expected
Normal file
1
tests/streams/dump_varint_positive.expected
Normal file
@ -0,0 +1 @@
|
|||||||
|
Ђв
|
1
tests/streams/load_varint_cutoff.in
Normal file
1
tests/streams/load_varint_cutoff.in
Normal file
@ -0,0 +1 @@
|
|||||||
|
ȁ
|
2
tests/streams/message_dump_file_multiple.expected
Normal file
2
tests/streams/message_dump_file_multiple.expected
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
•šï:bTesting•šï:bTesting
|
||||||
|
|
1
tests/streams/message_dump_file_single.expected
Normal file
1
tests/streams/message_dump_file_single.expected
Normal file
@ -0,0 +1 @@
|
|||||||
|
•šï:bTesting
|
268
tests/test_streams.py
Normal file
268
tests/test_streams.py
Normal file
@ -0,0 +1,268 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from io import BytesIO
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
from tests.output_betterproto import (
|
||||||
|
map,
|
||||||
|
nested,
|
||||||
|
oneof,
|
||||||
|
repeated,
|
||||||
|
repeatedpacked,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
oneof_example = oneof.Test().from_dict(
|
||||||
|
{"pitied": 1, "just_a_regular_field": 123456789, "bar_name": "Testing"}
|
||||||
|
)
|
||||||
|
|
||||||
|
len_oneof = len(oneof_example)
|
||||||
|
|
||||||
|
nested_example = nested.Test().from_dict(
|
||||||
|
{
|
||||||
|
"nested": {"count": 1},
|
||||||
|
"sibling": {"foo": 2},
|
||||||
|
"sibling2": {"foo": 3},
|
||||||
|
"msg": nested.TestMsg.THIS,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
repeated_example = repeated.Test().from_dict({"names": ["blah", "Blah2"]})
|
||||||
|
|
||||||
|
packed_example = repeatedpacked.Test().from_dict(
|
||||||
|
{"counts": [1, 2, 3], "signed": [-1, 2, -3], "fixed": [1.2, -2.3, 3.4]}
|
||||||
|
)
|
||||||
|
|
||||||
|
map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
|
||||||
|
|
||||||
|
streams_path = Path("tests/streams/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_varint_too_long():
|
||||||
|
with BytesIO(
|
||||||
|
b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01"
|
||||||
|
) as stream, pytest.raises(ValueError):
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
with BytesIO(b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01") as stream:
|
||||||
|
# This should not raise a ValueError, as it is within 64 bits
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_varint_file():
|
||||||
|
with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
|
||||||
|
assert betterproto.load_varint(stream) == (8, b"\x08") # Single-byte varint
|
||||||
|
stream.read(2) # Skip until first multi-byte
|
||||||
|
assert betterproto.load_varint(stream) == (
|
||||||
|
123456789,
|
||||||
|
b"\x95\x9A\xEF\x3A",
|
||||||
|
) # Multi-byte varint
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_varint_cutoff():
|
||||||
|
with open(streams_path / "load_varint_cutoff.in", "rb") as stream:
|
||||||
|
with pytest.raises(EOFError):
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
stream.seek(1)
|
||||||
|
with pytest.raises(EOFError):
|
||||||
|
betterproto.load_varint(stream)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_varint_file(tmp_path):
|
||||||
|
# Dump test varints to file
|
||||||
|
with open(tmp_path / "dump_varint_file.out", "wb") as stream:
|
||||||
|
betterproto.dump_varint(8, stream) # Single-byte varint
|
||||||
|
betterproto.dump_varint(123456789, stream) # Multi-byte varint
|
||||||
|
|
||||||
|
# Check that file contents are as expected
|
||||||
|
with open(tmp_path / "dump_varint_file.out", "rb") as test_stream, open(
|
||||||
|
streams_path / "message_dump_file_single.expected", "rb"
|
||||||
|
) as exp_stream:
|
||||||
|
assert betterproto.load_varint(test_stream) == betterproto.load_varint(
|
||||||
|
exp_stream
|
||||||
|
)
|
||||||
|
exp_stream.read(2)
|
||||||
|
assert betterproto.load_varint(test_stream) == betterproto.load_varint(
|
||||||
|
exp_stream
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_fields():
|
||||||
|
with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
|
||||||
|
parsed_bytes = betterproto.parse_fields(stream.read())
|
||||||
|
|
||||||
|
with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
|
||||||
|
parsed_stream = betterproto.load_fields(stream)
|
||||||
|
for field in parsed_bytes:
|
||||||
|
assert field == next(parsed_stream)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dump_file_single(tmp_path):
|
||||||
|
# Write the message to the stream
|
||||||
|
with open(tmp_path / "message_dump_file_single.out", "wb") as stream:
|
||||||
|
oneof_example.dump(stream)
|
||||||
|
|
||||||
|
# Check that the outputted file is exactly as expected
|
||||||
|
with open(tmp_path / "message_dump_file_single.out", "rb") as test_stream, open(
|
||||||
|
streams_path / "message_dump_file_single.expected", "rb"
|
||||||
|
) as exp_stream:
|
||||||
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_dump_file_multiple(tmp_path):
|
||||||
|
# Write the same Message twice and another, different message
|
||||||
|
with open(tmp_path / "message_dump_file_multiple.out", "wb") as stream:
|
||||||
|
oneof_example.dump(stream)
|
||||||
|
oneof_example.dump(stream)
|
||||||
|
nested_example.dump(stream)
|
||||||
|
|
||||||
|
# Check that all three Messages were outputted to the file correctly
|
||||||
|
with open(tmp_path / "message_dump_file_multiple.out", "rb") as test_stream, open(
|
||||||
|
streams_path / "message_dump_file_multiple.expected", "rb"
|
||||||
|
) as exp_stream:
|
||||||
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len():
|
||||||
|
assert len_oneof == len(bytes(oneof_example))
|
||||||
|
assert len(nested_example) == len(bytes(nested_example))
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_load_file_single():
|
||||||
|
with open(streams_path / "message_dump_file_single.expected", "rb") as stream:
|
||||||
|
assert oneof.Test().load(stream) == oneof_example
|
||||||
|
stream.seek(0)
|
||||||
|
assert oneof.Test().load(stream, len_oneof) == oneof_example
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_load_file_multiple():
|
||||||
|
with open(streams_path / "message_dump_file_multiple.expected", "rb") as stream:
|
||||||
|
oneof_size = len_oneof
|
||||||
|
assert oneof.Test().load(stream, oneof_size) == oneof_example
|
||||||
|
assert oneof.Test().load(stream, oneof_size) == oneof_example
|
||||||
|
assert nested.Test().load(stream) == nested_example
|
||||||
|
assert stream.read(1) == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_load_too_small():
|
||||||
|
with open(
|
||||||
|
streams_path / "message_dump_file_single.expected", "rb"
|
||||||
|
) as stream, pytest.raises(ValueError):
|
||||||
|
oneof.Test().load(stream, len_oneof - 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_too_large():
|
||||||
|
with open(
|
||||||
|
streams_path / "message_dump_file_single.expected", "rb"
|
||||||
|
) as stream, pytest.raises(ValueError):
|
||||||
|
oneof.Test().load(stream, len_oneof + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len_optional_field():
|
||||||
|
@dataclass
|
||||||
|
class Request(betterproto.Message):
|
||||||
|
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
|
||||||
|
|
||||||
|
assert len(Request()) == len(b"")
|
||||||
|
assert len(Request(flag=True)) == len(b"\n\x02\x08\x01")
|
||||||
|
assert len(Request(flag=False)) == len(b"\n\x00")
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len_repeated_field():
|
||||||
|
assert len(repeated_example) == len(bytes(repeated_example))
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len_packed_field():
|
||||||
|
assert len(packed_example) == len(bytes(packed_example))
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len_map_field():
|
||||||
|
assert len(map_example) == len(bytes(map_example))
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_len_empty_string():
|
||||||
|
@dataclass
|
||||||
|
class Empty(betterproto.Message):
|
||||||
|
string: str = betterproto.string_field(1, "group")
|
||||||
|
integer: int = betterproto.int32_field(2, "group")
|
||||||
|
|
||||||
|
empty = Empty().from_dict({"string": ""})
|
||||||
|
assert len(empty) == len(bytes(empty))
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_varint_size_negative():
|
||||||
|
single_byte = -1
|
||||||
|
multi_byte = -10000000
|
||||||
|
edge = -(1 << 63)
|
||||||
|
beyond = -(1 << 63) - 1
|
||||||
|
before = -(1 << 63) + 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
betterproto.size_varint(single_byte)
|
||||||
|
== len(betterproto.encode_varint(single_byte))
|
||||||
|
== 10
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
betterproto.size_varint(multi_byte)
|
||||||
|
== len(betterproto.encode_varint(multi_byte))
|
||||||
|
== 10
|
||||||
|
)
|
||||||
|
assert betterproto.size_varint(edge) == len(betterproto.encode_varint(edge)) == 10
|
||||||
|
assert (
|
||||||
|
betterproto.size_varint(before) == len(betterproto.encode_varint(before)) == 10
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
betterproto.size_varint(beyond)
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_varint_size_positive():
|
||||||
|
single_byte = 1
|
||||||
|
multi_byte = 10000000
|
||||||
|
|
||||||
|
assert betterproto.size_varint(single_byte) == len(
|
||||||
|
betterproto.encode_varint(single_byte)
|
||||||
|
)
|
||||||
|
assert betterproto.size_varint(multi_byte) == len(
|
||||||
|
betterproto.encode_varint(multi_byte)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_varint_negative(tmp_path):
|
||||||
|
single_byte = -1
|
||||||
|
multi_byte = -10000000
|
||||||
|
edge = -(1 << 63)
|
||||||
|
beyond = -(1 << 63) - 1
|
||||||
|
before = -(1 << 63) + 1
|
||||||
|
|
||||||
|
with open(tmp_path / "dump_varint_negative.out", "wb") as stream:
|
||||||
|
betterproto.dump_varint(single_byte, stream)
|
||||||
|
betterproto.dump_varint(multi_byte, stream)
|
||||||
|
betterproto.dump_varint(edge, stream)
|
||||||
|
betterproto.dump_varint(before, stream)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
betterproto.dump_varint(beyond, stream)
|
||||||
|
|
||||||
|
with open(streams_path / "dump_varint_negative.expected", "rb") as exp_stream, open(
|
||||||
|
tmp_path / "dump_varint_negative.out", "rb"
|
||||||
|
) as test_stream:
|
||||||
|
assert test_stream.read() == exp_stream.read()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_varint_positive(tmp_path):
|
||||||
|
single_byte = 1
|
||||||
|
multi_byte = 10000000
|
||||||
|
|
||||||
|
with open(tmp_path / "dump_varint_positive.out", "wb") as stream:
|
||||||
|
betterproto.dump_varint(single_byte, stream)
|
||||||
|
betterproto.dump_varint(multi_byte, stream)
|
||||||
|
|
||||||
|
with open(tmp_path / "dump_varint_positive.out", "rb") as test_stream, open(
|
||||||
|
streams_path / "dump_varint_positive.expected", "rb"
|
||||||
|
) as exp_stream:
|
||||||
|
assert test_stream.read() == exp_stream.read()
|
Loading…
x
Reference in New Issue
Block a user