Fix bugs and remove footgun feature in AsyncChannel
This commit is contained in:
parent
c8229e53a7
commit
e1ccd540a9
@ -30,14 +30,15 @@ class ChannelDone(Exception):
|
|||||||
|
|
||||||
class AsyncChannel(AsyncIterable[T]):
|
class AsyncChannel(AsyncIterable[T]):
|
||||||
"""
|
"""
|
||||||
A buffered async channel for sending items between coroutines with FIFO semantics.
|
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||||
|
|
||||||
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
client = GeneratedStub(grpclib_chan)
|
client = GeneratedStub(grpclib_chan)
|
||||||
# The channel can be initialised with items to send immediately
|
request_chan = await AsyncChannel()
|
||||||
request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)])
|
# We can start be sending all the requests we already have
|
||||||
|
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
||||||
async for response in client.rpc_call(request_chan):
|
async for response in client.rpc_call(request_chan):
|
||||||
# The response iterator will remain active until the connection is closed
|
# The response iterator will remain active until the connection is closed
|
||||||
...
|
...
|
||||||
@ -48,7 +49,6 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
request_chan.close()
|
request_chan.close()
|
||||||
|
|
||||||
Items can be sent through the channel by either:
|
Items can be sent through the channel by either:
|
||||||
- providing an iterable to the constructor
|
|
||||||
- providing an iterable to the send_from method
|
- providing an iterable to the send_from method
|
||||||
- passing them to the send method one at a time
|
- passing them to the send method one at a time
|
||||||
|
|
||||||
@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, *, buffer_limit: int = 0, close: bool = False,
|
||||||
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
|
|
||||||
*,
|
|
||||||
buffer_limit: int = 0,
|
|
||||||
close: bool = False,
|
|
||||||
):
|
):
|
||||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._sending_task = (
|
|
||||||
asyncio.ensure_future(self.send_from(source, close)) if source else None
|
|
||||||
)
|
|
||||||
self._waiting_recievers: int = 0
|
self._waiting_recievers: int = 0
|
||||||
# Track whether flush has been invoked so it can only happen once
|
# Track whether flush has been invoked so it can only happen once
|
||||||
self._flushed = False
|
self._flushed = False
|
||||||
@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
async def __anext__(self) -> T:
|
async def __anext__(self) -> T:
|
||||||
if self.done:
|
if self.done():
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
self._waiting_recievers += 1
|
self._waiting_recievers += 1
|
||||||
try:
|
try:
|
||||||
result = await self._queue.get()
|
result = await self._queue.get()
|
||||||
if result is self.__flush:
|
if result is self.__flush:
|
||||||
raise StopAsyncIteration
|
raise StopAsyncIteration
|
||||||
|
return result
|
||||||
finally:
|
finally:
|
||||||
self._waiting_recievers -= 1
|
self._waiting_recievers -= 1
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
|
|
||||||
async def send_from(
|
async def send_from(
|
||||||
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||||
):
|
) -> "AsyncChannel[T]":
|
||||||
"""
|
"""
|
||||||
Iterates the given [Async]Iterable and sends all the resulting items.
|
Iterates the given [Async]Iterable and sends all the resulting items.
|
||||||
If close is set to True then subsequent send calls will be rejected with a
|
If close is set to True then subsequent send calls will be rejected with a
|
||||||
@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
await self._queue.put(item)
|
await self._queue.put(item)
|
||||||
if close:
|
if close:
|
||||||
# Complete the closing process
|
# Complete the closing process
|
||||||
await self.close()
|
self.close()
|
||||||
|
return self
|
||||||
|
|
||||||
async def send(self, item: T):
|
async def send(self, item: T) -> "AsyncChannel[T]":
|
||||||
"""
|
"""
|
||||||
Send a single item over this channel.
|
Send a single item over this channel.
|
||||||
:param item: The item to send
|
:param item: The item to send
|
||||||
@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
if self._closed:
|
if self._closed:
|
||||||
raise ChannelClosed("Cannot send through a closed channel")
|
raise ChannelClosed("Cannot send through a closed channel")
|
||||||
await self._queue.put(item)
|
await self._queue.put(item)
|
||||||
|
return self
|
||||||
|
|
||||||
async def recieve(self) -> Optional[T]:
|
async def recieve(self) -> Optional[T]:
|
||||||
"""
|
"""
|
||||||
@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
or None if the channel is closed before another item is sent.
|
or None if the channel is closed before another item is sent.
|
||||||
:return: An item from the channel
|
:return: An item from the channel
|
||||||
"""
|
"""
|
||||||
if self.done:
|
if self.done():
|
||||||
raise ChannelDone("Cannot recieve from a closed channel")
|
raise ChannelDone("Cannot recieve from a closed channel")
|
||||||
self._waiting_recievers += 1
|
self._waiting_recievers += 1
|
||||||
try:
|
try:
|
||||||
@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]):
|
|||||||
"""
|
"""
|
||||||
Close this channel to new items
|
Close this channel to new items
|
||||||
"""
|
"""
|
||||||
if self._sending_task is not None:
|
|
||||||
self._sending_task.cancel()
|
|
||||||
self._closed = True
|
self._closed = True
|
||||||
asyncio.ensure_future(self._flush_queue())
|
asyncio.ensure_future(self._flush_queue())
|
||||||
|
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
from betterproto.tests.output_betterproto.service.service import (
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
DoThingResponse,
|
DoThingResponse,
|
||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
|
|||||||
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||||
# immediately and we'll use it to send more_things later, after recieving some
|
# immediately and we'll use it to send more_things later, after recieving some
|
||||||
# results
|
# results
|
||||||
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
|
request_chan = AsyncChannel()
|
||||||
|
send_initial_requests = asyncio.ensure_future(
|
||||||
|
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||||
|
)
|
||||||
response_index = 0
|
response_index = 0
|
||||||
async for response in client.get_different_things(request_chan):
|
async for response in client.get_different_things(request_chan):
|
||||||
assert response.name == expected_things[response_index]
|
assert response.name == expected_things[response_index]
|
||||||
@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
|
|||||||
if more_things:
|
if more_things:
|
||||||
# Send some more requests as we recieve reponses to be sure coordination of
|
# Send some more requests as we recieve reponses to be sure coordination of
|
||||||
# send/recieve events doesn't matter
|
# send/recieve events doesn't matter
|
||||||
another_response = await request_chan.send(
|
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||||
GetThingRequest(more_things.pop(0))
|
elif not send_initial_requests.done():
|
||||||
)
|
# Make sure the sending task it completed
|
||||||
if another_response is not None:
|
await send_initial_requests
|
||||||
assert another_response.name == expected_things[response_index]
|
|
||||||
assert another_response.version == response_index
|
|
||||||
response_index += 1
|
|
||||||
else:
|
else:
|
||||||
# No more things to send make sure channel is closed
|
# No more things to send make sure channel is closed
|
||||||
await request_chan.close()
|
request_chan.close()
|
||||||
|
assert response_index == len(
|
||||||
|
expected_things
|
||||||
|
), "Didn't recieve all exptected responses"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user