Merge remote-tracking branch 'daniel/master' into fix/service-input-message

# Conflicts:
#	betterproto/plugin.py
This commit is contained in:
boukeversteegh 2020-07-08 23:00:32 +02:00
commit 72d72b4603
12 changed files with 59 additions and 60 deletions

.gitignore vendored
View File

@@ -14,3 +14,4 @@ output
 .idea
 .DS_Store
 .tox
+.venv

View File

@@ -22,7 +22,7 @@ from typing import (
 )
 from ._types import T
-from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
+from .casing import camel_case, safe_snake_case, snake_case
 from .grpc.grpclib_client import ServiceStub
 if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
@@ -378,7 +378,7 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
         result |= (b & 0x7F) << shift
         pos += 1
         if not (b & 0x80):
-            return (result, pos)
+            return result, pos
         shift += 7
         if shift >= 64:
             raise ValueError("Too many bytes when decoding varint.")
@@ -479,7 +479,7 @@ class ProtoClassMetadata:
                 assert meta.map_types
                 kt = cls._cls_for(field, index=0)
                 vt = cls._cls_for(field, index=1)
-                Entry = dataclasses.make_dataclass(
+                field_cls[field.name] = dataclasses.make_dataclass(
                     "Entry",
                     [
                         ("key", kt, dataclass_field(1, meta.map_types[0])),
@@ -487,7 +487,6 @@ class ProtoClassMetadata:
                     ],
                     bases=(Message,),
                 )
-                field_cls[field.name] = Entry
                 field_cls[field.name + ".value"] = vt
             else:
                 field_cls[field.name] = cls._cls_for(field)
@@ -588,7 +587,7 @@ class Message(ABC):
             serialize_empty = False
             if isinstance(value, Message) and value._serialized_on_wire:
                 # Empty messages can still be sent on the wire if they were
-                # set (or recieved empty).
+                # set (or received empty).
                 serialize_empty = True
             if value == self._get_field_default(field_name) and not (
@@ -926,7 +925,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
 # Circular import workaround: google.protobuf depends on base classes defined above.
-from .lib.google.protobuf import (
+from .lib.google.protobuf import (  # noqa
     Duration,
     Timestamp,
     BoolValue,

View File

@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, TypeVar
 if TYPE_CHECKING:
     from . import Message
-    from grpclib._protocols import IProtoMessage
+    from grpclib._typing import IProtoMessage
 # Bound type variable to allow methods to return `self` of subclasses
 T = TypeVar("T", bound="Message")

View File

@@ -20,6 +20,8 @@ def safe_snake_case(value: str) -> str:
     "and",
     "as",
     "assert",
+    "async",
+    "await",
     "break",
     "class",
     "continue",

View File

@@ -2,7 +2,6 @@ from abc import ABC
 import asyncio
 import grpclib.const
 from typing import (
-    Any,
     AsyncIterable,
     AsyncIterator,
     Collection,
@@ -17,8 +16,8 @@ from typing import (
 from .._types import ST, T
 if TYPE_CHECKING:
-    from grpclib._protocols import IProtoMessage
+    from grpclib._typing import IProtoMessage
-    from grpclib.client import Channel, Stream
+    from grpclib.client import Channel
     from grpclib.metadata import Deadline

View File

@@ -21,7 +21,7 @@ class ChannelClosed(Exception):
 class ChannelDone(Exception):
     """
-    An exception raised on an attempt to send recieve from a channel that is both closed
+    An exception raised on an attempt to send receive from a channel that is both closed
     and empty.
     """
@@ -32,41 +32,41 @@ class AsyncChannel(AsyncIterable[T]):
     """
     A buffered async channel for sending items between coroutines with FIFO ordering.
-    This makes decoupled bidirection steaming gRPC requests easy if used like:
+    This makes decoupled bidirectional steaming gRPC requests easy if used like:
     .. code-block:: python
         client = GeneratedStub(grpclib_chan)
-        request_chan = await AsyncChannel()
+        request_channel = await AsyncChannel()
         # We can start be sending all the requests we already have
-        await request_chan.send_from([ReqestObject(...), ReqestObject(...)])
+        await request_channel.send_from([RequestObject(...), RequestObject(...)])
-        async for response in client.rpc_call(request_chan):
+        async for response in client.rpc_call(request_channel):
             # The response iterator will remain active until the connection is closed
             ...
             # More items can be sent at any time
-            await request_chan.send(ReqestObject(...))
+            await request_channel.send(RequestObject(...))
         ...
         # The channel must be closed to complete the gRPC connection
-        request_chan.close()
+        request_channel.close()
     Items can be sent through the channel by either:
     - providing an iterable to the send_from method
    - passing them to the send method one at a time
-    Items can be recieved from the channel by either:
+    Items can be received from the channel by either:
     - iterating over the channel with a for loop to get all items
-    - calling the recieve method to get one item at a time
+    - calling the receive method to get one item at a time
-    If the channel is empty then recievers will wait until either an item appears or the
+    If the channel is empty then receivers will wait until either an item appears or the
     channel is closed.
     Once the channel is closed then subsequent attempt to send through the channel will
     fail with a ChannelClosed exception.
-    When th channel is closed and empty then it is done, and further attempts to recieve
+    When th channel is closed and empty then it is done, and further attempts to receive
     from it will fail with a ChannelDone exception
-    If multiple coroutines recieve from the channel concurrently, each item sent will be
+    If multiple coroutines receive from the channel concurrently, each item sent will be
-    recieved by only one of the recievers.
+    received by only one of the receivers.
     :param source:
         An optional iterable will items that should be sent through the channel
@@ -74,7 +74,7 @@ class AsyncChannel(AsyncIterable[T]):
     :param buffer_limit:
         Limit the number of items that can be buffered in the channel, A value less than
         1 implies no limit. If the channel is full then attempts to send more items will
-        result in the sender waiting until an item is recieved from the channel.
+        result in the sender waiting until an item is received from the channel.
     :param close:
         If set to True then the channel will automatically close after exhausting source
         or immediately if no source is provided.
@@ -85,7 +85,7 @@ class AsyncChannel(AsyncIterable[T]):
     ):
         self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
         self._closed = False
-        self._waiting_recievers: int = 0
+        self._waiting_receivers: int = 0
         # Track whether flush has been invoked so it can only happen once
         self._flushed = False
@@ -95,14 +95,14 @@ class AsyncChannel(AsyncIterable[T]):
     async def __anext__(self) -> T:
         if self.done():
             raise StopAsyncIteration
-        self._waiting_recievers += 1
+        self._waiting_receivers += 1
         try:
             result = await self._queue.get()
             if result is self.__flush:
                 raise StopAsyncIteration
             return result
         finally:
-            self._waiting_recievers -= 1
+            self._waiting_receivers -= 1
             self._queue.task_done()
     def closed(self) -> bool:
@@ -116,12 +116,12 @@ class AsyncChannel(AsyncIterable[T]):
         Check if this channel is done.
         :return: True if this channel is closed and and has been drained of items in
-            which case any further attempts to recieve an item from this channel will raise
+            which case any further attempts to receive an item from this channel will raise
             a ChannelDone exception.
         """
         # After close the channel is not yet done until there is at least one waiting
-        # reciever per enqueued item.
+        # receiver per enqueued item.
-        return self._closed and self._queue.qsize() <= self._waiting_recievers
+        return self._closed and self._queue.qsize() <= self._waiting_receivers
     async def send_from(
         self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False
@@ -158,22 +158,22 @@ class AsyncChannel(AsyncIterable[T]):
             await self._queue.put(item)
         return self
-    async def recieve(self) -> Optional[T]:
+    async def receive(self) -> Optional[T]:
         """
         Returns the next item from this channel when it becomes available,
         or None if the channel is closed before another item is sent.
         :return: An item from the channel
         """
         if self.done():
-            raise ChannelDone("Cannot recieve from a closed channel")
+            raise ChannelDone("Cannot receive from a closed channel")
-        self._waiting_recievers += 1
+        self._waiting_receivers += 1
         try:
             result = await self._queue.get()
             if result is self.__flush:
                 return None
             return result
         finally:
-            self._waiting_recievers -= 1
+            self._waiting_receivers -= 1
             self._queue.task_done()
     def close(self):
@@ -190,8 +190,8 @@ class AsyncChannel(AsyncIterable[T]):
         """
         if not self._flushed:
             self._flushed = True
-            deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize())
+            deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize())
-            for _ in range(deadlocked_recievers):
+            for _ in range(deadlocked_receivers):
                 await self._queue.put(self.__flush)
     # A special signal object for flushing the queue when the channel is closed
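The docstring above describes sending items with send()/send_from() and consuming them either by iterating the channel or by calling the renamed receive() method. A minimal usage sketch under those semantics follows; the import path, the helper names produce/consume/main, and the item values are illustrative assumptions, not part of this commit:

    import asyncio
    # Assumed module path for AsyncChannel; adjust to wherever the class actually lives.
    from betterproto.grpc.util.async_channel import AsyncChannel

    async def produce(channel: AsyncChannel) -> None:
        # Queue a batch of items, then one more, then close so consumers can finish.
        await channel.send_from(["a", "b"])
        await channel.send("c")
        channel.close()

    async def consume(channel: AsyncChannel) -> None:
        # Iterating the channel yields items until it is closed and drained;
        # channel.receive() is the one-item-at-a-time alternative.
        async for item in channel:
            print(item)

    async def main() -> None:
        channel = AsyncChannel()  # no source: items are sent explicitly by produce()
        await asyncio.gather(produce(channel), consume(channel))

    asyncio.run(main())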

View File

@@ -199,7 +199,7 @@ def generate_code(request, response):
         # Render and then format the output file.
         f.content = black.format_str(
             template.render(description=template_data),
-            mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
+            mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
         )
         # Make each output directory a package with __init__ file

View File

@@ -3,7 +3,6 @@ import asyncio
 import os
 from pathlib import Path
 import shutil
-import subprocess
 import sys
 from typing import Set

View File

@@ -3,7 +3,6 @@ from betterproto.tests.output_betterproto.service.service import (
     DoThingResponse,
     DoThingRequest,
     GetThingRequest,
-    GetThingResponse,
     TestStub as ThingServiceClient,
 )
 import grpclib
@@ -18,14 +17,14 @@ async def _test_client(client, name="clean room", **kwargs):
     assert response.names == [name]
-def _assert_request_meta_recieved(deadline, metadata):
+def _assert_request_meta_received(deadline, metadata):
     def server_side_test(stream):
         assert stream.deadline._timestamp == pytest.approx(
             deadline._timestamp, 1
-        ), "The provided deadline should be recieved serverside"
+        ), "The provided deadline should be received serverside"
         assert (
             stream.metadata["authorization"] == metadata["authorization"]
-        ), "The provided authorization metadata should be recieved serverside"
+        ), "The provided authorization metadata should be received serverside"
     return server_side_test
@@ -42,7 +41,7 @@ async def test_service_call_with_upfront_request_params():
     deadline = grpclib.metadata.Deadline.from_timeout(22)
     metadata = {"authorization": "12345"}
     async with ChannelFor(
-        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
+        [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
     ) as channel:
         await _test_client(
             ThingServiceClient(channel, deadline=deadline, metadata=metadata)
@@ -53,7 +52,7 @@ async def test_service_call_with_upfront_request_params():
     deadline = grpclib.metadata.Deadline.from_timeout(timeout)
     metadata = {"authorization": "12345"}
     async with ChannelFor(
-        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
+        [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
     ) as channel:
         await _test_client(
             ThingServiceClient(channel, timeout=timeout, metadata=metadata)
@@ -70,7 +69,7 @@ async def test_service_call_lower_level_with_overrides():
     kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
     kwarg_metadata = {"authorization": "12345"}
     async with ChannelFor(
-        [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
+        [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)]
     ) as channel:
         client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
         response = await client._unary_unary(
@@ -92,7 +91,7 @@ async def test_service_call_lower_level_with_overrides():
     async with ChannelFor(
         [
             ThingService(
-                test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
+                test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata),
             )
         ]
     ) as channel:
@@ -140,8 +139,8 @@ async def test_async_gen_for_stream_stream_request():
             assert response.version == response_index + 1
             response_index += 1
             if more_things:
-                # Send some more requests as we recieve reponses to be sure coordination of
+                # Send some more requests as we receive responses to be sure coordination of
-                # send/recieve events doesn't matter
+                # send/receive events doesn't matter
                 await request_chan.send(GetThingRequest(more_things.pop(0)))
             elif not send_initial_requests.done():
                 # Make sure the sending task it completed
@@ -151,4 +150,4 @@ async def test_async_gen_for_stream_stream_request():
                 request_chan.close()
     assert response_index == len(
         expected_things
-    ), "Didn't recieve all exptected responses"
+    ), "Didn't receive all expected responses"

View File

@@ -3,10 +3,10 @@ from betterproto.tests.output_betterproto.service.service import (
     DoThingRequest,
     GetThingRequest,
     GetThingResponse,
-    TestStub as ThingServiceClient,
 )
 import grpclib
-from typing import Any, Dict
+import grpclib.server
+from typing import Dict
 class ThingService:

View File

@@ -21,7 +21,7 @@ test_cases = [
 @pytest.mark.asyncio
 @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
-async def test_channel_recieves_wrapped_type(
+async def test_channel_receives_wrapped_type(
     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
 ):
     wrapped_value = wrapper_class()

View File

@@ -14,23 +14,23 @@ def test_has_field():
     # Unset by default
     foo = Foo()
-    assert betterproto.serialized_on_wire(foo.bar) == False
+    assert betterproto.serialized_on_wire(foo.bar) is False
     # Serialized after setting something
     foo.bar.baz = 1
-    assert betterproto.serialized_on_wire(foo.bar) == True
+    assert betterproto.serialized_on_wire(foo.bar) is True
     # Still has it after setting the default value
     foo.bar.baz = 0
-    assert betterproto.serialized_on_wire(foo.bar) == True
+    assert betterproto.serialized_on_wire(foo.bar) is True
     # Manual override (don't do this)
     foo.bar._serialized_on_wire = False
-    assert betterproto.serialized_on_wire(foo.bar) == False
+    assert betterproto.serialized_on_wire(foo.bar) is False
     # Can manually set it but defaults to false
     foo.bar = Bar()
-    assert betterproto.serialized_on_wire(foo.bar) == False
+    assert betterproto.serialized_on_wire(foo.bar) is False
 def test_class_init():
@@ -118,7 +118,7 @@ def test_oneof_support():
     # Group 1 shouldn't be touched, group 2 should have reset
     assert foo.sub.val == 0
-    assert betterproto.serialized_on_wire(foo.sub) == False
+    assert betterproto.serialized_on_wire(foo.sub) is False
     assert betterproto.which_one_of(foo, "group2")[0] == "abc"
     # Zero value should always serialize for one-of
@@ -175,8 +175,8 @@ def test_optional_flag():
     assert bytes(Request(flag=False)) == b"\n\x00"
     # Differentiate between not passed and the zero-value.
-    assert Request().parse(b"").flag == None
+    assert Request().parse(b"").flag is None
-    assert Request().parse(b"\n\x00").flag == False
+    assert Request().parse(b"\n\x00").flag is False
 def test_to_dict_default_values():