Fix bugs and remove footgun feature in AsyncChannel
This commit is contained in:
		@@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
 | 
			
		||||
        *,
 | 
			
		||||
        buffer_limit: int = 0,
 | 
			
		||||
        close: bool = False,
 | 
			
		||||
        self, *, buffer_limit: int = 0, close: bool = False,
 | 
			
		||||
    ):
 | 
			
		||||
        self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
 | 
			
		||||
        self._closed = False
 | 
			
		||||
        self._sending_task = (
 | 
			
		||||
            asyncio.ensure_future(self.send_from(source, close)) if source else None
 | 
			
		||||
        )
 | 
			
		||||
        self._waiting_recievers: int = 0
 | 
			
		||||
        # Track whether flush has been invoked so it can only happen once
 | 
			
		||||
        self._flushed = False
 | 
			
		||||
@@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    async def __anext__(self) -> T:
 | 
			
		||||
        if self.done:
 | 
			
		||||
        if self.done():
 | 
			
		||||
            raise StopAsyncIteration
 | 
			
		||||
        self._waiting_recievers += 1
 | 
			
		||||
        try:
 | 
			
		||||
            result = await self._queue.get()
 | 
			
		||||
            if result is self.__flush:
 | 
			
		||||
                raise StopAsyncIteration
 | 
			
		||||
            return result
 | 
			
		||||
        finally:
 | 
			
		||||
            self._waiting_recievers -= 1
 | 
			
		||||
            self._queue.task_done()
 | 
			
		||||
@@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
 | 
			
		||||
    async def send_from(
 | 
			
		||||
        self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
 | 
			
		||||
    ):
 | 
			
		||||
    ) -> "AsyncChannel[T]":
 | 
			
		||||
        """
 | 
			
		||||
        Iterates the given [Async]Iterable and sends all the resulting items.
 | 
			
		||||
        If close is set to True then subsequent send calls will be rejected with a
 | 
			
		||||
@@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
                await self._queue.put(item)
 | 
			
		||||
        if close:
 | 
			
		||||
            # Complete the closing process
 | 
			
		||||
            await self.close()
 | 
			
		||||
            self.close()
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    async def send(self, item: T):
 | 
			
		||||
    async def send(self, item: T) -> "AsyncChannel[T]":
 | 
			
		||||
        """
 | 
			
		||||
        Send a single item over this channel.
 | 
			
		||||
        :param item: The item to send
 | 
			
		||||
@@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
        if self._closed:
 | 
			
		||||
            raise ChannelClosed("Cannot send through a closed channel")
 | 
			
		||||
        await self._queue.put(item)
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    async def recieve(self) -> Optional[T]:
 | 
			
		||||
        """
 | 
			
		||||
@@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
        or None if the channel is closed before another item is sent.
 | 
			
		||||
        :return: An item from the channel
 | 
			
		||||
        """
 | 
			
		||||
        if self.done:
 | 
			
		||||
        if self.done():
 | 
			
		||||
            raise ChannelDone("Cannot recieve from a closed channel")
 | 
			
		||||
        self._waiting_recievers += 1
 | 
			
		||||
        try:
 | 
			
		||||
@@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]):
 | 
			
		||||
        """
 | 
			
		||||
        Close this channel to new items
 | 
			
		||||
        """
 | 
			
		||||
        if self._sending_task is not None:
 | 
			
		||||
            self._sending_task.cancel()
 | 
			
		||||
        self._closed = True
 | 
			
		||||
        asyncio.ensure_future(self._flush_queue())
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,3 +1,4 @@
 | 
			
		||||
import asyncio
 | 
			
		||||
from betterproto.tests.output_betterproto.service.service import (
 | 
			
		||||
    DoThingResponse,
 | 
			
		||||
    DoThingRequest,
 | 
			
		||||
@@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request():
 | 
			
		||||
        # Use an AsyncChannel to decouple sending and recieving, it'll send some_things
 | 
			
		||||
        # immediately and we'll use it to send more_things later, after recieving some
 | 
			
		||||
        # results
 | 
			
		||||
        request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
 | 
			
		||||
        request_chan = AsyncChannel()
 | 
			
		||||
        send_initial_requests = asyncio.ensure_future(
 | 
			
		||||
            request_chan.send_from(GetThingRequest(name) for name in some_things)
 | 
			
		||||
        )
 | 
			
		||||
        response_index = 0
 | 
			
		||||
        async for response in client.get_different_things(request_chan):
 | 
			
		||||
            assert response.name == expected_things[response_index]
 | 
			
		||||
@@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request():
 | 
			
		||||
            if more_things:
 | 
			
		||||
                # Send some more requests as we recieve reponses to be sure coordination of
 | 
			
		||||
                # send/recieve events doesn't matter
 | 
			
		||||
                another_response = await request_chan.send(
 | 
			
		||||
                    GetThingRequest(more_things.pop(0))
 | 
			
		||||
                )
 | 
			
		||||
                if another_response is not None:
 | 
			
		||||
                    assert another_response.name == expected_things[response_index]
 | 
			
		||||
                    assert another_response.version == response_index
 | 
			
		||||
                    response_index += 1
 | 
			
		||||
                await request_chan.send(GetThingRequest(more_things.pop(0)))
 | 
			
		||||
            elif not send_initial_requests.done():
 | 
			
		||||
                # Make sure the sending task it completed
 | 
			
		||||
                await send_initial_requests
 | 
			
		||||
            else:
 | 
			
		||||
                # No more things to send make sure channel is closed
 | 
			
		||||
                await request_chan.close()
 | 
			
		||||
                request_chan.close()
 | 
			
		||||
        assert response_index == len(
 | 
			
		||||
            expected_things
 | 
			
		||||
        ), "Didn't recieve all exptected responses"
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user