From 159c30ddd8917a14a56c65b96c400fd4b9c90f5d Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:02:05 +0200 Subject: [PATCH 1/4] Fix close not awaitable, fix done is callable, fix return async next value --- betterproto/grpc/util/async_channel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..3a104ca 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -100,13 +100,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -151,7 +152,7 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() async def send(self, item: T): """ @@ -168,7 +169,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: From f7aa6150e25368d2c6be75e661fa2afe52eb05e0 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:02:37 +0200 Subject: [PATCH 2/4] Add test-cases for client stream-stream --- betterproto/tests/grpc/test_stream_stream.py | 124 +++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 betterproto/tests/grpc/test_stream_stream.py diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py new file mode 100644 index 0000000..5768189 --- /dev/null +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -0,0 +1,124 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +import betterproto +from betterproto.grpc.util.async_channel import AsyncChannel + + +@dataclass +class Message(betterproto.Message): + body: str = betterproto.string_field(1) + + +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + +@pytest.fixture +def expected_responses(): + return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] + + +class ClientStub: + async def connect(self, requests): + await asyncio.sleep(0.1) + async for request in requests: + await asyncio.sleep(0.1) + yield request + await asyncio.sleep(0.1) + yield Message("Done") + + +@pytest.fixture +def client(): + # channel = Channel(host='127.0.0.1', port=50051) + # return ClientStub(channel) + return ClientStub() + + +@pytest.mark.asyncio +async def test_from_list_close_automatically(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_from_list_close_manually_immediately(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_from_list_close_manually_after_connect(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_before_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_after_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + + responses = client.connect(requests) + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_close_manually_immediately(client, expected_responses): + requests = AsyncChannel() + + responses = client.connect(requests) + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + requests.close() + + assert await to_list(responses) == expected_responses From 0814729c5af0f96ad933ae0e72c695fd0dce8d41 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:14:13 +0200 Subject: [PATCH 3/4] Add cases for send() --- betterproto/tests/grpc/test_stream_stream.py | 43 ++++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py index 5768189..3c2c7e2 100644 --- a/betterproto/tests/grpc/test_stream_stream.py +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -13,20 +13,13 @@ class Message(betterproto.Message): body: str = betterproto.string_field(1) -async def to_list(generator: AsyncIterator): - lis = [] - async for value in generator: - lis.append(value) - return lis - - @pytest.fixture def expected_responses(): return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] class ClientStub: - async def connect(self, requests): + async def connect(self, requests: AsyncIterator): await asyncio.sleep(0.1) async for request in requests: await asyncio.sleep(0.1) @@ -35,6 +28,13 @@ class ClientStub: yield Message("Done") +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + @pytest.fixture def client(): # channel = Channel(host='127.0.0.1', port=50051) @@ -122,3 +122,30 @@ async def test_send_from_close_manually_immediately(client, expected_responses): requests.close() assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses From 50bb67bf5dca04ded331adbcdcedab3aed7d7de1 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Mon, 15 Jun 2020 23:35:56 +0200 Subject: [PATCH 4/4] Fix bugs and remove footgun feature in AsyncChannel --- betterproto/grpc/util/async_channel.py | 24 +++++++------------ betterproto/tests/grpc/test_grpclib_client.py | 22 ++++++++++------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..fd0ecc2 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]): """ def __init__( - self, - source: Union[Iterable[T], AsyncIterable[T]] = tuple(), - *, - buffer_limit: int = 0, - close: bool = False, + self, *, buffer_limit: int = 0, close: bool = False, ): self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._closed = False - self._sending_task = ( - asyncio.ensure_future(self.send_from(source, close)) if source else None - ) self._waiting_recievers: int = 0 # Track whether flush has been invoked so it can only happen once self._flushed = False @@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]): async def send_from( self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ): + ) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() + return self - async def send(self, item: T): + async def send(self, item: T) -> "AsyncChannel[T]": """ Send a single item over this channel. :param item: The item to send @@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]): if self._closed: raise ChannelClosed("Cannot send through a closed channel") await self._queue.put(item) + return self async def recieve(self) -> Optional[T]: """ @@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: @@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]): """ Close this channel to new items """ - if self._sending_task is not None: - self._sending_task.cancel() self._closed = True asyncio.ensure_future(self._flush_queue()) diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py index dc57fe4..6c34ece 100644 --- a/betterproto/tests/grpc/test_grpclib_client.py +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -1,3 +1,4 @@ +import asyncio from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request(): # Use an AsyncChannel to decouple sending and recieving, it'll send some_things # immediately and we'll use it to send more_things later, after recieving some # results - request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) response_index = 0 async for response in client.get_different_things(request_chan): assert response.name == expected_things[response_index] @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request(): if more_things: # Send some more requests as we recieve reponses to be sure coordination of # send/recieve events doesn't matter - another_response = await request_chan.send( - GetThingRequest(more_things.pop(0)) - ) - if another_response is not None: - assert another_response.name == expected_things[response_index] - assert another_response.version == response_index - response_index += 1 + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests else: # No more things to send make sure channel is closed - await request_chan.close() + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses"