Move ServiceStub to a seperate module and add more rpcs to service test

This commit is contained in:
Nat Noordanus 2020-05-23 23:35:28 +02:00
parent a757da1b29
commit 09f821921f
8 changed files with 336 additions and 137 deletions

3
.gitignore vendored
View File

@ -9,4 +9,5 @@ betterproto/tests/output_*
dist
**/*.egg-info
output
.idea
.idea
.DS_Store

View File

@ -5,8 +5,9 @@ import json
import struct
import sys
from abc import ABC
from base64 import b64encode, b64decode
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
import stringcase
from typing import (
Any,
AsyncGenerator,
@ -22,22 +23,12 @@ from typing import (
SupportsBytes,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
TYPE_CHECKING,
)
import grpclib.const
import stringcase
from ._types import ST, T
from .casing import safe_snake_case
if TYPE_CHECKING:
from grpclib._protocols import IProtoMessage
from grpclib.client import Channel
from grpclib.metadata import Deadline
from .grpc.grpclib_client import ServiceStub
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
# Apply backport of datetime.fromisoformat from 3.7
@ -431,11 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
)
# Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message")
ST = TypeVar("ST", bound="IProtoMessage")
class ProtoClassMetadata:
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
@ -1027,118 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type:
TYPE_STRING: StringValue,
TYPE_BYTES: BytesValue,
}[proto_type]
_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.
"""
def __init__(
self,
channel: "Channel",
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> None:
self.channel = channel
self.timeout = timeout
self.deadline = deadline
self.metadata = metadata
def __resolve_request_kwargs(
self,
timeout: Optional[float],
deadline: Optional["Deadline"],
metadata: Optional[_MetadataLike],
):
return {
"timeout": self.timeout if timeout is None else timeout,
"deadline": self.deadline if deadline is None else deadline,
"metadata": self.metadata if metadata is None else metadata,
}
async def _unary_unary(
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> T:
"""Make a unary request and return the response."""
async with self.channel.request(
route,
grpclib.const.Cardinality.UNARY_UNARY,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _unary_stream(
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> AsyncGenerator[T, None]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route,
grpclib.const.Cardinality.UNARY_STREAM,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
async for message in stream:
yield message
async def _stream_unary(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> T:
"""Make a stream request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _stream_stream(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> AsyncGenerator[T, None]:
"""Make a stream request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
async for message in stream:
yield message

5
betterproto/_types.py Normal file
View File

@ -0,0 +1,5 @@
from typing import TypeVar
# Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message")
ST = TypeVar("ST", bound="IProtoMessage")

View File

View File

@ -0,0 +1,135 @@
from abc import ABC
import grpclib.const
from typing import (
AsyncGenerator,
AsyncIterator,
Collection,
Iterator,
Mapping,
Optional,
Tuple,
TYPE_CHECKING,
Type,
Union,
)
from .._types import ST, T
if TYPE_CHECKING:
from grpclib._protocols import IProtoMessage
from grpclib.client import Channel
from grpclib.metadata import Deadline
_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.
"""
def __init__(
self,
channel: "Channel",
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> None:
self.channel = channel
self.timeout = timeout
self.deadline = deadline
self.metadata = metadata
def __resolve_request_kwargs(
self,
timeout: Optional[float],
deadline: Optional["Deadline"],
metadata: Optional[_MetadataLike],
):
return {
"timeout": self.timeout if timeout is None else timeout,
"deadline": self.deadline if deadline is None else deadline,
"metadata": self.metadata if metadata is None else metadata,
}
async def _unary_unary(
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> T:
"""Make a unary request and return the response."""
async with self.channel.request(
route,
grpclib.const.Cardinality.UNARY_UNARY,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _unary_stream(
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[_MetadataLike] = None,
) -> AsyncGenerator[T, None]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route,
grpclib.const.Cardinality.UNARY_STREAM,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
async for message in stream:
yield message
async def _stream_unary(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> T:
"""Make a stream request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _stream_stream(
self,
route: str,
request_iterator: Iterator["IProtoMessage"],
request_type: Type[ST],
response_type: Type[T],
) -> AsyncGenerator[T, None]:
"""Make a stream request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type
) as stream:
for message in request_iterator:
await stream.send_message(message)
await stream.send_request(end=True)
async for message in stream:
yield message

View File

@ -311,7 +311,6 @@ def generate_code(request, response):
}
for j, method in enumerate(service.method):
input_message = None
input_type = get_ref_type(
package, output["imports"], method.input_type

View File

@ -3,13 +3,25 @@ syntax = "proto3";
package service;
message DoThingRequest {
int32 iterations = 1;
string name = 1;
}
message DoThingResponse {
int32 successfulIterations = 1;
repeated string names = 1;
}
message GetThingRequest {
string name = 1;
}
message GetThingResponse {
string name = 1;
int32 version = 2;
}
service Test {
rpc DoThing (DoThingRequest) returns (DoThingResponse);
rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse);
rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse);
rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse);
}

View File

@ -0,0 +1,176 @@
import betterproto
import grpclib
from grpclib.testing import ChannelFor
import pytest
from typing import Dict
from betterproto.tests.output_betterproto.service.service import (
DoThingResponse,
DoThingRequest,
GetThingRequest,
GetThingResponse,
TestStub as ThingServiceClient,
)
class ThingService:
def __init__(self, test_hook=None):
# This lets us pass assertions to the servicer ;)
self.test_hook = test_hook
async def DoThing(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
):
request = await stream.recv_message()
if self.test_hook is not None:
self.test_hook(stream)
await stream.send_message(DoThingResponse([request.name]))
async def DoManyThings(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
):
thing_names = [request.name for request in stream]
if self.test_hook is not None:
self.test_hook(stream)
await stream.send_message(DoThingResponse(thing_names))
async def GetThingVersions(
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
):
request = await stream.recv_message()
if self.test_hook is not None:
self.test_hook(stream)
for version_num in range(1, 6):
await stream.send_message(
GetThingResponse(name=request, version=version_num)
)
async def GetDifferentThings(
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
):
if self.test_hook is not None:
self.test_hook(stream)
# Response to each input item immediately
for request in stream:
await stream.send_message(GetThingResponse(name=request.name, version=1))
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
"/service.Test/DoThing": grpclib.const.Handler(
self.DoThing,
grpclib.const.Cardinality.UNARY_UNARY,
DoThingRequest,
DoThingResponse,
),
"/service.Test/DoManyThings": grpclib.const.Handler(
self.DoManyThings,
grpclib.const.Cardinality.STREAM_UNARY,
DoThingRequest,
DoThingResponse,
),
"/service.Test/GetThingVersions": grpclib.const.Handler(
self.GetThingVersions,
grpclib.const.Cardinality.UNARY_STREAM,
GetThingRequest,
GetThingResponse,
),
"/service.Test/GetDifferentThings": grpclib.const.Handler(
self.GetDifferentThings,
grpclib.const.Cardinality.STREAM_STREAM,
GetThingRequest,
GetThingResponse,
),
}
async def _test_stub(stub, name="clean room", **kwargs):
response = await stub.do_thing(name=name)
assert response.names == [name]
def _assert_request_meta_recieved(deadline, metadata):
def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1
), "The provided deadline should be recieved serverside"
assert (
stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be recieved serverside"
return server_side_test
@pytest.mark.asyncio
async def test_simple_service_call():
async with ChannelFor([ThingService()]) as channel:
await _test_stub(ThingServiceClient(channel))
@pytest.mark.asyncio
async def test_service_call_with_upfront_request_params():
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
) as channel:
await _test_stub(
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
)
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
) as channel:
await _test_stub(
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
)
@pytest.mark.asyncio
async def test_service_call_lower_level_with_overrides():
THING_TO_DO = "get milk"
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"}
async with ChannelFor(
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))]
) as channel:
stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
deadline=kwarg_deadline,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
kwarg_timeout = 9000
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
kwarg_metadata = {"authorization": "09876"}
async with ChannelFor(
[
ThingService(
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata)
)
]
) as channel:
stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary(
"/service.Test/DoThing",
DoThingRequest(THING_TO_DO),
DoThingResponse,
timeout=kwarg_timeout,
metadata=kwarg_metadata,
)
assert response.names == [THING_TO_DO]