Implement basic async gRPC support
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import struct
|
||||
from typing import (
|
||||
get_type_hints,
|
||||
AsyncGenerator,
|
||||
Union,
|
||||
Generator,
|
||||
Any,
|
||||
@@ -17,6 +18,9 @@ from typing import (
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
|
||||
import inspect
|
||||
|
||||
# Proto 3 data types
|
||||
@@ -92,7 +96,14 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
def get_default(proto_type: int) -> Any:
|
||||
class _PLACEHOLDER:
|
||||
pass
|
||||
|
||||
|
||||
PLACEHOLDER: Any = _PLACEHOLDER()
|
||||
|
||||
|
||||
def get_default(proto_type: str) -> Any:
|
||||
"""Get the default (zero value) for a given type."""
|
||||
return {
|
||||
TYPE_BOOL: False,
|
||||
@@ -114,8 +125,6 @@ class FieldMetadata:
|
||||
proto_type: str
|
||||
# Map information if the proto_type is a map
|
||||
map_types: Optional[Tuple[str, str]]
|
||||
# Default value if given
|
||||
default: Any
|
||||
|
||||
@staticmethod
|
||||
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||
@@ -124,23 +133,12 @@ class FieldMetadata:
|
||||
|
||||
|
||||
def dataclass_field(
|
||||
number: int,
|
||||
proto_type: str,
|
||||
default: Any = None,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
**kwargs: dict,
|
||||
number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None
|
||||
) -> dataclasses.Field:
|
||||
"""Creates a dataclass field with attached protobuf metadata."""
|
||||
if callable(default):
|
||||
kwargs["default_factory"] = default
|
||||
elif isinstance(default, dict) or isinstance(default, list):
|
||||
kwargs["default_factory"] = lambda: default
|
||||
else:
|
||||
kwargs["default"] = default
|
||||
|
||||
return dataclasses.field(
|
||||
**kwargs,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
|
||||
default=PLACEHOLDER,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types)},
|
||||
)
|
||||
|
||||
|
||||
@@ -149,68 +147,68 @@ def dataclass_field(
|
||||
# out at runtime. The generated dataclass variables are still typed correctly.
|
||||
|
||||
|
||||
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_ENUM, default=default)
|
||||
def enum_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_ENUM)
|
||||
|
||||
|
||||
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_BOOL, default=default)
|
||||
def bool_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_BOOL)
|
||||
|
||||
|
||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_INT32, default=default)
|
||||
def int32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_INT32)
|
||||
|
||||
|
||||
def int64_field(number: int, default: int = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_INT64, default=default)
|
||||
def int64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_INT64)
|
||||
|
||||
|
||||
def uint32_field(number: int, default: int = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT32, default=default)
|
||||
def uint32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT32)
|
||||
|
||||
|
||||
def uint64_field(number: int, default: int = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT64, default=default)
|
||||
def uint64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_UINT64)
|
||||
|
||||
|
||||
def sint32_field(number: int, default: int = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT32, default=default)
|
||||
def sint32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT32)
|
||||
|
||||
|
||||
def sint64_field(number: int, default: int = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT64, default=default)
|
||||
def sint64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SINT64)
|
||||
|
||||
|
||||
def float_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_FLOAT, default=default)
|
||||
def float_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FLOAT)
|
||||
|
||||
|
||||
def double_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_DOUBLE, default=default)
|
||||
def double_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_DOUBLE)
|
||||
|
||||
|
||||
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED32, default=default)
|
||||
def fixed32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED32)
|
||||
|
||||
|
||||
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED64, default=default)
|
||||
def fixed64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_FIXED64)
|
||||
|
||||
|
||||
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED32, default=default)
|
||||
def sfixed32_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED32)
|
||||
|
||||
|
||||
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED64, default=default)
|
||||
def sfixed64_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_SFIXED64)
|
||||
|
||||
|
||||
def string_field(number: int, default: str = "") -> Any:
|
||||
return dataclass_field(number, TYPE_STRING, default=default)
|
||||
def string_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_STRING)
|
||||
|
||||
|
||||
def bytes_field(number: int, default: bytes = b"") -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES, default=default)
|
||||
def bytes_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES)
|
||||
|
||||
|
||||
def message_field(number: int) -> Any:
|
||||
@@ -218,9 +216,7 @@ def message_field(number: int) -> Any:
|
||||
|
||||
|
||||
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
||||
return dataclass_field(
|
||||
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
|
||||
)
|
||||
return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
|
||||
|
||||
|
||||
def _pack_fmt(proto_type: str) -> str:
|
||||
@@ -336,6 +332,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
number = num_wire >> 3
|
||||
wire_type = num_wire & 0x7
|
||||
|
||||
decoded: Any
|
||||
if wire_type == 0:
|
||||
decoded, i = decode_varint(value, i)
|
||||
elif wire_type == 1:
|
||||
@@ -369,11 +366,15 @@ class Message(ABC):
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
for field in dataclasses.fields(self):
|
||||
if getattr(self, field.name) != PLACEHOLDER:
|
||||
# Skip anything not set (aka set to the sentinel value)
|
||||
continue
|
||||
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
t = self._cls_for(field, index=-1)
|
||||
|
||||
value = 0
|
||||
value: Any = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
value = {}
|
||||
@@ -419,6 +420,7 @@ class Message(ABC):
|
||||
continue
|
||||
|
||||
for k, v in value.items():
|
||||
assert meta.map_types
|
||||
sk = _serialize_single(1, meta.map_types[0], k)
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
@@ -431,10 +433,13 @@ class Message(ABC):
|
||||
|
||||
return output
|
||||
|
||||
# For compatibility with other libraries
|
||||
SerializeToString = __bytes__
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
module = inspect.getmodule(self)
|
||||
type_hints = get_type_hints(self, vars(module))
|
||||
module = inspect.getmodule(self.__class__)
|
||||
type_hints = get_type_hints(self.__class__, vars(module))
|
||||
cls = type_hints[field.name]
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = type_hints[field.name].__args__[index]
|
||||
@@ -465,6 +470,7 @@ class Message(ABC):
|
||||
elif meta.proto_type in [TYPE_MAP]:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
assert meta.map_types
|
||||
kt = self._cls_for(field, index=0)
|
||||
vt = self._cls_for(field, index=1)
|
||||
Entry = dataclasses.make_dataclass(
|
||||
@@ -479,7 +485,7 @@ class Message(ABC):
|
||||
|
||||
return value
|
||||
|
||||
def parse(self, data: bytes) -> T:
|
||||
def parse(self: T, data: bytes) -> T:
|
||||
"""
|
||||
Parse the binary encoded Protobuf into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
@@ -490,6 +496,7 @@ class Message(ABC):
|
||||
field = fields[parsed.number]
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
value: Any
|
||||
if (
|
||||
parsed.wire_type == WIRE_LEN_DELIM
|
||||
and meta.proto_type in PACKED_TYPES
|
||||
@@ -528,8 +535,15 @@ class Message(ABC):
|
||||
# TODO: handle unknown fields
|
||||
pass
|
||||
|
||||
from typing import cast
|
||||
|
||||
return self
|
||||
|
||||
# For compatibility with other libraries.
|
||||
@classmethod
|
||||
def FromString(cls: Type[T], data: bytes) -> T:
|
||||
return cls().parse(data)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Returns a dict representation of this message instance which can be
|
||||
@@ -557,11 +571,11 @@ class Message(ABC):
|
||||
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v != field.default:
|
||||
elif v != get_default(meta.proto_type):
|
||||
output[field.name] = v
|
||||
return output
|
||||
|
||||
def from_dict(self, value: dict) -> T:
|
||||
def from_dict(self: T, value: dict) -> T:
|
||||
"""
|
||||
Parse the key/value pairs in `value` into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
@@ -578,7 +592,7 @@ class Message(ABC):
|
||||
v.append(cls().from_dict(value[field.name][i]))
|
||||
else:
|
||||
v.from_dict(value[field.name])
|
||||
elif meta.proto_type == "map" and meta.map_types[1] == TYPE_MESSAGE:
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field.name)
|
||||
cls = self._cls_for(field, index=1)
|
||||
for k in value[field.name]:
|
||||
@@ -587,13 +601,48 @@ class Message(ABC):
|
||||
setattr(self, field.name, value[field.name])
|
||||
return self
|
||||
|
||||
def to_json(self) -> bytes:
|
||||
def to_json(self) -> str:
|
||||
"""Returns the encoded JSON representation of this message instance."""
|
||||
return json.dumps(self.to_dict())
|
||||
|
||||
def from_json(self, value: bytes) -> T:
|
||||
def from_json(self: T, value: Union[str, bytes]) -> T:
|
||||
"""
|
||||
Parse the key/value pairs in `value` into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
return self.from_dict(json.loads(value))
|
||||
|
||||
|
||||
ResponseType = TypeVar("ResponseType", bound="Message")
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
"""
|
||||
|
||||
def __init__(self, channel: grpclib.client.Channel) -> None:
|
||||
self.channel = channel
|
||||
|
||||
async def _unary_unary(
|
||||
self, route: str, request_type: Type, response_type: Type[T], request: Any
|
||||
) -> T:
|
||||
"""Make a unary request and return the response."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
response = await stream.recv_message()
|
||||
assert response is not None
|
||||
return response
|
||||
|
||||
async def _unary_stream(
|
||||
self, route: str, request_type: Type, response_type: Type[T], request: Any
|
||||
) -> AsyncGenerator[T, None]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type
|
||||
) as stream:
|
||||
await stream.send_message(request, end=True)
|
||||
async for message in stream:
|
||||
yield message
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: {{ description.filename }}
|
||||
# sources: {{ ', '.join(description.files) }}
|
||||
# plugin: python-betterproto
|
||||
{% if description.enums %}import enum
|
||||
{% endif %}
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
from typing import AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import betterproto
|
||||
{% if description.services %}
|
||||
import grpclib
|
||||
{% endif %}
|
||||
{% for i in description.imports %}
|
||||
|
||||
{{ i }}
|
||||
@@ -48,3 +51,36 @@ class {{ message.name }}(betterproto.Message):
|
||||
|
||||
|
||||
{% endfor %}
|
||||
{% for service in description.services %}
|
||||
class {{ service.name }}Stub(betterproto.ServiceStub):
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
request = {{ method.input }}()
|
||||
{% for field in method.input_message.properties %}
|
||||
{% if field.field_type == 'message' %}
|
||||
if {{ field.name }} is not None:
|
||||
request.{{ field.name }} = {{ field.name }}
|
||||
{% else %}
|
||||
request.{{ field.name }} = {{ field.name }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{% if method.server_streaming %}
|
||||
async for response in self._unary_stream(
|
||||
"{{ method.route }}",
|
||||
{{ method.input }},
|
||||
{{ method.output }},
|
||||
request,
|
||||
):
|
||||
yield response
|
||||
{% else %}
|
||||
return await self._unary_unary(
|
||||
"{{ method.route }}",
|
||||
{{ method.input }},
|
||||
{{ method.output }},
|
||||
request,
|
||||
)
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
{% endfor %}
|
||||
|
||||
Reference in New Issue
Block a user