Fix bugs and remove footgun feature in AsyncChannel

This commit is contained in:
Nat Noordanus 2020-06-15 23:35:56 +02:00
parent c8229e53a7
commit e1ccd540a9
2 changed files with 26 additions and 28 deletions

View File

@ -30,14 +30,15 @@ class ChannelDone(Exception):
class AsyncChannel(AsyncIterable[T]): class AsyncChannel(AsyncIterable[T]):
""" """
A buffered async channel for sending items between coroutines with FIFO semantics. A buffered async channel for sending items between coroutines with FIFO ordering.
This makes decoupled bidirection steaming gRPC requests easy if used like: This makes decoupled bidirection steaming gRPC requests easy if used like:
.. code-block:: python .. code-block:: python
client = GeneratedStub(grpclib_chan) client = GeneratedStub(grpclib_chan)
# The channel can be initialised with items to send immediately request_chan = await AsyncChannel()
request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)]) # We can start be sending all the requests we already have
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
async for response in client.rpc_call(request_chan): async for response in client.rpc_call(request_chan):
# The response iterator will remain active until the connection is closed # The response iterator will remain active until the connection is closed
... ...
@ -48,7 +49,6 @@ class AsyncChannel(AsyncIterable[T]):
request_chan.close() request_chan.close()
Items can be sent through the channel by either: Items can be sent through the channel by either:
- providing an iterable to the constructor
- providing an iterable to the send_from method - providing an iterable to the send_from method
- passing them to the send method one at a time - passing them to the send method one at a time
@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
""" """
def __init__( def __init__(
self, self, *, buffer_limit: int = 0, close: bool = False,
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
*,
buffer_limit: int = 0,
close: bool = False,
): ):
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
self._closed = False self._closed = False
self._sending_task = (
asyncio.ensure_future(self.send_from(source, close)) if source else None
)
self._waiting_recievers: int = 0 self._waiting_recievers: int = 0
# Track whether flush has been invoked so it can only happen once # Track whether flush has been invoked so it can only happen once
self._flushed = False self._flushed = False
@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]):
return self return self
async def __anext__(self) -> T: async def __anext__(self) -> T:
if self.done: if self.done():
raise StopAsyncIteration raise StopAsyncIteration
self._waiting_recievers += 1 self._waiting_recievers += 1
try: try:
result = await self._queue.get() result = await self._queue.get()
if result is self.__flush: if result is self.__flush:
raise StopAsyncIteration raise StopAsyncIteration
return result
finally: finally:
self._waiting_recievers -= 1 self._waiting_recievers -= 1
self._queue.task_done() self._queue.task_done()
@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]):
async def send_from( async def send_from(
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
): ) -> "AsyncChannel[T]":
""" """
Iterates the given [Async]Iterable and sends all the resulting items. Iterates the given [Async]Iterable and sends all the resulting items.
If close is set to True then subsequent send calls will be rejected with a If close is set to True then subsequent send calls will be rejected with a
@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]):
await self._queue.put(item) await self._queue.put(item)
if close: if close:
# Complete the closing process # Complete the closing process
await self.close() self.close()
return self
async def send(self, item: T): async def send(self, item: T) -> "AsyncChannel[T]":
""" """
Send a single item over this channel. Send a single item over this channel.
:param item: The item to send :param item: The item to send
@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]):
if self._closed: if self._closed:
raise ChannelClosed("Cannot send through a closed channel") raise ChannelClosed("Cannot send through a closed channel")
await self._queue.put(item) await self._queue.put(item)
return self
async def recieve(self) -> Optional[T]: async def recieve(self) -> Optional[T]:
""" """
@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]):
or None if the channel is closed before another item is sent. or None if the channel is closed before another item is sent.
:return: An item from the channel :return: An item from the channel
""" """
if self.done: if self.done():
raise ChannelDone("Cannot recieve from a closed channel") raise ChannelDone("Cannot recieve from a closed channel")
self._waiting_recievers += 1 self._waiting_recievers += 1
try: try:
@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]):
""" """
Close this channel to new items Close this channel to new items
""" """
if self._sending_task is not None:
self._sending_task.cancel()
self._closed = True self._closed = True
asyncio.ensure_future(self._flush_queue()) asyncio.ensure_future(self._flush_queue())

View File

@ -1,3 +1,4 @@
import asyncio
from betterproto.tests.output_betterproto.service.service import ( from betterproto.tests.output_betterproto.service.service import (
DoThingResponse, DoThingResponse,
DoThingRequest, DoThingRequest,
@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things # Use an AsyncChannel to decouple sending and recieving, it'll send some_things
# immediately and we'll use it to send more_things later, after recieving some # immediately and we'll use it to send more_things later, after recieving some
# results # results
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) request_chan = AsyncChannel()
send_initial_requests = asyncio.ensure_future(
request_chan.send_from(GetThingRequest(name) for name in some_things)
)
response_index = 0 response_index = 0
async for response in client.get_different_things(request_chan): async for response in client.get_different_things(request_chan):
assert response.name == expected_things[response_index] assert response.name == expected_things[response_index]
@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
if more_things: if more_things:
# Send some more requests as we recieve reponses to be sure coordination of # Send some more requests as we recieve reponses to be sure coordination of
# send/recieve events doesn't matter # send/recieve events doesn't matter
another_response = await request_chan.send( await request_chan.send(GetThingRequest(more_things.pop(0)))
GetThingRequest(more_things.pop(0)) elif not send_initial_requests.done():
) # Make sure the sending task it completed
if another_response is not None: await send_initial_requests
assert another_response.name == expected_things[response_index]
assert another_response.version == response_index
response_index += 1
else: else:
# No more things to send make sure channel is closed # No more things to send make sure channel is closed
await request_chan.close() request_chan.close()
assert response_index == len(
expected_things
), "Didn't recieve all exptected responses"