Merge remote-tracking branch 'daniel/master' into fix/service-input-message
# Conflicts: # betterproto/plugin.py
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -14,3 +14,4 @@ output | ||||
| .idea | ||||
| .DS_Store | ||||
| .tox | ||||
| .venv | ||||
|   | ||||
| @@ -22,7 +22,7 @@ from typing import ( | ||||
| ) | ||||
|  | ||||
| from ._types import T | ||||
| from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case | ||||
| from .casing import camel_case, safe_snake_case, snake_case | ||||
| from .grpc.grpclib_client import ServiceStub | ||||
|  | ||||
| if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): | ||||
| @@ -378,7 +378,7 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i | ||||
|         result |= (b & 0x7F) << shift | ||||
|         pos += 1 | ||||
|         if not (b & 0x80): | ||||
|             return (result, pos) | ||||
|             return result, pos | ||||
|         shift += 7 | ||||
|         if shift >= 64: | ||||
|             raise ValueError("Too many bytes when decoding varint.") | ||||
| @@ -479,7 +479,7 @@ class ProtoClassMetadata: | ||||
|                 assert meta.map_types | ||||
|                 kt = cls._cls_for(field, index=0) | ||||
|                 vt = cls._cls_for(field, index=1) | ||||
|                 Entry = dataclasses.make_dataclass( | ||||
|                 field_cls[field.name] = dataclasses.make_dataclass( | ||||
|                     "Entry", | ||||
|                     [ | ||||
|                         ("key", kt, dataclass_field(1, meta.map_types[0])), | ||||
| @@ -487,7 +487,6 @@ class ProtoClassMetadata: | ||||
|                     ], | ||||
|                     bases=(Message,), | ||||
|                 ) | ||||
|                 field_cls[field.name] = Entry | ||||
|                 field_cls[field.name + ".value"] = vt | ||||
|             else: | ||||
|                 field_cls[field.name] = cls._cls_for(field) | ||||
| @@ -588,7 +587,7 @@ class Message(ABC): | ||||
|             serialize_empty = False | ||||
|             if isinstance(value, Message) and value._serialized_on_wire: | ||||
|                 # Empty messages can still be sent on the wire if they were | ||||
|                 # set (or recieved empty). | ||||
|                 # set (or received empty). | ||||
|                 serialize_empty = True | ||||
|  | ||||
|             if value == self._get_field_default(field_name) and not ( | ||||
| @@ -926,7 +925,7 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||
|  | ||||
|  | ||||
| # Circular import workaround: google.protobuf depends on base classes defined above. | ||||
| from .lib.google.protobuf import ( | ||||
| from .lib.google.protobuf import (  # noqa | ||||
|     Duration, | ||||
|     Timestamp, | ||||
|     BoolValue, | ||||
|   | ||||
| @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, TypeVar | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from . import Message | ||||
|     from grpclib._protocols import IProtoMessage | ||||
|     from grpclib._typing import IProtoMessage | ||||
|  | ||||
| # Bound type variable to allow methods to return `self` of subclasses | ||||
| T = TypeVar("T", bound="Message") | ||||
|   | ||||
| @@ -20,6 +20,8 @@ def safe_snake_case(value: str) -> str: | ||||
|         "and", | ||||
|         "as", | ||||
|         "assert", | ||||
|         "async", | ||||
|         "await", | ||||
|         "break", | ||||
|         "class", | ||||
|         "continue", | ||||
|   | ||||
| @@ -2,7 +2,6 @@ from abc import ABC | ||||
| import asyncio | ||||
| import grpclib.const | ||||
| from typing import ( | ||||
|     Any, | ||||
|     AsyncIterable, | ||||
|     AsyncIterator, | ||||
|     Collection, | ||||
| @@ -17,8 +16,8 @@ from typing import ( | ||||
| from .._types import ST, T | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from grpclib._protocols import IProtoMessage | ||||
|     from grpclib.client import Channel, Stream | ||||
|     from grpclib._typing import IProtoMessage | ||||
|     from grpclib.client import Channel | ||||
|     from grpclib.metadata import Deadline | ||||
|  | ||||
|  | ||||
|   | ||||
| @@ -21,7 +21,7 @@ class ChannelClosed(Exception): | ||||
|  | ||||
| class ChannelDone(Exception): | ||||
|     """ | ||||
|     An exception raised on an attempt to send recieve from a channel that is both closed | ||||
|     An exception raised on an attempt to send receive from a channel that is both closed | ||||
|     and empty. | ||||
|     """ | ||||
|  | ||||
| @@ -32,41 +32,41 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|     """ | ||||
|     A buffered async channel for sending items between coroutines with FIFO ordering. | ||||
|  | ||||
|     This makes decoupled bidirection steaming gRPC requests easy if used like: | ||||
|     This makes decoupled bidirectional steaming gRPC requests easy if used like: | ||||
|  | ||||
|     .. code-block:: python | ||||
|         client = GeneratedStub(grpclib_chan) | ||||
|         request_chan = await AsyncChannel() | ||||
|         request_channel = await AsyncChannel() | ||||
|         # We can start be sending all the requests we already have | ||||
|         await request_chan.send_from([ReqestObject(...), ReqestObject(...)]) | ||||
|         async for response in client.rpc_call(request_chan): | ||||
|         await request_channel.send_from([RequestObject(...), RequestObject(...)]) | ||||
|         async for response in client.rpc_call(request_channel): | ||||
|             # The response iterator will remain active until the connection is closed | ||||
|             ... | ||||
|             # More items can be sent at any time | ||||
|             await request_chan.send(ReqestObject(...)) | ||||
|             await request_channel.send(RequestObject(...)) | ||||
|             ... | ||||
|             # The channel must be closed to complete the gRPC connection | ||||
|             request_chan.close() | ||||
|             request_channel.close() | ||||
|  | ||||
|     Items can be sent through the channel by either: | ||||
|     - providing an iterable to the send_from method | ||||
|     - passing them to the send method one at a time | ||||
|  | ||||
|     Items can be recieved from the channel by either: | ||||
|     Items can be received from the channel by either: | ||||
|     - iterating over the channel with a for loop to get all items | ||||
|     - calling the recieve method to get one item at a time | ||||
|     - calling the receive method to get one item at a time | ||||
|  | ||||
|     If the channel is empty then recievers will wait until either an item appears or the | ||||
|     If the channel is empty then receivers will wait until either an item appears or the | ||||
|     channel is closed. | ||||
|  | ||||
|     Once the channel is closed then subsequent attempt to send through the channel will | ||||
|     fail with a ChannelClosed exception. | ||||
|  | ||||
|     When th channel is closed and empty then it is done, and further attempts to recieve | ||||
|     When th channel is closed and empty then it is done, and further attempts to receive | ||||
|     from it will fail with a ChannelDone exception | ||||
|  | ||||
|     If multiple coroutines recieve from the channel concurrently, each item sent will be | ||||
|     recieved by only one of the recievers. | ||||
|     If multiple coroutines receive from the channel concurrently, each item sent will be | ||||
|     received by only one of the receivers. | ||||
|  | ||||
|     :param source: | ||||
|         An optional iterable will items that should be sent through the channel | ||||
| @@ -74,7 +74,7 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|     :param buffer_limit: | ||||
|         Limit the number of items that can be buffered in the channel, A value less than | ||||
|         1 implies no limit. If the channel is full then attempts to send more items will | ||||
|         result in the sender waiting until an item is recieved from the channel. | ||||
|         result in the sender waiting until an item is received from the channel. | ||||
|     :param close: | ||||
|         If set to True then the channel will automatically close after exhausting source | ||||
|         or immediately if no source is provided. | ||||
| @@ -85,7 +85,7 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|     ): | ||||
|         self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit) | ||||
|         self._closed = False | ||||
|         self._waiting_recievers: int = 0 | ||||
|         self._waiting_receivers: int = 0 | ||||
|         # Track whether flush has been invoked so it can only happen once | ||||
|         self._flushed = False | ||||
|  | ||||
| @@ -95,14 +95,14 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|     async def __anext__(self) -> T: | ||||
|         if self.done(): | ||||
|             raise StopAsyncIteration | ||||
|         self._waiting_recievers += 1 | ||||
|         self._waiting_receivers += 1 | ||||
|         try: | ||||
|             result = await self._queue.get() | ||||
|             if result is self.__flush: | ||||
|                 raise StopAsyncIteration | ||||
|             return result | ||||
|         finally: | ||||
|             self._waiting_recievers -= 1 | ||||
|             self._waiting_receivers -= 1 | ||||
|             self._queue.task_done() | ||||
|  | ||||
|     def closed(self) -> bool: | ||||
| @@ -116,12 +116,12 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|         Check if this channel is done. | ||||
|  | ||||
|         :return: True if this channel is closed and and has been drained of items in | ||||
|         which case any further attempts to recieve an item from this channel will raise | ||||
|         which case any further attempts to receive an item from this channel will raise | ||||
|         a ChannelDone exception. | ||||
|         """ | ||||
|         # After close the channel is not yet done until there is at least one waiting | ||||
|         # reciever per enqueued item. | ||||
|         return self._closed and self._queue.qsize() <= self._waiting_recievers | ||||
|         # receiver per enqueued item. | ||||
|         return self._closed and self._queue.qsize() <= self._waiting_receivers | ||||
|  | ||||
|     async def send_from( | ||||
|         self, source: Union[Iterable[T], AsyncIterable[T]], close: bool = False | ||||
| @@ -158,22 +158,22 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|         await self._queue.put(item) | ||||
|         return self | ||||
|  | ||||
|     async def recieve(self) -> Optional[T]: | ||||
|     async def receive(self) -> Optional[T]: | ||||
|         """ | ||||
|         Returns the next item from this channel when it becomes available, | ||||
|         or None if the channel is closed before another item is sent. | ||||
|         :return: An item from the channel | ||||
|         """ | ||||
|         if self.done(): | ||||
|             raise ChannelDone("Cannot recieve from a closed channel") | ||||
|         self._waiting_recievers += 1 | ||||
|             raise ChannelDone("Cannot receive from a closed channel") | ||||
|         self._waiting_receivers += 1 | ||||
|         try: | ||||
|             result = await self._queue.get() | ||||
|             if result is self.__flush: | ||||
|                 return None | ||||
|             return result | ||||
|         finally: | ||||
|             self._waiting_recievers -= 1 | ||||
|             self._waiting_receivers -= 1 | ||||
|             self._queue.task_done() | ||||
|  | ||||
|     def close(self): | ||||
| @@ -190,8 +190,8 @@ class AsyncChannel(AsyncIterable[T]): | ||||
|         """ | ||||
|         if not self._flushed: | ||||
|             self._flushed = True | ||||
|             deadlocked_recievers = max(0, self._waiting_recievers - self._queue.qsize()) | ||||
|             for _ in range(deadlocked_recievers): | ||||
|             deadlocked_receivers = max(0, self._waiting_receivers - self._queue.qsize()) | ||||
|             for _ in range(deadlocked_receivers): | ||||
|                 await self._queue.put(self.__flush) | ||||
|  | ||||
|     # A special signal object for flushing the queue when the channel is closed | ||||
|   | ||||
| @@ -199,7 +199,7 @@ def generate_code(request, response): | ||||
|         # Render and then format the output file. | ||||
|         f.content = black.format_str( | ||||
|             template.render(description=template_data), | ||||
|             mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), | ||||
|             mode=black.FileMode(target_versions={black.TargetVersion.PY37}), | ||||
|         ) | ||||
|  | ||||
|     # Make each output directory a package with __init__ file | ||||
|   | ||||
| @@ -3,7 +3,6 @@ import asyncio | ||||
| import os | ||||
| from pathlib import Path | ||||
| import shutil | ||||
| import subprocess | ||||
| import sys | ||||
| from typing import Set | ||||
|  | ||||
|   | ||||
| @@ -3,7 +3,6 @@ from betterproto.tests.output_betterproto.service.service import ( | ||||
|     DoThingResponse, | ||||
|     DoThingRequest, | ||||
|     GetThingRequest, | ||||
|     GetThingResponse, | ||||
|     TestStub as ThingServiceClient, | ||||
| ) | ||||
| import grpclib | ||||
| @@ -18,14 +17,14 @@ async def _test_client(client, name="clean room", **kwargs): | ||||
|     assert response.names == [name] | ||||
|  | ||||
|  | ||||
| def _assert_request_meta_recieved(deadline, metadata): | ||||
| def _assert_request_meta_received(deadline, metadata): | ||||
|     def server_side_test(stream): | ||||
|         assert stream.deadline._timestamp == pytest.approx( | ||||
|             deadline._timestamp, 1 | ||||
|         ), "The provided deadline should be recieved serverside" | ||||
|         ), "The provided deadline should be received serverside" | ||||
|         assert ( | ||||
|             stream.metadata["authorization"] == metadata["authorization"] | ||||
|         ), "The provided authorization metadata should be recieved serverside" | ||||
|         ), "The provided authorization metadata should be received serverside" | ||||
|  | ||||
|     return server_side_test | ||||
|  | ||||
| @@ -42,7 +41,7 @@ async def test_service_call_with_upfront_request_params(): | ||||
|     deadline = grpclib.metadata.Deadline.from_timeout(22) | ||||
|     metadata = {"authorization": "12345"} | ||||
|     async with ChannelFor( | ||||
|         [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] | ||||
|         [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)] | ||||
|     ) as channel: | ||||
|         await _test_client( | ||||
|             ThingServiceClient(channel, deadline=deadline, metadata=metadata) | ||||
| @@ -53,7 +52,7 @@ async def test_service_call_with_upfront_request_params(): | ||||
|     deadline = grpclib.metadata.Deadline.from_timeout(timeout) | ||||
|     metadata = {"authorization": "12345"} | ||||
|     async with ChannelFor( | ||||
|         [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] | ||||
|         [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)] | ||||
|     ) as channel: | ||||
|         await _test_client( | ||||
|             ThingServiceClient(channel, timeout=timeout, metadata=metadata) | ||||
| @@ -70,7 +69,7 @@ async def test_service_call_lower_level_with_overrides(): | ||||
|     kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28) | ||||
|     kwarg_metadata = {"authorization": "12345"} | ||||
|     async with ChannelFor( | ||||
|         [ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)] | ||||
|         [ThingService(test_hook=_assert_request_meta_received(deadline, metadata),)] | ||||
|     ) as channel: | ||||
|         client = ThingServiceClient(channel, deadline=deadline, metadata=metadata) | ||||
|         response = await client._unary_unary( | ||||
| @@ -92,7 +91,7 @@ async def test_service_call_lower_level_with_overrides(): | ||||
|     async with ChannelFor( | ||||
|         [ | ||||
|             ThingService( | ||||
|                 test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata), | ||||
|                 test_hook=_assert_request_meta_received(kwarg_deadline, kwarg_metadata), | ||||
|             ) | ||||
|         ] | ||||
|     ) as channel: | ||||
| @@ -140,8 +139,8 @@ async def test_async_gen_for_stream_stream_request(): | ||||
|             assert response.version == response_index + 1 | ||||
|             response_index += 1 | ||||
|             if more_things: | ||||
|                 # Send some more requests as we recieve reponses to be sure coordination of | ||||
|                 # send/recieve events doesn't matter | ||||
|                 # Send some more requests as we receive responses to be sure coordination of | ||||
|                 # send/receive events doesn't matter | ||||
|                 await request_chan.send(GetThingRequest(more_things.pop(0))) | ||||
|             elif not send_initial_requests.done(): | ||||
|                 # Make sure the sending task it completed | ||||
| @@ -151,4 +150,4 @@ async def test_async_gen_for_stream_stream_request(): | ||||
|                 request_chan.close() | ||||
|         assert response_index == len( | ||||
|             expected_things | ||||
|         ), "Didn't recieve all exptected responses" | ||||
|         ), "Didn't receive all expected responses" | ||||
|   | ||||
| @@ -3,10 +3,10 @@ from betterproto.tests.output_betterproto.service.service import ( | ||||
|     DoThingRequest, | ||||
|     GetThingRequest, | ||||
|     GetThingResponse, | ||||
|     TestStub as ThingServiceClient, | ||||
| ) | ||||
| import grpclib | ||||
| from typing import Any, Dict | ||||
| import grpclib.server | ||||
| from typing import Dict | ||||
|  | ||||
|  | ||||
| class ThingService: | ||||
|   | ||||
| @@ -21,7 +21,7 @@ test_cases = [ | ||||
|  | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||
| async def test_channel_recieves_wrapped_type( | ||||
| async def test_channel_receives_wrapped_type( | ||||
|     service_method: Callable[[TestStub], Any], wrapper_class: Callable, value | ||||
| ): | ||||
|     wrapped_value = wrapper_class() | ||||
|   | ||||
| @@ -14,23 +14,23 @@ def test_has_field(): | ||||
|  | ||||
|     # Unset by default | ||||
|     foo = Foo() | ||||
|     assert betterproto.serialized_on_wire(foo.bar) == False | ||||
|     assert betterproto.serialized_on_wire(foo.bar) is False | ||||
|  | ||||
|     # Serialized after setting something | ||||
|     foo.bar.baz = 1 | ||||
|     assert betterproto.serialized_on_wire(foo.bar) == True | ||||
|     assert betterproto.serialized_on_wire(foo.bar) is True | ||||
|  | ||||
|     # Still has it after setting the default value | ||||
|     foo.bar.baz = 0 | ||||
|     assert betterproto.serialized_on_wire(foo.bar) == True | ||||
|     assert betterproto.serialized_on_wire(foo.bar) is True | ||||
|  | ||||
|     # Manual override (don't do this) | ||||
|     foo.bar._serialized_on_wire = False | ||||
|     assert betterproto.serialized_on_wire(foo.bar) == False | ||||
|     assert betterproto.serialized_on_wire(foo.bar) is False | ||||
|  | ||||
|     # Can manually set it but defaults to false | ||||
|     foo.bar = Bar() | ||||
|     assert betterproto.serialized_on_wire(foo.bar) == False | ||||
|     assert betterproto.serialized_on_wire(foo.bar) is False | ||||
|  | ||||
|  | ||||
| def test_class_init(): | ||||
| @@ -118,7 +118,7 @@ def test_oneof_support(): | ||||
|  | ||||
|     # Group 1 shouldn't be touched, group 2 should have reset | ||||
|     assert foo.sub.val == 0 | ||||
|     assert betterproto.serialized_on_wire(foo.sub) == False | ||||
|     assert betterproto.serialized_on_wire(foo.sub) is False | ||||
|     assert betterproto.which_one_of(foo, "group2")[0] == "abc" | ||||
|  | ||||
|     # Zero value should always serialize for one-of | ||||
| @@ -175,8 +175,8 @@ def test_optional_flag(): | ||||
|     assert bytes(Request(flag=False)) == b"\n\x00" | ||||
|  | ||||
|     # Differentiate between not passed and the zero-value. | ||||
|     assert Request().parse(b"").flag == None | ||||
|     assert Request().parse(b"\n\x00").flag == False | ||||
|     assert Request().parse(b"").flag is None | ||||
|     assert Request().parse(b"\n\x00").flag is False | ||||
|  | ||||
|  | ||||
| def test_to_dict_default_values(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user