diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 3a104ca..fd0ecc2 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]): """ def __init__( - self, - source: Union[Iterable[T], AsyncIterable[T]] = tuple(), - *, - buffer_limit: int = 0, - close: bool = False, + self, *, buffer_limit: int = 0, close: bool = False, ): self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._closed = False - self._sending_task = ( - asyncio.ensure_future(self.send_from(source, close)) if source else None - ) self._waiting_recievers: int = 0 # Track whether flush has been invoked so it can only happen once self._flushed = False @@ -132,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]): async def send_from( self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ): + ) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -153,8 +146,9 @@ class AsyncChannel(AsyncIterable[T]): if close: # Complete the closing process self.close() + return self - async def send(self, item: T): + async def send(self, item: T) -> "AsyncChannel[T]": """ Send a single item over this channel. :param item: The item to send @@ -162,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]): if self._closed: raise ChannelClosed("Cannot send through a closed channel") await self._queue.put(item) + return self async def recieve(self) -> Optional[T]: """ @@ -185,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]): """ Close this channel to new items """ - if self._sending_task is not None: - self._sending_task.cancel() self._closed = True asyncio.ensure_future(self._flush_queue()) diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py index dc57fe4..6c34ece 100644 --- a/betterproto/tests/grpc/test_grpclib_client.py +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -1,3 +1,4 @@ +import asyncio from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request(): # Use an AsyncChannel to decouple sending and recieving, it'll send some_things # immediately and we'll use it to send more_things later, after recieving some # results - request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) response_index = 0 async for response in client.get_different_things(request_chan): assert response.name == expected_things[response_index] @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request(): if more_things: # Send some more requests as we recieve reponses to be sure coordination of # send/recieve events doesn't matter - another_response = await request_chan.send( - GetThingRequest(more_things.pop(0)) - ) - if another_response is not None: - assert another_response.name == expected_things[response_index] - assert another_response.version == response_index - response_index += 1 + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests else: # No more things to send make sure channel is closed - await request_chan.close() + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses"