Implement basic async gRPC support

This commit is contained in:
Daniel G. Taylor 2019-10-16 22:52:38 -07:00
parent 41a96f65ee
commit d93214eccd
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
6 changed files with 322 additions and 98 deletions

View File

@ -13,6 +13,7 @@ rope = "*"
[packages] [packages]
protobuf = "*" protobuf = "*"
jinja2 = "*" jinja2 = "*"
grpclib = "*"
[requires] [requires]
python_version = "3.7" python_version = "3.7"

64
Pipfile.lock generated
View File

@ -1,7 +1,7 @@
{ {
"_meta": { "_meta": {
"hash": { "hash": {
"sha256": "6c1797fb4eb73be97ca566206527c9d648b90f38c5bf2caf4b69537cd325ced9" "sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40"
}, },
"pipfile-spec": 6, "pipfile-spec": 6,
"requires": { "requires": {
@ -16,6 +16,34 @@
] ]
}, },
"default": { "default": {
"grpclib": {
"hashes": [
"sha256:d19e2ea87cb073e5b0825dfee15336fd2b1c09278d271816e04c90faddc107ea"
],
"index": "pypi",
"version": "==0.3.0"
},
"h2": {
"hashes": [
"sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e",
"sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4"
],
"version": "==3.1.1"
},
"hpack": {
"hashes": [
"sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89",
"sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2"
],
"version": "==3.0.0"
},
"hyperframe": {
"hashes": [
"sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40",
"sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f"
],
"version": "==5.2.0"
},
"jinja2": { "jinja2": {
"hashes": [ "hashes": [
"sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f",
@ -57,6 +85,40 @@
], ],
"version": "==1.1.1" "version": "==1.1.1"
}, },
"multidict": {
"hashes": [
"sha256:024b8129695a952ebd93373e45b5d341dbb87c17ce49637b34000093f243dd4f",
"sha256:041e9442b11409be5e4fc8b6a97e4bcead758ab1e11768d1e69160bdde18acc3",
"sha256:045b4dd0e5f6121e6f314d81759abd2c257db4634260abcfe0d3f7083c4908ef",
"sha256:047c0a04e382ef8bd74b0de01407e8d8632d7d1b4db6f2561106af812a68741b",
"sha256:068167c2d7bbeebd359665ac4fff756be5ffac9cda02375b5c5a7c4777038e73",
"sha256:148ff60e0fffa2f5fad2eb25aae7bef23d8f3b8bdaf947a65cdbe84a978092bc",
"sha256:1d1c77013a259971a72ddaa83b9f42c80a93ff12df6a4723be99d858fa30bee3",
"sha256:1d48bc124a6b7a55006d97917f695effa9725d05abe8ee78fd60d6588b8344cd",
"sha256:31dfa2fc323097f8ad7acd41aa38d7c614dd1960ac6681745b6da124093dc351",
"sha256:34f82db7f80c49f38b032c5abb605c458bac997a6c3142e0d6c130be6fb2b941",
"sha256:3d5dd8e5998fb4ace04789d1d008e2bb532de501218519d70bb672c4c5a2fc5d",
"sha256:4a6ae52bd3ee41ee0f3acf4c60ceb3f44e0e3bc52ab7da1c2b2aa6703363a3d1",
"sha256:4b02a3b2a2f01d0490dd39321c74273fed0568568ea0e7ea23e02bd1fb10a10b",
"sha256:4b843f8e1dd6a3195679d9838eb4670222e8b8d01bc36c9894d6c3538316fa0a",
"sha256:5de53a28f40ef3c4fd57aeab6b590c2c663de87a5af76136ced519923d3efbb3",
"sha256:61b2b33ede821b94fa99ce0b09c9ece049c7067a33b279f343adfe35108a4ea7",
"sha256:6a3a9b0f45fd75dc05d8e93dc21b18fc1670135ec9544d1ad4acbcf6b86781d0",
"sha256:76ad8e4c69dadbb31bad17c16baee61c0d1a4a73bed2590b741b2e1a46d3edd0",
"sha256:7ba19b777dc00194d1b473180d4ca89a054dd18de27d0ee2e42a103ec9b7d014",
"sha256:7c1b7eab7a49aa96f3db1f716f0113a8a2e93c7375dd3d5d21c4941f1405c9c5",
"sha256:7fc0eee3046041387cbace9314926aa48b681202f8897f8bff3809967a049036",
"sha256:8ccd1c5fff1aa1427100ce188557fc31f1e0a383ad8ec42c559aabd4ff08802d",
"sha256:8e08dd76de80539d613654915a2f5196dbccc67448df291e69a88712ea21e24a",
"sha256:c18498c50c59263841862ea0501da9f2b3659c00db54abfbf823a80787fde8ce",
"sha256:c49db89d602c24928e68c0d510f4fcf8989d77defd01c973d6cbe27e684833b1",
"sha256:ce20044d0317649ddbb4e54dab3c1bcc7483c78c27d3f58ab3d0c7e6bc60d26a",
"sha256:d1071414dd06ca2eafa90c85a079169bfeb0e5f57fd0b45d44c092546fcd6fd9",
"sha256:d3be11ac43ab1a3e979dac80843b42226d5d3cccd3986f2e03152720a4297cd7",
"sha256:db603a1c235d110c860d5f39988ebc8218ee028f07a7cbc056ba6424372ca31b"
],
"version": "==4.5.2"
},
"protobuf": { "protobuf": {
"hashes": [ "hashes": [
"sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f", "sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f",

View File

@ -33,5 +33,8 @@ This project is heavily inspired by, and borrows functionality from:
- [ ] Well-known Google types - [ ] Well-known Google types
- [ ] JSON that isn't completely naive. - [ ] JSON that isn't completely naive.
- [ ] Async service stubs - [ ] Async service stubs
- [x] Unary-unary
- [x] Server streaming response
- [ ] Client streaming request
- [ ] Python package - [ ] Python package
- [ ] Cleanup! - [ ] Cleanup!

View File

@ -3,6 +3,7 @@ import json
import struct import struct
from typing import ( from typing import (
get_type_hints, get_type_hints,
AsyncGenerator,
Union, Union,
Generator, Generator,
Any, Any,
@ -17,6 +18,9 @@ from typing import (
) )
import dataclasses import dataclasses
import grpclib.client
import grpclib.const
import inspect import inspect
# Proto 3 data types # Proto 3 data types
@ -92,7 +96,14 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
def get_default(proto_type: int) -> Any: class _PLACEHOLDER:
pass
PLACEHOLDER: Any = _PLACEHOLDER()
def get_default(proto_type: str) -> Any:
"""Get the default (zero value) for a given type.""" """Get the default (zero value) for a given type."""
return { return {
TYPE_BOOL: False, TYPE_BOOL: False,
@ -114,8 +125,6 @@ class FieldMetadata:
proto_type: str proto_type: str
# Map information if the proto_type is a map # Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]] map_types: Optional[Tuple[str, str]]
# Default value if given
default: Any
@staticmethod @staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata": def get(field: dataclasses.Field) -> "FieldMetadata":
@ -124,23 +133,12 @@ class FieldMetadata:
def dataclass_field( def dataclass_field(
number: int, number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None
proto_type: str,
default: Any = None,
map_types: Optional[Tuple[str, str]] = None,
**kwargs: dict,
) -> dataclasses.Field: ) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata.""" """Creates a dataclass field with attached protobuf metadata."""
if callable(default):
kwargs["default_factory"] = default
elif isinstance(default, dict) or isinstance(default, list):
kwargs["default_factory"] = lambda: default
else:
kwargs["default"] = default
return dataclasses.field( return dataclasses.field(
**kwargs, default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)}, metadata={"betterproto": FieldMetadata(number, proto_type, map_types)},
) )
@ -149,68 +147,68 @@ def dataclass_field(
# out at runtime. The generated dataclass variables are still typed correctly. # out at runtime. The generated dataclass variables are still typed correctly.
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: def enum_field(number: int) -> Any:
return dataclass_field(number, TYPE_ENUM, default=default) return dataclass_field(number, TYPE_ENUM)
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any: def bool_field(number: int) -> Any:
return dataclass_field(number, TYPE_BOOL, default=default) return dataclass_field(number, TYPE_BOOL)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: def int32_field(number: int) -> Any:
return dataclass_field(number, TYPE_INT32, default=default) return dataclass_field(number, TYPE_INT32)
def int64_field(number: int, default: int = 0) -> Any: def int64_field(number: int) -> Any:
return dataclass_field(number, TYPE_INT64, default=default) return dataclass_field(number, TYPE_INT64)
def uint32_field(number: int, default: int = 0) -> Any: def uint32_field(number: int) -> Any:
return dataclass_field(number, TYPE_UINT32, default=default) return dataclass_field(number, TYPE_UINT32)
def uint64_field(number: int, default: int = 0) -> Any: def uint64_field(number: int) -> Any:
return dataclass_field(number, TYPE_UINT64, default=default) return dataclass_field(number, TYPE_UINT64)
def sint32_field(number: int, default: int = 0) -> Any: def sint32_field(number: int) -> Any:
return dataclass_field(number, TYPE_SINT32, default=default) return dataclass_field(number, TYPE_SINT32)
def sint64_field(number: int, default: int = 0) -> Any: def sint64_field(number: int) -> Any:
return dataclass_field(number, TYPE_SINT64, default=default) return dataclass_field(number, TYPE_SINT64)
def float_field(number: int, default: float = 0.0) -> Any: def float_field(number: int) -> Any:
return dataclass_field(number, TYPE_FLOAT, default=default) return dataclass_field(number, TYPE_FLOAT)
def double_field(number: int, default: float = 0.0) -> Any: def double_field(number: int) -> Any:
return dataclass_field(number, TYPE_DOUBLE, default=default) return dataclass_field(number, TYPE_DOUBLE)
def fixed32_field(number: int, default: float = 0.0) -> Any: def fixed32_field(number: int) -> Any:
return dataclass_field(number, TYPE_FIXED32, default=default) return dataclass_field(number, TYPE_FIXED32)
def fixed64_field(number: int, default: float = 0.0) -> Any: def fixed64_field(number: int) -> Any:
return dataclass_field(number, TYPE_FIXED64, default=default) return dataclass_field(number, TYPE_FIXED64)
def sfixed32_field(number: int, default: float = 0.0) -> Any: def sfixed32_field(number: int) -> Any:
return dataclass_field(number, TYPE_SFIXED32, default=default) return dataclass_field(number, TYPE_SFIXED32)
def sfixed64_field(number: int, default: float = 0.0) -> Any: def sfixed64_field(number: int) -> Any:
return dataclass_field(number, TYPE_SFIXED64, default=default) return dataclass_field(number, TYPE_SFIXED64)
def string_field(number: int, default: str = "") -> Any: def string_field(number: int) -> Any:
return dataclass_field(number, TYPE_STRING, default=default) return dataclass_field(number, TYPE_STRING)
def bytes_field(number: int, default: bytes = b"") -> Any: def bytes_field(number: int) -> Any:
return dataclass_field(number, TYPE_BYTES, default=default) return dataclass_field(number, TYPE_BYTES)
def message_field(number: int) -> Any: def message_field(number: int) -> Any:
@ -218,9 +216,7 @@ def message_field(number: int) -> Any:
def map_field(number: int, key_type: str, value_type: str) -> Any: def map_field(number: int, key_type: str, value_type: str) -> Any:
return dataclass_field( return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type))
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
)
def _pack_fmt(proto_type: str) -> str: def _pack_fmt(proto_type: str) -> str:
@ -336,6 +332,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
number = num_wire >> 3 number = num_wire >> 3
wire_type = num_wire & 0x7 wire_type = num_wire & 0x7
decoded: Any
if wire_type == 0: if wire_type == 0:
decoded, i = decode_varint(value, i) decoded, i = decode_varint(value, i)
elif wire_type == 1: elif wire_type == 1:
@ -369,11 +366,15 @@ class Message(ABC):
# Set a default value for each field in the class after `__init__` has # Set a default value for each field in the class after `__init__` has
# already been run. # already been run.
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
if getattr(self, field.name) != PLACEHOLDER:
# Skip anything not set (aka set to the sentinel value)
continue
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
t = self._cls_for(field, index=-1) t = self._cls_for(field, index=-1)
value = 0 value: Any = 0
if meta.proto_type == TYPE_MAP: if meta.proto_type == TYPE_MAP:
# Maps cannot be repeated, so we check these first. # Maps cannot be repeated, so we check these first.
value = {} value = {}
@ -419,6 +420,7 @@ class Message(ABC):
continue continue
for k, v in value.items(): for k, v in value.items():
assert meta.map_types
sk = _serialize_single(1, meta.map_types[0], k) sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v) sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv) output += _serialize_single(meta.number, meta.proto_type, sk + sv)
@ -431,10 +433,13 @@ class Message(ABC):
return output return output
# For compatibility with other libraries
SerializeToString = __bytes__
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints.""" """Get the message class for a field from the type hints."""
module = inspect.getmodule(self) module = inspect.getmodule(self.__class__)
type_hints = get_type_hints(self, vars(module)) type_hints = get_type_hints(self.__class__, vars(module))
cls = type_hints[field.name] cls = type_hints[field.name]
if hasattr(cls, "__args__") and index >= 0: if hasattr(cls, "__args__") and index >= 0:
cls = type_hints[field.name].__args__[index] cls = type_hints[field.name].__args__[index]
@ -465,6 +470,7 @@ class Message(ABC):
elif meta.proto_type in [TYPE_MAP]: elif meta.proto_type in [TYPE_MAP]:
# TODO: This is slow, use a cache to make it faster since each # TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class. # key/value pair will recreate the class.
assert meta.map_types
kt = self._cls_for(field, index=0) kt = self._cls_for(field, index=0)
vt = self._cls_for(field, index=1) vt = self._cls_for(field, index=1)
Entry = dataclasses.make_dataclass( Entry = dataclasses.make_dataclass(
@ -479,7 +485,7 @@ class Message(ABC):
return value return value
def parse(self, data: bytes) -> T: def parse(self: T, data: bytes) -> T:
""" """
Parse the binary encoded Protobuf into this message instance. This Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
@ -490,6 +496,7 @@ class Message(ABC):
field = fields[parsed.number] field = fields[parsed.number]
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
value: Any
if ( if (
parsed.wire_type == WIRE_LEN_DELIM parsed.wire_type == WIRE_LEN_DELIM
and meta.proto_type in PACKED_TYPES and meta.proto_type in PACKED_TYPES
@ -528,8 +535,15 @@ class Message(ABC):
# TODO: handle unknown fields # TODO: handle unknown fields
pass pass
from typing import cast
return self return self
# For compatibility with other libraries.
@classmethod
def FromString(cls: Type[T], data: bytes) -> T:
return cls().parse(data)
def to_dict(self) -> dict: def to_dict(self) -> dict:
""" """
Returns a dict representation of this message instance which can be Returns a dict representation of this message instance which can be
@ -557,11 +571,11 @@ class Message(ABC):
if v: if v:
output[field.name] = v output[field.name] = v
elif v != field.default: elif v != get_default(meta.proto_type):
output[field.name] = v output[field.name] = v
return output return output
def from_dict(self, value: dict) -> T: def from_dict(self: T, value: dict) -> T:
""" """
Parse the key/value pairs in `value` into this message instance. This Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
@ -578,7 +592,7 @@ class Message(ABC):
v.append(cls().from_dict(value[field.name][i])) v.append(cls().from_dict(value[field.name][i]))
else: else:
v.from_dict(value[field.name]) v.from_dict(value[field.name])
elif meta.proto_type == "map" and meta.map_types[1] == TYPE_MESSAGE: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field.name) v = getattr(self, field.name)
cls = self._cls_for(field, index=1) cls = self._cls_for(field, index=1)
for k in value[field.name]: for k in value[field.name]:
@ -587,13 +601,48 @@ class Message(ABC):
setattr(self, field.name, value[field.name]) setattr(self, field.name, value[field.name])
return self return self
def to_json(self) -> bytes: def to_json(self) -> str:
"""Returns the encoded JSON representation of this message instance.""" """Returns the encoded JSON representation of this message instance."""
return json.dumps(self.to_dict()) return json.dumps(self.to_dict())
def from_json(self, value: bytes) -> T: def from_json(self: T, value: Union[str, bytes]) -> T:
""" """
Parse the key/value pairs in `value` into this message instance. This Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
""" """
return self.from_dict(json.loads(value)) return self.from_dict(json.loads(value))
ResponseType = TypeVar("ResponseType", bound="Message")
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.
"""
def __init__(self, channel: grpclib.client.Channel) -> None:
self.channel = channel
async def _unary_unary(
self, route: str, request_type: Type, response_type: Type[T], request: Any
) -> T:
"""Make a unary request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type
) as stream:
await stream.send_message(request, end=True)
response = await stream.recv_message()
assert response is not None
return response
async def _unary_stream(
self, route: str, request_type: Type, response_type: Type[T], request: Any
) -> AsyncGenerator[T, None]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type
) as stream:
await stream.send_message(request, end=True)
async for message in stream:
yield message

View File

@ -1,12 +1,15 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: {{ description.filename }} # sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto # plugin: python-betterproto
{% if description.enums %}import enum {% if description.enums %}import enum
{% endif %} {% endif %}
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List from typing import AsyncGenerator, Dict, List, Optional
import betterproto import betterproto
{% if description.services %}
import grpclib
{% endif %}
{% for i in description.imports %} {% for i in description.imports %}
{{ i }} {{ i }}
@ -48,3 +51,36 @@ class {{ message.name }}(betterproto.Message):
{% endfor %} {% endfor %}
{% for service in description.services %}
class {{ service.name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
request = {{ method.input }}()
{% for field in method.input_message.properties %}
{% if field.field_type == 'message' %}
if {{ field.name }} is not None:
request.{{ field.name }} = {{ field.name }}
{% else %}
request.{{ field.name }} = {{ field.name }}
{% endif %}
{% endfor %}
{% if method.server_streaming %}
async for response in self._unary_stream(
"{{ method.route }}",
{{ method.input }},
{{ method.output }},
request,
):
yield response
{% else %}
return await self._unary_unary(
"{{ method.route }}",
{{ method.input }},
{{ method.output }},
request,
)
{% endif %}
{% endfor %}
{% endfor %}

View File

@ -5,6 +5,7 @@ import sys
import itertools import itertools
import json import json
import os.path import os.path
import re
from typing import Tuple, Any, List from typing import Tuple, Any, List
import textwrap import textwrap
@ -13,6 +14,7 @@ from google.protobuf.descriptor_pb2 import (
EnumDescriptorProto, EnumDescriptorProto,
FileDescriptorProto, FileDescriptorProto,
FieldDescriptorProto, FieldDescriptorProto,
ServiceDescriptorProto,
) )
from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.compiler import plugin_pb2 as plugin
@ -21,6 +23,32 @@ from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader from jinja2 import Environment, PackageLoader
def snake_case(value: str) -> str:
return (
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
)
def get_ref_type(package: str, imports: set, type_name: str) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary.
"""
type_name = type_name.lstrip(".")
if type_name.startswith(package):
# This is the current package, which has nested types flattened.
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
return type_name
def py_type( def py_type(
package: str, package: str,
imports: set, imports: set,
@ -37,35 +65,29 @@ def py_type(
return "str" return "str"
elif descriptor.type in [11, 14]: elif descriptor.type in [11, 14]:
# Type referencing another defined Message or a named enum # Type referencing another defined Message or a named enum
message_type = descriptor.type_name.lstrip(".") return get_ref_type(package, imports, descriptor.type_name)
if message_type.startswith(package):
# This is the current package, which has nested types flattened.
message_type = (
f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
)
if "." in message_type:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = message_type.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
message_type = f"{parts[-2]}.{parts[-1]}"
# print(
# descriptor.name,
# package,
# descriptor.type_name,
# message_type,
# file=sys.stderr,
# )
return message_type
elif descriptor.type == 12: elif descriptor.type == 12:
return "bytes" return "bytes"
else: else:
raise NotImplementedError(f"Unknown type {descriptor.type}") raise NotImplementedError(f"Unknown type {descriptor.type}")
def get_py_zero(type_num: int) -> str:
zero = 0
if type_num in []:
zero = 0.0
elif type_num == 8:
zero = "False"
elif type_num == 9:
zero = '""'
elif type_num == 11:
zero = "None"
elif type_num == 12:
zero = 'b""'
return zero
def traverse(proto_file): def traverse(proto_file):
def _traverse(path, items): def _traverse(path, items):
for i, item in enumerate(items): for i, item in enumerate(items):
@ -73,6 +95,7 @@ def traverse(proto_file):
if isinstance(item, DescriptorProto): if isinstance(item, DescriptorProto):
for enum in item.enum_type: for enum in item.enum_type:
enum.name = item.name + enum.name
yield enum, path + [i, 4] yield enum, path + [i, 4]
if item.nested_type: if item.nested_type:
@ -103,7 +126,8 @@ def get_comment(proto_file, path: List[int]) -> str:
lines[0] = lines[0].strip('"') lines[0] = lines[0].strip('"')
return f' """{lines[0]}"""' return f' """{lines[0]}"""'
else: else:
return f' """\n{" ".join(lines)}\n """' joined = "\n ".join(lines)
return f' """\n {joined}\n """'
return "" return ""
@ -116,10 +140,6 @@ def generate_code(request, response):
) )
template = env.get_template("main.py") template = env.get_template("main.py")
# TODO: Refactor below to generate a single file per package if packages
# are being used, otherwise one output for each input. Figure out how to
# set up relative imports when needed and change the Message type refs to
# use the import names when not in the current module.
output_map = {} output_map = {}
for proto_file in request.proto_file: for proto_file in request.proto_file:
out = proto_file.package out = proto_file.package
@ -136,7 +156,16 @@ def generate_code(request, response):
for filename, options in output_map.items(): for filename, options in output_map.items():
package = options["package"] package = options["package"]
# print(package, filename, file=sys.stderr) # print(package, filename, file=sys.stderr)
output = {"package": package, "imports": set(), "messages": [], "enums": []} output = {
"package": package,
"files": [f.name for f in options["files"]],
"imports": set(),
"messages": [],
"enums": [],
"services": [],
}
type_mapping = {}
for proto_file in options["files"]: for proto_file in options["files"]:
# print(proto_file.message_type, file=sys.stderr) # print(proto_file.message_type, file=sys.stderr)
@ -164,6 +193,7 @@ def generate_code(request, response):
for i, f in enumerate(item.field): for i, f in enumerate(item.field):
t = py_type(package, output["imports"], item, f) t = py_type(package, output["imports"], item, f)
zero = get_py_zero(f.type)
repeated = False repeated = False
packed = False packed = False
@ -172,12 +202,16 @@ def generate_code(request, response):
map_types = None map_types = None
if f.type == 11: if f.type == 11:
# This might be a map... # This might be a map...
message_type = f.type_name.split(".").pop() message_type = f.type_name.split(".").pop().lower()
map_entry = f"{f.name.capitalize()}Entry" # message_type = py_type(package)
map_entry = f"{f.name.replace('_', '').lower()}entry"
if message_type == map_entry: if message_type == map_entry:
for nested in item.nested_type: for nested in item.nested_type:
if nested.name == map_entry: if (
nested.name.replace("_", "").lower()
== map_entry
):
if nested.options.map_entry: if nested.options.map_entry:
# print("Found a map!", file=sys.stderr) # print("Found a map!", file=sys.stderr)
k = py_type( k = py_type(
@ -203,6 +237,7 @@ def generate_code(request, response):
# Repeated field # Repeated field
repeated = True repeated = True
t = f"List[{t}]" t = f"List[{t}]"
zero = "[]"
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True packed = True
@ -216,6 +251,7 @@ def generate_code(request, response):
"field_type": field_type, "field_type": field_type,
"map_types": map_types, "map_types": map_types,
"type": t, "type": t,
"zero": zero,
"repeated": repeated, "repeated": repeated,
"packed": packed, "packed": packed,
} }
@ -223,7 +259,6 @@ def generate_code(request, response):
# print(f, file=sys.stderr) # print(f, file=sys.stderr)
output["messages"].append(data) output["messages"].append(data)
elif isinstance(item, EnumDescriptorProto): elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr) # print(item.name, path, file=sys.stderr)
data.update( data.update(
@ -243,6 +278,44 @@ def generate_code(request, response):
output["enums"].append(data) output["enums"].append(data)
for service in proto_file.service:
# print(service, file=sys.stderr)
# TODO: comments
data = {"name": service.name, "methods": []}
for method in service.method:
if method.client_streaming:
raise NotImplementedError("Client streaming not yet supported")
input_message = None
input_type = get_ref_type(
package, output["imports"], method.input_type
).strip('"')
for msg in output["messages"]:
if msg["name"] == input_type:
input_message = msg
break
data["methods"].append(
{
"name": method.name,
"py_name": snake_case(method.name),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type(
package, output["imports"], method.input_type
).strip('"'),
"input_message": input_message,
"output": get_ref_type(
package, output["imports"], method.output_type
).strip('"'),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
}
)
output["services"].append(data)
output["imports"] = sorted(output["imports"]) output["imports"] = sorted(output["imports"])
# Fill response # Fill response
@ -256,7 +329,7 @@ def generate_code(request, response):
inits = set([""]) inits = set([""])
for f in response.file: for f in response.file:
# Ensure output paths exist # Ensure output paths exist
print(f.name, file=sys.stderr) # print(f.name, file=sys.stderr)
dirnames = os.path.dirname(f.name) dirnames = os.path.dirname(f.name)
if dirnames: if dirnames:
os.makedirs(dirnames, exist_ok=True) os.makedirs(dirnames, exist_ok=True)