Expose timeout, deadline and metadata parameters from grpclib (#352)

This commit is contained in:
Arun Babu Neelicattu 2022-03-13 23:34:11 +01:00 committed by GitHub
parent 62da35b3ea
commit 18a518efa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 101 additions and 18 deletions

View File

@ -22,10 +22,10 @@ if TYPE_CHECKING:
from grpclib.metadata import Deadline from grpclib.metadata import Deadline
_Value = Union[str, bytes] Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
_MessageLike = Union[T, ST] MessageLike = Union[T, ST]
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]] MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
class ServiceStub(ABC): class ServiceStub(ABC):
@ -39,7 +39,7 @@ class ServiceStub(ABC):
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> None: ) -> None:
self.channel = channel self.channel = channel
self.timeout = timeout self.timeout = timeout
@ -50,7 +50,7 @@ class ServiceStub(ABC):
self, self,
timeout: Optional[float], timeout: Optional[float],
deadline: Optional["Deadline"], deadline: Optional["Deadline"],
metadata: Optional[_MetadataLike], metadata: Optional[MetadataLike],
): ):
return { return {
"timeout": self.timeout if timeout is None else timeout, "timeout": self.timeout if timeout is None else timeout,
@ -61,12 +61,12 @@ class ServiceStub(ABC):
async def _unary_unary( async def _unary_unary(
self, self,
route: str, route: str,
request: _MessageLike, request: MessageLike,
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> T: ) -> T:
"""Make a unary request and return the response.""" """Make a unary request and return the response."""
async with self.channel.request( async with self.channel.request(
@ -84,12 +84,12 @@ class ServiceStub(ABC):
async def _unary_stream( async def _unary_stream(
self, self,
route: str, route: str,
request: _MessageLike, request: MessageLike,
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
"""Make a unary request and return the stream response iterator.""" """Make a unary request and return the stream response iterator."""
async with self.channel.request( async with self.channel.request(
@ -106,13 +106,13 @@ class ServiceStub(ABC):
async def _stream_unary( async def _stream_unary(
self, self,
route: str, route: str,
request_iterator: _MessageSource, request_iterator: MessageSource,
request_type: Type[ST], request_type: Type[ST],
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> T: ) -> T:
"""Make a stream request and return the response.""" """Make a stream request and return the response."""
async with self.channel.request( async with self.channel.request(
@ -130,13 +130,13 @@ class ServiceStub(ABC):
async def _stream_stream( async def _stream_stream(
self, self,
route: str, route: str,
request_iterator: _MessageSource, request_iterator: MessageSource,
request_type: Type[ST], request_type: Type[ST],
response_type: Type[T], response_type: Type[T],
*, *,
timeout: Optional[float] = None, timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None, deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None, metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
""" """
Make a stream request and return an AsyncIterator to iterate over response Make a stream request and return an AsyncIterator to iterate over response
@ -161,7 +161,7 @@ class ServiceStub(ABC):
raise raise
@staticmethod @staticmethod
async def _send_messages(stream, messages: _MessageSource): async def _send_messages(stream, messages: MessageSource):
if isinstance(messages, AsyncIterable): if isinstance(messages, AsyncIterable):
async for message in messages: async for message in messages:
await stream.send_message(message) await stream.send_message(message)

View File

@ -232,6 +232,7 @@ class OutputTemplate:
messages: List["MessageCompiler"] = field(default_factory=list) messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list) enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list) services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
@property @property
def package(self) -> str: def package(self) -> str:
@ -679,6 +680,15 @@ class ServiceMethodCompiler(ProtoContentBase):
if self.client_streaming or self.server_streaming: if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator") self.output_file.typing_imports.add("AsyncIterator")
# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
self.output_file.imports_type_checking_only.add(
"from grpclib.metadata import Deadline"
)
super().__post_init__() # check for unset fields super().__post_init__() # check for unset fields
@property @property

View File

@ -20,6 +20,13 @@ from betterproto.grpc.grpclib_server import ServiceBase
import grpclib import grpclib
{% endif %} {% endif %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %} {% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum): class {{ enum.py_name }}(betterproto.Enum):
@ -86,6 +93,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{# Client streaming: need a request iterator instead #} {# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%} {%- endif -%}
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["_MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}
@ -98,6 +108,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{{ method.py_input_message_param }}_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
): ):
yield response yield response
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
@ -105,6 +118,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
"{{ method.route }}", "{{ method.route }}",
{{ method.py_input_message_param }}, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }}, {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
): ):
yield response yield response
@ -115,13 +131,19 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
"{{ method.route }}", "{{ method.route }}",
{{ method.py_input_message_param }}_iterator, {{ method.py_input_message_param }}_iterator,
{{ method.py_input_message_type }}, {{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
) )
{% else %}{# i.e. not client streaming #} {% else %}{# i.e. not client streaming #}
return await self._unary_unary( return await self._unary_unary(
"{{ method.route }}", "{{ method.route }}",
{{ method.py_input_message_param }}, {{ method.py_input_message_param }},
{{ method.py_output_message_type.strip('"') }} {{ method.py_output_message_type.strip('"') }},
timeout=timeout,
deadline=deadline,
metadata=metadata,
) )
{% endif %}{# client streaming #} {% endif %}{# client streaming #}
{% endif %} {% endif %}

View File

@ -1,9 +1,11 @@
import asyncio import asyncio
import sys import sys
import uuid
import grpclib import grpclib
import grpclib.metadata import grpclib.metadata
import grpclib.server import grpclib.server
import grpclib.client
import pytest import pytest
from betterproto.grpc.util.async_channel import AsyncChannel from betterproto.grpc.util.async_channel import AsyncChannel
from grpclib.testing import ChannelFor from grpclib.testing import ChannelFor
@ -18,7 +20,7 @@ from .thing_service import ThingService
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs): async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
response = await client.do_thing(DoThingRequest(name=name)) response = await client.do_thing(DoThingRequest(name=name), **kwargs)
assert response.names == [name] assert response.names == [name]
@ -172,6 +174,55 @@ async def test_service_call_lower_level_with_overrides():
assert response.names == [THING_TO_DO] assert response.names == [THING_TO_DO]
@pytest.mark.asyncio
@pytest.mark.parametrize(
("overrides",),
[
(dict(timeout=10),),
(dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
(dict(metadata={"authorization": str(uuid.uuid4())}),),
(dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
],
)
async def test_service_call_high_level_with_overrides(mocker, overrides):
request_spy = mocker.spy(grpclib.client.Channel, "request")
name = str(uuid.uuid4())
defaults = dict(
timeout=99,
deadline=grpclib.metadata.Deadline.from_timeout(99),
metadata={"authorization": name},
)
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_received(
deadline=grpclib.metadata.Deadline.from_timeout(
overrides.get("timeout", 99)
),
metadata=overrides.get("metadata", defaults.get("metadata")),
)
)
]
) as channel:
client = ThingServiceClient(channel, **defaults)
await _test_client(client, name=name, **overrides)
assert request_spy.call_count == 1
# for python <3.8 request_spy.call_args.kwargs do not work
_, request_spy_call_kwargs = request_spy.call_args_list[0]
# ensure all overrides were successful
for key, value in overrides.items():
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == value
# ensure default values were retained
for key in set(defaults.keys()) - set(overrides.keys()):
assert key in request_spy_call_kwargs
assert request_spy_call_kwargs[key] == defaults[key]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_gen_for_unary_stream_request(): async def test_async_gen_for_unary_stream_request():
thing_name = "my milkshakes" thing_name = "my milkshakes"