diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py index 5768189..3c2c7e2 100644 --- a/betterproto/tests/grpc/test_stream_stream.py +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -13,20 +13,13 @@ class Message(betterproto.Message): body: str = betterproto.string_field(1) -async def to_list(generator: AsyncIterator): - lis = [] - async for value in generator: - lis.append(value) - return lis - - @pytest.fixture def expected_responses(): return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] class ClientStub: - async def connect(self, requests): + async def connect(self, requests: AsyncIterator): await asyncio.sleep(0.1) async for request in requests: await asyncio.sleep(0.1) @@ -35,6 +28,13 @@ class ClientStub: yield Message("Done") +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + @pytest.fixture def client(): # channel = Channel(host='127.0.0.1', port=50051) @@ -122,3 +122,30 @@ async def test_send_from_close_manually_immediately(client, expected_responses): requests.close() assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses