Implement basic async gRPC support
commit d93214eccd
parent 41a96f65ee
Pipfile: 1 line changed
@@ -13,6 +13,7 @@ rope = "*"
 [packages]
 protobuf = "*"
 jinja2 = "*"
+grpclib = "*"
 
 [requires]
 python_version = "3.7"
Pipfile.lock (generated): 64 lines changed
@@ -1,7 +1,7 @@
 {
     "_meta": {
         "hash": {
-            "sha256": "6c1797fb4eb73be97ca566206527c9d648b90f38c5bf2caf4b69537cd325ced9"
+            "sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40"
         },
         "pipfile-spec": 6,
         "requires": {
@@ -16,6 +16,34 @@
         ]
     },
     "default": {
+        "grpclib": {
+            "hashes": [
+                "sha256:d19e2ea87cb073e5b0825dfee15336fd2b1c09278d271816e04c90faddc107ea"
+            ],
+            "index": "pypi",
+            "version": "==0.3.0"
+        },
+        "h2": {
+            "hashes": [
+                "sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e",
+                "sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4"
+            ],
+            "version": "==3.1.1"
+        },
+        "hpack": {
+            "hashes": [
+                "sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89",
+                "sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2"
+            ],
+            "version": "==3.0.0"
+        },
+        "hyperframe": {
+            "hashes": [
+                "sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40",
+                "sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f"
+            ],
+            "version": "==5.2.0"
+        },
         "jinja2": {
             "hashes": [
                 "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f",
@@ -57,6 +85,40 @@
             ],
             "version": "==1.1.1"
         },
+        "multidict": {
+            "hashes": [
+                "sha256:024b8129695a952ebd93373e45b5d341dbb87c17ce49637b34000093f243dd4f",
+                "sha256:041e9442b11409be5e4fc8b6a97e4bcead758ab1e11768d1e69160bdde18acc3",
+                "sha256:045b4dd0e5f6121e6f314d81759abd2c257db4634260abcfe0d3f7083c4908ef",
+                "sha256:047c0a04e382ef8bd74b0de01407e8d8632d7d1b4db6f2561106af812a68741b",
+                "sha256:068167c2d7bbeebd359665ac4fff756be5ffac9cda02375b5c5a7c4777038e73",
+                "sha256:148ff60e0fffa2f5fad2eb25aae7bef23d8f3b8bdaf947a65cdbe84a978092bc",
+                "sha256:1d1c77013a259971a72ddaa83b9f42c80a93ff12df6a4723be99d858fa30bee3",
+                "sha256:1d48bc124a6b7a55006d97917f695effa9725d05abe8ee78fd60d6588b8344cd",
+                "sha256:31dfa2fc323097f8ad7acd41aa38d7c614dd1960ac6681745b6da124093dc351",
+                "sha256:34f82db7f80c49f38b032c5abb605c458bac997a6c3142e0d6c130be6fb2b941",
+                "sha256:3d5dd8e5998fb4ace04789d1d008e2bb532de501218519d70bb672c4c5a2fc5d",
+                "sha256:4a6ae52bd3ee41ee0f3acf4c60ceb3f44e0e3bc52ab7da1c2b2aa6703363a3d1",
+                "sha256:4b02a3b2a2f01d0490dd39321c74273fed0568568ea0e7ea23e02bd1fb10a10b",
+                "sha256:4b843f8e1dd6a3195679d9838eb4670222e8b8d01bc36c9894d6c3538316fa0a",
+                "sha256:5de53a28f40ef3c4fd57aeab6b590c2c663de87a5af76136ced519923d3efbb3",
+                "sha256:61b2b33ede821b94fa99ce0b09c9ece049c7067a33b279f343adfe35108a4ea7",
+                "sha256:6a3a9b0f45fd75dc05d8e93dc21b18fc1670135ec9544d1ad4acbcf6b86781d0",
+                "sha256:76ad8e4c69dadbb31bad17c16baee61c0d1a4a73bed2590b741b2e1a46d3edd0",
+                "sha256:7ba19b777dc00194d1b473180d4ca89a054dd18de27d0ee2e42a103ec9b7d014",
+                "sha256:7c1b7eab7a49aa96f3db1f716f0113a8a2e93c7375dd3d5d21c4941f1405c9c5",
+                "sha256:7fc0eee3046041387cbace9314926aa48b681202f8897f8bff3809967a049036",
+                "sha256:8ccd1c5fff1aa1427100ce188557fc31f1e0a383ad8ec42c559aabd4ff08802d",
+                "sha256:8e08dd76de80539d613654915a2f5196dbccc67448df291e69a88712ea21e24a",
+                "sha256:c18498c50c59263841862ea0501da9f2b3659c00db54abfbf823a80787fde8ce",
+                "sha256:c49db89d602c24928e68c0d510f4fcf8989d77defd01c973d6cbe27e684833b1",
+                "sha256:ce20044d0317649ddbb4e54dab3c1bcc7483c78c27d3f58ab3d0c7e6bc60d26a",
+                "sha256:d1071414dd06ca2eafa90c85a079169bfeb0e5f57fd0b45d44c092546fcd6fd9",
+                "sha256:d3be11ac43ab1a3e979dac80843b42226d5d3cccd3986f2e03152720a4297cd7",
+                "sha256:db603a1c235d110c860d5f39988ebc8218ee028f07a7cbc056ba6424372ca31b"
+            ],
+            "version": "==4.5.2"
+        },
         "protobuf": {
             "hashes": [
                 "sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f",
@@ -33,5 +33,8 @@ This project is heavily inspired by, and borrows functionality from:
 - [ ] Well-known Google types
 - [ ] JSON that isn't completely naive.
 - [ ] Async service stubs
+  - [x] Unary-unary
+  - [x] Server streaming response
+  - [ ] Client streaming request
 - [ ] Python package
 - [ ] Cleanup!
@@ -3,6 +3,7 @@ import json
 import struct
 from typing import (
     get_type_hints,
+    AsyncGenerator,
     Union,
     Generator,
     Any,
@@ -17,6 +18,9 @@ from typing import (
 )
 import dataclasses
 
+import grpclib.client
+import grpclib.const
+
 import inspect
 
 # Proto 3 data types
@@ -92,7 +96,14 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
 WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
 
 
-def get_default(proto_type: int) -> Any:
+class _PLACEHOLDER:
+    pass
+
+
+PLACEHOLDER: Any = _PLACEHOLDER()
+
+
+def get_default(proto_type: str) -> Any:
     """Get the default (zero value) for a given type."""
     return {
         TYPE_BOOL: False,
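The _PLACEHOLDER sentinel introduced above is the heart of this refactor: every generated field now defaults to the same PLACEHOLDER object, and the real zero values are filled in after construction. A standalone sketch of the pattern, independent of betterproto (class and field names are illustrative):

import dataclasses
from typing import Any


class _Placeholder:
    pass


PLACEHOLDER: Any = _Placeholder()


@dataclasses.dataclass
class Example:
    name: str = PLACEHOLDER
    count: int = PLACEHOLDER

    def __post_init__(self) -> None:
        # Swap the sentinel for the type's zero value on any field the
        # caller did not set explicitly.
        if self.name is PLACEHOLDER:
            self.name = ""
        if self.count is PLACEHOLDER:
            self.count = 0


print(Example(name="x"))  # Example(name='x', count=0)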
@@ -114,8 +125,6 @@ class FieldMetadata:
     proto_type: str
     # Map information if the proto_type is a map
     map_types: Optional[Tuple[str, str]]
-    # Default value if given
-    default: Any
 
     @staticmethod
     def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -124,23 +133,12 @@ class FieldMetadata:
 
 
 def dataclass_field(
-    number: int,
-    proto_type: str,
-    default: Any = None,
-    map_types: Optional[Tuple[str, str]] = None,
-    **kwargs: dict,
+    number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None
 ) -> dataclasses.Field:
     """Creates a dataclass field with attached protobuf metadata."""
-    if callable(default):
-        kwargs["default_factory"] = default
-    elif isinstance(default, dict) or isinstance(default, list):
-        kwargs["default_factory"] = lambda: default
-    else:
-        kwargs["default"] = default
-
     return dataclasses.field(
-        **kwargs,
-        metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
+        default=PLACEHOLDER,
+        metadata={"betterproto": FieldMetadata(number, proto_type, map_types)},
     )
 
 
@@ -149,68 +147,68 @@ def dataclass_field(
 # out at runtime. The generated dataclass variables are still typed correctly.
 
 
-def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
-    return dataclass_field(number, TYPE_ENUM, default=default)
+def enum_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_ENUM)
 
 
-def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
-    return dataclass_field(number, TYPE_BOOL, default=default)
+def bool_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_BOOL)
 
 
-def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
-    return dataclass_field(number, TYPE_INT32, default=default)
+def int32_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_INT32)
 
 
-def int64_field(number: int, default: int = 0) -> Any:
-    return dataclass_field(number, TYPE_INT64, default=default)
+def int64_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_INT64)
 
 
-def uint32_field(number: int, default: int = 0) -> Any:
-    return dataclass_field(number, TYPE_UINT32, default=default)
+def uint32_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_UINT32)
 
 
-def uint64_field(number: int, default: int = 0) -> Any:
-    return dataclass_field(number, TYPE_UINT64, default=default)
+def uint64_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_UINT64)
 
 
-def sint32_field(number: int, default: int = 0) -> Any:
-    return dataclass_field(number, TYPE_SINT32, default=default)
+def sint32_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_SINT32)
 
 
-def sint64_field(number: int, default: int = 0) -> Any:
-    return dataclass_field(number, TYPE_SINT64, default=default)
+def sint64_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_SINT64)
 
 
-def float_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_FLOAT, default=default)
+def float_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_FLOAT)
 
 
-def double_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_DOUBLE, default=default)
+def double_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_DOUBLE)
 
 
-def fixed32_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_FIXED32, default=default)
+def fixed32_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_FIXED32)
 
 
-def fixed64_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_FIXED64, default=default)
+def fixed64_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_FIXED64)
 
 
-def sfixed32_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_SFIXED32, default=default)
+def sfixed32_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_SFIXED32)
 
 
-def sfixed64_field(number: int, default: float = 0.0) -> Any:
-    return dataclass_field(number, TYPE_SFIXED64, default=default)
+def sfixed64_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_SFIXED64)
 
 
-def string_field(number: int, default: str = "") -> Any:
-    return dataclass_field(number, TYPE_STRING, default=default)
+def string_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_STRING)
 
 
-def bytes_field(number: int, default: bytes = b"") -> Any:
-    return dataclass_field(number, TYPE_BYTES, default=default)
+def bytes_field(number: int) -> Any:
+    return dataclass_field(number, TYPE_BYTES)
 
 
 def message_field(number: int) -> Any:
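With the default arguments removed, the field helpers only record the field number and proto type; zero values now come from get_default() during __post_init__. A hand-written message roughly equivalent to what the plugin generates might look like the sketch below (the Greeting class and its fields are hypothetical, and the snippet assumes the betterproto package from this repository is importable):

from dataclasses import dataclass

import betterproto


@dataclass
class Greeting(betterproto.Message):
    message: str = betterproto.string_field(1)
    count: int = betterproto.int32_field(2)


g = Greeting()                         # unset fields resolve to zero values
print(g.message, g.count)              # -> '' 0
print(bytes(Greeting(message="hi")))   # wire-format bytes via __bytes__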
@@ -218,9 +216,7 @@ def message_field(number: int) -> Any:
 
 
 def map_field(number: int, key_type: str, value_type: str) -> Any:
-    return dataclass_field(
-        number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
-    )
+    return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
 
 
 def _pack_fmt(proto_type: str) -> str:
@@ -336,6 +332,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
     number = num_wire >> 3
     wire_type = num_wire & 0x7
 
+    decoded: Any
     if wire_type == 0:
         decoded, i = decode_varint(value, i)
     elif wire_type == 1:
@@ -369,11 +366,15 @@ class Message(ABC):
         # Set a default value for each field in the class after `__init__` has
         # already been run.
         for field in dataclasses.fields(self):
+            if getattr(self, field.name) != PLACEHOLDER:
+                # Skip anything not set (aka set to the sentinel value)
+                continue
+
             meta = FieldMetadata.get(field)
 
             t = self._cls_for(field, index=-1)
 
-            value = 0
+            value: Any = 0
             if meta.proto_type == TYPE_MAP:
                 # Maps cannot be repeated, so we check these first.
                 value = {}
@@ -419,6 +420,7 @@ class Message(ABC):
                     continue
 
                 for k, v in value.items():
+                    assert meta.map_types
                     sk = _serialize_single(1, meta.map_types[0], k)
                     sv = _serialize_single(2, meta.map_types[1], v)
                     output += _serialize_single(meta.number, meta.proto_type, sk + sv)
@@ -431,10 +433,13 @@ class Message(ABC):
 
         return output
 
+    # For compatibility with other libraries
+    SerializeToString = __bytes__
+
     def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
         """Get the message class for a field from the type hints."""
-        module = inspect.getmodule(self)
-        type_hints = get_type_hints(self, vars(module))
+        module = inspect.getmodule(self.__class__)
+        type_hints = get_type_hints(self.__class__, vars(module))
         cls = type_hints[field.name]
         if hasattr(cls, "__args__") and index >= 0:
             cls = type_hints[field.name].__args__[index]
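Binding SerializeToString to __bytes__ keeps the familiar protobuf method name, and resolving type hints against the class and its defining module is what lets the string annotations in generated code be turned back into classes. A minimal, self-contained illustration of that resolution step (not betterproto code):

import inspect
from typing import get_type_hints


class Inner:
    pass


class Outer:
    child: "Inner" = None


module = inspect.getmodule(Outer)            # module that defines Outer
hints = get_type_hints(Outer, vars(module))  # resolves the "Inner" string
print(hints["child"] is Inner)               # True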
@@ -465,6 +470,7 @@ class Message(ABC):
         elif meta.proto_type in [TYPE_MAP]:
             # TODO: This is slow, use a cache to make it faster since each
             # key/value pair will recreate the class.
+            assert meta.map_types
             kt = self._cls_for(field, index=0)
             vt = self._cls_for(field, index=1)
             Entry = dataclasses.make_dataclass(
@@ -479,7 +485,7 @@ class Message(ABC):
 
         return value
 
-    def parse(self, data: bytes) -> T:
+    def parse(self: T, data: bytes) -> T:
         """
         Parse the binary encoded Protobuf into this message instance. This
         returns the instance itself and is therefore assignable and chainable.
@@ -490,6 +496,7 @@ class Message(ABC):
             field = fields[parsed.number]
             meta = FieldMetadata.get(field)
 
+            value: Any
             if (
                 parsed.wire_type == WIRE_LEN_DELIM
                 and meta.proto_type in PACKED_TYPES
@@ -528,8 +535,15 @@ class Message(ABC):
                 # TODO: handle unknown fields
                 pass
 
+        from typing import cast
+
         return self
 
+    # For compatibility with other libraries.
+    @classmethod
+    def FromString(cls: Type[T], data: bytes) -> T:
+        return cls().parse(data)
+
     def to_dict(self) -> dict:
         """
         Returns a dict representation of this message instance which can be
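FromString mirrors the classmethod offered by the official protobuf runtime, so callers can decode without first constructing an instance. A hedged usage sketch, reusing the hypothetical Greeting message from the earlier example:

data = bytes(Greeting(message="hi"))

decoded_a = Greeting().parse(data)     # betterproto-native, chainable
decoded_b = Greeting.FromString(data)  # protobuf-compatible classmethod
assert decoded_a == decoded_b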
@@ -557,11 +571,11 @@ class Message(ABC):
 
             if v:
                 output[field.name] = v
-            elif v != field.default:
+            elif v != get_default(meta.proto_type):
                 output[field.name] = v
         return output
 
-    def from_dict(self, value: dict) -> T:
+    def from_dict(self: T, value: dict) -> T:
         """
         Parse the key/value pairs in `value` into this message instance. This
         returns the instance itself and is therefore assignable and chainable.
@@ -578,7 +592,7 @@ class Message(ABC):
                         v.append(cls().from_dict(value[field.name][i]))
                     else:
                         v.from_dict(value[field.name])
-                elif meta.proto_type == "map" and meta.map_types[1] == TYPE_MESSAGE:
+                elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
                     v = getattr(self, field.name)
                     cls = self._cls_for(field, index=1)
                     for k in value[field.name]:
@@ -587,13 +601,48 @@ class Message(ABC):
                 setattr(self, field.name, value[field.name])
         return self
 
-    def to_json(self) -> bytes:
+    def to_json(self) -> str:
         """Returns the encoded JSON representation of this message instance."""
         return json.dumps(self.to_dict())
 
-    def from_json(self, value: bytes) -> T:
+    def from_json(self: T, value: Union[str, bytes]) -> T:
         """
         Parse the key/value pairs in `value` into this message instance. This
         returns the instance itself and is therefore assignable and chainable.
         """
         return self.from_dict(json.loads(value))
+
+
+ResponseType = TypeVar("ResponseType", bound="Message")
+
+
+class ServiceStub(ABC):
+    """
+    Base class for async gRPC service stubs.
+    """
+
+    def __init__(self, channel: grpclib.client.Channel) -> None:
+        self.channel = channel
+
+    async def _unary_unary(
+        self, route: str, request_type: Type, response_type: Type[T], request: Any
+    ) -> T:
+        """Make a unary request and return the response."""
+        async with self.channel.request(
+            route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type
+        ) as stream:
+            await stream.send_message(request, end=True)
+            response = await stream.recv_message()
+            assert response is not None
+            return response
+
+    async def _unary_stream(
+        self, route: str, request_type: Type, response_type: Type[T], request: Any
+    ) -> AsyncGenerator[T, None]:
+        """Make a unary request and return the stream response iterator."""
+        async with self.channel.request(
+            route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type
+        ) as stream:
+            await stream.send_message(request, end=True)
+            async for message in stream:
+                yield message
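A sketch of how application code might drive a stub generated on top of ServiceStub. GreeterStub, say_hello, and say_hello_stream are hypothetical generated names; Channel and close() are grpclib's actual client API.

import asyncio

from grpclib.client import Channel


async def main() -> None:
    channel = Channel("127.0.0.1", 50051)
    greeter = GreeterStub(channel)  # hypothetical subclass of ServiceStub

    reply = await greeter.say_hello(name="world")  # unary-unary
    print(reply.message)

    async for update in greeter.say_hello_stream(name="world"):  # server streaming
        print(update.message)

    channel.close()


asyncio.run(main())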
@@ -1,12 +1,15 @@
 # Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: {{ description.filename }}
+# sources: {{ ', '.join(description.files) }}
 # plugin: python-betterproto
 {% if description.enums %}import enum
 {% endif %}
 from dataclasses import dataclass
-from typing import Dict, List
+from typing import AsyncGenerator, Dict, List, Optional
 
 import betterproto
+{% if description.services %}
+import grpclib
+{% endif %}
 {% for i in description.imports %}
 
 {{ i }}
@@ -48,3 +51,36 @@ class {{ message.name }}(betterproto.Message):
 
 
 {% endfor %}
+{% for service in description.services %}
+class {{ service.name }}Stub(betterproto.ServiceStub):
+    {% for method in service.methods %}
+    async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
+        request = {{ method.input }}()
+        {% for field in method.input_message.properties %}
+        {% if field.field_type == 'message' %}
+        if {{ field.name }} is not None:
+            request.{{ field.name }} = {{ field.name }}
+        {% else %}
+        request.{{ field.name }} = {{ field.name }}
+        {% endif %}
+        {% endfor %}
+
+        {% if method.server_streaming %}
+        async for response in self._unary_stream(
+            "{{ method.route }}",
+            {{ method.input }},
+            {{ method.output }},
+            request,
+        ):
+            yield response
+        {% else %}
+        return await self._unary_unary(
+            "{{ method.route }}",
+            {{ method.input }},
+            {{ method.output }},
+            request,
+        )
+        {% endif %}
+
+    {% endfor %}
+{% endfor %}
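For a hypothetical helloworld.Greeter service with a unary SayHello method, the template above would render a stub roughly like the following (illustrative output only, not produced by this commit):

class GreeterStub(betterproto.ServiceStub):
    async def say_hello(self, *, name: str = "") -> HelloReply:
        request = HelloRequest()
        request.name = name

        return await self._unary_unary(
            "/helloworld.Greeter/SayHello", HelloRequest, HelloReply, request
        )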
@@ -5,6 +5,7 @@ import sys
 import itertools
 import json
 import os.path
+import re
 from typing import Tuple, Any, List
 import textwrap
 
@@ -13,6 +14,7 @@ from google.protobuf.descriptor_pb2 import (
     EnumDescriptorProto,
     FileDescriptorProto,
     FieldDescriptorProto,
+    ServiceDescriptorProto,
 )
 
 from google.protobuf.compiler import plugin_pb2 as plugin
@@ -21,6 +23,32 @@ from google.protobuf.compiler import plugin_pb2 as plugin
 from jinja2 import Environment, PackageLoader
 
 
+def snake_case(value: str) -> str:
+    return (
+        re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
+    )
+
+
+def get_ref_type(package: str, imports: set, type_name: str) -> str:
+    """
+    Return a Python type name for a proto type reference. Adds the import if
+    necessary.
+    """
+    type_name = type_name.lstrip(".")
+    if type_name.startswith(package):
+        # This is the current package, which has nested types flattened.
+        type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
+
+    if "." in type_name:
+        # This is imported from another package. No need
+        # to use a forward ref and we need to add the import.
+        parts = type_name.split(".")
+        imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
+        type_name = f"{parts[-2]}.{parts[-1]}"
+
+    return type_name
+
+
 def py_type(
     package: str,
     imports: set,
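For reference, the snake_case helper can be exercised on its own; the function below repeats the regex from the diff, and the expected outputs are assumptions rather than cases from the repository's tests:

import re


def snake_case(value: str) -> str:
    # Insert "_" before interior capitals, then lowercase and trim.
    return re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")


print(snake_case("SayHello"))      # say_hello
print(snake_case("GetUser"))       # get_user
print(snake_case("HTTPSRequest"))  # https_request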
@@ -37,35 +65,29 @@
         return "str"
     elif descriptor.type in [11, 14]:
         # Type referencing another defined Message or a named enum
-        message_type = descriptor.type_name.lstrip(".")
-        if message_type.startswith(package):
-            # This is the current package, which has nested types flattened.
-            message_type = (
-                f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
-            )
-
-        if "." in message_type:
-            # This is imported from another package. No need
-            # to use a forward ref and we need to add the import.
-            parts = message_type.split(".")
-            imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
-            message_type = f"{parts[-2]}.{parts[-1]}"
-
-        # print(
-        #     descriptor.name,
-        #     package,
-        #     descriptor.type_name,
-        #     message_type,
-        #     file=sys.stderr,
-        # )
-
-        return message_type
+        return get_ref_type(package, imports, descriptor.type_name)
     elif descriptor.type == 12:
         return "bytes"
     else:
         raise NotImplementedError(f"Unknown type {descriptor.type}")
 
 
+def get_py_zero(type_num: int) -> str:
+    zero = 0
+    if type_num in []:
+        zero = 0.0
+    elif type_num == 8:
+        zero = "False"
+    elif type_num == 9:
+        zero = '""'
+    elif type_num == 11:
+        zero = "None"
+    elif type_num == 12:
+        zero = 'b""'
+
+    return zero
+
+
 def traverse(proto_file):
     def _traverse(path, items):
         for i, item in enumerate(items):
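get_py_zero supplies the literal used as a keyword-argument default in generated stub signatures. The standalone sketch below restates the same mapping as a dict to make the intent explicit (illustration only, not the plugin code):

# Descriptor type numbers: 8=bool, 9=string, 11=message, 12=bytes.
ZERO_LITERALS = {8: "False", 9: '""', 11: "None", 12: 'b""'}


def zero_literal(type_num: int) -> str:
    # Numeric types fall through to "0", mirroring get_py_zero above.
    return ZERO_LITERALS.get(type_num, "0")


print(zero_literal(9))   # ""
print(zero_literal(5))   # 0 (int32)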
@@ -73,6 +95,7 @@ def traverse(proto_file):
 
         if isinstance(item, DescriptorProto):
             for enum in item.enum_type:
+                enum.name = item.name + enum.name
                 yield enum, path + [i, 4]
 
             if item.nested_type:
@@ -103,7 +126,8 @@ def get_comment(proto_file, path: List[int]) -> str:
             lines[0] = lines[0].strip('"')
             return f'    """{lines[0]}"""'
         else:
-            return f'    """\n{" ".join(lines)}\n    """'
+            joined = "\n    ".join(lines)
+            return f'    """\n    {joined}\n    """'
 
     return ""
 
@@ -116,10 +140,6 @@ def generate_code(request, response):
     )
     template = env.get_template("main.py")
 
-    # TODO: Refactor below to generate a single file per package if packages
-    # are being used, otherwise one output for each input. Figure out how to
-    # set up relative imports when needed and change the Message type refs to
-    # use the import names when not in the current module.
    output_map = {}
     for proto_file in request.proto_file:
         out = proto_file.package
@@ -136,7 +156,16 @@
     for filename, options in output_map.items():
         package = options["package"]
         # print(package, filename, file=sys.stderr)
-        output = {"package": package, "imports": set(), "messages": [], "enums": []}
+        output = {
+            "package": package,
+            "files": [f.name for f in options["files"]],
+            "imports": set(),
+            "messages": [],
+            "enums": [],
+            "services": [],
+        }
+
+        type_mapping = {}
 
         for proto_file in options["files"]:
             # print(proto_file.message_type, file=sys.stderr)
@@ -164,6 +193,7 @@
 
                     for i, f in enumerate(item.field):
                         t = py_type(package, output["imports"], item, f)
+                        zero = get_py_zero(f.type)
 
                         repeated = False
                         packed = False
@@ -172,12 +202,16 @@
                         map_types = None
                         if f.type == 11:
                             # This might be a map...
-                            message_type = f.type_name.split(".").pop()
-                            map_entry = f"{f.name.capitalize()}Entry"
+                            message_type = f.type_name.split(".").pop().lower()
+                            # message_type = py_type(package)
+                            map_entry = f"{f.name.replace('_', '').lower()}entry"
 
                             if message_type == map_entry:
                                 for nested in item.nested_type:
-                                    if nested.name == map_entry:
+                                    if (
+                                        nested.name.replace("_", "").lower()
+                                        == map_entry
+                                    ):
                                         if nested.options.map_entry:
                                             # print("Found a map!", file=sys.stderr)
                                             k = py_type(
@@ -203,6 +237,7 @@
                             # Repeated field
                             repeated = True
                             t = f"List[{t}]"
+                            zero = "[]"
 
                             if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
                                 packed = True
@@ -216,6 +251,7 @@
                             "field_type": field_type,
                             "map_types": map_types,
                             "type": t,
+                            "zero": zero,
                             "repeated": repeated,
                             "packed": packed,
                         }
@@ -223,7 +259,6 @@
                         # print(f, file=sys.stderr)
 
                     output["messages"].append(data)
-
                 elif isinstance(item, EnumDescriptorProto):
                     # print(item.name, path, file=sys.stderr)
                     data.update(
@@ -243,6 +278,44 @@
 
                     output["enums"].append(data)
 
+            for service in proto_file.service:
+                # print(service, file=sys.stderr)
+
+                # TODO: comments
+                data = {"name": service.name, "methods": []}
+
+                for method in service.method:
+                    if method.client_streaming:
+                        raise NotImplementedError("Client streaming not yet supported")
+
+                    input_message = None
+                    input_type = get_ref_type(
+                        package, output["imports"], method.input_type
+                    ).strip('"')
+                    for msg in output["messages"]:
+                        if msg["name"] == input_type:
+                            input_message = msg
+                            break
+
+                    data["methods"].append(
+                        {
+                            "name": method.name,
+                            "py_name": snake_case(method.name),
+                            "route": f"/{package}.{service.name}/{method.name}",
+                            "input": get_ref_type(
+                                package, output["imports"], method.input_type
+                            ).strip('"'),
+                            "input_message": input_message,
+                            "output": get_ref_type(
+                                package, output["imports"], method.output_type
+                            ).strip('"'),
+                            "client_streaming": method.client_streaming,
+                            "server_streaming": method.server_streaming,
+                        }
+                    )
+
+                output["services"].append(data)
+
         output["imports"] = sorted(output["imports"])
 
         # Fill response
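The route strings assembled above follow the standard gRPC path convention of /package.Service/Method. A tiny illustration for a hypothetical helloworld.Greeter service:

package, service_name, method_name = "helloworld", "Greeter", "SayHello"
route = f"/{package}.{service_name}/{method_name}"
print(route)  # /helloworld.Greeter/SayHello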
@@ -256,7 +329,7 @@
     inits = set([""])
     for f in response.file:
         # Ensure output paths exist
-        print(f.name, file=sys.stderr)
+        # print(f.name, file=sys.stderr)
         dirnames = os.path.dirname(f.name)
         if dirnames:
             os.makedirs(dirnames, exist_ok=True)