From 09f821921f9b680c0e8b70020393097643fc466f Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sat, 23 May 2020 23:35:28 +0200 Subject: [PATCH] Move ServiceStub to a seperate module and add more rpcs to service test --- .gitignore | 3 +- betterproto/__init__.py | 137 +------------- betterproto/_types.py | 5 + betterproto/grpc/__init__.py | 0 betterproto/grpc/grpclib_client.py | 135 ++++++++++++++ betterproto/plugin.py | 1 - .../tests/inputs/service/service.proto | 16 +- betterproto/tests/test_service_client.py | 176 ++++++++++++++++++ 8 files changed, 336 insertions(+), 137 deletions(-) create mode 100644 betterproto/_types.py create mode 100644 betterproto/grpc/__init__.py create mode 100644 betterproto/grpc/grpclib_client.py create mode 100644 betterproto/tests/test_service_client.py diff --git a/.gitignore b/.gitignore index dd22728..4ae66e7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ betterproto/tests/output_* dist **/*.egg-info output -.idea \ No newline at end of file +.idea +.DS_Store diff --git a/betterproto/__init__.py b/betterproto/__init__.py index a2e7a18..8288aaf 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -5,8 +5,9 @@ import json import struct import sys from abc import ABC -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone +import stringcase from typing import ( Any, AsyncGenerator, @@ -22,22 +23,12 @@ from typing import ( SupportsBytes, Tuple, Type, - TypeVar, Union, get_type_hints, - TYPE_CHECKING, ) - - -import grpclib.const -import stringcase - +from ._types import ST, T from .casing import safe_snake_case - -if TYPE_CHECKING: - from grpclib._protocols import IProtoMessage - from grpclib.client import Channel - from grpclib.metadata import Deadline +from .grpc.grpclib_client import ServiceStub if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): # Apply backport of datetime.fromisoformat from 3.7 @@ -431,11 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: ) -# Bound type variable to allow methods to return `self` of subclasses -T = TypeVar("T", bound="Message") -ST = TypeVar("ST", bound="IProtoMessage") - - class ProtoClassMetadata: oneof_group_by_field: Dict[str, str] oneof_field_by_group: Dict[str, Set[dataclasses.Field]] @@ -1027,118 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type: TYPE_STRING: StringValue, TYPE_BYTES: BytesValue, }[proto_type] - - -_Value = Union[str, bytes] -_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] - - -class ServiceStub(ABC): - """ - Base class for async gRPC service stubs. - """ - - def __init__( - self, - channel: "Channel", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> None: - self.channel = channel - self.timeout = timeout - self.deadline = deadline - self.metadata = metadata - - def __resolve_request_kwargs( - self, - timeout: Optional[float], - deadline: Optional["Deadline"], - metadata: Optional[_MetadataLike], - ): - return { - "timeout": self.timeout if timeout is None else timeout, - "deadline": self.deadline if deadline is None else deadline, - "metadata": self.metadata if metadata is None else metadata, - } - - async def _unary_unary( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> T: - """Make a unary request and return the response.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_UNARY, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _unary_stream( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> AsyncGenerator[T, None]: - """Make a unary request and return the stream response iterator.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_STREAM, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - async for message in stream: - yield message - - async def _stream_unary( - self, - route: str, - request_iterator: Iterator["IProtoMessage"], - request_type: Type[ST], - response_type: Type[T], - ) -> T: - """Make a stream request and return the response.""" - async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type - ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _stream_stream( - self, - route: str, - request_iterator: Iterator["IProtoMessage"], - request_type: Type[ST], - response_type: Type[T], - ) -> AsyncGenerator[T, None]: - """Make a stream request and return the stream response iterator.""" - async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type - ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) - async for message in stream: - yield message diff --git a/betterproto/_types.py b/betterproto/_types.py new file mode 100644 index 0000000..0ff23e4 --- /dev/null +++ b/betterproto/_types.py @@ -0,0 +1,5 @@ +from typing import TypeVar + +# Bound type variable to allow methods to return `self` of subclasses +T = TypeVar("T", bound="Message") +ST = TypeVar("ST", bound="IProtoMessage") diff --git a/betterproto/grpc/__init__.py b/betterproto/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py new file mode 100644 index 0000000..757982e --- /dev/null +++ b/betterproto/grpc/grpclib_client.py @@ -0,0 +1,135 @@ +from abc import ABC +import grpclib.const +from typing import ( + AsyncGenerator, + AsyncIterator, + Collection, + Iterator, + Mapping, + Optional, + Tuple, + TYPE_CHECKING, + Type, + Union, +) +from .._types import ST, T + +if TYPE_CHECKING: + from grpclib._protocols import IProtoMessage + from grpclib.client import Channel + from grpclib.metadata import Deadline + + +_Value = Union[str, bytes] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] + + +class ServiceStub(ABC): + """ + Base class for async gRPC service stubs. + """ + + def __init__( + self, + channel: "Channel", + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> None: + self.channel = channel + self.timeout = timeout + self.deadline = deadline + self.metadata = metadata + + def __resolve_request_kwargs( + self, + timeout: Optional[float], + deadline: Optional["Deadline"], + metadata: Optional[_MetadataLike], + ): + return { + "timeout": self.timeout if timeout is None else timeout, + "deadline": self.deadline if deadline is None else deadline, + "metadata": self.metadata if metadata is None else metadata, + } + + async def _unary_unary( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> T: + """Make a unary request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_UNARY, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _unary_stream( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncGenerator[T, None]: + """Make a unary request and return the stream response iterator.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_STREAM, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + async for message in stream: + yield message + + async def _stream_unary( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> T: + """Make a stream request and return the response.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> AsyncGenerator[T, None]: + """Make a stream request and return the stream response iterator.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + async for message in stream: + yield message diff --git a/betterproto/plugin.py b/betterproto/plugin.py index b877ce6..44515d5 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -311,7 +311,6 @@ def generate_code(request, response): } for j, method in enumerate(service.method): - input_message = None input_type = get_ref_type( package, output["imports"], method.input_type diff --git a/betterproto/tests/inputs/service/service.proto b/betterproto/tests/inputs/service/service.proto index 7c931ed..acfbcdd 100644 --- a/betterproto/tests/inputs/service/service.proto +++ b/betterproto/tests/inputs/service/service.proto @@ -3,13 +3,25 @@ syntax = "proto3"; package service; message DoThingRequest { - int32 iterations = 1; + string name = 1; } message DoThingResponse { - int32 successfulIterations = 1; + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; } service Test { rpc DoThing (DoThingRequest) returns (DoThingResponse); + rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse); + rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse); + rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse); } diff --git a/betterproto/tests/test_service_client.py b/betterproto/tests/test_service_client.py new file mode 100644 index 0000000..586095d --- /dev/null +++ b/betterproto/tests/test_service_client.py @@ -0,0 +1,176 @@ +import betterproto +import grpclib +from grpclib.testing import ChannelFor +import pytest +from typing import Dict +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def DoThing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def DoManyThings( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def GetThingVersions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request, version=version_num) + ) + + async def GetDifferentThings( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Response to each input item immediately + for request in stream: + await stream.send_message(GetThingResponse(name=request.name, version=1)) + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.DoThing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.DoManyThings, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.GetThingVersions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.GetDifferentThings, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } + + +async def _test_stub(stub, name="clean room", **kwargs): + response = await stub.do_thing(name=name) + assert response.names == [name] + + +def _assert_request_meta_recieved(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_stub(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + await _test_stub( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + await _test_stub( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata) + ) + ] + ) as channel: + stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO]