From 706bd5a475e73a297bc5217742346d9eeb597ccf Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Mon, 28 Oct 2019 20:58:33 -0700 Subject: [PATCH] Slightly simplify gRPC helper functions --- betterproto/__init__.py | 18 ++++++++++++------ betterproto/templates/template.py | 6 ++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/betterproto/__init__.py b/betterproto/__init__.py index a9bc4b6..cb9cd29 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -21,6 +21,7 @@ from typing import ( TypeVar, Union, get_type_hints, + TYPE_CHECKING, ) import grpclib.client @@ -29,6 +30,9 @@ import stringcase from .casing import safe_snake_case +if TYPE_CHECKING: + from grpclib._protocols import IProtoMessage + # Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" @@ -420,6 +424,7 @@ class Message(ABC): register the message fields which get used by the serializers and parsers to go between Python, binary and JSON protobuf message representations. """ + _serialized_on_wire: bool _unknown_fields: bytes _group_map: Dict[str, dict] @@ -705,7 +710,7 @@ class Message(ABC): for field in dataclasses.fields(self): meta = FieldMetadata.get(field) v = getattr(self, field.name) - cased_name = casing(field.name).rstrip("_") # type: ignore + cased_name = casing(field.name).rstrip("_") # type: ignore if meta.proto_type == "message": if isinstance(v, datetime): if v != DATETIME_ZERO: @@ -741,7 +746,7 @@ class Message(ABC): else: output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: - enum_values = list(self._cls_for(field)) # type: ignore + enum_values = list(self._cls_for(field)) # type: ignore if isinstance(v, list): output[cased_name] = [enum_values[e].name for e in v] else: @@ -902,6 +907,7 @@ class _WrappedMessage(Message): Google protobuf wrapper types base class. JSON representation is just the value itself. """ + value: Any def to_dict(self, casing: Casing = Casing.CAMEL) -> Any: @@ -982,11 +988,11 @@ class ServiceStub(ABC): self.channel = channel async def _unary_unary( - self, route: str, request_type: Type, response_type: Type[T], request: Any + self, route: str, request: "IProtoMessage", response_type: Type[T] ) -> T: """Make a unary request and return the response.""" async with self.channel.request( - route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type + route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type ) as stream: await stream.send_message(request, end=True) response = await stream.recv_message() @@ -994,11 +1000,11 @@ class ServiceStub(ABC): return response async def _unary_stream( - self, route: str, request_type: Type, response_type: Type[T], request: Any + self, route: str, request: "IProtoMessage", response_type: Type[T] ) -> AsyncGenerator[T, None]: """Make a unary request and return the stream response iterator.""" async with self.channel.request( - route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type + route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type ) as stream: await stream.send_message(request, end=True) async for message in stream: diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index eac4595..d0f14e8 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -81,17 +81,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% if method.server_streaming %} async for response in self._unary_stream( "{{ method.route }}", - {{ method.input }}, - {{ method.output }}, request, + {{ method.output }}, ): yield response {% else %} return await self._unary_unary( "{{ method.route }}", - {{ method.input }}, - {{ method.output }}, request, + {{ method.output }}, ) {% endif %}