Merge remote-tracking branch 'daniel/master' into fix/service-input-message
# Conflicts: # betterproto/plugin.py
This commit is contained in:
commit
72d72b4603
1
.gitignore
vendored
1
.gitignore
vendored
@ -14,3 +14,4 @@ output
|
||||
.idea
|
||||
.DS_Store
|
||||
.tox
|
||||
.venv
|
||||
|
@ -22,7 +22,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from ._types import T
|
||||
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
|
||||
from .casing import camel_case, safe_snake_case, snake_case
|
||||
from .grpc.grpclib_client import ServiceStub
|
||||
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||
@ -378,7 +378,7 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
|
||||
result |= (b & 0x7F) << shift
|
||||
pos += 1
|
||||
if not (b & 0x80):
|
||||
return (result, pos)
|
||||
return result, pos
|
||||
shift += 7
|
||||
if shift >= 64:
|
||||
raise ValueError("Too many bytes when decoding varint.")
|
||||
@ -479,7 +479,7 @@ class ProtoClassMetadata:
|
||||
assert meta.map_types
|
||||
kt = cls._cls_for(field, index=0)
|
||||
vt = cls._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
field_cls[field.name] = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||
@ -487,7 +487,6 @@ class ProtoClassMetadata:
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
field_cls[field.name] = Entry
|
||||
field_cls[field.name + ".value"] = vt
|
||||
else:
|
||||
field_cls[field.name] = cls._cls_for(field)
|
||||
@ -588,7 +587,7 @@ class Message(ABC):
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value._serialized_on_wire:
|
||||
# Empty messages can still be sent on the wire if they were
|
||||
# set (or recieved empty).
|
||||
# set (or received empty).
|
||||
serialize_empty = True
|
||||
|
||||
if value == self._get_field_default(field_name) and not (
|
||||
@ -926,7 +925,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
|
||||
|
||||
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||
from .lib.google.protobuf import (
|
||||
from .lib.google.protobuf import ( # noqa
|
||||
Duration,
|
||||
Timestamp,
|
||||
BoolValue,
|
||||
|
@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import Message
|
||||
from grpclib._protocols import IProtoMessage
|
||||
from grpclib._typing import IProtoMessage
|
||||
|
||||
# Bound type variable to allow methods to return `self` of subclasses
|
||||
T = TypeVar("T", bound="Message")
|
||||
|
@ -20,6 +20,8 @@ def safe_snake_case(value: str) -> str:
|
||||
"and",
|
||||
"as",
|
||||
"assert",
|
||||
"async",
|
||||
"await",
|
||||
"break",
|
||||
"class",
|
||||
"continue",
|
||||
|
@ -2,7 +2,6 @@ from abc import ABC
|
||||
import asyncio
|
||||
import grpclib.const
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Collection,
|
||||
@ -17,8 +16,8 @@ from typing import (
|
||||
from .._types import ST, T
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib._protocols import IProtoMessage
|
||||
from grpclib.client import Channel, Stream
|
||||
from grpclib._typing import IProtoMessage
|
||||
from grpclib.client import Channel
|
||||
from grpclib.metadata import Deadline
|
||||
|
||||
|
||||
|
@ -21,7 +21,7 @@ class ChannelClosed(Exception):
|
||||
|
||||
class ChannelDone(Exception):
|
||||
"""
|
||||
An exception raised on an attempt to send recieve from a channel that is both closed
|
||||
An exception raised on an attempt to send receive from a channel that is both closed
|
||||
and empty.
|
||||
"""
|
||||
|
||||
@ -32,41 +32,41 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
"""
|
||||
A buffered async channel for sending items between coroutines with FIFO ordering.
|
||||
|
||||
This makes decoupled bidirection steaming gRPC requests easy if used like:
|
||||
This makes decoupled bidirectional steaming gRPC requests easy if used like:
|
||||
|
||||
.. code-block:: python
|
||||
client = GeneratedStub(grpclib_chan)
|
||||
request_chan = await AsyncChannel()
|
||||
request_channel = await AsyncChannel()
|
||||
# We can start be sending all the requests we already have
|
||||
await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
|
||||
async for response in client.rpc_call(request_chan):
|
||||
await request_channel.send_from([RequestObject(...), RequestObject(...)])
|
||||
async for response in client.rpc_call(request_channel):
|
||||
# The response iterator will remain active until the connection is closed
|
||||
...
|
||||
# More items can be sent at any time
|
||||
await request_chan.send(ReqestObject(...))
|
||||
await request_channel.send(RequestObject(...))
|
||||
...
|
||||
# The channel must be closed to complete the gRPC connection
|
||||
request_chan.close()
|
||||
request_channel.close()
|
||||
|
||||
Items can be sent through the channel by either:
|
||||
- providing an iterable to the send_from method
|
||||
- passing them to the send method one at a time
|
||||
|
||||
Items can be recieved from the channel by either:
|
||||
Items can be received from the channel by either:
|
||||
- iterating over the channel with a for loop to get all items
|
||||
- calling the recieve method to get one item at a time
|
||||
- calling the receive method to get one item at a time
|
||||
|
||||
If the channel is empty then recievers will wait until either an item appears or the
|
||||
If the channel is empty then receivers will wait until either an item appears or the
|
||||
channel is closed.
|
||||
|
||||
Once the channel is closed then subsequent attempt to send through the channel will
|
||||
fail with a ChannelClosed exception.
|
||||
|
||||
When th channel is closed and empty then it is done, and further attempts to recieve
|
||||
When th channel is closed and empty then it is done, and further attempts to receive
|
||||
from it will fail with a ChannelDone exception
|
||||
|
||||
If multiple coroutines recieve from the channel concurrently, each item sent will be
|
||||
recieved by only one of the recievers.
|
||||
If multiple coroutines receive from the channel concurrently, each item sent will be
|
||||
received by only one of the receivers.
|
||||
|
||||
:param source:
|
||||
An optional iterable will items that should be sent through the channel
|
||||
@ -74,7 +74,7 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
:param buffer_limit:
|
||||
Limit the number of items that can be buffered in the channel, A value less than
|
||||
1 implies no limit. If the channel is full then attempts to send more items will
|
||||
result in the sender waiting until an item is recieved from the channel.
|
||||
result in the sender waiting until an item is received from the channel.
|
||||
:param close:
|
||||
If set to True then the channel will automatically close after exhausting source
|
||||
or immediately if no source is provided.
|
||||
@ -85,7 +85,7 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
):
|
||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||
self._closed = False
|
||||
self._waiting_recievers: int = 0
|
||||
self._waiting_receivers: int = 0
|
||||
# Track whether flush has been invoked so it can only happen once
|
||||
self._flushed = False
|
||||
|
||||
@ -95,14 +95,14 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
async def __anext__(self) -> T:
|
||||
if self.done():
|
||||
raise StopAsyncIteration
|
||||
self._waiting_recievers += 1
|
||||
self._waiting_receivers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
raise StopAsyncIteration
|
||||
return result
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._waiting_receivers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def closed(self) -> bool:
|
||||
@ -116,12 +116,12 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
Check if this channel is done.
|
||||
|
||||
:return: True if this channel is closed and and has been drained of items in
|
||||
which case any further attempts to recieve an item from this channel will raise
|
||||
which case any further attempts to receive an item from this channel will raise
|
||||
a ChannelDone exception.
|
||||
"""
|
||||
# After close the channel is not yet done until there is at least one waiting
|
||||
# reciever per enqueued item.
|
||||
return self._closed and self._queue.qsize() <= self._waiting_recievers
|
||||
# receiver per enqueued item.
|
||||
return self._closed and self._queue.qsize() <= self._waiting_receivers
|
||||
|
||||
async def send_from(
|
||||
self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
|
||||
@ -158,22 +158,22 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
await self._queue.put(item)
|
||||
return self
|
||||
|
||||
async def recieve(self) -> Optional[T]:
|
||||
async def receive(self) -> Optional[T]:
|
||||
"""
|
||||
Returns the next item from this channel when it becomes available,
|
||||
or None if the channel is closed before another item is sent.
|
||||
:return: An item from the channel
|
||||
"""
|
||||
if self.done():
|
||||
raise ChannelDone("Cannot recieve from a closed channel")
|
||||
self._waiting_recievers += 1
|
||||
raise ChannelDone("Cannot receive from a closed channel")
|
||||
self._waiting_receivers += 1
|
||||
try:
|
||||
result = await self._queue.get()
|
||||
if result is self.__flush:
|
||||
return None
|
||||
return result
|
||||
finally:
|
||||
self._waiting_recievers -= 1
|
||||
self._waiting_receivers -= 1
|
||||
self._queue.task_done()
|
||||
|
||||
def close(self):
|
||||
@ -190,8 +190,8 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
"""
|
||||
if not self._flushed:
|
||||
self._flushed = True
|
||||
deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
|
||||
for _ in range(deadlocked_recievers):
|
||||
deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
|
||||
for _ in range(deadlocked_receivers):
|
||||
await self._queue.put(self.__flush)
|
||||
|
||||
# A special signal object for flushing the queue when the channel is closed
|
||||
|
@ -199,7 +199,7 @@ def generate_code(request, response):
|
||||
# Render and then format the output file.
|
||||
f.content = black.format_str(
|
||||
template.render(description=template_data),
|
||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
||||
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
|
||||
)
|
||||
|
||||
# Make each output directory a package with __init__ file
|
||||
|
@ -3,7 +3,6 @@ import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Set
|
||||
|
||||
|
@ -3,7 +3,6 @@ from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
@ -18,14 +17,14 @@ async def _test_client(client, name="clean room", **kwargs):
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_recieved(deadline, metadata):
|
||||
def _assert_request_meta_received(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
), "The provided deadline should be received serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
), "The provided authorization metadata should be received serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
@ -42,7 +41,7 @@ async def test_service_call_with_upfront_request_params():
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
@ -53,7 +52,7 @@ async def test_service_call_with_upfront_request_params():
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
@ -70,7 +69,7 @@ async def test_service_call_lower_level_with_overrides():
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
[ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
@ -92,7 +91,7 @@ async def test_service_call_lower_level_with_overrides():
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||
test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
@ -140,8 +139,8 @@ async def test_async_gen_for_stream_stream_request():
|
||||
assert response.version == response_index + 1
|
||||
response_index += 1
|
||||
if more_things:
|
||||
# Send some more requests as we recieve reponses to be sure coordination of
|
||||
# send/recieve events doesn't matter
|
||||
# Send some more requests as we receive responses to be sure coordination of
|
||||
# send/receive events doesn't matter
|
||||
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||
elif not send_initial_requests.done():
|
||||
# Make sure the sending task it completed
|
||||
@ -151,4 +150,4 @@ async def test_async_gen_for_stream_stream_request():
|
||||
request_chan.close()
|
||||
assert response_index == len(
|
||||
expected_things
|
||||
), "Didn't recieve all exptected responses"
|
||||
), "Didn't receive all expected responses"
|
||||
|
@ -3,10 +3,10 @@ from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from typing import Any, Dict
|
||||
import grpclib.server
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class ThingService:
|
||||
|
@ -21,7 +21,7 @@ test_cases = [
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_recieves_wrapped_type(
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
|
@ -14,23 +14,23 @@ def test_has_field():
|
||||
|
||||
# Unset by default
|
||||
foo = Foo()
|
||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||
|
||||
# Serialized after setting something
|
||||
foo.bar.baz = 1
|
||||
assert betterproto.serialized_on_wire(foo.bar) == True
|
||||
assert betterproto.serialized_on_wire(foo.bar) is True
|
||||
|
||||
# Still has it after setting the default value
|
||||
foo.bar.baz = 0
|
||||
assert betterproto.serialized_on_wire(foo.bar) == True
|
||||
assert betterproto.serialized_on_wire(foo.bar) is True
|
||||
|
||||
# Manual override (don't do this)
|
||||
foo.bar._serialized_on_wire = False
|
||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||
|
||||
# Can manually set it but defaults to false
|
||||
foo.bar = Bar()
|
||||
assert betterproto.serialized_on_wire(foo.bar) == False
|
||||
assert betterproto.serialized_on_wire(foo.bar) is False
|
||||
|
||||
|
||||
def test_class_init():
|
||||
@ -118,7 +118,7 @@ def test_oneof_support():
|
||||
|
||||
# Group 1 shouldn't be touched, group 2 should have reset
|
||||
assert foo.sub.val == 0
|
||||
assert betterproto.serialized_on_wire(foo.sub) == False
|
||||
assert betterproto.serialized_on_wire(foo.sub) is False
|
||||
assert betterproto.which_one_of(foo, "group2")[0] == "abc"
|
||||
|
||||
# Zero value should always serialize for one-of
|
||||
@ -175,8 +175,8 @@ def test_optional_flag():
|
||||
assert bytes(Request(flag=False)) == b"\n\x00"
|
||||
|
||||
# Differentiate between not passed and the zero-value.
|
||||
assert Request().parse(b"").flag == None
|
||||
assert Request().parse(b"\n\x00").flag == False
|
||||
assert Request().parse(b"").flag is None
|
||||
assert Request().parse(b"\n\x00").flag is False
|
||||
|
||||
|
||||
def test_to_dict_default_values():
|
||||
|
Loading…
x
Reference in New Issue
Block a user