From 0814729c5af0f96ad933ae0e72c695fd0dce8d41 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:14:13 +0200 Subject: [PATCH] Add cases for send() --- betterproto/tests/grpc/test_stream_stream.py | 43 ++++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py index 5768189..3c2c7e2 100644 --- a/betterproto/tests/grpc/test_stream_stream.py +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -13,20 +13,13 @@ class Message(betterproto.Message): body: str = betterproto.string_field(1) -async def to_list(generator: AsyncIterator): - lis = [] - async for value in generator: - lis.append(value) - return lis - - @pytest.fixture def expected_responses(): return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] class ClientStub: - async def connect(self, requests): + async def connect(self, requests: AsyncIterator): await asyncio.sleep(0.1) async for request in requests: await asyncio.sleep(0.1) @@ -35,6 +28,13 @@ class ClientStub: yield Message("Done") +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + @pytest.fixture def client(): # channel = Channel(host='127.0.0.1', port=50051) @@ -122,3 +122,30 @@ async def test_send_from_close_manually_immediately(client, expected_responses): requests.close() assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses