Add support for streaming delimited messages (#529)
* Add support for streaming delimited messages

  This allows developers to easily dump and load multiple messages from a stream in a way that is compatible with official protobuf implementations (such as Java's `MessageLite#writeDelimitedTo(...)`).

* Add Java compatibility tests for streaming

  These tests stream data such as messages to output files, have a Java binary read them and then write them back using the `protobuf-java` functions, and then read them back in on the Python side to check that the returned data is as expected. This checks that the official Java implementation (and so any other matching implementations) can properly parse outputs from Betterproto, and vice-versa, ensuring compatibility in these functions between the two.

* Replace `xxxxableBuffer` with `SupportsXxxx`
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from shutil import which
|
||||
from subprocess import run
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@@ -40,6 +42,8 @@ map_example = map.Test().from_dict({"counts": {"blah": 1, "Blah2": 2}})
|
||||
|
||||
# Directory holding the stream fixture files these tests read (for example
# `delimited_messages.in`). The path is relative, so it is resolved against
# the current working directory.
streams_path = Path("tests/streams/")
|
||||
|
||||
# Location of the `java` executable as found by `shutil.which`, or None when
# it is absent; the Java compatibility tests are skipped in that case.
java = which("java")
|
||||
|
||||
|
||||
def test_load_varint_too_long():
|
||||
with BytesIO(
|
||||
@@ -127,6 +131,18 @@ def test_message_dump_file_multiple(tmp_path):
|
||||
assert test_stream.read() == exp_stream.read()
|
||||
|
||||
|
||||
def test_message_dump_delimited(tmp_path):
    """Dump three size-delimited messages and compare with the fixture file."""
    dumped_path = tmp_path / "message_dump_delimited.out"
    expected_path = streams_path / "delimited_messages.in"

    # Write the same sequence of messages the fixture file is compared against
    with open(dumped_path, "wb") as stream:
        for message in (oneof_example, oneof_example, nested_example):
            message.dump(stream, betterproto.SIZE_DELIMITED)

    # The bytes written must match the fixture exactly
    assert dumped_path.read_bytes() == expected_path.read_bytes()
|
||||
|
||||
|
||||
def test_message_len():
|
||||
assert len_oneof == len(bytes(oneof_example))
|
||||
assert len(nested_example) == len(bytes(nested_example))
|
||||
@@ -155,7 +171,15 @@ def test_message_load_too_small():
|
||||
oneof.Test().load(stream, len_oneof - 1)
|
||||
|
||||
|
||||
def test_message_too_large():
|
||||
def test_message_load_delimited():
    """Load three size-delimited messages from the fixture file in order."""
    # (message class, expected value) in the order they are read from the file
    expected_sequence = (
        (oneof.Test, oneof_example),
        (oneof.Test, oneof_example),
        (nested.Test, nested_example),
    )

    with open(streams_path / "delimited_messages.in", "rb") as stream:
        for message_class, expected in expected_sequence:
            loaded = message_class().load(stream, betterproto.SIZE_DELIMITED)
            assert loaded == expected

        # Nothing may be left in the stream after the last message
        assert stream.read(1) == b""
|
||||
|
||||
|
||||
def test_message_load_too_large():
|
||||
with open(
|
||||
streams_path / "message_dump_file_single.expected", "rb"
|
||||
) as stream, pytest.raises(ValueError):
|
||||
@@ -266,3 +290,145 @@ def test_dump_varint_positive(tmp_path):
|
||||
streams_path / "dump_varint_positive.expected", "rb"
|
||||
) as exp_stream:
|
||||
assert test_stream.read() == exp_stream.read()
|
||||
|
||||
|
||||
# Java compatibility tests
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def compile_jar():
    """Build the Java compatibility-test JAR once per module, or skip.

    Tests requesting this fixture are skipped when the `java` executable or
    Maven cannot be found, or when the Maven build exits with a non-zero
    status.
    """
    # Both tools must be available before attempting the build
    if java is None:
        pytest.skip("`java` command is absent and is required")

    maven = which("mvn")
    if maven is None:
        pytest.skip("Maven is absent and is required")

    # Build the JAR from the project's POM file
    build_command = [maven, "clean", "install", "-f", "tests/streams/java/pom.xml"]
    build = run(build_command)
    if build.returncode != 0:
        pytest.skip(
            "Maven compatibility-test.jar build failed (maybe Java version <11?)"
        )
|
||||
|
||||
|
||||
# Path of the compatibility-test JAR that `run_jar` passes to `java -jar`.
# It lives under the same `tests/streams/java` project that `compile_jar`
# builds with Maven.
jar = "tests/streams/java/target/compatibility-test.jar"
|
||||
|
||||
|
||||
def run_jar(command: str, tmp_path):
    """Run the compatibility-test JAR with `command` and `tmp_path` as arguments.

    Returns the completed process. Because `check=True` is passed, a non-zero
    exit status raises `subprocess.CalledProcessError`.
    """
    argv = [java, "-jar", jar, command, tmp_path]
    return run(argv, check=True)
|
||||
|
||||
|
||||
def run_java_single_varint(value: int, tmp_path) -> int:
    """Send one varint through the Java binary and return what comes back.

    `value` is dumped to `py_single_varint.out`, the JAR's `single_varint`
    command is run, and the result is loaded from `java_single_varint.out`.
    The output file must contain exactly one varint.
    """
    # NOTE(review): callers compare the result with `(value, bytes)` tuples,
    # so `betterproto.load_varint` presumably returns a tuple and the `-> int`
    # annotation may be inaccurate; confirm against `load_varint`.
    sent_path = tmp_path / "py_single_varint.out"
    echoed_path = tmp_path / "java_single_varint.out"

    # Python side: write the varint for Java to consume
    with open(sent_path, "wb") as out_stream:
        betterproto.dump_varint(value, out_stream)

    # Java side: read the varint and write it back out
    run_jar("single_varint", tmp_path)

    # Python side: read Java's output and check nothing follows it
    with open(echoed_path, "rb") as in_stream:
        echoed = betterproto.load_varint(in_stream)
        with pytest.raises(EOFError):
            betterproto.load_varint(in_stream)

    return echoed
|
||||
|
||||
|
||||
def test_single_varint(compile_jar, tmp_path):
    """Round-trip a single-byte and a multi-byte varint through Java, one at a time."""
    # Each case pairs the integer value with its encoded bytes
    cases = (
        (1, b"\x01"),  # single-byte varint
        (123456789, b"\x95\x9A\xEF\x3A"),  # multi-byte varint
    )

    for case in cases:
        # Write the varint to a file, have Java echo it, and compare
        echoed = run_java_single_varint(case[0], tmp_path)
        assert echoed == case
|
||||
|
||||
|
||||
def test_multiple_varints(compile_jar, tmp_path):
    """Round-trip three consecutive varints through Java in a single file."""
    # Each case pairs the integer value with its encoded bytes
    expected = [
        (1, b"\x01"),  # single-byte
        (123456789, b"\x95\x9A\xEF\x3A"),  # multi-byte
        (3000000000, b"\x80\xBC\xC1\x96\x0B"),  # does not fit in 32 bits
    ]

    # Write all three varints to the same file
    with open(tmp_path / "py_multiple_varints.out", "wb") as out_stream:
        for value, _encoded in expected:
            betterproto.dump_varint(value, out_stream)

    # Have Java read these varints and write them back
    run_jar("multiple_varints", tmp_path)

    # Read the varints from the Java output file; nothing may follow them
    with open(tmp_path / "java_multiple_varints.out", "rb") as in_stream:
        echoed = [betterproto.load_varint(in_stream) for _ in expected]
        with pytest.raises(EOFError):
            betterproto.load_varint(in_stream)

    for got, want in zip(echoed, expected):
        assert got == want
|
||||
|
||||
|
||||
def test_single_message(compile_jar, tmp_path):
    """Round-trip one message, dumped without a size prefix, through Java."""
    sent_path = tmp_path / "py_single_message.out"
    echoed_path = tmp_path / "java_single_message.out"

    # Python side: write the message for Java to consume
    with open(sent_path, "wb") as out_stream:
        oneof_example.dump(out_stream)

    # Java side: read the message and write it back out
    run_jar("single_message", tmp_path)

    # Python side: load exactly the serialized size, then expect end of file
    with open(echoed_path, "rb") as in_stream:
        target = oneof.Test()
        echoed = target.load(in_stream, len(bytes(oneof_example)))
        assert in_stream.read() == b""

    assert echoed == oneof_example
|
||||
|
||||
|
||||
def test_multiple_messages(compile_jar, tmp_path):
    """Round-trip two size-delimited messages of different types through Java."""
    sent_path = tmp_path / "py_multiple_messages.out"
    echoed_path = tmp_path / "java_multiple_messages.out"

    # Python side: write both messages, each size-delimited
    with open(sent_path, "wb") as out_stream:
        for message in (oneof_example, nested_example):
            message.dump(out_stream, betterproto.SIZE_DELIMITED)

    # Java side: read the messages and write them back out
    run_jar("multiple_messages", tmp_path)

    # Python side: load both messages back, then expect end of file
    with open(echoed_path, "rb") as in_stream:
        echoed_oneof = oneof.Test().load(in_stream, betterproto.SIZE_DELIMITED)
        echoed_nested = nested.Test().load(in_stream, betterproto.SIZE_DELIMITED)
        assert in_stream.read() == b""

    assert echoed_oneof == oneof_example
    assert echoed_nested == nested_example
|
||||
|
||||
|
||||
def test_infinite_messages(compile_jar, tmp_path):
    """Round-trip a run of size-delimited messages, reading back until EOF."""
    num_messages = 5

    # Python side: write the same message `num_messages` times, size-delimited
    with open(tmp_path / "py_infinite_messages.out", "wb") as out_stream:
        for _ in range(num_messages):
            oneof_example.dump(out_stream, betterproto.SIZE_DELIMITED)

    # Java side: read the messages and write them back out
    run_jar("infinite_messages", tmp_path)

    # Python side: keep loading messages until the stream is exhausted
    echoed = []
    with open(tmp_path / "java_infinite_messages.out", "rb") as in_stream:
        while True:
            try:
                message = oneof.Test().load(in_stream, betterproto.SIZE_DELIMITED)
            except EOFError:
                break
            echoed.append(message)

    assert len(echoed) == num_messages
|
||||
|
||||
Reference in New Issue
Block a user