From 50bb67bf5dca04ded331adbcdcedab3aed7d7de1 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Mon, 15 Jun 2020 23:35:56 +0200 Subject: [PATCH] Fix bugs and remove footgun feature in AsyncChannel --- betterproto/grpc/util/async_channel.py | 24 +++++++------------ betterproto/tests/grpc/test_grpclib_client.py | 22 ++++++++++------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..fd0ecc2 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]): """ def __init__( - self, - source: Union[Iterable[T], AsyncIterable[T]] = tuple(), - *, - buffer_limit: int = 0, - close: bool = False, + self, *, buffer_limit: int = 0, close: bool = False, ): self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._closed = False - self._sending_task = ( - asyncio.ensure_future(self.send_from(source, close)) if source else None - ) self._waiting_recievers: int = 0 # Track whether flush has been invoked so it can only happen once self._flushed = False @@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]): async def send_from( self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ): + ) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() + return self - async def send(self, item: T): + async def send(self, item: T) -> "AsyncChannel[T]": """ Send a single item over this channel. :param item: The item to send @@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]): if self._closed: raise ChannelClosed("Cannot send through a closed channel") await self._queue.put(item) + return self async def recieve(self) -> Optional[T]: """ @@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: @@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]): """ Close this channel to new items """ - if self._sending_task is not None: - self._sending_task.cancel() self._closed = True asyncio.ensure_future(self._flush_queue()) diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py index dc57fe4..6c34ece 100644 --- a/betterproto/tests/grpc/test_grpclib_client.py +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -1,3 +1,4 @@ +import asyncio from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request(): # Use an AsyncChannel to decouple sending and recieving, it'll send some_things # immediately and we'll use it to send more_things later, after recieving some # results - request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) response_index = 0 async for response in client.get_different_things(request_chan): assert response.name == expected_things[response_index] @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request(): if more_things: # Send some more requests as we recieve reponses to be sure coordination of # send/recieve events doesn't matter - another_response = await request_chan.send( - GetThingRequest(more_things.pop(0)) - ) - if another_response is not None: - assert another_response.name == expected_things[response_index] - assert another_response.version == response_index - response_index += 1 + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests else: # No more things to send make sure channel is closed - await request_chan.close() + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses"