Expose timeout, deadline and metadata parameters from grpclib (#352)
This commit is contained in:
parent
62da35b3ea
commit
18a518efa7
@ -22,10 +22,10 @@ if TYPE_CHECKING:
|
|||||||
from grpclib.metadata import Deadline
|
from grpclib.metadata import Deadline
|
||||||
|
|
||||||
|
|
||||||
_Value = Union[str, bytes]
|
Value = Union[str, bytes]
|
||||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
|
||||||
_MessageLike = Union[T, ST]
|
MessageLike = Union[T, ST]
|
||||||
_MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
|
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
|
||||||
|
|
||||||
|
|
||||||
class ServiceStub(ABC):
|
class ServiceStub(ABC):
|
||||||
@ -39,7 +39,7 @@ class ServiceStub(ABC):
|
|||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
deadline: Optional["Deadline"] = None,
|
deadline: Optional["Deadline"] = None,
|
||||||
metadata: Optional[_MetadataLike] = None,
|
metadata: Optional[MetadataLike] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.channel = channel
|
self.channel = channel
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
@ -50,7 +50,7 @@ class ServiceStub(ABC):
|
|||||||
self,
|
self,
|
||||||
timeout: Optional[float],
|
timeout: Optional[float],
|
||||||
deadline: Optional["Deadline"],
|
deadline: Optional["Deadline"],
|
||||||
metadata: Optional[_MetadataLike],
|
metadata: Optional[MetadataLike],
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"timeout": self.timeout if timeout is None else timeout,
|
"timeout": self.timeout if timeout is None else timeout,
|
||||||
@ -61,12 +61,12 @@ class ServiceStub(ABC):
|
|||||||
async def _unary_unary(
|
async def _unary_unary(
|
||||||
self,
|
self,
|
||||||
route: str,
|
route: str,
|
||||||
request: _MessageLike,
|
request: MessageLike,
|
||||||
response_type: Type[T],
|
response_type: Type[T],
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
deadline: Optional["Deadline"] = None,
|
deadline: Optional["Deadline"] = None,
|
||||||
metadata: Optional[_MetadataLike] = None,
|
metadata: Optional[MetadataLike] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Make a unary request and return the response."""
|
"""Make a unary request and return the response."""
|
||||||
async with self.channel.request(
|
async with self.channel.request(
|
||||||
@ -84,12 +84,12 @@ class ServiceStub(ABC):
|
|||||||
async def _unary_stream(
|
async def _unary_stream(
|
||||||
self,
|
self,
|
||||||
route: str,
|
route: str,
|
||||||
request: _MessageLike,
|
request: MessageLike,
|
||||||
response_type: Type[T],
|
response_type: Type[T],
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
deadline: Optional["Deadline"] = None,
|
deadline: Optional["Deadline"] = None,
|
||||||
metadata: Optional[_MetadataLike] = None,
|
metadata: Optional[MetadataLike] = None,
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
"""Make a unary request and return the stream response iterator."""
|
"""Make a unary request and return the stream response iterator."""
|
||||||
async with self.channel.request(
|
async with self.channel.request(
|
||||||
@ -106,13 +106,13 @@ class ServiceStub(ABC):
|
|||||||
async def _stream_unary(
|
async def _stream_unary(
|
||||||
self,
|
self,
|
||||||
route: str,
|
route: str,
|
||||||
request_iterator: _MessageSource,
|
request_iterator: MessageSource,
|
||||||
request_type: Type[ST],
|
request_type: Type[ST],
|
||||||
response_type: Type[T],
|
response_type: Type[T],
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
deadline: Optional["Deadline"] = None,
|
deadline: Optional["Deadline"] = None,
|
||||||
metadata: Optional[_MetadataLike] = None,
|
metadata: Optional[MetadataLike] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
"""Make a stream request and return the response."""
|
"""Make a stream request and return the response."""
|
||||||
async with self.channel.request(
|
async with self.channel.request(
|
||||||
@ -130,13 +130,13 @@ class ServiceStub(ABC):
|
|||||||
async def _stream_stream(
|
async def _stream_stream(
|
||||||
self,
|
self,
|
||||||
route: str,
|
route: str,
|
||||||
request_iterator: _MessageSource,
|
request_iterator: MessageSource,
|
||||||
request_type: Type[ST],
|
request_type: Type[ST],
|
||||||
response_type: Type[T],
|
response_type: Type[T],
|
||||||
*,
|
*,
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
deadline: Optional["Deadline"] = None,
|
deadline: Optional["Deadline"] = None,
|
||||||
metadata: Optional[_MetadataLike] = None,
|
metadata: Optional[MetadataLike] = None,
|
||||||
) -> AsyncIterator[T]:
|
) -> AsyncIterator[T]:
|
||||||
"""
|
"""
|
||||||
Make a stream request and return an AsyncIterator to iterate over response
|
Make a stream request and return an AsyncIterator to iterate over response
|
||||||
@ -161,7 +161,7 @@ class ServiceStub(ABC):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _send_messages(stream, messages: _MessageSource):
|
async def _send_messages(stream, messages: MessageSource):
|
||||||
if isinstance(messages, AsyncIterable):
|
if isinstance(messages, AsyncIterable):
|
||||||
async for message in messages:
|
async for message in messages:
|
||||||
await stream.send_message(message)
|
await stream.send_message(message)
|
||||||
|
@ -232,6 +232,7 @@ class OutputTemplate:
|
|||||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||||
|
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def package(self) -> str:
|
def package(self) -> str:
|
||||||
@ -679,6 +680,15 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
if self.client_streaming or self.server_streaming:
|
if self.client_streaming or self.server_streaming:
|
||||||
self.output_file.typing_imports.add("AsyncIterator")
|
self.output_file.typing_imports.add("AsyncIterator")
|
||||||
|
|
||||||
|
# add imports required for request arguments timeout, deadline and metadata
|
||||||
|
self.output_file.typing_imports.add("Optional")
|
||||||
|
self.output_file.imports_type_checking_only.add(
|
||||||
|
"from betterproto.grpc.grpclib_client import MetadataLike"
|
||||||
|
)
|
||||||
|
self.output_file.imports_type_checking_only.add(
|
||||||
|
"from grpclib.metadata import Deadline"
|
||||||
|
)
|
||||||
|
|
||||||
super().__post_init__() # check for unset fields
|
super().__post_init__() # check for unset fields
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -20,6 +20,13 @@ from betterproto.grpc.grpclib_server import ServiceBase
|
|||||||
import grpclib
|
import grpclib
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if output_file.imports_type_checking_only %}
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if output_file.enums %}{% for enum in output_file.enums %}
|
{% if output_file.enums %}{% for enum in output_file.enums %}
|
||||||
class {{ enum.py_name }}(betterproto.Enum):
|
class {{ enum.py_name }}(betterproto.Enum):
|
||||||
@ -86,6 +93,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{# Client streaming: need a request iterator instead #}
|
{# Client streaming: need a request iterator instead #}
|
||||||
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
|
, timeout: Optional[float] = None
|
||||||
|
, deadline: Optional["Deadline"] = None
|
||||||
|
, metadata: Optional["_MetadataLike"] = None
|
||||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
@ -98,6 +108,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{{ method.py_input_message_param }}_iterator,
|
{{ method.py_input_message_param }}_iterator,
|
||||||
{{ method.py_input_message_type }},
|
{{ method.py_input_message_type }},
|
||||||
{{ method.py_output_message_type.strip('"') }},
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
|
timeout=timeout,
|
||||||
|
deadline=deadline,
|
||||||
|
metadata=metadata,
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
{% else %}{# i.e. not client streaming #}
|
{% else %}{# i.e. not client streaming #}
|
||||||
@ -105,6 +118,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
{{ method.py_input_message_param }},
|
{{ method.py_input_message_param }},
|
||||||
{{ method.py_output_message_type.strip('"') }},
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
|
timeout=timeout,
|
||||||
|
deadline=deadline,
|
||||||
|
metadata=metadata,
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
@ -115,13 +131,19 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
{{ method.py_input_message_param }}_iterator,
|
{{ method.py_input_message_param }}_iterator,
|
||||||
{{ method.py_input_message_type }},
|
{{ method.py_input_message_type }},
|
||||||
{{ method.py_output_message_type.strip('"') }}
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
|
timeout=timeout,
|
||||||
|
deadline=deadline,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
{% else %}{# i.e. not client streaming #}
|
{% else %}{# i.e. not client streaming #}
|
||||||
return await self._unary_unary(
|
return await self._unary_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
{{ method.py_input_message_param }},
|
{{ method.py_input_message_param }},
|
||||||
{{ method.py_output_message_type.strip('"') }}
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
|
timeout=timeout,
|
||||||
|
deadline=deadline,
|
||||||
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
{% endif %}{# client streaming #}
|
{% endif %}{# client streaming #}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
import uuid
|
||||||
|
|
||||||
import grpclib
|
import grpclib
|
||||||
import grpclib.metadata
|
import grpclib.metadata
|
||||||
import grpclib.server
|
import grpclib.server
|
||||||
|
import grpclib.client
|
||||||
import pytest
|
import pytest
|
||||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
from grpclib.testing import ChannelFor
|
from grpclib.testing import ChannelFor
|
||||||
@ -18,7 +20,7 @@ from .thing_service import ThingService
|
|||||||
|
|
||||||
|
|
||||||
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
||||||
response = await client.do_thing(DoThingRequest(name=name))
|
response = await client.do_thing(DoThingRequest(name=name), **kwargs)
|
||||||
assert response.names == [name]
|
assert response.names == [name]
|
||||||
|
|
||||||
|
|
||||||
@ -172,6 +174,55 @@ async def test_service_call_lower_level_with_overrides():
|
|||||||
assert response.names == [THING_TO_DO]
|
assert response.names == [THING_TO_DO]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("overrides",),
|
||||||
|
[
|
||||||
|
(dict(timeout=10),),
|
||||||
|
(dict(deadline=grpclib.metadata.Deadline.from_timeout(10)),),
|
||||||
|
(dict(metadata={"authorization": str(uuid.uuid4())}),),
|
||||||
|
(dict(timeout=20, metadata={"authorization": str(uuid.uuid4())}),),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_service_call_high_level_with_overrides(mocker, overrides):
|
||||||
|
request_spy = mocker.spy(grpclib.client.Channel, "request")
|
||||||
|
name = str(uuid.uuid4())
|
||||||
|
defaults = dict(
|
||||||
|
timeout=99,
|
||||||
|
deadline=grpclib.metadata.Deadline.from_timeout(99),
|
||||||
|
metadata={"authorization": name},
|
||||||
|
)
|
||||||
|
|
||||||
|
async with ChannelFor(
|
||||||
|
[
|
||||||
|
ThingService(
|
||||||
|
test_hook=_assert_request_meta_received(
|
||||||
|
deadline=grpclib.metadata.Deadline.from_timeout(
|
||||||
|
overrides.get("timeout", 99)
|
||||||
|
),
|
||||||
|
metadata=overrides.get("metadata", defaults.get("metadata")),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
) as channel:
|
||||||
|
client = ThingServiceClient(channel, **defaults)
|
||||||
|
await _test_client(client, name=name, **overrides)
|
||||||
|
assert request_spy.call_count == 1
|
||||||
|
|
||||||
|
# for python <3.8 request_spy.call_args.kwargs do not work
|
||||||
|
_, request_spy_call_kwargs = request_spy.call_args_list[0]
|
||||||
|
|
||||||
|
# ensure all overrides were successful
|
||||||
|
for key, value in overrides.items():
|
||||||
|
assert key in request_spy_call_kwargs
|
||||||
|
assert request_spy_call_kwargs[key] == value
|
||||||
|
|
||||||
|
# ensure default values were retained
|
||||||
|
for key in set(defaults.keys()) - set(overrides.keys()):
|
||||||
|
assert key in request_spy_call_kwargs
|
||||||
|
assert request_spy_call_kwargs[key] == defaults[key]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_gen_for_unary_stream_request():
|
async def test_async_gen_for_unary_stream_request():
|
||||||
thing_name = "my milkshakes"
|
thing_name = "my milkshakes"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user