Slightly simplify gRPC helper functions
This commit is contained in:
parent
52beeb0d73
commit
706bd5a475
@ -21,6 +21,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
import grpclib.client
|
||||
@ -29,6 +30,9 @@ import stringcase
|
||||
|
||||
from .casing import safe_snake_case
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib._protocols import IProtoMessage
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
TYPE_BOOL = "bool"
|
||||
@ -420,6 +424,7 @@ class Message(ABC):
|
||||
register the message fields which get used by the serializers and parsers
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
|
||||
_serialized_on_wire: bool
|
||||
_unknown_fields: bytes
|
||||
_group_map: Dict[str, dict]
|
||||
@ -705,7 +710,7 @@ class Message(ABC):
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, datetime):
|
||||
if v != DATETIME_ZERO:
|
||||
@ -741,7 +746,7 @@ class Message(ABC):
|
||||
else:
|
||||
output[cased_name] = b64encode(v).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
enum_values = list(self._cls_for(field)) # type: ignore
|
||||
enum_values = list(self._cls_for(field)) # type: ignore
|
||||
if isinstance(v, list):
|
||||
output[cased_name] = [enum_values[e].name for e in v]
|
||||
else:
|
||||
@ -902,6 +907,7 @@ class _WrappedMessage(Message):
|
||||
Google protobuf wrapper types base class. JSON representation is just the
|
||||
value itself.
|
||||
"""
|
||||
|
||||
value: Any
|
||||
|
||||
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
|
||||
@ -982,11 +988,11 @@ class ServiceStub(ABC):
|
||||
self.channel = channel
|
||||
|
||||
async def _unary_unary(
|
||||
self, route: str, request_type: Type, response_type: Type[T], request: Any
|
||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
||||
) -> T:
|
||||
"""Make a unary request and return the response."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type
|
||||
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
response = await stream.recv_message()
|
||||
@ -994,11 +1000,11 @@ class ServiceStub(ABC):
|
||||
return response
|
||||
|
||||
async def _unary_stream(
|
||||
self, route: str, request_type: Type, response_type: Type[T], request: Any
|
||||
self, route: str, request: "IProtoMessage", response_type: Type[T]
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type
|
||||
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
async for message in stream:
|
||||
|
@ -81,17 +81,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% if method.server_streaming %}
|
||||
async for response in self._unary_stream(
|
||||
"{{ method.route }}",
|
||||
{{ method.input }},
|
||||
{{ method.output }},
|
||||
request,
|
||||
{{ method.output }},
|
||||
):
|
||||
yield response
|
||||
{% else %}
|
||||
return await self._unary_unary(
|
||||
"{{ method.route }}",
|
||||
{{ method.input }},
|
||||
{{ method.output }},
|
||||
request,
|
||||
{{ method.output }},
|
||||
)
|
||||
{% endif %}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user