commit
9532844929
1
.gitignore
vendored
1
.gitignore
vendored
@ -12,4 +12,5 @@ dist
|
|||||||
**/*.egg-info
|
**/*.egg-info
|
||||||
output
|
output
|
||||||
.idea
|
.idea
|
||||||
|
.DS_Store
|
||||||
.tox
|
.tox
|
||||||
|
@ -5,8 +5,9 @@ import json
|
|||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from base64 import b64encode, b64decode
|
from base64 import b64decode, b64encode
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import stringcase
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
AsyncGenerator,
|
AsyncGenerator,
|
||||||
@ -14,28 +15,20 @@ from typing import (
|
|||||||
Collection,
|
Collection,
|
||||||
Dict,
|
Dict,
|
||||||
Generator,
|
Generator,
|
||||||
|
Iterator,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
SupportsBytes,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
|
||||||
Union,
|
Union,
|
||||||
get_type_hints,
|
get_type_hints,
|
||||||
TYPE_CHECKING,
|
|
||||||
)
|
)
|
||||||
|
from ._types import ST, T
|
||||||
|
|
||||||
import grpclib.const
|
|
||||||
import stringcase
|
|
||||||
|
|
||||||
from .casing import safe_snake_case
|
from .casing import safe_snake_case
|
||||||
|
from .grpc.grpclib_client import ServiceStub
|
||||||
if TYPE_CHECKING:
|
|
||||||
from grpclib._protocols import IProtoMessage
|
|
||||||
from grpclib.client import Channel
|
|
||||||
from grpclib.metadata import Deadline
|
|
||||||
|
|
||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||||
# Apply backport of datetime.fromisoformat from 3.7
|
# Apply backport of datetime.fromisoformat from 3.7
|
||||||
@ -429,10 +422,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
|
||||||
T = TypeVar("T", bound="Message")
|
|
||||||
|
|
||||||
|
|
||||||
class ProtoClassMetadata:
|
class ProtoClassMetadata:
|
||||||
oneof_group_by_field: Dict[str, str]
|
oneof_group_by_field: Dict[str, str]
|
||||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||||
@ -451,7 +440,7 @@ class ProtoClassMetadata:
|
|||||||
|
|
||||||
def __init__(self, cls: Type["Message"]):
|
def __init__(self, cls: Type["Message"]):
|
||||||
by_field = {}
|
by_field = {}
|
||||||
by_group = {}
|
by_group: Dict[str, Set] = {}
|
||||||
by_field_name = {}
|
by_field_name = {}
|
||||||
by_field_number = {}
|
by_field_number = {}
|
||||||
|
|
||||||
@ -604,7 +593,7 @@ class Message(ABC):
|
|||||||
serialize_empty = False
|
serialize_empty = False
|
||||||
if isinstance(value, Message) and value._serialized_on_wire:
|
if isinstance(value, Message) and value._serialized_on_wire:
|
||||||
# Empty messages can still be sent on the wire if they were
|
# Empty messages can still be sent on the wire if they were
|
||||||
# set (or received empty).
|
# set (or recieved empty).
|
||||||
serialize_empty = True
|
serialize_empty = True
|
||||||
|
|
||||||
if value == self._get_field_default(field_name) and not (
|
if value == self._get_field_default(field_name) and not (
|
||||||
@ -791,7 +780,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
def to_dict(
|
def to_dict(
|
||||||
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
|
self, casing: Casing = Casing.CAMEL, include_default_values: bool = False
|
||||||
) -> dict:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Returns a dict representation of this message instance which can be
|
Returns a dict representation of this message instance which can be
|
||||||
used to serialize to e.g. JSON. Defaults to camel casing for
|
used to serialize to e.g. JSON. Defaults to camel casing for
|
||||||
@ -1024,83 +1013,3 @@ def _get_wrapper(proto_type: str) -> Type:
|
|||||||
TYPE_STRING: StringValue,
|
TYPE_STRING: StringValue,
|
||||||
TYPE_BYTES: BytesValue,
|
TYPE_BYTES: BytesValue,
|
||||||
}[proto_type]
|
}[proto_type]
|
||||||
|
|
||||||
|
|
||||||
_Value = Union[str, bytes]
|
|
||||||
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
|
||||||
|
|
||||||
|
|
||||||
class ServiceStub(ABC):
|
|
||||||
"""
|
|
||||||
Base class for async gRPC service stubs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
channel: "Channel",
|
|
||||||
*,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
deadline: Optional["Deadline"] = None,
|
|
||||||
metadata: Optional[_MetadataLike] = None,
|
|
||||||
) -> None:
|
|
||||||
self.channel = channel
|
|
||||||
self.timeout = timeout
|
|
||||||
self.deadline = deadline
|
|
||||||
self.metadata = metadata
|
|
||||||
|
|
||||||
def __resolve_request_kwargs(
|
|
||||||
self,
|
|
||||||
timeout: Optional[float],
|
|
||||||
deadline: Optional["Deadline"],
|
|
||||||
metadata: Optional[_MetadataLike],
|
|
||||||
):
|
|
||||||
return {
|
|
||||||
"timeout": self.timeout if timeout is None else timeout,
|
|
||||||
"deadline": self.deadline if deadline is None else deadline,
|
|
||||||
"metadata": self.metadata if metadata is None else metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _unary_unary(
|
|
||||||
self,
|
|
||||||
route: str,
|
|
||||||
request: "IProtoMessage",
|
|
||||||
response_type: Type[T],
|
|
||||||
*,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
deadline: Optional["Deadline"] = None,
|
|
||||||
metadata: Optional[_MetadataLike] = None,
|
|
||||||
) -> T:
|
|
||||||
"""Make a unary request and return the response."""
|
|
||||||
async with self.channel.request(
|
|
||||||
route,
|
|
||||||
grpclib.const.Cardinality.UNARY_UNARY,
|
|
||||||
type(request),
|
|
||||||
response_type,
|
|
||||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
|
||||||
) as stream:
|
|
||||||
await stream.send_message(request, end=True)
|
|
||||||
response = await stream.recv_message()
|
|
||||||
assert response is not None
|
|
||||||
return response
|
|
||||||
|
|
||||||
async def _unary_stream(
|
|
||||||
self,
|
|
||||||
route: str,
|
|
||||||
request: "IProtoMessage",
|
|
||||||
response_type: Type[T],
|
|
||||||
*,
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
deadline: Optional["Deadline"] = None,
|
|
||||||
metadata: Optional[_MetadataLike] = None,
|
|
||||||
) -> AsyncGenerator[T, None]:
|
|
||||||
"""Make a unary request and return the stream response iterator."""
|
|
||||||
async with self.channel.request(
|
|
||||||
route,
|
|
||||||
grpclib.const.Cardinality.UNARY_STREAM,
|
|
||||||
type(request),
|
|
||||||
response_type,
|
|
||||||
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
|
||||||
) as stream:
|
|
||||||
await stream.send_message(request, end=True)
|
|
||||||
async for message in stream:
|
|
||||||
yield message
|
|
||||||
|
9
betterproto/_types.py
Normal file
9
betterproto/_types.py
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import Message
|
||||||
|
from grpclib._protocols import IProtoMessage
|
||||||
|
|
||||||
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
|
T = TypeVar("T", bound="Message")
|
||||||
|
ST = TypeVar("ST", bound="IProtoMessage")
|
0
betterproto/grpc/__init__.py
Normal file
0
betterproto/grpc/__init__.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
170
betterproto/grpc/grpclib_client.py
Normal file
@ -0,0 +1,170 @@
|
|||||||
|
from abc import ABC
|
||||||
|
import asyncio
|
||||||
|
import grpclib.const
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncIterable,
|
||||||
|
AsyncIterator,
|
||||||
|
Collection,
|
||||||
|
Iterable,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
from .._types import ST, T
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from grpclib._protocols import IProtoMessage
|
||||||
|
from grpclib.client import Channel, Stream
|
||||||
|
from grpclib.metadata import Deadline
|
||||||
|
|
||||||
|
|
||||||
|
_Value = Union[str, bytes]
|
||||||
|
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
|
||||||
|
_MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceStub(ABC):
|
||||||
|
"""
|
||||||
|
Base class for async gRPC clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
channel: "Channel",
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> None:
|
||||||
|
self.channel = channel
|
||||||
|
self.timeout = timeout
|
||||||
|
self.deadline = deadline
|
||||||
|
self.metadata = metadata
|
||||||
|
|
||||||
|
def __resolve_request_kwargs(
|
||||||
|
self,
|
||||||
|
timeout: Optional[float],
|
||||||
|
deadline: Optional["Deadline"],
|
||||||
|
metadata: Optional[_MetadataLike],
|
||||||
|
):
|
||||||
|
return {
|
||||||
|
"timeout": self.timeout if timeout is None else timeout,
|
||||||
|
"deadline": self.deadline if deadline is None else deadline,
|
||||||
|
"metadata": self.metadata if metadata is None else metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _unary_unary(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request: "IProtoMessage",
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> T:
|
||||||
|
"""Make a unary request and return the response."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
|
type(request),
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_message(request, end=True)
|
||||||
|
response = await stream.recv_message()
|
||||||
|
assert response is not None
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _unary_stream(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request: "IProtoMessage",
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> AsyncIterator[T]:
|
||||||
|
"""Make a unary request and return the stream response iterator."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.UNARY_STREAM,
|
||||||
|
type(request),
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_message(request, end=True)
|
||||||
|
async for message in stream:
|
||||||
|
yield message
|
||||||
|
|
||||||
|
async def _stream_unary(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: _MessageSource,
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> T:
|
||||||
|
"""Make a stream request and return the response."""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.STREAM_UNARY,
|
||||||
|
request_type,
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await self._send_messages(stream, request_iterator)
|
||||||
|
response = await stream.recv_message()
|
||||||
|
assert response is not None
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def _stream_stream(
|
||||||
|
self,
|
||||||
|
route: str,
|
||||||
|
request_iterator: _MessageSource,
|
||||||
|
request_type: Type[ST],
|
||||||
|
response_type: Type[T],
|
||||||
|
*,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
|
deadline: Optional["Deadline"] = None,
|
||||||
|
metadata: Optional[_MetadataLike] = None,
|
||||||
|
) -> AsyncIterator[T]:
|
||||||
|
"""
|
||||||
|
Make a stream request and return an AsyncIterator to iterate over response
|
||||||
|
messages.
|
||||||
|
"""
|
||||||
|
async with self.channel.request(
|
||||||
|
route,
|
||||||
|
grpclib.const.Cardinality.STREAM_STREAM,
|
||||||
|
request_type,
|
||||||
|
response_type,
|
||||||
|
**self.__resolve_request_kwargs(timeout, deadline, metadata),
|
||||||
|
) as stream:
|
||||||
|
await stream.send_request()
|
||||||
|
sending_task = asyncio.ensure_future(
|
||||||
|
self._send_messages(stream, request_iterator)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
async for response in stream:
|
||||||
|
yield response
|
||||||
|
except:
|
||||||
|
sending_task.cancel()
|
||||||
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_messages(stream, messages: _MessageSource):
|
||||||
|
if isinstance(messages, AsyncIterable):
|
||||||
|
async for message in messages:
|
||||||
|
await stream.send_message(message)
|
||||||
|
else:
|
||||||
|
for message in messages:
|
||||||
|
await stream.send_message(message)
|
||||||
|
await stream.end()
|
0
betterproto/grpc/util/__init__.py
Normal file
0
betterproto/grpc/util/__init__.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
198
betterproto/grpc/util/async_channel.py
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
import asyncio
|
||||||
|
from typing import (
|
||||||
|
AsyncIterable,
|
||||||
|
AsyncIterator,
|
||||||
|
Iterable,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelClosed(Exception):
|
||||||
|
"""
|
||||||
|
An exception raised on an attempt to send through a closed channel
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDone(Exception):
|
||||||
|
"""
|
||||||
|
An exception raised on an attempt to send recieve from a channel that is both closed
|
||||||
|
and empty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncChannel(AsyncIterable[T]):
|
||||||
|
"""
|
||||||
|
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||||
|
|
||||||
|
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
client = GeneratedStub(grpclib_chan)
|
||||||
|
request_chan = await AsyncChannel()
|
||||||
|
# We can start be sending all the requests we already have
|
||||||
|
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
||||||
|
async for response in client.rpc_call(request_chan):
|
||||||
|
# The response iterator will remain active until the connection is closed
|
||||||
|
...
|
||||||
|
# More items can be sent at any time
|
||||||
|
await request_chan.send(ReqestObject(...))
|
||||||
|
...
|
||||||
|
# The channel must be closed to complete the gRPC connection
|
||||||
|
request_chan.close()
|
||||||
|
|
||||||
|
Items can be sent through the channel by either:
|
||||||
|
- providing an iterable to the send_from method
|
||||||
|
- passing them to the send method one at a time
|
||||||
|
|
||||||
|
Items can be recieved from the channel by either:
|
||||||
|
- iterating over the channel with a for loop to get all items
|
||||||
|
- calling the recieve method to get one item at a time
|
||||||
|
|
||||||
|
If the channel is empty then recievers will wait until either an item appears or the
|
||||||
|
channel is closed.
|
||||||
|
|
||||||
|
Once the channel is closed then subsequent attempt to send through the channel will
|
||||||
|
fail with a ChannelClosed exception.
|
||||||
|
|
||||||
|
When th channel is closed and empty then it is done, and further attempts to recieve
|
||||||
|
from it will fail with a ChannelDone exception
|
||||||
|
|
||||||
|
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
||||||
|
recieved by only one of the recievers.
|
||||||
|
|
||||||
|
:param source:
|
||||||
|
An optional iterable will items that should be sent through the channel
|
||||||
|
immediately.
|
||||||
|
:param buffer_limit:
|
||||||
|
Limit the number of items that can be buffered in the channel, A value less than
|
||||||
|
1 implies no limit. If the channel is full then attempts to send more items will
|
||||||
|
result in the sender waiting until an item is recieved from the channel.
|
||||||
|
:param close:
|
||||||
|
If set to True then the channel will automatically close after exhausting source
|
||||||
|
or immediately if no source is provided.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, *, buffer_limit: int = 0, close: bool = False,
|
||||||
|
):
|
||||||
|
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||||
|
self._closed = False
|
||||||
|
self._waiting_recievers: int = 0
|
||||||
|
# Track whether flush has been invoked so it can only happen once
|
||||||
|
self._flushed = False
|
||||||
|
|
||||||
|
def __aiter__(self) -> AsyncIterator[T]:
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self) -> T:
|
||||||
|
if self.done():
|
||||||
|
raise StopAsyncIteration
|
||||||
|
self._waiting_recievers += 1
|
||||||
|
try:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is self.__flush:
|
||||||
|
raise StopAsyncIteration
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self._waiting_recievers -= 1
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def closed(self) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if this channel is closed and no-longer accepting new items
|
||||||
|
"""
|
||||||
|
return self._closed
|
||||||
|
|
||||||
|
def done(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if this channel is done.
|
||||||
|
|
||||||
|
:return: True if this channel is closed and and has been drained of items in
|
||||||
|
which case any further attempts to recieve an item from this channel will raise
|
||||||
|
a ChannelDone exception.
|
||||||
|
"""
|
||||||
|
# After close the channel is not yet done until there is at least one waiting
|
||||||
|
# reciever per enqueued item.
|
||||||
|
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
||||||
|
|
||||||
|
async def send_from(
|
||||||
|
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||||
|
) -> "AsyncChannel[T]":
|
||||||
|
"""
|
||||||
|
Iterates the given [Async]Iterable and sends all the resulting items.
|
||||||
|
If close is set to True then subsequent send calls will be rejected with a
|
||||||
|
ChannelClosed exception.
|
||||||
|
:param source: an iterable of items to send
|
||||||
|
:param close:
|
||||||
|
if True then the channel will be closed after the source has been exhausted
|
||||||
|
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise ChannelClosed("Cannot send through a closed channel")
|
||||||
|
if isinstance(source, AsyncIterable):
|
||||||
|
async for item in source:
|
||||||
|
await self._queue.put(item)
|
||||||
|
else:
|
||||||
|
for item in source:
|
||||||
|
await self._queue.put(item)
|
||||||
|
if close:
|
||||||
|
# Complete the closing process
|
||||||
|
self.close()
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def send(self, item: T) -> "AsyncChannel[T]":
|
||||||
|
"""
|
||||||
|
Send a single item over this channel.
|
||||||
|
:param item: The item to send
|
||||||
|
"""
|
||||||
|
if self._closed:
|
||||||
|
raise ChannelClosed("Cannot send through a closed channel")
|
||||||
|
await self._queue.put(item)
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def recieve(self) -> Optional[T]:
|
||||||
|
"""
|
||||||
|
Returns the next item from this channel when it becomes available,
|
||||||
|
or None if the channel is closed before another item is sent.
|
||||||
|
:return: An item from the channel
|
||||||
|
"""
|
||||||
|
if self.done():
|
||||||
|
raise ChannelDone("Cannot recieve from a closed channel")
|
||||||
|
self._waiting_recievers += 1
|
||||||
|
try:
|
||||||
|
result = await self._queue.get()
|
||||||
|
if result is self.__flush:
|
||||||
|
return None
|
||||||
|
return result
|
||||||
|
finally:
|
||||||
|
self._waiting_recievers -= 1
|
||||||
|
self._queue.task_done()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
Close this channel to new items
|
||||||
|
"""
|
||||||
|
self._closed = True
|
||||||
|
asyncio.ensure_future(self._flush_queue())
|
||||||
|
|
||||||
|
async def _flush_queue(self):
|
||||||
|
"""
|
||||||
|
To be called after the channel is closed. Pushes a number of self.__flush
|
||||||
|
objects to the queue to ensure no waiting consumers get deadlocked.
|
||||||
|
"""
|
||||||
|
if not self._flushed:
|
||||||
|
self._flushed = True
|
||||||
|
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
||||||
|
for _ in range(deadlocked_recievers):
|
||||||
|
await self._queue.put(self.__flush)
|
||||||
|
|
||||||
|
# A special signal object for flushing the queue when the channel is closed
|
||||||
|
__flush = object()
|
@ -6,10 +6,10 @@ import re
|
|||||||
import stringcase
|
import stringcase
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
from typing import List
|
from typing import List, Union
|
||||||
|
import betterproto
|
||||||
from betterproto.casing import safe_snake_case
|
from betterproto.casing import safe_snake_case
|
||||||
from betterproto.compile.importing import get_ref_type
|
from betterproto.compile.importing import get_ref_type
|
||||||
import betterproto
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# betterproto[compiler] specific dependencies
|
# betterproto[compiler] specific dependencies
|
||||||
@ -58,8 +58,8 @@ def py_type(
|
|||||||
raise NotImplementedError(f"Unknown type {descriptor.type}")
|
raise NotImplementedError(f"Unknown type {descriptor.type}")
|
||||||
|
|
||||||
|
|
||||||
def get_py_zero(type_num: int) -> str:
|
def get_py_zero(type_num: int) -> Union[str, float]:
|
||||||
zero = 0
|
zero: Union[str, float] = 0
|
||||||
if type_num in []:
|
if type_num in []:
|
||||||
zero = 0.0
|
zero = 0.0
|
||||||
elif type_num == 8:
|
elif type_num == 8:
|
||||||
@ -311,9 +311,6 @@ def generate_code(request, response):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for j, method in enumerate(service.method):
|
for j, method in enumerate(service.method):
|
||||||
if method.client_streaming:
|
|
||||||
raise NotImplementedError("Client streaming not yet supported")
|
|
||||||
|
|
||||||
input_message = None
|
input_message = None
|
||||||
input_type = get_ref_type(
|
input_type = get_ref_type(
|
||||||
package, output["imports"], method.input_type
|
package, output["imports"], method.input_type
|
||||||
@ -347,8 +344,12 @@ def generate_code(request, response):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if method.client_streaming:
|
||||||
|
output["typing_imports"].add("AsyncIterable")
|
||||||
|
output["typing_imports"].add("Iterable")
|
||||||
|
output["typing_imports"].add("Union")
|
||||||
if method.server_streaming:
|
if method.server_streaming:
|
||||||
output["typing_imports"].add("AsyncGenerator")
|
output["typing_imports"].add("AsyncIterator")
|
||||||
|
|
||||||
output["services"].append(data)
|
output["services"].append(data)
|
||||||
|
|
||||||
|
@ -63,11 +63,28 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
async def {{ method.py_name }}(self
|
||||||
|
{%- if not method.client_streaming -%}
|
||||||
|
{%- if method.input_message and method.input_message.properties -%}, *,
|
||||||
|
{%- for field in method.input_message.properties -%}
|
||||||
|
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
|
||||||
|
Optional[{{ field.type }}]
|
||||||
|
{%- else -%}
|
||||||
|
{{ field.type }}
|
||||||
|
{%- endif -%} = {{ field.zero }}
|
||||||
|
{%- if not loop.last %}, {% endif -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{%- endif -%}
|
||||||
|
{%- else -%}
|
||||||
|
{# Client streaming: need a request iterator instead #}
|
||||||
|
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
|
||||||
|
{%- endif -%}
|
||||||
|
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if not method.client_streaming %}
|
||||||
request = {{ method.input }}()
|
request = {{ method.input }}()
|
||||||
{% for field in method.input_message.properties %}
|
{% for field in method.input_message.properties %}
|
||||||
{% if field.field_type == 'message' %}
|
{% if field.field_type == 'message' %}
|
||||||
@ -77,20 +94,41 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
request.{{ field.py_name }} = {{ field.py_name }}
|
request.{{ field.py_name }} = {{ field.py_name }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% if method.server_streaming %}
|
{% if method.server_streaming %}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
async for response in self._stream_stream(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }},
|
||||||
|
):
|
||||||
|
yield response
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
async for response in self._unary_stream(
|
async for response in self._unary_stream(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }},
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
{% else %}
|
|
||||||
|
{% endif %}{# if client streaming #}
|
||||||
|
{% else %}{# i.e. not server streaming #}
|
||||||
|
{% if method.client_streaming %}
|
||||||
|
return await self._stream_unary(
|
||||||
|
"{{ method.route }}",
|
||||||
|
request_iterator,
|
||||||
|
{{ method.input }},
|
||||||
|
{{ method.output }}
|
||||||
|
)
|
||||||
|
{% else %}{# i.e. not client streaming #}
|
||||||
return await self._unary_unary(
|
return await self._unary_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
request,
|
||||||
{{ method.output }},
|
{{ method.output }}
|
||||||
)
|
)
|
||||||
|
{% endif %}{# client streaming #}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import glob
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
from pathlib import Path
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
@ -20,58 +21,63 @@ from betterproto.tests.util import (
|
|||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
|
|
||||||
def clear_directory(path: str):
|
def clear_directory(dir_path: Path):
|
||||||
for file_or_directory in glob.glob(os.path.join(path, "*")):
|
for file_or_directory in dir_path.glob("*"):
|
||||||
if os.path.isdir(file_or_directory):
|
if file_or_directory.is_dir():
|
||||||
shutil.rmtree(file_or_directory)
|
shutil.rmtree(file_or_directory)
|
||||||
else:
|
else:
|
||||||
os.remove(file_or_directory)
|
file_or_directory.unlink()
|
||||||
|
|
||||||
|
|
||||||
def generate(whitelist: Set[str]):
|
async def generate(whitelist: Set[str], verbose: bool):
|
||||||
path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)}
|
test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
|
||||||
name_whitelist = {e for e in whitelist if not os.path.exists(e)}
|
|
||||||
|
|
||||||
test_case_names = set(get_directories(inputs_path))
|
path_whitelist = set()
|
||||||
|
name_whitelist = set()
|
||||||
failed_test_cases = []
|
for item in whitelist:
|
||||||
|
if item in test_case_names:
|
||||||
|
name_whitelist.add(item)
|
||||||
|
continue
|
||||||
|
path_whitelist.add(item)
|
||||||
|
|
||||||
|
generation_tasks = []
|
||||||
for test_case_name in sorted(test_case_names):
|
for test_case_name in sorted(test_case_names):
|
||||||
test_case_input_path = os.path.realpath(
|
test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
|
||||||
os.path.join(inputs_path, test_case_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
whitelist
|
whitelist
|
||||||
and test_case_input_path not in path_whitelist
|
and str(test_case_input_path) not in path_whitelist
|
||||||
and test_case_name not in name_whitelist
|
and test_case_name not in name_whitelist
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
|
generation_tasks.append(
|
||||||
|
generate_test_case_output(test_case_input_path, test_case_name, verbose)
|
||||||
|
)
|
||||||
|
|
||||||
print(f"Generating output for {test_case_name}")
|
failed_test_cases = []
|
||||||
try:
|
# Wait for all subprocs and match any failures to names to report
|
||||||
generate_test_case_output(test_case_name, test_case_input_path)
|
for test_case_name, result in zip(
|
||||||
except subprocess.CalledProcessError as e:
|
sorted(test_case_names), await asyncio.gather(*generation_tasks)
|
||||||
|
):
|
||||||
|
if result != 0:
|
||||||
failed_test_cases.append(test_case_name)
|
failed_test_cases.append(test_case_name)
|
||||||
|
|
||||||
if failed_test_cases:
|
if failed_test_cases:
|
||||||
sys.stderr.write("\nFailed to generate the following test cases:\n")
|
sys.stderr.write(
|
||||||
|
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
|
||||||
|
)
|
||||||
for failed_test_case in failed_test_cases:
|
for failed_test_case in failed_test_cases:
|
||||||
sys.stderr.write(f"- {failed_test_case}\n")
|
sys.stderr.write(f"- {failed_test_case}\n")
|
||||||
|
|
||||||
|
|
||||||
def generate_test_case_output(test_case_name, test_case_input_path=None):
|
async def generate_test_case_output(
|
||||||
if not test_case_input_path:
|
test_case_input_path: Path, test_case_name: str, verbose: bool
|
||||||
test_case_input_path = os.path.realpath(
|
) -> int:
|
||||||
os.path.join(inputs_path, test_case_name)
|
"""
|
||||||
)
|
Returns the max of the subprocess return values
|
||||||
|
"""
|
||||||
|
|
||||||
test_case_output_path_reference = os.path.join(
|
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||||
output_path_reference, test_case_name
|
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
|
||||||
)
|
|
||||||
test_case_output_path_betterproto = os.path.join(
|
|
||||||
output_path_betterproto, test_case_name
|
|
||||||
)
|
|
||||||
|
|
||||||
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||||
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||||
@ -79,14 +85,36 @@ def generate_test_case_output(test_case_name, test_case_input_path=None):
|
|||||||
clear_directory(test_case_output_path_reference)
|
clear_directory(test_case_output_path_reference)
|
||||||
clear_directory(test_case_output_path_betterproto)
|
clear_directory(test_case_output_path_betterproto)
|
||||||
|
|
||||||
protoc_reference(test_case_input_path, test_case_output_path_reference)
|
(
|
||||||
protoc_plugin(test_case_input_path, test_case_output_path_betterproto)
|
(ref_out, ref_err, ref_code),
|
||||||
|
(plg_out, plg_err, plg_code),
|
||||||
|
) = await asyncio.gather(
|
||||||
|
protoc_reference(test_case_input_path, test_case_output_path_reference),
|
||||||
|
protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
|
||||||
|
)
|
||||||
|
|
||||||
|
message = f"Generated output for {test_case_name!r}"
|
||||||
|
if verbose:
|
||||||
|
print(f"\033[31;1;4m{message}\033[0m")
|
||||||
|
if ref_out:
|
||||||
|
sys.stdout.buffer.write(ref_out)
|
||||||
|
if ref_err:
|
||||||
|
sys.stderr.buffer.write(ref_err)
|
||||||
|
if plg_out:
|
||||||
|
sys.stdout.buffer.write(plg_out)
|
||||||
|
if plg_err:
|
||||||
|
sys.stderr.buffer.write(plg_err)
|
||||||
|
sys.stdout.buffer.flush()
|
||||||
|
sys.stderr.buffer.flush()
|
||||||
|
else:
|
||||||
|
print(message)
|
||||||
|
|
||||||
|
return max(ref_code, plg_code)
|
||||||
|
|
||||||
|
|
||||||
HELP = "\n".join(
|
HELP = "\n".join(
|
||||||
[
|
(
|
||||||
"Usage: python generate.py",
|
"Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
|
||||||
" python generate.py [DIRECTORIES or NAMES]",
|
|
||||||
"Generate python classes for standard tests.",
|
"Generate python classes for standard tests.",
|
||||||
"",
|
"",
|
||||||
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
||||||
@ -94,7 +122,7 @@ HELP = "\n".join(
|
|||||||
"",
|
"",
|
||||||
"NAMES One or more test-case names to generate classes for.",
|
"NAMES One or more test-case names to generate classes for.",
|
||||||
" python generate.py bool double enums",
|
" python generate.py bool double enums",
|
||||||
]
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -102,9 +130,13 @@ def main():
|
|||||||
if set(sys.argv).intersection({"-h", "--help"}):
|
if set(sys.argv).intersection({"-h", "--help"}):
|
||||||
print(HELP)
|
print(HELP)
|
||||||
return
|
return
|
||||||
|
if sys.argv[1:2] == ["-v"]:
|
||||||
|
verbose = True
|
||||||
|
whitelist = set(sys.argv[2:])
|
||||||
|
else:
|
||||||
|
verbose = False
|
||||||
whitelist = set(sys.argv[1:])
|
whitelist = set(sys.argv[1:])
|
||||||
|
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
|
||||||
generate(whitelist)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
0
betterproto/tests/grpc/__init__.py
Normal file
0
betterproto/tests/grpc/__init__.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
import asyncio
|
||||||
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
|
DoThingResponse,
|
||||||
|
DoThingRequest,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
TestStub as ThingServiceClient,
|
||||||
|
)
|
||||||
|
import grpclib
|
||||||
|
from grpclib.testing import ChannelFor
|
||||||
|
import pytest
|
||||||
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
|
from .thing_service import ThingService
|
||||||
|
|
||||||
|
|
||||||
|
async def _test_client(client, name="clean room", **kwargs):
|
||||||
|
response = await client.do_thing(name=name)
|
||||||
|
assert response.names == [name]
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_request_meta_recieved(deadline, metadata):
|
||||||
|
def server_side_test(stream):
|
||||||
|
assert stream.deadline._timestamp == pytest.approx(
|
||||||
|
deadline._timestamp, 1
|
||||||
|
), "The provided deadline should be recieved serverside"
|
||||||
|
assert (
|
||||||
|
stream.metadata["authorization"] == metadata["authorization"]
|
||||||
|
), "The provided authorization metadata should be recieved serverside"
|
||||||
|
|
||||||
|
return server_side_test
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_simple_service_call():
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
await _test_client(ThingServiceClient(channel))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_call_with_upfront_request_params():
|
||||||
|
# Setting deadline
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
await _test_client(
|
||||||
|
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Setting timeout
|
||||||
|
timeout = 99
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
await _test_client(
|
||||||
|
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_service_call_lower_level_with_overrides():
|
||||||
|
THING_TO_DO = "get milk"
|
||||||
|
|
||||||
|
# Setting deadline
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||||
|
kwarg_metadata = {"authorization": "12345"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||||
|
) as channel:
|
||||||
|
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
response = await client._unary_unary(
|
||||||
|
"/service.Test/DoThing",
|
||||||
|
DoThingRequest(THING_TO_DO),
|
||||||
|
DoThingResponse,
|
||||||
|
deadline=kwarg_deadline,
|
||||||
|
metadata=kwarg_metadata,
|
||||||
|
)
|
||||||
|
assert response.names == [THING_TO_DO]
|
||||||
|
|
||||||
|
# Setting timeout
|
||||||
|
timeout = 99
|
||||||
|
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||||
|
metadata = {"authorization": "12345"}
|
||||||
|
kwarg_timeout = 9000
|
||||||
|
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||||
|
kwarg_metadata = {"authorization": "09876"}
|
||||||
|
async with ChannelFor(
|
||||||
|
[
|
||||||
|
ThingService(
|
||||||
|
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
) as channel:
|
||||||
|
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||||
|
response = await client._unary_unary(
|
||||||
|
"/service.Test/DoThing",
|
||||||
|
DoThingRequest(THING_TO_DO),
|
||||||
|
DoThingResponse,
|
||||||
|
timeout=kwarg_timeout,
|
||||||
|
metadata=kwarg_metadata,
|
||||||
|
)
|
||||||
|
assert response.names == [THING_TO_DO]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_gen_for_unary_stream_request():
|
||||||
|
thing_name = "my milkshakes"
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
expected_versions = [5, 4, 3, 2, 1]
|
||||||
|
async for response in client.get_thing_versions(name=thing_name):
|
||||||
|
assert response.name == thing_name
|
||||||
|
assert response.version == expected_versions.pop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_gen_for_stream_stream_request():
|
||||||
|
some_things = ["cake", "cricket", "coral reef"]
|
||||||
|
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||||
|
expected_things = (*some_things, *more_things)
|
||||||
|
|
||||||
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
|
client = ThingServiceClient(channel)
|
||||||
|
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||||
|
# immediately and we'll use it to send more_things later, after recieving some
|
||||||
|
# results
|
||||||
|
request_chan = AsyncChannel()
|
||||||
|
send_initial_requests = asyncio.ensure_future(
|
||||||
|
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||||
|
)
|
||||||
|
response_index = 0
|
||||||
|
async for response in client.get_different_things(request_chan):
|
||||||
|
assert response.name == expected_things[response_index]
|
||||||
|
assert response.version == response_index + 1
|
||||||
|
response_index += 1
|
||||||
|
if more_things:
|
||||||
|
# Send some more requests as we recieve reponses to be sure coordination of
|
||||||
|
# send/recieve events doesn't matter
|
||||||
|
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||||
|
elif not send_initial_requests.done():
|
||||||
|
# Make sure the sending task it completed
|
||||||
|
await send_initial_requests
|
||||||
|
else:
|
||||||
|
# No more things to send make sure channel is closed
|
||||||
|
request_chan.close()
|
||||||
|
assert response_index == len(
|
||||||
|
expected_things
|
||||||
|
), "Didn't recieve all exptected responses"
|
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
import asyncio
|
||||||
|
import betterproto
|
||||||
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import pytest
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Message(betterproto.Message):
|
||||||
|
body: str = betterproto.string_field(1)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def expected_responses():
|
||||||
|
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
|
||||||
|
|
||||||
|
|
||||||
|
class ClientStub:
|
||||||
|
async def connect(self, requests: AsyncIterator):
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
async for request in requests:
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield request
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
yield Message("Done")
|
||||||
|
|
||||||
|
|
||||||
|
async def to_list(generator: AsyncIterator):
|
||||||
|
result = []
|
||||||
|
async for value in generator:
|
||||||
|
result.append(value)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
# channel = Channel(host='127.0.0.1', port=50051)
|
||||||
|
# return ClientStub(channel)
|
||||||
|
return ClientStub()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_before_connect_and_close_automatically(
|
||||||
|
client, expected_responses
|
||||||
|
):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||||
|
)
|
||||||
|
responses = client.connect(requests)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_after_connect_and_close_automatically(
|
||||||
|
client, expected_responses
|
||||||
|
):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_from_close_manually_immediately(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
await requests.send_from(
|
||||||
|
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
|
||||||
|
)
|
||||||
|
requests.close()
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_individually_and_close_before_connect(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send(Message(body="Hello world 1"))
|
||||||
|
await requests.send(Message(body="Hello world 2"))
|
||||||
|
requests.close()
|
||||||
|
responses = client.connect(requests)
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_individually_and_close_after_connect(client, expected_responses):
|
||||||
|
requests = AsyncChannel()
|
||||||
|
await requests.send(Message(body="Hello world 1"))
|
||||||
|
await requests.send(Message(body="Hello world 2"))
|
||||||
|
responses = client.connect(requests)
|
||||||
|
requests.close()
|
||||||
|
|
||||||
|
assert await to_list(responses) == expected_responses
|
83
betterproto/tests/grpc/thing_service.py
Normal file
83
betterproto/tests/grpc/thing_service.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
|
DoThingResponse,
|
||||||
|
DoThingRequest,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
TestStub as ThingServiceClient,
|
||||||
|
)
|
||||||
|
import grpclib
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
|
||||||
|
class ThingService:
|
||||||
|
def __init__(self, test_hook=None):
|
||||||
|
# This lets us pass assertions to the servicer ;)
|
||||||
|
self.test_hook = test_hook
|
||||||
|
|
||||||
|
async def do_thing(
|
||||||
|
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||||
|
):
|
||||||
|
request = await stream.recv_message()
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
await stream.send_message(DoThingResponse([request.name]))
|
||||||
|
|
||||||
|
async def do_many_things(
|
||||||
|
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||||
|
):
|
||||||
|
thing_names = [request.name for request in stream]
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
await stream.send_message(DoThingResponse(thing_names))
|
||||||
|
|
||||||
|
async def get_thing_versions(
|
||||||
|
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||||
|
):
|
||||||
|
request = await stream.recv_message()
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
for version_num in range(1, 6):
|
||||||
|
await stream.send_message(
|
||||||
|
GetThingResponse(name=request.name, version=version_num)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_different_things(
|
||||||
|
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||||
|
):
|
||||||
|
if self.test_hook is not None:
|
||||||
|
self.test_hook(stream)
|
||||||
|
# Respond to each input item immediately
|
||||||
|
response_num = 0
|
||||||
|
async for request in stream:
|
||||||
|
response_num += 1
|
||||||
|
await stream.send_message(
|
||||||
|
GetThingResponse(name=request.name, version=response_num)
|
||||||
|
)
|
||||||
|
|
||||||
|
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||||
|
return {
|
||||||
|
"/service.Test/DoThing": grpclib.const.Handler(
|
||||||
|
self.do_thing,
|
||||||
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
|
DoThingRequest,
|
||||||
|
DoThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||||
|
self.do_many_things,
|
||||||
|
grpclib.const.Cardinality.STREAM_UNARY,
|
||||||
|
DoThingRequest,
|
||||||
|
DoThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||||
|
self.get_thing_versions,
|
||||||
|
grpclib.const.Cardinality.UNARY_STREAM,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
),
|
||||||
|
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||||
|
self.get_different_things,
|
||||||
|
grpclib.const.Cardinality.STREAM_STREAM,
|
||||||
|
GetThingRequest,
|
||||||
|
GetThingResponse,
|
||||||
|
),
|
||||||
|
}
|
@ -23,7 +23,7 @@ test_cases = [
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
async def test_channel_receives_wrapped_type(
|
async def test_channel_recieves_wrapped_type(
|
||||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||||
):
|
):
|
||||||
wrapped_value = wrapper_class()
|
wrapped_value = wrapper_class()
|
||||||
|
@ -3,13 +3,25 @@ syntax = "proto3";
|
|||||||
package service;
|
package service;
|
||||||
|
|
||||||
message DoThingRequest {
|
message DoThingRequest {
|
||||||
int32 iterations = 1;
|
string name = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DoThingResponse {
|
message DoThingResponse {
|
||||||
int32 successfulIterations = 1;
|
repeated string names = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetThingRequest {
|
||||||
|
string name = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetThingResponse {
|
||||||
|
string name = 1;
|
||||||
|
int32 version = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
service Test {
|
service Test {
|
||||||
rpc DoThing (DoThingRequest) returns (DoThingResponse);
|
rpc DoThing (DoThingRequest) returns (DoThingResponse);
|
||||||
|
rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse);
|
||||||
|
rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse);
|
||||||
|
rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse);
|
||||||
}
|
}
|
||||||
|
@ -1,132 +0,0 @@
|
|||||||
import betterproto
|
|
||||||
import grpclib
|
|
||||||
from grpclib.testing import ChannelFor
|
|
||||||
import pytest
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from betterproto.tests.output_betterproto.service.service import (
|
|
||||||
DoThingResponse,
|
|
||||||
DoThingRequest,
|
|
||||||
TestStub as ExampleServiceStub,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleService:
|
|
||||||
def __init__(self, test_hook=None):
|
|
||||||
# This lets us pass assertions to the servicer ;)
|
|
||||||
self.test_hook = test_hook
|
|
||||||
|
|
||||||
async def DoThing(
|
|
||||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
|
||||||
):
|
|
||||||
request = await stream.recv_message()
|
|
||||||
print("self.test_hook", self.test_hook)
|
|
||||||
if self.test_hook is not None:
|
|
||||||
self.test_hook(stream)
|
|
||||||
for iteration in range(request.iterations):
|
|
||||||
pass
|
|
||||||
await stream.send_message(DoThingResponse(request.iterations))
|
|
||||||
|
|
||||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
|
||||||
return {
|
|
||||||
"/service.Test/DoThing": grpclib.const.Handler(
|
|
||||||
self.DoThing,
|
|
||||||
grpclib.const.Cardinality.UNARY_UNARY,
|
|
||||||
DoThingRequest,
|
|
||||||
DoThingResponse,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _test_stub(stub, iterations=42, **kwargs):
|
|
||||||
response = await stub.do_thing(iterations=iterations)
|
|
||||||
assert response.successful_iterations == iterations
|
|
||||||
|
|
||||||
|
|
||||||
def _get_server_side_test(deadline, metadata):
|
|
||||||
def server_side_test(stream):
|
|
||||||
assert stream.deadline._timestamp == pytest.approx(
|
|
||||||
deadline._timestamp, 1
|
|
||||||
), "The provided deadline should be recieved serverside"
|
|
||||||
assert (
|
|
||||||
stream.metadata["authorization"] == metadata["authorization"]
|
|
||||||
), "The provided authorization metadata should be recieved serverside"
|
|
||||||
|
|
||||||
return server_side_test
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_simple_service_call():
|
|
||||||
async with ChannelFor([ExampleService()]) as channel:
|
|
||||||
await _test_stub(ExampleServiceStub(channel))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_service_call_with_upfront_request_params():
|
|
||||||
# Setting deadline
|
|
||||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
|
||||||
metadata = {"authorization": "12345"}
|
|
||||||
async with ChannelFor(
|
|
||||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
|
||||||
) as channel:
|
|
||||||
await _test_stub(
|
|
||||||
ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Setting timeout
|
|
||||||
timeout = 99
|
|
||||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
|
||||||
metadata = {"authorization": "12345"}
|
|
||||||
async with ChannelFor(
|
|
||||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
|
||||||
) as channel:
|
|
||||||
await _test_stub(
|
|
||||||
ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_service_call_lower_level_with_overrides():
|
|
||||||
ITERATIONS = 99
|
|
||||||
|
|
||||||
# Setting deadline
|
|
||||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
|
||||||
metadata = {"authorization": "12345"}
|
|
||||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
|
||||||
kwarg_metadata = {"authorization": "12345"}
|
|
||||||
async with ChannelFor(
|
|
||||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
|
||||||
) as channel:
|
|
||||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
|
||||||
response = await stub._unary_unary(
|
|
||||||
"/service.Test/DoThing",
|
|
||||||
DoThingRequest(ITERATIONS),
|
|
||||||
DoThingResponse,
|
|
||||||
deadline=kwarg_deadline,
|
|
||||||
metadata=kwarg_metadata,
|
|
||||||
)
|
|
||||||
assert response.successful_iterations == ITERATIONS
|
|
||||||
|
|
||||||
# Setting timeout
|
|
||||||
timeout = 99
|
|
||||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
|
||||||
metadata = {"authorization": "12345"}
|
|
||||||
kwarg_timeout = 9000
|
|
||||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
|
||||||
kwarg_metadata = {"authorization": "09876"}
|
|
||||||
async with ChannelFor(
|
|
||||||
[
|
|
||||||
ExampleService(
|
|
||||||
test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
) as channel:
|
|
||||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
|
||||||
response = await stub._unary_unary(
|
|
||||||
"/service.Test/DoThing",
|
|
||||||
DoThingRequest(ITERATIONS),
|
|
||||||
DoThingResponse,
|
|
||||||
timeout=kwarg_timeout,
|
|
||||||
metadata=kwarg_metadata,
|
|
||||||
)
|
|
||||||
assert response.successful_iterations == ITERATIONS
|
|
@ -23,7 +23,7 @@ from google.protobuf.json_format import Parse
|
|||||||
|
|
||||||
class TestCases:
|
class TestCases:
|
||||||
def __init__(self, path, services: Set[str], xfail: Set[str]):
|
def __init__(self, path, services: Set[str], xfail: Set[str]):
|
||||||
_all = set(get_directories(path))
|
_all = set(get_directories(path)) - {"__pycache__"}
|
||||||
_services = services
|
_services = services
|
||||||
_messages = (_all - services) - {"__pycache__"}
|
_messages = (_all - services) - {"__pycache__"}
|
||||||
_messages_with_json = {
|
_messages_with_json = {
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import subprocess
|
from pathlib import Path
|
||||||
from typing import Generator
|
from typing import Generator, IO, Optional
|
||||||
|
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
root_path = os.path.dirname(os.path.realpath(__file__))
|
root_path = Path(__file__).resolve().parent
|
||||||
inputs_path = os.path.join(root_path, "inputs")
|
inputs_path = root_path.joinpath("inputs")
|
||||||
output_path_reference = os.path.join(root_path, "output_reference")
|
output_path_reference = root_path.joinpath("output_reference")
|
||||||
output_path_betterproto = os.path.join(root_path, "output_betterproto")
|
output_path_betterproto = root_path.joinpath("output_betterproto")
|
||||||
|
|
||||||
if os.name == "nt":
|
if os.name == "nt":
|
||||||
plugin_path = os.path.join(root_path, "..", "plugin.bat")
|
plugin_path = root_path.joinpath("..", "plugin.bat").resolve()
|
||||||
else:
|
else:
|
||||||
plugin_path = os.path.join(root_path, "..", "plugin.py")
|
plugin_path = root_path.joinpath("..", "plugin.py").resolve()
|
||||||
|
|
||||||
|
|
||||||
def get_files(path, end: str) -> Generator[str, None, None]:
|
def get_files(path, suffix: str) -> Generator[str, None, None]:
|
||||||
for r, dirs, files in os.walk(path):
|
for r, dirs, files in os.walk(path):
|
||||||
for filename in [f for f in files if f.endswith(end)]:
|
for filename in [f for f in files if f.endswith(suffix)]:
|
||||||
yield os.path.join(r, filename)
|
yield os.path.join(r, filename)
|
||||||
|
|
||||||
|
|
||||||
@ -27,36 +28,30 @@ def get_directories(path):
|
|||||||
yield directory
|
yield directory
|
||||||
|
|
||||||
|
|
||||||
def relative(file: str, path: str):
|
async def protoc_plugin(path: str, output_dir: str):
|
||||||
return os.path.join(os.path.dirname(file), path)
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
|
||||||
|
|
||||||
def read_relative(file: str, path: str):
|
|
||||||
with open(relative(file, path)) as fh:
|
|
||||||
return fh.read()
|
|
||||||
|
|
||||||
|
|
||||||
def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess:
|
|
||||||
return subprocess.run(
|
|
||||||
f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
|
f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
|
||||||
shell=True,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
check=True,
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
|
return (*(await proc.communicate()), proc.returncode)
|
||||||
|
|
||||||
|
|
||||||
def protoc_reference(path: str, output_dir: str):
|
async def protoc_reference(path: str, output_dir: str):
|
||||||
subprocess.run(
|
proc = await asyncio.create_subprocess_shell(
|
||||||
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
|
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
|
||||||
shell=True,
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
)
|
)
|
||||||
|
return (*(await proc.communicate()), proc.returncode)
|
||||||
|
|
||||||
|
|
||||||
def get_test_case_json_data(test_case_name, json_file_name=None):
|
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
|
||||||
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
|
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
|
||||||
test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name)
|
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
|
||||||
|
|
||||||
if not os.path.exists(test_data_file_path):
|
if not test_data_file_path.exists():
|
||||||
return None
|
return None
|
||||||
|
|
||||||
with open(test_data_file_path) as fh:
|
with test_data_file_path.open("r") as fh:
|
||||||
return fh.read()
|
return fh.read()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user