diff --git a/.gitignore b/.gitignore index 6b9e7f0..93227fc 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,5 @@ dist **/*.egg-info output .idea +.DS_Store .tox diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 5d901be..c1e60ea 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -5,8 +5,9 @@ import json import struct import sys from abc import ABC -from base64 import b64encode, b64decode +from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone +import stringcase from typing import ( Any, AsyncGenerator, @@ -14,28 +15,20 @@ from typing import ( Collection, Dict, Generator, + Iterator, List, Mapping, Optional, Set, + SupportsBytes, Tuple, Type, - TypeVar, Union, get_type_hints, - TYPE_CHECKING, ) - - -import grpclib.const -import stringcase - +from ._types import ST, T from .casing import safe_snake_case - -if TYPE_CHECKING: - from grpclib._protocols import IProtoMessage - from grpclib.client import Channel - from grpclib.metadata import Deadline +from .grpc.grpclib_client import ServiceStub if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): # Apply backport of datetime.fromisoformat from 3.7 @@ -429,10 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: ) -# Bound type variable to allow methods to return `self` of subclasses -T = TypeVar("T", bound="Message") - - class ProtoClassMetadata: oneof_group_by_field: Dict[str, str] oneof_field_by_group: Dict[str, Set[dataclasses.Field]] @@ -451,7 +440,7 @@ class ProtoClassMetadata: def __init__(self, cls: Type["Message"]): by_field = {} - by_group = {} + by_group: Dict[str, Set] = {} by_field_name = {} by_field_number = {} @@ -604,7 +593,7 @@ class Message(ABC): serialize_empty = False if isinstance(value, Message) and value._serialized_on_wire: # Empty messages can still be sent on the wire if they were - # set (or received empty). + # set (or recieved empty). serialize_empty = True if value == self._get_field_default(field_name) and not ( @@ -791,7 +780,7 @@ class Message(ABC): def to_dict( self, casing: Casing = Casing.CAMEL, include_default_values: bool = False - ) -> dict: + ) -> Dict[str, Any]: """ Returns a dict representation of this message instance which can be used to serialize to e.g. JSON. Defaults to camel casing for @@ -1024,83 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type: TYPE_STRING: StringValue, TYPE_BYTES: BytesValue, }[proto_type] - - -_Value = Union[str, bytes] -_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] - - -class ServiceStub(ABC): - """ - Base class for async gRPC service stubs. - """ - - def __init__( - self, - channel: "Channel", - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> None: - self.channel = channel - self.timeout = timeout - self.deadline = deadline - self.metadata = metadata - - def __resolve_request_kwargs( - self, - timeout: Optional[float], - deadline: Optional["Deadline"], - metadata: Optional[_MetadataLike], - ): - return { - "timeout": self.timeout if timeout is None else timeout, - "deadline": self.deadline if deadline is None else deadline, - "metadata": self.metadata if metadata is None else metadata, - } - - async def _unary_unary( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> T: - """Make a unary request and return the response.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_UNARY, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - response = await stream.recv_message() - assert response is not None - return response - - async def _unary_stream( - self, - route: str, - request: "IProtoMessage", - response_type: Type[T], - *, - timeout: Optional[float] = None, - deadline: Optional["Deadline"] = None, - metadata: Optional[_MetadataLike] = None, - ) -> AsyncGenerator[T, None]: - """Make a unary request and return the stream response iterator.""" - async with self.channel.request( - route, - grpclib.const.Cardinality.UNARY_STREAM, - type(request), - response_type, - **self.__resolve_request_kwargs(timeout, deadline, metadata), - ) as stream: - await stream.send_message(request, end=True) - async for message in stream: - yield message diff --git a/betterproto/_types.py b/betterproto/_types.py new file mode 100644 index 0000000..d03432c --- /dev/null +++ b/betterproto/_types.py @@ -0,0 +1,9 @@ +from typing import TYPE_CHECKING, TypeVar + +if TYPE_CHECKING: + from . import Message + from grpclib._protocols import IProtoMessage + +# Bound type variable to allow methods to return `self` of subclasses +T = TypeVar("T", bound="Message") +ST = TypeVar("ST", bound="IProtoMessage") diff --git a/betterproto/grpc/__init__.py b/betterproto/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/grpclib_client.py b/betterproto/grpc/grpclib_client.py new file mode 100644 index 0000000..7f48fb9 --- /dev/null +++ b/betterproto/grpc/grpclib_client.py @@ -0,0 +1,170 @@ +from abc import ABC +import asyncio +import grpclib.const +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Collection, + Iterable, + Mapping, + Optional, + Tuple, + TYPE_CHECKING, + Type, + Union, +) +from .._types import ST, T + +if TYPE_CHECKING: + from grpclib._protocols import IProtoMessage + from grpclib.client import Channel, Stream + from grpclib.metadata import Deadline + + +_Value = Union[str, bytes] +_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]] +_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] + + +class ServiceStub(ABC): + """ + Base class for async gRPC clients. + """ + + def __init__( + self, + channel: "Channel", + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> None: + self.channel = channel + self.timeout = timeout + self.deadline = deadline + self.metadata = metadata + + def __resolve_request_kwargs( + self, + timeout: Optional[float], + deadline: Optional["Deadline"], + metadata: Optional[_MetadataLike], + ): + return { + "timeout": self.timeout if timeout is None else timeout, + "deadline": self.deadline if deadline is None else deadline, + "metadata": self.metadata if metadata is None else metadata, + } + + async def _unary_unary( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> T: + """Make a unary request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_UNARY, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _unary_stream( + self, + route: str, + request: "IProtoMessage", + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncIterator[T]: + """Make a unary request and return the stream response iterator.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.UNARY_STREAM, + type(request), + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_message(request, end=True) + async for message in stream: + yield message + + async def _stream_unary( + self, + route: str, + request_iterator: _MessageSource, + request_type: Type[ST], + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> T: + """Make a stream request and return the response.""" + async with self.channel.request( + route, + grpclib.const.Cardinality.STREAM_UNARY, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await self._send_messages(stream, request_iterator) + response = await stream.recv_message() + assert response is not None + return response + + async def _stream_stream( + self, + route: str, + request_iterator: _MessageSource, + request_type: Type[ST], + response_type: Type[T], + *, + timeout: Optional[float] = None, + deadline: Optional["Deadline"] = None, + metadata: Optional[_MetadataLike] = None, + ) -> AsyncIterator[T]: + """ + Make a stream request and return an AsyncIterator to iterate over response + messages. + """ + async with self.channel.request( + route, + grpclib.const.Cardinality.STREAM_STREAM, + request_type, + response_type, + **self.__resolve_request_kwargs(timeout, deadline, metadata), + ) as stream: + await stream.send_request() + sending_task = asyncio.ensure_future( + self._send_messages(stream, request_iterator) + ) + try: + async for response in stream: + yield response + except: + sending_task.cancel() + raise + + @staticmethod + async def _send_messages(stream, messages: _MessageSource): + if isinstance(messages, AsyncIterable): + async for message in messages: + await stream.send_message(message) + else: + for message in messages: + await stream.send_message(message) + await stream.end() diff --git a/betterproto/grpc/util/__init__.py b/betterproto/grpc/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/grpc/util/async_channel.py b/betterproto/grpc/util/async_channel.py new file mode 100644 index 0000000..de020a6 --- /dev/null +++ b/betterproto/grpc/util/async_channel.py @@ -0,0 +1,198 @@ +import asyncio +from typing import ( + AsyncIterable, + AsyncIterator, + Iterable, + Optional, + TypeVar, + Union, +) + +T = TypeVar("T") + + +class ChannelClosed(Exception): + """ + An exception raised on an attempt to send through a closed channel + """ + + pass + + +class ChannelDone(Exception): + """ + An exception raised on an attempt to send recieve from a channel that is both closed + and empty. + """ + + pass + + +class AsyncChannel(AsyncIterable[T]): + """ + A buffered async channel for sending items between coroutines with FIFO ordering. + + This makes decoupled bidirection steaming gRPC requests easy if used like: + + .. code-block:: python + client = GeneratedStub(grpclib_chan) + request_chan = await AsyncChannel() + # We can start be sending all the requests we already have + await request_chan.send_from([ReqestObject(...), ReqestObject(...)]) + async for response in client.rpc_call(request_chan): + # The response iterator will remain active until the connection is closed + ... + # More items can be sent at any time + await request_chan.send(ReqestObject(...)) + ... + # The channel must be closed to complete the gRPC connection + request_chan.close() + + Items can be sent through the channel by either: + - providing an iterable to the send_from method + - passing them to the send method one at a time + + Items can be recieved from the channel by either: + - iterating over the channel with a for loop to get all items + - calling the recieve method to get one item at a time + + If the channel is empty then recievers will wait until either an item appears or the + channel is closed. + + Once the channel is closed then subsequent attempt to send through the channel will + fail with a ChannelClosed exception. + + When th channel is closed and empty then it is done, and further attempts to recieve + from it will fail with a ChannelDone exception + + If multiple coroutines recieve from the channel concurrently, each item sent will be + recieved by only one of the recievers. + + :param source: + An optional iterable will items that should be sent through the channel + immediately. + :param buffer_limit: + Limit the number of items that can be buffered in the channel, A value less than + 1 implies no limit. If the channel is full then attempts to send more items will + result in the sender waiting until an item is recieved from the channel. + :param close: + If set to True then the channel will automatically close after exhausting source + or immediately if no source is provided. + """ + + def __init__( + self, *, buffer_limit: int = 0, close: bool = False, + ): + self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) + self._closed = False + self._waiting_recievers: int = 0 + # Track whether flush has been invoked so it can only happen once + self._flushed = False + + def __aiter__(self) -> AsyncIterator[T]: + return self + + async def __anext__(self) -> T: + if self.done(): + raise StopAsyncIteration + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + raise StopAsyncIteration + return result + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def closed(self) -> bool: + """ + Returns True if this channel is closed and no-longer accepting new items + """ + return self._closed + + def done(self) -> bool: + """ + Check if this channel is done. + + :return: True if this channel is closed and and has been drained of items in + which case any further attempts to recieve an item from this channel will raise + a ChannelDone exception. + """ + # After close the channel is not yet done until there is at least one waiting + # reciever per enqueued item. + return self._closed and self._queue.qsize() <= self._waiting_recievers + + async def send_from( + self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False + ) -> "AsyncChannel[T]": + """ + Iterates the given [Async]Iterable and sends all the resulting items. + If close is set to True then subsequent send calls will be rejected with a + ChannelClosed exception. + :param source: an iterable of items to send + :param close: + if True then the channel will be closed after the source has been exhausted + + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + if isinstance(source, AsyncIterable): + async for item in source: + await self._queue.put(item) + else: + for item in source: + await self._queue.put(item) + if close: + # Complete the closing process + self.close() + return self + + async def send(self, item: T) -> "AsyncChannel[T]": + """ + Send a single item over this channel. + :param item: The item to send + """ + if self._closed: + raise ChannelClosed("Cannot send through a closed channel") + await self._queue.put(item) + return self + + async def recieve(self) -> Optional[T]: + """ + Returns the next item from this channel when it becomes available, + or None if the channel is closed before another item is sent. + :return: An item from the channel + """ + if self.done(): + raise ChannelDone("Cannot recieve from a closed channel") + self._waiting_recievers += 1 + try: + result = await self._queue.get() + if result is self.__flush: + return None + return result + finally: + self._waiting_recievers -= 1 + self._queue.task_done() + + def close(self): + """ + Close this channel to new items + """ + self._closed = True + asyncio.ensure_future(self._flush_queue()) + + async def _flush_queue(self): + """ + To be called after the channel is closed. Pushes a number of self.__flush + objects to the queue to ensure no waiting consumers get deadlocked. + """ + if not self._flushed: + self._flushed = True + deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize()) + for _ in range(deadlocked_recievers): + await self._queue.put(self.__flush) + + # A special signal object for flushing the queue when the channel is closed + __flush = object() diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 5524afd..928f026 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -6,10 +6,10 @@ import re import stringcase import sys import textwrap -from typing import List +from typing import List, Union +import betterproto from betterproto.casing import safe_snake_case from betterproto.compile.importing import get_ref_type -import betterproto try: # betterproto[compiler] specific dependencies @@ -58,8 +58,8 @@ def py_type( raise NotImplementedError(f"Unknown type {descriptor.type}") -def get_py_zero(type_num: int) -> str: - zero = 0 +def get_py_zero(type_num: int) -> Union[str, float]: + zero: Union[str, float] = 0 if type_num in []: zero = 0.0 elif type_num == 8: @@ -311,9 +311,6 @@ def generate_code(request, response): } for j, method in enumerate(service.method): - if method.client_streaming: - raise NotImplementedError("Client streaming not yet supported") - input_message = None input_type = get_ref_type( package, output["imports"], method.input_type @@ -347,8 +344,12 @@ def generate_code(request, response): } ) + if method.client_streaming: + output["typing_imports"].add("AsyncIterable") + output["typing_imports"].add("Iterable") + output["typing_imports"].add("Union") if method.server_streaming: - output["typing_imports"].add("AsyncGenerator") + output["typing_imports"].add("AsyncIterator") output["services"].append(data) diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index 3a19422..3894619 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% for method in service.methods %} - async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + async def {{ method.py_name }}(self + {%- if not method.client_streaming -%} + {%- if method.input_message and method.input_message.properties -%}, *, + {%- for field in method.input_message.properties -%} + {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} + Optional[{{ field.type }}] + {%- else -%} + {{ field.type }} + {%- endif -%} = {{ field.zero }} + {%- if not loop.last %}, {% endif -%} + {%- endfor -%} + {%- endif -%} + {%- else -%} + {# Client streaming: need a request iterator instead #} + , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]] + {%- endif -%} + ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} {% endif %} + {% if not method.client_streaming %} request = {{ method.input }}() {% for field in method.input_message.properties %} {% if field.field_type == 'message' %} @@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): request.{{ field.py_name }} = {{ field.py_name }} {% endif %} {% endfor %} + {% endif %} {% if method.server_streaming %} + {% if method.client_streaming %} + async for response in self._stream_stream( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }}, + ): + yield response + {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, {{ method.output }}, ): yield response - {% else %} + + {% endif %}{# if client streaming #} + {% else %}{# i.e. not server streaming #} + {% if method.client_streaming %} + return await self._stream_unary( + "{{ method.route }}", + request_iterator, + {{ method.input }}, + {{ method.output }} + ) + {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output }}, + {{ method.output }} ) + {% endif %}{# client streaming #} {% endif %} {% endfor %} diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index fc85b7f..5c555ff 100755 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -import glob +import asyncio import os +from pathlib import Path import shutil import subprocess import sys @@ -20,58 +21,63 @@ from betterproto.tests.util import ( os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -def clear_directory(path: str): - for file_or_directory in glob.glob(os.path.join(path, "*")): - if os.path.isdir(file_or_directory): +def clear_directory(dir_path: Path): + for file_or_directory in dir_path.glob("*"): + if file_or_directory.is_dir(): shutil.rmtree(file_or_directory) else: - os.remove(file_or_directory) + file_or_directory.unlink() -def generate(whitelist: Set[str]): - path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)} - name_whitelist = {e for e in whitelist if not os.path.exists(e)} +async def generate(whitelist: Set[str], verbose: bool): + test_case_names = set(get_directories(inputs_path)) - {"__pycache__"} - test_case_names = set(get_directories(inputs_path)) - - failed_test_cases = [] + path_whitelist = set() + name_whitelist = set() + for item in whitelist: + if item in test_case_names: + name_whitelist.add(item) + continue + path_whitelist.add(item) + generation_tasks = [] for test_case_name in sorted(test_case_names): - test_case_input_path = os.path.realpath( - os.path.join(inputs_path, test_case_name) - ) - + test_case_input_path = inputs_path.joinpath(test_case_name).resolve() if ( whitelist - and test_case_input_path not in path_whitelist + and str(test_case_input_path) not in path_whitelist and test_case_name not in name_whitelist ): continue + generation_tasks.append( + generate_test_case_output(test_case_input_path, test_case_name, verbose) + ) - print(f"Generating output for {test_case_name}") - try: - generate_test_case_output(test_case_name, test_case_input_path) - except subprocess.CalledProcessError as e: + failed_test_cases = [] + # Wait for all subprocs and match any failures to names to report + for test_case_name, result in zip( + sorted(test_case_names), await asyncio.gather(*generation_tasks) + ): + if result != 0: failed_test_cases.append(test_case_name) if failed_test_cases: - sys.stderr.write("\nFailed to generate the following test cases:\n") + sys.stderr.write( + "\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n" + ) for failed_test_case in failed_test_cases: sys.stderr.write(f"- {failed_test_case}\n") -def generate_test_case_output(test_case_name, test_case_input_path=None): - if not test_case_input_path: - test_case_input_path = os.path.realpath( - os.path.join(inputs_path, test_case_name) - ) +async def generate_test_case_output( + test_case_input_path: Path, test_case_name: str, verbose: bool +) -> int: + """ + Returns the max of the subprocess return values + """ - test_case_output_path_reference = os.path.join( - output_path_reference, test_case_name - ) - test_case_output_path_betterproto = os.path.join( - output_path_betterproto, test_case_name - ) + test_case_output_path_reference = output_path_reference.joinpath(test_case_name) + test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name) os.makedirs(test_case_output_path_reference, exist_ok=True) os.makedirs(test_case_output_path_betterproto, exist_ok=True) @@ -79,14 +85,36 @@ def generate_test_case_output(test_case_name, test_case_input_path=None): clear_directory(test_case_output_path_reference) clear_directory(test_case_output_path_betterproto) - protoc_reference(test_case_input_path, test_case_output_path_reference) - protoc_plugin(test_case_input_path, test_case_output_path_betterproto) + ( + (ref_out, ref_err, ref_code), + (plg_out, plg_err, plg_code), + ) = await asyncio.gather( + protoc_reference(test_case_input_path, test_case_output_path_reference), + protoc_plugin(test_case_input_path, test_case_output_path_betterproto), + ) + + message = f"Generated output for {test_case_name!r}" + if verbose: + print(f"\033[31;1;4m{message}\033[0m") + if ref_out: + sys.stdout.buffer.write(ref_out) + if ref_err: + sys.stderr.buffer.write(ref_err) + if plg_out: + sys.stdout.buffer.write(plg_out) + if plg_err: + sys.stderr.buffer.write(plg_err) + sys.stdout.buffer.flush() + sys.stderr.buffer.flush() + else: + print(message) + + return max(ref_code, plg_code) HELP = "\n".join( - [ - "Usage: python generate.py", - " python generate.py [DIRECTORIES or NAMES]", + ( + "Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]", "Generate python classes for standard tests.", "", "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", @@ -94,7 +122,7 @@ HELP = "\n".join( "", "NAMES One or more test-case names to generate classes for.", " python generate.py bool double enums", - ] + ) ) @@ -102,9 +130,13 @@ def main(): if set(sys.argv).intersection({"-h", "--help"}): print(HELP) return - whitelist = set(sys.argv[1:]) - - generate(whitelist) + if sys.argv[1:2] == ["-v"]: + verbose = True + whitelist = set(sys.argv[2:]) + else: + verbose = False + whitelist = set(sys.argv[1:]) + asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose)) if __name__ == "__main__": diff --git a/betterproto/tests/grpc/__init__.py b/betterproto/tests/grpc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/tests/grpc/test_grpclib_client.py b/betterproto/tests/grpc/test_grpclib_client.py new file mode 100644 index 0000000..6c34ece --- /dev/null +++ b/betterproto/tests/grpc/test_grpclib_client.py @@ -0,0 +1,154 @@ +import asyncio +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from grpclib.testing import ChannelFor +import pytest +from betterproto.grpc.util.async_channel import AsyncChannel +from .thing_service import ThingService + + +async def _test_client(client, name="clean room", **kwargs): + response = await client.do_thing(name=name) + assert response.names == [name] + + +def _assert_request_meta_recieved(deadline, metadata): + def server_side_test(stream): + assert stream.deadline._timestamp == pytest.approx( + deadline._timestamp, 1 + ), "The provided deadline should be recieved serverside" + assert ( + stream.metadata["authorization"] == metadata["authorization"] + ), "The provided authorization metadata should be recieved serverside" + + return server_side_test + + +@pytest.mark.asyncio +async def test_simple_service_call(): + async with ChannelFor([ThingService()]) as channel: + await _test_client(ThingServiceClient(channel)) + + +@pytest.mark.asyncio +async def test_service_call_with_upfront_request_params(): + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, deadline=deadline, metadata=metadata) + ) + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + await _test_client( + ThingServiceClient(channel, timeout=timeout, metadata=metadata) + ) + + +@pytest.mark.asyncio +async def test_service_call_lower_level_with_overrides(): + THING_TO_DO = "get milk" + + # Setting deadline + deadline = grpclib.metadata.Deadline.from_timeout(22) + metadata = {"authorization": "12345"} + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) + kwarg_metadata = {"authorization": "12345"} + async with ChannelFor( + [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + deadline=kwarg_deadline, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + # Setting timeout + timeout = 99 + deadline = grpclib.metadata.Deadline.from_timeout(timeout) + metadata = {"authorization": "12345"} + kwarg_timeout = 9000 + kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) + kwarg_metadata = {"authorization": "09876"} + async with ChannelFor( + [ + ThingService( + test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata), + ) + ] + ) as channel: + client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) + response = await client._unary_unary( + "/service.Test/DoThing", + DoThingRequest(THING_TO_DO), + DoThingResponse, + timeout=kwarg_timeout, + metadata=kwarg_metadata, + ) + assert response.names == [THING_TO_DO] + + +@pytest.mark.asyncio +async def test_async_gen_for_unary_stream_request(): + thing_name = "my milkshakes" + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + expected_versions = [5, 4, 3, 2, 1] + async for response in client.get_thing_versions(name=thing_name): + assert response.name == thing_name + assert response.version == expected_versions.pop() + + +@pytest.mark.asyncio +async def test_async_gen_for_stream_stream_request(): + some_things = ["cake", "cricket", "coral reef"] + more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"] + expected_things = (*some_things, *more_things) + + async with ChannelFor([ThingService()]) as channel: + client = ThingServiceClient(channel) + # Use an AsyncChannel to decouple sending and recieving, it'll send some_things + # immediately and we'll use it to send more_things later, after recieving some + # results + request_chan = AsyncChannel() + send_initial_requests = asyncio.ensure_future( + request_chan.send_from(GetThingRequest(name) for name in some_things) + ) + response_index = 0 + async for response in client.get_different_things(request_chan): + assert response.name == expected_things[response_index] + assert response.version == response_index + 1 + response_index += 1 + if more_things: + # Send some more requests as we recieve reponses to be sure coordination of + # send/recieve events doesn't matter + await request_chan.send(GetThingRequest(more_things.pop(0))) + elif not send_initial_requests.done(): + # Make sure the sending task it completed + await send_initial_requests + else: + # No more things to send make sure channel is closed + request_chan.close() + assert response_index == len( + expected_things + ), "Didn't recieve all exptected responses" diff --git a/betterproto/tests/grpc/test_stream_stream.py b/betterproto/tests/grpc/test_stream_stream.py new file mode 100644 index 0000000..2fc9237 --- /dev/null +++ b/betterproto/tests/grpc/test_stream_stream.py @@ -0,0 +1,100 @@ +import asyncio +import betterproto +from betterproto.grpc.util.async_channel import AsyncChannel +from dataclasses import dataclass +import pytest +from typing import AsyncIterator + + +@dataclass +class Message(betterproto.Message): + body: str = betterproto.string_field(1) + + +@pytest.fixture +def expected_responses(): + return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")] + + +class ClientStub: + async def connect(self, requests: AsyncIterator): + await asyncio.sleep(0.1) + async for request in requests: + await asyncio.sleep(0.1) + yield request + await asyncio.sleep(0.1) + yield Message("Done") + + +async def to_list(generator: AsyncIterator): + result = [] + async for value in generator: + result.append(value) + return result + + +@pytest.fixture +def client(): + # channel = Channel(host='127.0.0.1', port=50051) + # return ClientStub(channel) + return ClientStub() + + +@pytest.mark.asyncio +async def test_send_from_before_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_after_connect_and_close_automatically( + client, expected_responses +): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=True + ) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_from_close_manually_immediately(client, expected_responses): + requests = AsyncChannel() + responses = client.connect(requests) + await requests.send_from( + [Message(body="Hello world 1"), Message(body="Hello world 2")], close=False + ) + requests.close() + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_before_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + requests.close() + responses = client.connect(requests) + + assert await to_list(responses) == expected_responses + + +@pytest.mark.asyncio +async def test_send_individually_and_close_after_connect(client, expected_responses): + requests = AsyncChannel() + await requests.send(Message(body="Hello world 1")) + await requests.send(Message(body="Hello world 2")) + responses = client.connect(requests) + requests.close() + + assert await to_list(responses) == expected_responses diff --git a/betterproto/tests/grpc/thing_service.py b/betterproto/tests/grpc/thing_service.py new file mode 100644 index 0000000..bc9fff8 --- /dev/null +++ b/betterproto/tests/grpc/thing_service.py @@ -0,0 +1,83 @@ +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + GetThingRequest, + GetThingResponse, + TestStub as ThingServiceClient, +) +import grpclib +from typing import Any, Dict + + +class ThingService: + def __init__(self, test_hook=None): + # This lets us pass assertions to the servicer ;) + self.test_hook = test_hook + + async def do_thing( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse([request.name])) + + async def do_many_things( + self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" + ): + thing_names = [request.name for request in stream] + if self.test_hook is not None: + self.test_hook(stream) + await stream.send_message(DoThingResponse(thing_names)) + + async def get_thing_versions( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + request = await stream.recv_message() + if self.test_hook is not None: + self.test_hook(stream) + for version_num in range(1, 6): + await stream.send_message( + GetThingResponse(name=request.name, version=version_num) + ) + + async def get_different_things( + self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]" + ): + if self.test_hook is not None: + self.test_hook(stream) + # Respond to each input item immediately + response_num = 0 + async for request in stream: + response_num += 1 + await stream.send_message( + GetThingResponse(name=request.name, version=response_num) + ) + + def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]: + return { + "/service.Test/DoThing": grpclib.const.Handler( + self.do_thing, + grpclib.const.Cardinality.UNARY_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/DoManyThings": grpclib.const.Handler( + self.do_many_things, + grpclib.const.Cardinality.STREAM_UNARY, + DoThingRequest, + DoThingResponse, + ), + "/service.Test/GetThingVersions": grpclib.const.Handler( + self.get_thing_versions, + grpclib.const.Cardinality.UNARY_STREAM, + GetThingRequest, + GetThingResponse, + ), + "/service.Test/GetDifferentThings": grpclib.const.Handler( + self.get_different_things, + grpclib.const.Cardinality.STREAM_STREAM, + GetThingRequest, + GetThingResponse, + ), + } diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index 02fa193..bd5f602 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -23,7 +23,7 @@ test_cases = [ @pytest.mark.asyncio @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) -async def test_channel_receives_wrapped_type( +async def test_channel_recieves_wrapped_type( service_method: Callable[[TestStub], Any], wrapper_class: Callable, value ): wrapped_value = wrapper_class() diff --git a/betterproto/tests/inputs/service/service.proto b/betterproto/tests/inputs/service/service.proto index 7c931ed..acfbcdd 100644 --- a/betterproto/tests/inputs/service/service.proto +++ b/betterproto/tests/inputs/service/service.proto @@ -3,13 +3,25 @@ syntax = "proto3"; package service; message DoThingRequest { - int32 iterations = 1; + string name = 1; } message DoThingResponse { - int32 successfulIterations = 1; + repeated string names = 1; +} + +message GetThingRequest { + string name = 1; +} + +message GetThingResponse { + string name = 1; + int32 version = 2; } service Test { rpc DoThing (DoThingRequest) returns (DoThingResponse); + rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse); + rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse); + rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse); } diff --git a/betterproto/tests/inputs/service/test_service.py b/betterproto/tests/inputs/service/test_service.py deleted file mode 100644 index 2a6ca59..0000000 --- a/betterproto/tests/inputs/service/test_service.py +++ /dev/null @@ -1,132 +0,0 @@ -import betterproto -import grpclib -from grpclib.testing import ChannelFor -import pytest -from typing import Dict - -from betterproto.tests.output_betterproto.service.service import ( - DoThingResponse, - DoThingRequest, - TestStub as ExampleServiceStub, -) - - -class ExampleService: - def __init__(self, test_hook=None): - # This lets us pass assertions to the servicer ;) - self.test_hook = test_hook - - async def DoThing( - self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]" - ): - request = await stream.recv_message() - print("self.test_hook", self.test_hook) - if self.test_hook is not None: - self.test_hook(stream) - for iteration in range(request.iterations): - pass - await stream.send_message(DoThingResponse(request.iterations)) - - def __mapping__(self) -> Dict[str, grpclib.const.Handler]: - return { - "/service.Test/DoThing": grpclib.const.Handler( - self.DoThing, - grpclib.const.Cardinality.UNARY_UNARY, - DoThingRequest, - DoThingResponse, - ) - } - - -async def _test_stub(stub, iterations=42, **kwargs): - response = await stub.do_thing(iterations=iterations) - assert response.successful_iterations == iterations - - -def _get_server_side_test(deadline, metadata): - def server_side_test(stream): - assert stream.deadline._timestamp == pytest.approx( - deadline._timestamp, 1 - ), "The provided deadline should be recieved serverside" - assert ( - stream.metadata["authorization"] == metadata["authorization"] - ), "The provided authorization metadata should be recieved serverside" - - return server_side_test - - -@pytest.mark.asyncio -async def test_simple_service_call(): - async with ChannelFor([ExampleService()]) as channel: - await _test_stub(ExampleServiceStub(channel)) - - -@pytest.mark.asyncio -async def test_service_call_with_upfront_request_params(): - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - ) - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - await _test_stub( - ExampleServiceStub(channel, timeout=timeout, metadata=metadata) - ) - - -@pytest.mark.asyncio -async def test_service_call_lower_level_with_overrides(): - ITERATIONS = 99 - - # Setting deadline - deadline = grpclib.metadata.Deadline.from_timeout(22) - metadata = {"authorization": "12345"} - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) - kwarg_metadata = {"authorization": "12345"} - async with ChannelFor( - [ExampleService(test_hook=_get_server_side_test(deadline, metadata))] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - deadline=kwarg_deadline, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS - - # Setting timeout - timeout = 99 - deadline = grpclib.metadata.Deadline.from_timeout(timeout) - metadata = {"authorization": "12345"} - kwarg_timeout = 9000 - kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout) - kwarg_metadata = {"authorization": "09876"} - async with ChannelFor( - [ - ExampleService( - test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata) - ) - ] - ) as channel: - stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) - response = await stub._unary_unary( - "/service.Test/DoThing", - DoThingRequest(ITERATIONS), - DoThingResponse, - timeout=kwarg_timeout, - metadata=kwarg_metadata, - ) - assert response.successful_iterations == ITERATIONS diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index 7de1c69..cb5974d 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -23,7 +23,7 @@ from google.protobuf.json_format import Parse class TestCases: def __init__(self, path, services: Set[str], xfail: Set[str]): - _all = set(get_directories(path)) + _all = set(get_directories(path)) - {"__pycache__"} _services = services _messages = (_all - services) - {"__pycache__"} _messages_with_json = { diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index a7cff7a..61ba53e 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -1,23 +1,24 @@ +import asyncio import os -import subprocess -from typing import Generator +from pathlib import Path +from typing import Generator, IO, Optional os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -root_path = os.path.dirname(os.path.realpath(__file__)) -inputs_path = os.path.join(root_path, "inputs") -output_path_reference = os.path.join(root_path, "output_reference") -output_path_betterproto = os.path.join(root_path, "output_betterproto") +root_path = Path(__file__).resolve().parent +inputs_path = root_path.joinpath("inputs") +output_path_reference = root_path.joinpath("output_reference") +output_path_betterproto = root_path.joinpath("output_betterproto") if os.name == "nt": - plugin_path = os.path.join(root_path, "..", "plugin.bat") + plugin_path = root_path.joinpath("..", "plugin.bat").resolve() else: - plugin_path = os.path.join(root_path, "..", "plugin.py") + plugin_path = root_path.joinpath("..", "plugin.py").resolve() -def get_files(path, end: str) -> Generator[str, None, None]: +def get_files(path, suffix: str) -> Generator[str, None, None]: for r, dirs, files in os.walk(path): - for filename in [f for f in files if f.endswith(end)]: + for filename in [f for f in files if f.endswith(suffix)]: yield os.path.join(r, filename) @@ -27,36 +28,30 @@ def get_directories(path): yield directory -def relative(file: str, path: str): - return os.path.join(os.path.dirname(file), path) - - -def read_relative(file: str, path: str): - with open(relative(file, path)) as fh: - return fh.read() - - -def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess: - return subprocess.run( +async def protoc_plugin(path: str, output_dir: str): + proc = await asyncio.create_subprocess_shell( f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto", - shell=True, - check=True, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) + return (*(await proc.communicate()), proc.returncode) -def protoc_reference(path: str, output_dir: str): - subprocess.run( +async def protoc_reference(path: str, output_dir: str): + proc = await asyncio.create_subprocess_shell( f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", - shell=True, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) + return (*(await proc.communicate()), proc.returncode) -def get_test_case_json_data(test_case_name, json_file_name=None): +def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None): test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json" - test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name) + test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name) - if not os.path.exists(test_data_file_path): + if not test_data_file_path.exists(): return None - with open(test_data_file_path) as fh: + with test_data_file_path.open("r") as fh: return fh.read()