Slightly simplify gRPC helper functions

This commit is contained in:
Daniel G. Taylor 2019-10-28 20:58:33 -07:00
parent 52beeb0d73
commit 706bd5a475
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
2 changed files with 14 additions and 10 deletions

View File

@ -21,6 +21,7 @@ from typing import (
TypeVar,
Union,
get_type_hints,
TYPE_CHECKING,
)
import grpclib.client
@ -29,6 +30,9 @@ import stringcase
from .casing import safe_snake_case
if TYPE_CHECKING:
from grpclib._protocols import IProtoMessage
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@ -420,6 +424,7 @@ class Message(ABC):
register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations.
"""
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
@ -705,7 +710,7 @@ class Message(ABC):
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") # type: ignore
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO:
@ -741,7 +746,7 @@ class Message(ABC):
else:
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field)) # type: ignore
enum_values = list(self._cls_for(field)) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
else:
@ -902,6 +907,7 @@ class _WrappedMessage(Message):
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
@ -982,11 +988,11 @@ class ServiceStub(ABC):
self.channel = channel
async def _unary_unary(
self, route: str, request_type: Type, response_type: Type[T], request: Any
self, route: str, request: "IProtoMessage", response_type: Type[T]
) -> T:
"""Make a unary request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
) as stream:
await stream.send_message(request, end=True)
response = await stream.recv_message()
@ -994,11 +1000,11 @@ class ServiceStub(ABC):
return response
async def _unary_stream(
self, route: str, request_type: Type, response_type: Type[T], request: Any
self, route: str, request: "IProtoMessage", response_type: Type[T]
) -> AsyncGenerator[T, None]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
) as stream:
await stream.send_message(request, end=True)
async for message in stream:

View File

@ -81,17 +81,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if method.server_streaming %}
async for response in self._unary_stream(
"{{ method.route }}",
{{ method.input }},
{{ method.output }},
request,
{{ method.output }},
):
yield response
{% else %}
return await self._unary_unary(
"{{ method.route }}",
{{ method.input }},
{{ method.output }},
request,
{{ method.output }},
)
{% endif %}