From a757da1b293969f148ba512bb7b5392dfb817cb6 Mon Sep 17 00:00:00 2001 From: Hans Lellelid Date: Mon, 11 May 2020 15:30:29 -0400 Subject: [PATCH 01/10] Adding basic support (untested) for client streaming --- betterproto/__init__.py | 38 ++++++++++++++++++++++++ betterproto/plugin.py | 5 ++-- betterproto/templates/template.py.j2 | 44 ++++++++++++++++++++++++++-- 3 files changed, 82 insertions(+), 5 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 5d901be..a2e7a18 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -14,10 +14,12 @@ from typing import ( Collection, Dict, Generator, + Iterator, List, Mapping, Optional, Set, + SupportsBytes, Tuple, Type, TypeVar, @@ -431,6 +433,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") +ST = TypeVar("ST", bound="IProtoMessage") class ProtoClassMetadata: @@ -1104,3 +1107,38 @@ class ServiceStub(ABC): await stream.send_message(request, end=True) async for message in stream: yield message + + async def _stream_unary( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> T: + """Make a stream request and return the response.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> AsyncGenerator[T, None]: + """Make a stream request and return the stream response iterator.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + async for message in stream: + yield message diff --git a/betterproto/plugin.py b/betterproto/plugin.py index e300318..b877ce6 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -311,8 +311,6 @@ def generate_code(request, response): } for j, method in enumerate(service.method): - if method.client_streaming: - raise NotImplementedError("Client streaming not yet supported") input_message = None input_type = get_ref_type( @@ -350,6 +348,9 @@ def generate_code(request, response): if method.server_streaming: output["typing_imports"].add("AsyncGenerator") + if method.client_streaming: + output["typing_imports"].add("Iterator") + output["services"].append(data) output["imports"] = sorted(output["imports"]) diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index 3a19422..c4c3029 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% for method in service.methods %} - async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + {%- if method.input_message and method.input_message.properties -%}, *, + {%- for field in method.input_message.properties -%} + {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} + Optional[{{ field.type }}] + {%- else -%} + {{ field.type }} + {%- endif -%} = {{ field.zero }} + {%- if not loop.last %}, {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: Iterator["{{ method.input }}"] + {%- endif -%} + ) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} {% endif %} + {% if not method.client_streaming %} request = {{ method.input }}() {% for field in method.input_message.properties %} {% if field.field_type == 'message' %} @@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): request.{{ field.py_name }} = {{ field.py_name }} {% endif %} {% endfor %} + {% endif %} {% if method.server_streaming %} + {% if method.client_streaming %} + async for response in self._stream_stream( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }}, + ): + yield response + {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, {{ method.output }}, ): yield response - {% else %} + + {% endif %}{# if client streaming #} + {% else %}{# i.e. not server streaming #} + {% if method.client_streaming %} + return await self._stream_unary( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }} + ) + {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output }}, + {{ method.output }} ) + {% endif %}{# client streaming #} {% endif %} {% endfor %} From 09f821921f9b680c0e8b70020393097643fc466f Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sat, 23 May 2020 23:35:28 +0200 Subject: [PATCH 02/10] Move ServiceStub to a seperate module and add more rpcs to service test --- .gitignore | 3 +- betterproto/__init__.py | 137 +------------- betterproto/_types.py | 5 + betterproto/grpc/__init__.py | 0 betterproto/grpc/grpclib_client.py | 135 ++++++++++++++ betterproto/plugin.py | 1 - .../tests/inputs/service/service.proto | 16 +- betterproto/tests/test_service_client.py | 176 ++++++++++++++++++ 8 files changed, 336 insertions(+), 137 deletions(-) create mode 100644 betterproto/_types.py create mode 100644 betterproto/grpc/__init__.py create mode 100644 betterproto/grpc/grpclib_client.py create mode 100644 betterproto/tests/test_service_client.py diff --git a/.gitignore b/.gitignore index dd22728..4ae66e7 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ betterproto/tests/output_* dist **/*.egg-info output -.idea \ No newline at end of file +.idea +.DS_Store diff --git a/betterproto/__init__.py b/betterproto/__init__.py index a2e7a18..8288aaf 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -5,8 +5,9 @@ import json import struct import sys from abc import ABC -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone +import stringcase from typing import ( Any, AsyncGenerator, @@ -22,22 +23,12 @@ from typing import ( SupportsBytes, Tuple, Type, - TypeVar, Union, get_type_hints, - TYPE_CHECKING, ) - - -import grpclib.const -import stringcase - +from ._types import ST, T from .casing import safe_snake_case - -if TYPE_CHECKING: - from grpclib._protocols import IProtoMessage - from grpclib.client import Channel - from grpclib.metadata import Deadline +from .grpc.grpclib_client import ServiceStub if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): # Apply backport of datetime.fromisoformat from 3.7 @@ -431,11 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: ) -# Bound type variable to allow methods to return `self` of subclasses -T = TypeVar("T", bound="Message") -ST = TypeVar("ST", bound="IProtoMessage") - - class ProtoClassMetadata: oneof_group_by_field: Dict[str, str] oneof_field_by_group: Dict[str, Set[dataclasses.Field]] @@ -1027,118 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type: TYPE_STRING: StringValue, TYPE_BYTES: BytesValue, }[proto_type] - - -_Value = Union[str, bytes] -_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] - - -class ServiceStub(ABC): - """ - Base class for async gRPC service stubs. - """ - - def __init__( - self, - channel: "Channel", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> None: - self.channel = channel - self.timeout = timeout - self.deadline = deadline - self.metadata = metadata - - def __resolve_request_kwargs( - self, - timeout: Optional[float], - deadline: Optional["Deadline"], - metadata: Optional[_MetadataLike], - ): - return { - "timeout": self.timeout if timeout is None else timeout, - "deadline": self.deadline if deadline is None else deadline, - "metadata": self.metadata if metadata is None else metadata, - } - - async def _unary_unary( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> T: - """Make a unary request and return the response.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_UNARY, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _unary_stream( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> AsyncGenerator[T, None]: - """Make a unary request and return the stream response iterator.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_STREAM, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - async for message in stream: - yield message - - async def _stream_unary( - self, - route: str, - request_iterator: Iterator["IProtoMessage"], - request_type: Type[ST], - response_type: Type[T], - ) -> T: - """Make a stream request and return the response.""" - async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type - ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _stream_stream( - self, - route: str, - request_iterator: Iterator["IProtoMessage"], - request_type: Type[ST], - response_type: Type[T], - ) -> AsyncGenerator[T, None]: - """Make a stream request and return the stream response iterator.""" - async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type - ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) - async for message in stream: - yield message diff --git a/betterproto/_types.py b/betterproto/_types.py new file mode 100644 index 0000000..0ff23e4 --- /dev/null +++ b/betterproto/_types.py @@ -0,0 +1,5 @@ +from typing import TypeVar + +# Bound type variable to allow methods to return `self` of subclasses +T = TypeVar("T", bound="Message") +ST = TypeVar("ST", bound="IProtoMessage") diff --git a/betterproto/grpc/__init__.py b/betterproto/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py new file mode 100644 index 0000000..757982e --- /dev/null +++ b/betterproto/grpc/grpclib_client.py @@ -0,0 +1,135 @@ +from abc import ABC +import grpclib.const +from typing import ( + AsyncGenerator, + AsyncIterator, + Collection, + Iterator, + Mapping, + Optional, + Tuple, + TYPE_CHECKING, + Type, + Union, +) +from .._types import ST, T + +if TYPE_CHECKING: + from grpclib._protocols import IProtoMessage + from grpclib.client import Channel + from grpclib.metadata import Deadline + + +_Value = Union[str, bytes] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] + + +class ServiceStub(ABC): + """ + Base class for async gRPC service stubs. + """ + + def __init__( + self, + channel: "Channel", + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> None: + self.channel = channel + self.timeout = timeout + self.deadline = deadline + self.metadata = metadata + + def __resolve_request_kwargs( + self, + timeout: Optional[float], + deadline: Optional["Deadline"], + metadata: Optional[_MetadataLike], + ): + return { + "timeout": self.timeout if timeout is None else timeout, + "deadline": self.deadline if deadline is None else deadline, + "metadata": self.metadata if metadata is None else metadata, + } + + async def _unary_unary( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> T: + """Make a unary request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_UNARY, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _unary_stream( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncGenerator[T, None]: + """Make a unary request and return the stream response iterator.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_STREAM, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + async for message in stream: + yield message + + async def _stream_unary( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> T: + """Make a stream request and return the response.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: Iterator["IProtoMessage"], + request_type: Type[ST], + response_type: Type[T], + ) -> AsyncGenerator[T, None]: + """Make a stream request and return the stream response iterator.""" + async with self.channel.request( + route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + ) as stream: + for message in request_iterator: + await stream.send_message(message) + await stream.send_request(end=True) + async for message in stream: + yield message diff --git a/betterproto/plugin.py b/betterproto/plugin.py index b877ce6..44515d5 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -311,7 +311,6 @@ def generate_code(request, response): } for j, method in enumerate(service.method): - input_message = None input_type = get_ref_type( package, output["imports"], method.input_type diff --git a/betterproto/tests/inputs/service/service.proto b/betterproto/tests/inputs/service/service.proto index 7c931ed..acfbcdd 100644 --- a/betterproto/tests/inputs/service/service.proto +++ b/betterproto/tests/inputs/service/service.proto @@ -3,13 +3,25 @@ syntax = "proto3"; package service; message DoThingRequest { - int32 iterations = 1; + string name = 1; } message DoThingResponse { - int32 successfulIterations = 1; + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; } service Test { rpc DoThing (DoThingRequest) returns (DoThingResponse); + rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse); + rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse); + rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse); } diff --git a/betterproto/tests/test_service_client.py b/betterproto/tests/test_service_client.py new file mode 100644 index 0000000..586095d --- /dev/null +++ b/betterproto/tests/test_service_client.py @@ -0,0 +1,176 @@ +import betterproto +import grpclib +from grpclib.testing import ChannelFor +import pytest +from typing import Dict +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def DoThing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def DoManyThings( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def GetThingVersions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request, version=version_num) + ) + + async def GetDifferentThings( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Response to each input item immediately + for request in stream: + await stream.send_message(GetThingResponse(name=request.name, version=1)) + + def __mapping__(self) -> Dict[str, grpclib.const.Handler]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.DoThing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.DoManyThings, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.GetThingVersions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.GetDifferentThings, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } + + +async def _test_stub(stub, name="clean room", **kwargs): + response = await stub.do_thing(name=name) + assert response.names == [name] + + +def _assert_request_meta_recieved(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_stub(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + await _test_stub( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + await _test_stub( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] + ) as channel: + stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata) + ) + ] + ) as channel: + stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await stub._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] From 4b6f55dce58d82f8db14ada7e08063c14eab9b94 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sun, 7 Jun 2020 17:51:26 +0200 Subject: [PATCH 03/10] Finish implementation and testing of client Including stream_unary and stream_stream call methods. Also - improve organisation of relevant tests - fix some generated type annotations - Add AsyncChannel utility cos it's useful --- betterproto/__init__.py | 2 +- betterproto/grpc/grpclib_client.py | 68 ++++-- betterproto/grpc/util/__init__.py | 0 betterproto/grpc/util/async_channel.py | 204 ++++++++++++++++++ betterproto/plugin.py | 9 +- betterproto/templates/template.py.j2 | 16 +- betterproto/tests/grpc/__init__.py | 0 betterproto/tests/grpc/test_grpclib_client.py | 150 +++++++++++++ betterproto/tests/grpc/thing_service.py | 83 +++++++ .../test_googletypes_response.py | 2 +- .../tests/inputs/service/test_service.py | 132 ------------ betterproto/tests/test_service_client.py | 176 --------------- 12 files changed, 503 insertions(+), 339 deletions(-) create mode 100644 betterproto/grpc/util/__init__.py create mode 100644 betterproto/grpc/util/async_channel.py create mode 100644 betterproto/tests/grpc/__init__.py create mode 100644 betterproto/tests/grpc/test_grpclib_client.py create mode 100644 betterproto/tests/grpc/thing_service.py delete mode 100644 betterproto/tests/inputs/service/test_service.py delete mode 100644 betterproto/tests/test_service_client.py diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 8288aaf..6a53d65 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -593,7 +593,7 @@ class Message(ABC): serialize_empty = False if isinstance(value, Message) and value._serialized_on_wire: # Empty messages can still be sent on the wire if they were - # set (or received empty). + # set (or recieved empty). serialize_empty = True if value == self._get_field_default(field_name) and not ( diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py index 757982e..7218574 100644 --- a/betterproto/grpc/grpclib_client.py +++ b/betterproto/grpc/grpclib_client.py @@ -1,7 +1,8 @@ from abc import ABC +import asyncio import grpclib.const from typing import ( - AsyncGenerator, + Any, AsyncIterator, Collection, Iterator, @@ -16,17 +17,18 @@ from .._types import ST, T if TYPE_CHECKING: from grpclib._protocols import IProtoMessage - from grpclib.client import Channel + from grpclib.client import Channel, Stream from grpclib.metadata import Deadline _Value = Union[str, bytes] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] +_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]] class ServiceStub(ABC): """ - Base class for async gRPC service stubs. + Base class for async gRPC clients. """ def __init__( @@ -86,7 +88,7 @@ class ServiceStub(ABC): timeout: Optional[float] = None, deadline: Optional["Deadline"] = None, metadata: Optional[_MetadataLike] = None, - ) -> AsyncGenerator[T, None]: + ) -> AsyncIterator[T]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( route, @@ -102,17 +104,23 @@ class ServiceStub(ABC): async def _stream_unary( self, route: str, - request_iterator: Iterator["IProtoMessage"], + request_iterator: _MessageSource, request_type: Type[ST], response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, ) -> T: """Make a stream request and return the response.""" async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_UNARY, request_type, response_type + route, + grpclib.const.Cardinality.STREAM_UNARY, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: - for message in request_iterator: - await stream.send_message(message) - await stream.send_request(end=True) + await self._send_messages(stream, request_iterator) response = await stream.recv_message() assert response is not None return response @@ -120,16 +128,42 @@ class ServiceStub(ABC): async def _stream_stream( self, route: str, - request_iterator: Iterator["IProtoMessage"], + request_iterator: _MessageSource, request_type: Type[ST], response_type: Type[T], - ) -> AsyncGenerator[T, None]: - """Make a stream request and return the stream response iterator.""" + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncIterator[T]: + """ + Make a stream request and return an AsyncIterator to iterate over response + messages. + """ async with self.channel.request( - route, grpclib.const.Cardinality.STREAM_STREAM, request_type, response_type + route, + grpclib.const.Cardinality.STREAM_STREAM, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), ) as stream: - for message in request_iterator: + await stream.send_request() + sending_task = asyncio.ensure_future( + self._send_messages(stream, request_iterator) + ) + try: + async for response in stream: + yield response + except: + sending_task.cancel() + raise + + @staticmethod + async def _send_messages(stream, messages: _MessageSource): + if hasattr(messages, "__aiter__"): + async for message in messages: await stream.send_message(message) - await stream.send_request(end=True) - async for message in stream: - yield message + else: + for message in messages: + await stream.send_message(message) + await stream.end() diff --git a/betterproto/grpc/util/__init__.py b/betterproto/grpc/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py new file mode 100644 index 0000000..7e83c94 --- /dev/null +++ b/betterproto/grpc/util/async_channel.py @@ -0,0 +1,204 @@ +import asyncio +from typing import ( + AsyncIterable, + AsyncIterator, + Iterable, + Optional, + TypeVar, + Union, +) + +T = TypeVar("T") + + +class ChannelClosed(Exception): + """ + An exception raised on an attempt to send through a closed channel + """ + + pass + + +class ChannelDone(Exception): + """ + An exception raised on an attempt to send recieve from a channel that is both closed + and empty. + """ + + pass + + +class AsyncChannel(AsyncIterable[T]): + """ + A buffered async channel for sending items between coroutines with FIFO semantics. + + This makes decoupled bidirection steaming gRPC requests easy if used like: + + .. code-block:: python + client = GeneratedStub(grpclib_chan) + # The channel can be initialised with items to send immediately + request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)]) + async for response in client.rpc_call(request_chan): + # The response iterator will remain active until the connection is closed + ... + # More items can be sent at any time + await request_chan.send(ReqestObject(...)) + ... + # The channel must be closed to complete the gRPC connection + request_chan.close() + + Items can be sent through the channel by either: + - providing an iterable to the constructor + - providing an iterable to the send_from method + - passing them to the send method one at a time + + Items can be recieved from the channel by either: + - iterating over the channel with a for loop to get all items + - calling the recieve method to get one item at a time + + If the channel is empty then recievers will wait until either an item appears or the + channel is closed. + + Once the channel is closed then subsequent attempt to send through the channel will + fail with a ChannelClosed exception. + + When th channel is closed and empty then it is done, and further attempts to recieve + from it will fail with a ChannelDone exception + + If multiple coroutines recieve from the channel concurrently, each item sent will be + recieved by only one of the recievers. + + :param source: + An optional iterable will items that should be sent through the channel + immediately. + :param buffer_limit: + Limit the number of items that can be buffered in the channel, A value less than + 1 implies no limit. If the channel is full then attempts to send more items will + result in the sender waiting until an item is recieved from the channel. + :param close: + If set to True then the channel will automatically close after exhausting source + or immediately if no source is provided. + """ + + def __init__( + self, + source: Union[Iterable[T], AsyncIterable[T]] = tuple(), + *, + buffer_limit: int = 0, + close: bool = False, + ): + self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) + self._closed = False + self._sending_task = ( + asyncio.ensure_future(self.send_from(source, close)) if source else None + ) + self._waiting_recievers: int = 0 + # Track whether flush has been invoked so it can only happen once + self._flushed = False + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + if self.done: + raise StopAsyncIteration + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + raise StopAsyncIteration + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def closed(self) -> bool: + """ + Returns True if this channel is closed and no-longer accepting new items + """ + return self._closed + + def done(self) -> bool: + """ + Check if this channel is done. + + :return: True if this channel is closed and and has been drained of items in + which case any further attempts to recieve an item from this channel will raise + a ChannelDone exception. + """ + # After close the channel is not yet done until there is at least one waiting + # reciever per enqueued item. + return self._closed and self._queue.qsize() <= self._waiting_recievers + + async def send_from( + self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False + ): + """ + Iterates the given [Async]Iterable and sends all the resulting items. + If close is set to True then subsequent send calls will be rejected with a + ChannelClosed exception. + :param source: an iterable of items to send + :param close: + if True then the channel will be closed after the source has been exhausted + + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + if isinstance(source, AsyncIterable): + async for item in source: + await self._queue.put(item) + else: + for item in source: + await self._queue.put(item) + if close: + # Complete the closing process + await self.close() + + async def send(self, item: T): + """ + Send a single item over this channel. + :param item: The item to send + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + await self._queue.put(item) + + async def recieve(self) -> Optional[T]: + """ + Returns the next item from this channel when it becomes available, + or None if the channel is closed before another item is sent. + :return: An item from the channel + """ + if self.done: + raise ChannelDone("Cannot recieve from a closed channel") + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + return None + return result + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def close(self): + """ + Close this channel to new items + """ + if self._sending_task is not None: + self._sending_task.cancel() + self._closed = True + asyncio.ensure_future(self._flush_queue()) + + async def _flush_queue(self): + """ + To be called after the channel is closed. Pushes a number of self.__flush + objects to the queue to ensure no waiting consumers get deadlocked. + """ + if not self._flushed: + self._flushed = True + deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize()) + for _ in range(deadlocked_recievers): + await self._queue.put(self.__flush) + + # A special signal object for flushing the queue when the channel is closed + __flush = object() diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 44515d5..85fd905 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -344,11 +344,12 @@ def generate_code(request, response): } ) - if method.server_streaming: - output["typing_imports"].add("AsyncGenerator") - if method.client_streaming: - output["typing_imports"].add("Iterator") + output["typing_imports"].add("AsyncIterable") + output["typing_imports"].add("Iterable") + output["typing_imports"].add("Union") + if method.server_streaming: + output["typing_imports"].add("AsyncIterator") output["services"].append(data) diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index c4c3029..3894619 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -77,9 +77,9 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , request_iterator: Iterator["{{ method.input }}"] + , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]] {%- endif -%} - ) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} @@ -97,7 +97,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% if method.server_streaming %} - {% if method.client_streaming %} + {% if method.client_streaming %} async for response in self._stream_stream( "{{ method.route }}", request_iterator, @@ -105,7 +105,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ method.output }}, ): yield response - {% else %}{# i.e. not client streaming #} + {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, @@ -113,22 +113,22 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): ): yield response - {% endif %}{# if client streaming #} + {% endif %}{# if client streaming #} {% else %}{# i.e. not server streaming #} - {% if method.client_streaming %} + {% if method.client_streaming %} return await self._stream_unary( "{{ method.route }}", request_iterator, {{ method.input }}, {{ method.output }} ) - {% else %}{# i.e. not client streaming #} + {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, {{ method.output }} ) - {% endif %}{# client streaming #} + {% endif %}{# client streaming #} {% endif %} {% endfor %} diff --git a/betterproto/tests/grpc/__init__.py b/betterproto/tests/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py new file mode 100644 index 0000000..dc57fe4 --- /dev/null +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -0,0 +1,150 @@ +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from grpclib.testing import ChannelFor +import pytest +from betterproto.grpc.util.async_channel import AsyncChannel +from .thing_service import ThingService + + +async def _test_client(client, name="clean room", **kwargs): + response = await client.do_thing(name=name) + assert response.names == [name] + + +def _assert_request_meta_recieved(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_client(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata), + ) + ] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + +@pytest.mark.asyncio +async def test_async_gen_for_unary_stream_request(): + thing_name = "my milkshakes" + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + expected_versions = [5, 4, 3, 2, 1] + async for response in client.get_thing_versions(name=thing_name): + assert response.name == thing_name + assert response.version == expected_versions.pop() + + +@pytest.mark.asyncio +async def test_async_gen_for_stream_stream_request(): + some_things = ["cake", "cricket", "coral reef"] + more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"] + expected_things = (*some_things, *more_things) + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + # Use an AsyncChannel to decouple sending and recieving, it'll send some_things + # immediately and we'll use it to send more_things later, after recieving some + # results + request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + response_index = 0 + async for response in client.get_different_things(request_chan): + assert response.name == expected_things[response_index] + assert response.version == response_index + 1 + response_index += 1 + if more_things: + # Send some more requests as we recieve reponses to be sure coordination of + # send/recieve events doesn't matter + another_response = await request_chan.send( + GetThingRequest(more_things.pop(0)) + ) + if another_response is not None: + assert another_response.name == expected_things[response_index] + assert another_response.version == response_index + response_index += 1 + else: + # No more things to send make sure channel is closed + await request_chan.close() diff --git a/betterproto/tests/grpc/thing_service.py b/betterproto/tests/grpc/thing_service.py new file mode 100644 index 0000000..bc9fff8 --- /dev/null +++ b/betterproto/tests/grpc/thing_service.py @@ -0,0 +1,83 @@ +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from typing import Any, Dict + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def do_thing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def do_many_things( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def get_thing_versions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request.name, version=version_num) + ) + + async def get_different_things( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Respond to each input item immediately + response_num = 0 + async for request in stream: + response_num += 1 + await stream.send_message( + GetThingResponse(name=request.name, version=response_num) + ) + + def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.do_thing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.do_many_things, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.get_thing_versions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.get_different_things, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index 02fa193..bd5f602 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -23,7 +23,7 @@ test_cases = [ @pytest.mark.asyncio @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) -async def test_channel_receives_wrapped_type( +async def test_channel_recieves_wrapped_type( service_method: Callable[[TestStub], Any], wrapper_class: Callable, value ): wrapped_value = wrapper_class() diff --git a/betterproto/tests/inputs/service/test_service.py b/betterproto/tests/inputs/service/test_service.py deleted file mode 100644 index 2a6ca59..0000000 --- a/betterproto/tests/inputs/service/test_service.py +++ /dev/null @@ -1,132 +0,0 @@ -import betterproto -import grpclib -from grpclib.testing import ChannelFor -import pytest -from typing import Dict - -from betterproto.tests.output_betterproto.service.service import ( - DoThingResponse, - DoThingRequest, - TestStub as ExampleServiceStub, -) - - -class ExampleService: - def __init__(self, test_hook=None): - # This lets us pass assertions to the servicer ;) - self.test_hook = test_hook - - async def DoThing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - request = await stream.recv_message() - print("self.test_hook", self.test_hook) - if self.test_hook is not None: - self.test_hook(stream) - for iteration in range(request.iterations): - pass - await stream.send_message(DoThingResponse(request.iterations)) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/service.Test/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - DoThingRequest, - DoThingResponse, - ) - } - - -async def _test_stub(stub, iterations=42, **kwargs): - response = await stub.do_thing(iterations=iterations) - assert response.successful_iterations == iterations - - -def _get_server_side_test(deadline, metadata): - def server_side_test(stream): - assert stream.deadline._timestamp == pytest.approx( - deadline._timestamp, 1 - ), "The provided deadline should be recieved serverside" - assert ( - stream.metadata["authorization"] == metadata["authorization"] - ), "The provided authorization metadata should be recieved serverside" - - return server_side_test - - -@pytest.mark.asyncio -async def test_simple_service_call(): - async with ChannelFor([ExampleService()]) as channel: - await _test_stub(ExampleServiceStub(channel)) - - -@pytest.mark.asyncio -async def test_service_call_with_upfront_request_params(): - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - ) - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, timeout=timeout, metadata=metadata) - ) - - -@pytest.mark.asyncio -async def test_service_call_lower_level_with_overrides(): - ITERATIONS = 99 - - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) - kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - deadline=kwarg_deadline, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - kwarg_timeout = 9000 - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) - kwarg_metadata = {"authorization": "09876"} - async with ChannelFor( - [ - ExampleService( - test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata) - ) - ] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - timeout=kwarg_timeout, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS diff --git a/betterproto/tests/test_service_client.py b/betterproto/tests/test_service_client.py deleted file mode 100644 index 586095d..0000000 --- a/betterproto/tests/test_service_client.py +++ /dev/null @@ -1,176 +0,0 @@ -import betterproto -import grpclib -from grpclib.testing import ChannelFor -import pytest -from typing import Dict -from betterproto.tests.output_betterproto.service.service import ( - DoThingResponse, - DoThingRequest, - GetThingRequest, - GetThingResponse, - TestStub as ThingServiceClient, -) - - -class ThingService: - def __init__(self, test_hook=None): - # This lets us pass assertions to the servicer ;) - self.test_hook = test_hook - - async def DoThing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - request = await stream.recv_message() - if self.test_hook is not None: - self.test_hook(stream) - await stream.send_message(DoThingResponse([request.name])) - - async def DoManyThings( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - thing_names = [request.name for request in stream] - if self.test_hook is not None: - self.test_hook(stream) - await stream.send_message(DoThingResponse(thing_names)) - - async def GetThingVersions( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): - request = await stream.recv_message() - if self.test_hook is not None: - self.test_hook(stream) - for version_num in range(1, 6): - await stream.send_message( - GetThingResponse(name=request, version=version_num) - ) - - async def GetDifferentThings( - self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" - ): - if self.test_hook is not None: - self.test_hook(stream) - # Response to each input item immediately - for request in stream: - await stream.send_message(GetThingResponse(name=request.name, version=1)) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/service.Test/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - DoThingRequest, - DoThingResponse, - ), - "/service.Test/DoManyThings": grpclib.const.Handler( - self.DoManyThings, - grpclib.const.Cardinality.STREAM_UNARY, - DoThingRequest, - DoThingResponse, - ), - "/service.Test/GetThingVersions": grpclib.const.Handler( - self.GetThingVersions, - grpclib.const.Cardinality.UNARY_STREAM, - GetThingRequest, - GetThingResponse, - ), - "/service.Test/GetDifferentThings": grpclib.const.Handler( - self.GetDifferentThings, - grpclib.const.Cardinality.STREAM_STREAM, - GetThingRequest, - GetThingResponse, - ), - } - - -async def _test_stub(stub, name="clean room", **kwargs): - response = await stub.do_thing(name=name) - assert response.names == [name] - - -def _assert_request_meta_recieved(deadline, metadata): - def server_side_test(stream): - assert stream.deadline._timestamp == pytest.approx( - deadline._timestamp, 1 - ), "The provided deadline should be recieved serverside" - assert ( - stream.metadata["authorization"] == metadata["authorization"] - ), "The provided authorization metadata should be recieved serverside" - - return server_side_test - - -@pytest.mark.asyncio -async def test_simple_service_call(): - async with ChannelFor([ThingService()]) as channel: - await _test_stub(ThingServiceClient(channel)) - - -@pytest.mark.asyncio -async def test_service_call_with_upfront_request_params(): - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - await _test_stub( - ThingServiceClient(channel, deadline=deadline, metadata=metadata) - ) - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - await _test_stub( - ThingServiceClient(channel, timeout=timeout, metadata=metadata) - ) - - -@pytest.mark.asyncio -async def test_service_call_lower_level_with_overrides(): - THING_TO_DO = "get milk" - - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) - kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata))] - ) as channel: - stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(THING_TO_DO), - DoThingResponse, - deadline=kwarg_deadline, - metadata=kwarg_metadata, - ) - assert response.names == [THING_TO_DO] - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - kwarg_timeout = 9000 - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) - kwarg_metadata = {"authorization": "09876"} - async with ChannelFor( - [ - ThingService( - test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata) - ) - ] - ) as channel: - stub = ThingServiceClient(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(THING_TO_DO), - DoThingResponse, - timeout=kwarg_timeout, - metadata=kwarg_metadata, - ) - assert response.names == [THING_TO_DO] From 3185c670981de023e400c5e6854231cd29195e12 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sun, 7 Jun 2020 17:53:06 +0200 Subject: [PATCH 04/10] Improve generate script - Fix issue with __pycache__ dirs getting picked up - parallelise code generation with asyncio for 3x speedup - silence protoc output unless -v option is supplied - Use pathlib ;) --- betterproto/tests/generate.py | 114 ++++++++++++++++++++----------- betterproto/tests/test_inputs.py | 2 +- betterproto/tests/util.py | 55 +++++++-------- 3 files changed, 99 insertions(+), 72 deletions(-) diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index fc85b7f..5c555ff 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -import glob +import asyncio import os +from pathlib import Path import shutil import subprocess import sys @@ -20,58 +21,63 @@ from betterproto.tests.util import ( os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -def clear_directory(path: str): - for file_or_directory in glob.glob(os.path.join(path, "*")): - if os.path.isdir(file_or_directory): +def clear_directory(dir_path: Path): + for file_or_directory in dir_path.glob("*"): + if file_or_directory.is_dir(): shutil.rmtree(file_or_directory) else: - os.remove(file_or_directory) + file_or_directory.unlink() -def generate(whitelist: Set[str]): - path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)} - name_whitelist = {e for e in whitelist if not os.path.exists(e)} +async def generate(whitelist: Set[str], verbose: bool): + test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} - test_case_names = set(get_directories(inputs_path)) - - failed_test_cases = [] + path_whitelist = set() + name_whitelist = set() + for item in whitelist: + if item in test_case_names: + name_whitelist.add(item) + continue + path_whitelist.add(item) + generation_tasks = [] for test_case_name in sorted(test_case_names): - test_case_input_path = os.path.realpath( - os.path.join(inputs_path, test_case_name) - ) - + test_case_input_path = inputs_path.joinpath(test_case_name).resolve() if ( whitelist - and test_case_input_path not in path_whitelist + and str(test_case_input_path) not in path_whitelist and test_case_name not in name_whitelist ): continue + generation_tasks.append( + generate_test_case_output(test_case_input_path, test_case_name, verbose) + ) - print(f"Generating output for {test_case_name}") - try: - generate_test_case_output(test_case_name, test_case_input_path) - except subprocess.CalledProcessError as e: + failed_test_cases = [] + # Wait for all subprocs and match any failures to names to report + for test_case_name, result in zip( + sorted(test_case_names), await asyncio.gather(*generation_tasks) + ): + if result != 0: failed_test_cases.append(test_case_name) if failed_test_cases: - sys.stderr.write("\nFailed to generate the following test cases:\n") + sys.stderr.write( + "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" + ) for failed_test_case in failed_test_cases: sys.stderr.write(f"- {failed_test_case}\n") -def generate_test_case_output(test_case_name, test_case_input_path=None): - if not test_case_input_path: - test_case_input_path = os.path.realpath( - os.path.join(inputs_path, test_case_name) - ) +async def generate_test_case_output( + test_case_input_path: Path, test_case_name: str, verbose: bool +) -> int: + """ + Returns the max of the subprocess return values + """ - test_case_output_path_reference = os.path.join( - output_path_reference, test_case_name - ) - test_case_output_path_betterproto = os.path.join( - output_path_betterproto, test_case_name - ) + test_case_output_path_reference = output_path_reference.joinpath(test_case_name) + test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name) os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) @@ -79,14 +85,36 @@ def generate_test_case_output(test_case_name, test_case_input_path=None): clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) - protoc_reference(test_case_input_path, test_case_output_path_reference) - protoc_plugin(test_case_input_path, test_case_output_path_betterproto) + ( + (ref_out, ref_err, ref_code), + (plg_out, plg_err, plg_code), + ) = await asyncio.gather( + protoc_reference(test_case_input_path, test_case_output_path_reference), + protoc_plugin(test_case_input_path, test_case_output_path_betterproto), + ) + + message = f"Generated output for {test_case_name!r}" + if verbose: + print(f"\033[31;1;4m{message}\033[0m") + if ref_out: + sys.stdout.buffer.write(ref_out) + if ref_err: + sys.stderr.buffer.write(ref_err) + if plg_out: + sys.stdout.buffer.write(plg_out) + if plg_err: + sys.stderr.buffer.write(plg_err) + sys.stdout.buffer.flush() + sys.stderr.buffer.flush() + else: + print(message) + + return max(ref_code, plg_code) HELP = "\n".join( - [ - "Usage: python generate.py", - " python generate.py [DIRECTORIES or NAMES]", + ( + "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]", "Generate python classes for standard tests.", "", "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", @@ -94,7 +122,7 @@ HELP = "\n".join( "", "NAMES One or more test-case names to generate classes for.", " python generate.py bool double enums", - ] + ) ) @@ -102,9 +130,13 @@ def main(): if set(sys.argv).intersection({"-h", "--help"}): print(HELP) return - whitelist = set(sys.argv[1:]) - - generate(whitelist) + if sys.argv[1:2] == ["-v"]: + verbose = True + whitelist = set(sys.argv[2:]) + else: + verbose = False + whitelist = set(sys.argv[1:]) + asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) if __name__ == "__main__": diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index cac8327..5fd3ccc 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -23,7 +23,7 @@ from google.protobuf.json_format import Parse class TestCases: def __init__(self, path, services: Set[str], xfail: Set[str]): - _all = set(get_directories(path)) + _all = set(get_directories(path)) - {"__pycache__"} _services = services _messages = _all - services _messages_with_json = { diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index a7cff7a..61ba53e 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -1,23 +1,24 @@ +import asyncio import os -import subprocess -from typing import Generator +from pathlib import Path +from typing import Generator, IO, Optional os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -root_path = os.path.dirname(os.path.realpath(__file__)) -inputs_path = os.path.join(root_path, "inputs") -output_path_reference = os.path.join(root_path, "output_reference") -output_path_betterproto = os.path.join(root_path, "output_betterproto") +root_path = Path(__file__).resolve().parent +inputs_path = root_path.joinpath("inputs") +output_path_reference = root_path.joinpath("output_reference") +output_path_betterproto = root_path.joinpath("output_betterproto") if os.name == "nt": - plugin_path = os.path.join(root_path, "..", "plugin.bat") + plugin_path = root_path.joinpath("..", "plugin.bat").resolve() else: - plugin_path = os.path.join(root_path, "..", "plugin.py") + plugin_path = root_path.joinpath("..", "plugin.py").resolve() -def get_files(path, end: str) -> Generator[str, None, None]: +def get_files(path, suffix: str) -> Generator[str, None, None]: for r, dirs, files in os.walk(path): - for filename in [f for f in files if f.endswith(end)]: + for filename in [f for f in files if f.endswith(suffix)]: yield os.path.join(r, filename) @@ -27,36 +28,30 @@ def get_directories(path): yield directory -def relative(file: str, path: str): - return os.path.join(os.path.dirname(file), path) - - -def read_relative(file: str, path: str): - with open(relative(file, path)) as fh: - return fh.read() - - -def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess: - return subprocess.run( +async def protoc_plugin(path: str, output_dir: str): + proc = await asyncio.create_subprocess_shell( f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto", - shell=True, - check=True, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) + return (*(await proc.communicate()), proc.returncode) -def protoc_reference(path: str, output_dir: str): - subprocess.run( +async def protoc_reference(path: str, output_dir: str): + proc = await asyncio.create_subprocess_shell( f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", - shell=True, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) + return (*(await proc.communicate()), proc.returncode) -def get_test_case_json_data(test_case_name, json_file_name=None): +def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json" - test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name) + test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) - if not os.path.exists(test_data_file_path): + if not test_data_file_path.exists(): return None - with open(test_data_file_path) as fh: + with test_data_file_path.open("r") as fh: return fh.read() From c8229e53a7b73f60c4d658e512cebb1cf8080a7d Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Sun, 7 Jun 2020 19:10:41 +0200 Subject: [PATCH 05/10] Fix most mypy warnings --- betterproto/__init__.py | 4 ++-- betterproto/_types.py | 6 +++++- betterproto/grpc/grpclib_client.py | 7 ++++--- betterproto/plugin.py | 8 ++++---- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 6a53d65..c1e60ea 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -440,7 +440,7 @@ class ProtoClassMetadata: def __init__(self, cls: Type["Message"]): by_field = {} - by_group = {} + by_group: Dict[str, Set] = {} by_field_name = {} by_field_number = {} @@ -780,7 +780,7 @@ class Message(ABC): def to_dict( self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> dict: + ) -> Dict[str, Any]: """ Returns a dict representation of this message instance which can be used to serialize to e.g. JSON. Defaults to camel casing for diff --git a/betterproto/_types.py b/betterproto/_types.py index 0ff23e4..d03432c 100644 --- a/betterproto/_types.py +++ b/betterproto/_types.py @@ -1,4 +1,8 @@ -from typing import TypeVar +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from . import Message + from grpclib._protocols import IProtoMessage # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py index 7218574..7f48fb9 100644 --- a/betterproto/grpc/grpclib_client.py +++ b/betterproto/grpc/grpclib_client.py @@ -3,9 +3,10 @@ import asyncio import grpclib.const from typing import ( Any, + AsyncIterable, AsyncIterator, Collection, - Iterator, + Iterable, Mapping, Optional, Tuple, @@ -23,7 +24,7 @@ if TYPE_CHECKING: _Value = Union[str, bytes] _MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] -_MessageSource = Union[Iterator["IProtoMessage"], AsyncIterator["IProtoMessage"]] +_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] class ServiceStub(ABC): @@ -160,7 +161,7 @@ class ServiceStub(ABC): @staticmethod async def _send_messages(stream, messages: _MessageSource): - if hasattr(messages, "__aiter__"): + if isinstance(messages, AsyncIterable): async for message in messages: await stream.send_message(message) else: diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 85fd905..ed14e00 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -6,10 +6,10 @@ import re import stringcase import sys import textwrap -from typing import List +from typing import List, Union +import betterproto from betterproto.casing import safe_snake_case from betterproto.compile.importing import get_ref_type -import betterproto try: # betterproto[compiler] specific dependencies @@ -58,8 +58,8 @@ def py_type( raise NotImplementedError(f"Unknown type {descriptor.type}") -def get_py_zero(type_num: int) -> str: - zero = 0 +def get_py_zero(type_num: int) -> Union[str, float]: + zero: Union[str, float] = 0 if type_num in []: zero = 0.0 elif type_num == 8: From 159c30ddd8917a14a56c65b96c400fd4b9c90f5d Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:02:05 +0200 Subject: [PATCH 06/10] Fix close not awaitable, fix done is callable, fix return async next value --- betterproto/grpc/util/async_channel.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..3a104ca 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -100,13 +100,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -151,7 +152,7 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() async def send(self, item: T): """ @@ -168,7 +169,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: From f7aa6150e25368d2c6be75e661fa2afe52eb05e0 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:02:37 +0200 Subject: [PATCH 07/10] Add test-cases for client stream-stream --- betterproto/tests/grpc/test_stream_stream.py | 124 +++++++++++++++++++ 1 file changed, 124 insertions(+) create mode 100644 betterproto/tests/grpc/test_stream_stream.py diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py new file mode 100644 index 0000000..5768189 --- /dev/null +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -0,0 +1,124 @@ +import asyncio +from dataclasses import dataclass +from typing import AsyncIterator + +import pytest + +import betterproto +from betterproto.grpc.util.async_channel import AsyncChannel + + +@dataclass +class Message(betterproto.Message): + body: str = betterproto.string_field(1) + + +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + +@pytest.fixture +def expected_responses(): + return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] + + +class ClientStub: + async def connect(self, requests): + await asyncio.sleep(0.1) + async for request in requests: + await asyncio.sleep(0.1) + yield request + await asyncio.sleep(0.1) + yield Message("Done") + + +@pytest.fixture +def client(): + # channel = Channel(host='127.0.0.1', port=50051) + # return ClientStub(channel) + return ClientStub() + + +@pytest.mark.asyncio +async def test_from_list_close_automatically(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_from_list_close_manually_immediately(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_from_list_close_manually_after_connect(client, expected_responses): + requests = AsyncChannel( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_before_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_after_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + + responses = client.connect(requests) + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_close_manually_immediately(client, expected_responses): + requests = AsyncChannel() + + responses = client.connect(requests) + + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + + requests.close() + + assert await to_list(responses) == expected_responses From 0814729c5af0f96ad933ae0e72c695fd0dce8d41 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Mon, 15 Jun 2020 18:14:13 +0200 Subject: [PATCH 08/10] Add cases for send() --- betterproto/tests/grpc/test_stream_stream.py | 43 ++++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py index 5768189..3c2c7e2 100644 --- a/betterproto/tests/grpc/test_stream_stream.py +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -13,20 +13,13 @@ class Message(betterproto.Message): body: str = betterproto.string_field(1) -async def to_list(generator: AsyncIterator): - lis = [] - async for value in generator: - lis.append(value) - return lis - - @pytest.fixture def expected_responses(): return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] class ClientStub: - async def connect(self, requests): + async def connect(self, requests: AsyncIterator): await asyncio.sleep(0.1) async for request in requests: await asyncio.sleep(0.1) @@ -35,6 +28,13 @@ class ClientStub: yield Message("Done") +async def to_list(generator: AsyncIterator): + lis = [] + async for value in generator: + lis.append(value) + return lis + + @pytest.fixture def client(): # channel = Channel(host='127.0.0.1', port=50051) @@ -122,3 +122,30 @@ async def test_send_from_close_manually_immediately(client, expected_responses): requests.close() assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + + responses = client.connect(requests) + + requests.close() + + assert await to_list(responses) == expected_responses From 50bb67bf5dca04ded331adbcdcedab3aed7d7de1 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Mon, 15 Jun 2020 23:35:56 +0200 Subject: [PATCH 09/10] Fix bugs and remove footgun feature in AsyncChannel --- betterproto/grpc/util/async_channel.py | 24 +++++++------------ betterproto/tests/grpc/test_grpclib_client.py | 22 ++++++++++------- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..fd0ecc2 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]): """ def __init__( - self, - source: Union[Iterable[T], AsyncIterable[T]] = tuple(), - *, - buffer_limit: int = 0, - close: bool = False, + self, *, buffer_limit: int = 0, close: bool = False, ): self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._closed = False - self._sending_task = ( - asyncio.ensure_future(self.send_from(source, close)) if source else None - ) self._waiting_recievers: int = 0 # Track whether flush has been invoked so it can only happen once self._flushed = False @@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]): async def send_from( self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ): + ) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() + return self - async def send(self, item: T): + async def send(self, item: T) -> "AsyncChannel[T]": """ Send a single item over this channel. :param item: The item to send @@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]): if self._closed: raise ChannelClosed("Cannot send through a closed channel") await self._queue.put(item) + return self async def recieve(self) -> Optional[T]: """ @@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: @@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]): """ Close this channel to new items """ - if self._sending_task is not None: - self._sending_task.cancel() self._closed = True asyncio.ensure_future(self._flush_queue()) diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py index dc57fe4..6c34ece 100644 --- a/betterproto/tests/grpc/test_grpclib_client.py +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -1,3 +1,4 @@ +import asyncio from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request(): # Use an AsyncChannel to decouple sending and recieving, it'll send some_things # immediately and we'll use it to send more_things later, after recieving some # results - request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) response_index = 0 async for response in client.get_different_things(request_chan): assert response.name == expected_things[response_index] @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request(): if more_things: # Send some more requests as we recieve reponses to be sure coordination of # send/recieve events doesn't matter - another_response = await request_chan.send( - GetThingRequest(more_things.pop(0)) - ) - if another_response is not None: - assert another_response.name == expected_things[response_index] - assert another_response.version == response_index - response_index += 1 + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests else: # No more things to send make sure channel is closed - await request_chan.close() + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses" From e1ccd540a9e00ff60e519e9ed2048366628c2a02 Mon Sep 17 00:00:00 2001 From: Nat Noordanus Date: Mon, 15 Jun 2020 23:35:56 +0200 Subject: [PATCH 10/10] Fix bugs and remove footgun feature in AsyncChannel --- betterproto/grpc/util/async_channel.py | 32 ++++++++----------- betterproto/tests/grpc/test_grpclib_client.py | 22 +++++++------ 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py index 7e83c94..de020a6 100644 --- a/betterproto/grpc/util/async_channel.py +++ b/betterproto/grpc/util/async_channel.py @@ -30,14 +30,15 @@ class ChannelDone(Exception): class AsyncChannel(AsyncIterable[T]): """ - A buffered async channel for sending items between coroutines with FIFO semantics. + A buffered async channel for sending items between coroutines with FIFO ordering. This makes decoupled bidirection steaming gRPC requests easy if used like: .. code-block:: python client = GeneratedStub(grpclib_chan) - # The channel can be initialised with items to send immediately - request_chan = AsyncChannel([ReqestObject(...), ReqestObject(...)]) + request_chan = await AsyncChannel() + # We can start be sending all the requests we already have + await request_chan.send_from([ReqestObject(...), ReqestObject(...)]) async for response in client.rpc_call(request_chan): # The response iterator will remain active until the connection is closed ... @@ -48,7 +49,6 @@ class AsyncChannel(AsyncIterable[T]): request_chan.close() Items can be sent through the channel by either: - - providing an iterable to the constructor - providing an iterable to the send_from method - passing them to the send method one at a time @@ -81,17 +81,10 @@ class AsyncChannel(AsyncIterable[T]): """ def __init__( - self, - source: Union[Iterable[T], AsyncIterable[T]] = tuple(), - *, - buffer_limit: int = 0, - close: bool = False, + self, *, buffer_limit: int = 0, close: bool = False, ): self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) self._closed = False - self._sending_task = ( - asyncio.ensure_future(self.send_from(source, close)) if source else None - ) self._waiting_recievers: int = 0 # Track whether flush has been invoked so it can only happen once self._flushed = False @@ -100,13 +93,14 @@ class AsyncChannel(AsyncIterable[T]): return self async def __anext__(self) -> T: - if self.done: + if self.done(): raise StopAsyncIteration self._waiting_recievers += 1 try: result = await self._queue.get() if result is self.__flush: raise StopAsyncIteration + return result finally: self._waiting_recievers -= 1 self._queue.task_done() @@ -131,7 +125,7 @@ class AsyncChannel(AsyncIterable[T]): async def send_from( self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False - ): + ) -> "AsyncChannel[T]": """ Iterates the given [Async]Iterable and sends all the resulting items. If close is set to True then subsequent send calls will be rejected with a @@ -151,9 +145,10 @@ class AsyncChannel(AsyncIterable[T]): await self._queue.put(item) if close: # Complete the closing process - await self.close() + self.close() + return self - async def send(self, item: T): + async def send(self, item: T) -> "AsyncChannel[T]": """ Send a single item over this channel. :param item: The item to send @@ -161,6 +156,7 @@ class AsyncChannel(AsyncIterable[T]): if self._closed: raise ChannelClosed("Cannot send through a closed channel") await self._queue.put(item) + return self async def recieve(self) -> Optional[T]: """ @@ -168,7 +164,7 @@ class AsyncChannel(AsyncIterable[T]): or None if the channel is closed before another item is sent. :return: An item from the channel """ - if self.done: + if self.done(): raise ChannelDone("Cannot recieve from a closed channel") self._waiting_recievers += 1 try: @@ -184,8 +180,6 @@ class AsyncChannel(AsyncIterable[T]): """ Close this channel to new items """ - if self._sending_task is not None: - self._sending_task.cancel() self._closed = True asyncio.ensure_future(self._flush_queue()) diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py index dc57fe4..6c34ece 100644 --- a/betterproto/tests/grpc/test_grpclib_client.py +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -1,3 +1,4 @@ +import asyncio from betterproto.tests.output_betterproto.service.service import ( DoThingResponse, DoThingRequest, @@ -129,7 +130,10 @@ async def test_async_gen_for_stream_stream_request(): # Use an AsyncChannel to decouple sending and recieving, it'll send some_things # immediately and we'll use it to send more_things later, after recieving some # results - request_chan = AsyncChannel(GetThingRequest(name) for name in some_things) + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) response_index = 0 async for response in client.get_different_things(request_chan): assert response.name == expected_things[response_index] @@ -138,13 +142,13 @@ async def test_async_gen_for_stream_stream_request(): if more_things: # Send some more requests as we recieve reponses to be sure coordination of # send/recieve events doesn't matter - another_response = await request_chan.send( - GetThingRequest(more_things.pop(0)) - ) - if another_response is not None: - assert another_response.name == expected_things[response_index] - assert another_response.version == response_index - response_index += 1 + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests else: # No more things to send make sure channel is closed - await request_chan.close() + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses"