Finish implementation and testing of client
Including stream_unary and stream_stream call methods. Also - improve organisation of relevant tests - fix some generated type annotations - Add AsyncChannel utility cos it's useful
This commit is contained in:
parent
09f821921f
commit
4b6f55dce5
@ -593,7 +593,7 @@ class Message(ABC):
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or received empty).
|
||||
# set (or recieved empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field_name) and not (
|
||||
|
@ -1,7 +1,8 @@
|
||||
from abc import ABC
|
||||
import asyncio
|
||||
import grpclib.const
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Collection,
|
||||
Iterator,
|
||||
@ -16,17 +17,18 @@ from .._types import ST, T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib._protocols import IProtoMessage
|
||||
from grpclib.client import Channel
|
||||
from grpclib.client import Channel, Stream
|
||||
from grpclib.metadata import Deadline
|
||||
|
||||
|
||||
_Value = Union[str, bytes]
|
||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
||||
_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
Base class for async gRPC clients.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -86,7 +88,7 @@ class ServiceStub(ABC):
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> AsyncGenerator[T, None]:
|
||||
) -> AsyncIterator[T]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
@ -102,17 +104,23 @@ class ServiceStub(ABC):
|
||||
async def _stream_unary(
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: Iterator["IProtoMessage"],
|
||||
request_iterator: _MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> T:
|
||||
"""Make a stream request and return the response."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
|
||||
route,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
request_type,
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
for message in request_iterator:
|
||||
await stream.send_message(message)
|
||||
await stream.send_request(end=True)
|
||||
await self._send_messages(stream, request_iterator)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
return response
|
||||
@ -120,16 +128,42 @@ class ServiceStub(ABC):
|
||||
async def _stream_stream(
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: Iterator["IProtoMessage"],
|
||||
request_iterator: _MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""Make a stream request and return the stream response iterator."""
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[_MetadataLike] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
"""
|
||||
Make a stream request and return an AsyncIterator to iterate over response
|
||||
messages.
|
||||
"""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
|
||||
route,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
request_type,
|
||||
response_type,
|
||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||
) as stream:
|
||||
for message in request_iterator:
|
||||
await stream.send_request()
|
||||
sending_task = asyncio.ensure_future(
|
||||
self._send_messages(stream, request_iterator)
|
||||
)
|
||||
try:
|
||||
async for response in stream:
|
||||
yield response
|
||||
except:
|
||||
sending_task.cancel()
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def _send_messages(stream, messages: _MessageSource):
|
||||
if hasattr(messages, "__aiter__"):
|
||||
async for message in messages:
|
||||
await stream.send_message(message)
|
||||
await stream.send_request(end=True)
|
||||
async for message in stream:
|
||||
yield message
|
||||
else:
|
||||
for message in messages:
|
||||
await stream.send_message(message)
|
||||
await stream.end()
|
||||
|
0
betterproto/grpc/util/__init__.py
Normal file
0
betterproto/grpc/util/__init__.py
Normal file
204
betterproto/grpc/util/async_channel.py
Normal file
204
betterproto/grpc/util/async_channel.py
Normal file
@ -0,0 +1,204 @@
|
||||
import asyncio
|
||||
from typing import (
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Iterable,
|
||||
Optional,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ChannelClosed(Exception):
|
||||
"""
|
||||
An exception raised on an attempt to send through a closed channel
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ChannelDone(Exception):
|
||||
"""
|
||||
An exception raised on an attempt to send recieve from a channel that is both closed
|
||||
and empty.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class AsyncChannel(AsyncIterable[T]):
|
||||
"""
|
||||
A buffered async channel for sending items between coroutines with FIFO semantics.
|
||||
|
||||
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||
|
||||
.. code-block:: python
|
||||
client = GeneratedStub(grpclib_chan)
|
||||
# The channel can be initialised with items to send immediately
|
||||
request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)])
|
||||
async for response in client.rpc_call(request_chan):
|
||||
# The response iterator will remain active until the connection is closed
|
||||
...
|
||||
# More items can be sent at any time
|
||||
await request_chan.send(ReqestObject(...))
|
||||
...
|
||||
# The channel must be closed to complete the gRPC connection
|
||||
request_chan.close()
|
||||
|
||||
Items can be sent through the channel by either:
|
||||
- providing an iterable to the constructor
|
||||
- providing an iterable to the send_from method
|
||||
- passing them to the send method one at a time
|
||||
|
||||
Items can be recieved from the channel by either:
|
||||
- iterating over the channel with a for loop to get all items
|
||||
- calling the recieve method to get one item at a time
|
||||
|
||||
If the channel is empty then recievers will wait until either an item appears or the
|
||||
channel is closed.
|
||||
|
||||
Once the channel is closed then subsequent attempt to send through the channel will
|
||||
fail with a ChannelClosed exception.
|
||||
|
||||
When th channel is closed and empty then it is done, and further attempts to recieve
|
||||
from it will fail with a ChannelDone exception
|
||||
|
||||
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
||||
recieved by only one of the recievers.
|
||||
|
||||
:param source:
|
||||
An optional iterable will items that should be sent through the channel
|
||||
immediately.
|
||||
:param buffer_limit:
|
||||
Limit the number of items that can be buffered in the channel, A value less than
|
||||
1 implies no limit. If the channel is full then attempts to send more items will
|
||||
result in the sender waiting until an item is recieved from the channel.
|
||||
:param close:
|
||||
If set to True then the channel will automatically close after exhausting source
|
||||
or immediately if no source is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
source: Union[Iterable[T], AsyncIterable[T]] = tuple(),
|
||||
*,
|
||||
buffer_limit: int = 0,
|
||||
close: bool = False,
|
||||
):
|
||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||
self._closed = False
|
||||
self._sending_task = (
|
||||
asyncio.ensure_future(self.send_from(source, close)) if source else None
|
||||
)
|
||||
self._waiting_recievers: int = 0
|
||||
# Track whether flush has been invoked so it can only happen once
|
||||
self._flushed = False
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[T]:
|
||||
return self
|
||||
|
||||
async def __anext__(self) -> T:
|
||||
if self.done:
|
||||
raise StopAsyncIteration
|
||||
self._waiting_recievers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
raise StopAsyncIteration
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def closed(self) -> bool:
|
||||
"""
|
||||
Returns True if this channel is closed and no-longer accepting new items
|
||||
"""
|
||||
return self._closed
|
||||
|
||||
def done(self) -> bool:
|
||||
"""
|
||||
Check if this channel is done.
|
||||
|
||||
:return: True if this channel is closed and and has been drained of items in
|
||||
which case any further attempts to recieve an item from this channel will raise
|
||||
a ChannelDone exception.
|
||||
"""
|
||||
# After close the channel is not yet done until there is at least one waiting
|
||||
# reciever per enqueued item.
|
||||
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
||||
|
||||
async def send_from(
|
||||
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||
):
|
||||
"""
|
||||
Iterates the given [Async]Iterable and sends all the resulting items.
|
||||
If close is set to True then subsequent send calls will be rejected with a
|
||||
ChannelClosed exception.
|
||||
:param source: an iterable of items to send
|
||||
:param close:
|
||||
if True then the channel will be closed after the source has been exhausted
|
||||
|
||||
"""
|
||||
if self._closed:
|
||||
raise ChannelClosed("Cannot send through a closed channel")
|
||||
if isinstance(source, AsyncIterable):
|
||||
async for item in source:
|
||||
await self._queue.put(item)
|
||||
else:
|
||||
for item in source:
|
||||
await self._queue.put(item)
|
||||
if close:
|
||||
# Complete the closing process
|
||||
await self.close()
|
||||
|
||||
async def send(self, item: T):
|
||||
"""
|
||||
Send a single item over this channel.
|
||||
:param item: The item to send
|
||||
"""
|
||||
if self._closed:
|
||||
raise ChannelClosed("Cannot send through a closed channel")
|
||||
await self._queue.put(item)
|
||||
|
||||
async def recieve(self) -> Optional[T]:
|
||||
"""
|
||||
Returns the next item from this channel when it becomes available,
|
||||
or None if the channel is closed before another item is sent.
|
||||
:return: An item from the channel
|
||||
"""
|
||||
if self.done:
|
||||
raise ChannelDone("Cannot recieve from a closed channel")
|
||||
self._waiting_recievers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
return None
|
||||
return result
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Close this channel to new items
|
||||
"""
|
||||
if self._sending_task is not None:
|
||||
self._sending_task.cancel()
|
||||
self._closed = True
|
||||
asyncio.ensure_future(self._flush_queue())
|
||||
|
||||
async def _flush_queue(self):
|
||||
"""
|
||||
To be called after the channel is closed. Pushes a number of self.__flush
|
||||
objects to the queue to ensure no waiting consumers get deadlocked.
|
||||
"""
|
||||
if not self._flushed:
|
||||
self._flushed = True
|
||||
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
||||
for _ in range(deadlocked_recievers):
|
||||
await self._queue.put(self.__flush)
|
||||
|
||||
# A special signal object for flushing the queue when the channel is closed
|
||||
__flush = object()
|
@ -344,11 +344,12 @@ def generate_code(request, response):
|
||||
}
|
||||
)
|
||||
|
||||
if method.server_streaming:
|
||||
output["typing_imports"].add("AsyncGenerator")
|
||||
|
||||
if method.client_streaming:
|
||||
output["typing_imports"].add("Iterator")
|
||||
output["typing_imports"].add("AsyncIterable")
|
||||
output["typing_imports"].add("Iterable")
|
||||
output["typing_imports"].add("Union")
|
||||
if method.server_streaming:
|
||||
output["typing_imports"].add("AsyncIterator")
|
||||
|
||||
output["services"].append(data)
|
||||
|
||||
|
@ -77,9 +77,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, request_iterator: Iterator["{{ method.input }}"]
|
||||
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
@ -97,7 +97,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% endif %}
|
||||
|
||||
{% if method.server_streaming %}
|
||||
{% if method.client_streaming %}
|
||||
{% if method.client_streaming %}
|
||||
async for response in self._stream_stream(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
@ -105,7 +105,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{{ method.output }},
|
||||
):
|
||||
yield response
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
async for response in self._unary_stream(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
@ -113,22 +113,22 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
):
|
||||
yield response
|
||||
|
||||
{% endif %}{# if client streaming #}
|
||||
{% endif %}{# if client streaming #}
|
||||
{% else %}{# i.e. not server streaming #}
|
||||
{% if method.client_streaming %}
|
||||
{% if method.client_streaming %}
|
||||
return await self._stream_unary(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
{{ method.input }},
|
||||
{{ method.output }}
|
||||
)
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
return await self._unary_unary(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
{{ method.output }}
|
||||
)
|
||||
{% endif %}{# client streaming #}
|
||||
{% endif %}{# client streaming #}
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
|
0
betterproto/tests/grpc/__init__.py
Normal file
0
betterproto/tests/grpc/__init__.py
Normal file
150
betterproto/tests/grpc/test_grpclib_client.py
Normal file
150
betterproto/tests/grpc/test_grpclib_client.py
Normal file
@ -0,0 +1,150 @@
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client, name="clean room", **kwargs):
|
||||
response = await client.do_thing(name=name)
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_recieved(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
THING_TO_DO = "get milk"
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_unary_stream_request():
|
||||
thing_name = "my milkshakes"
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(name=thing_name):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_stream_stream_request():
|
||||
some_things = ["cake", "cricket", "coral reef"]
|
||||
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||
expected_things = (*some_things, *more_things)
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||
# immediately and we'll use it to send more_things later, after recieving some
|
||||
# results
|
||||
request_chan = AsyncChannel(GetThingRequest(name) for name in some_things)
|
||||
response_index = 0
|
||||
async for response in client.get_different_things(request_chan):
|
||||
assert response.name == expected_things[response_index]
|
||||
assert response.version == response_index + 1
|
||||
response_index += 1
|
||||
if more_things:
|
||||
# Send some more requests as we recieve reponses to be sure coordination of
|
||||
# send/recieve events doesn't matter
|
||||
another_response = await request_chan.send(
|
||||
GetThingRequest(more_things.pop(0))
|
||||
)
|
||||
if another_response is not None:
|
||||
assert another_response.name == expected_things[response_index]
|
||||
assert another_response.version == response_index
|
||||
response_index += 1
|
||||
else:
|
||||
# No more things to send make sure channel is closed
|
||||
await request_chan.close()
|
83
betterproto/tests/grpc/thing_service.py
Normal file
83
betterproto/tests/grpc/thing_service.py
Normal file
@ -0,0 +1,83 @@
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ThingService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def do_thing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse([request.name]))
|
||||
|
||||
async def do_many_things(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
||||
async def get_thing_versions(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for version_num in range(1, 6):
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=version_num)
|
||||
)
|
||||
|
||||
async def get_different_things(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
# Respond to each input item immediately
|
||||
response_num = 0
|
||||
async for request in stream:
|
||||
response_num += 1
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=response_num)
|
||||
)
|
||||
|
||||
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.do_thing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||
self.do_many_things,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||
self.get_thing_versions,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||
self.get_different_things,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
}
|
@ -23,7 +23,7 @@ test_cases = [
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
async def test_channel_recieves_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
|
@ -1,132 +0,0 @@
|
||||
import betterproto
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from typing import Dict
|
||||
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
TestStub as ExampleServiceStub,
|
||||
)
|
||||
|
||||
|
||||
class ExampleService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def DoThing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
print("self.test_hook", self.test_hook)
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for iteration in range(request.iterations):
|
||||
pass
|
||||
await stream.send_message(DoThingResponse(request.iterations))
|
||||
|
||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.DoThing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
async def _test_stub(stub, iterations=42, **kwargs):
|
||||
response = await stub.do_thing(iterations=iterations)
|
||||
assert response.successful_iterations == iterations
|
||||
|
||||
|
||||
def _get_server_side_test(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ExampleService()]) as channel:
|
||||
await _test_stub(ExampleServiceStub(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
ITERATIONS = 99
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(ITERATIONS),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.successful_iterations == ITERATIONS
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ExampleService(
|
||||
test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(ITERATIONS),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.successful_iterations == ITERATIONS
|
@ -1,176 +0,0 @@
|
||||
import betterproto
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from typing import Dict
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
|
||||
|
||||
class ThingService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def DoThing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse([request.name]))
|
||||
|
||||
async def DoManyThings(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
||||
async def GetThingVersions(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for version_num in range(1, 6):
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request, version=version_num)
|
||||
)
|
||||
|
||||
async def GetDifferentThings(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
# Response to each input item immediately
|
||||
for request in stream:
|
||||
await stream.send_message(GetThingResponse(name=request.name, version=1))
|
||||
|
||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.DoThing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||
self.DoManyThings,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||
self.GetThingVersions,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||
self.GetDifferentThings,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def _test_stub(stub, name="clean room", **kwargs):
|
||||
response = await stub.do_thing(name=name)
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_recieved(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
await _test_stub(ThingServiceClient(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
THING_TO_DO = "get milk"
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
|
||||
) as channel:
|
||||
stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata)
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
Loading…
x
Reference in New Issue
Block a user