From d93214eccdf21294b7285f5b40649d8db31b2e72 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Wed, 16 Oct 2019 22:52:38 -0700 Subject: [PATCH] Implement basic async gRPC support --- Pipfile | 1 + Pipfile.lock | 64 ++++++++++++- README.md | 3 + betterproto/__init__.py | 171 ++++++++++++++++++++++------------ betterproto/templates/main.py | 40 +++++++- protoc-gen-betterpy.py | 141 +++++++++++++++++++++------- 6 files changed, 322 insertions(+), 98 deletions(-) diff --git a/Pipfile b/Pipfile index 63f99a3..aae4234 100644 --- a/Pipfile +++ b/Pipfile @@ -13,6 +13,7 @@ rope = "*" [packages] protobuf = "*" jinja2 = "*" +grpclib = "*" [requires] python_version = "3.7" diff --git a/Pipfile.lock b/Pipfile.lock index 03ab532..2b39de5 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "6c1797fb4eb73be97ca566206527c9d648b90f38c5bf2caf4b69537cd325ced9" + "sha256": "f698150037f2a8ac554e4d37ecd4619ba35d1aa570f5b641d048ec9c6b23eb40" }, "pipfile-spec": 6, "requires": { @@ -16,6 +16,34 @@ ] }, "default": { + "grpclib": { + "hashes": [ + "sha256:d19e2ea87cb073e5b0825dfee15336fd2b1c09278d271816e04c90faddc107ea" + ], + "index": "pypi", + "version": "==0.3.0" + }, + "h2": { + "hashes": [ + "sha256:ac377fcf586314ef3177bfd90c12c7826ab0840edeb03f0f24f511858326049e", + "sha256:b8a32bd282594424c0ac55845377eea13fa54fe4a8db012f3a198ed923dc3ab4" + ], + "version": "==3.1.1" + }, + "hpack": { + "hashes": [ + "sha256:0edd79eda27a53ba5be2dfabf3b15780928a0dff6eb0c60a3d6767720e970c89", + "sha256:8eec9c1f4bfae3408a3f30500261f7e6a65912dc138526ea054f9ad98892e9d2" + ], + "version": "==3.0.0" + }, + "hyperframe": { + "hashes": [ + "sha256:5187962cb16dcc078f23cb5a4b110098d546c3f41ff2d4038a9896893bbd0b40", + "sha256:a9f5c17f2cc3c719b917c4f33ed1c61bd1f8dfac4b1bd23b7c80b3400971b41f" + ], + "version": "==5.2.0" + }, "jinja2": { "hashes": [ "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", @@ -57,6 +85,40 @@ ], "version": "==1.1.1" }, + "multidict": { + "hashes": [ + "sha256:024b8129695a952ebd93373e45b5d341dbb87c17ce49637b34000093f243dd4f", + "sha256:041e9442b11409be5e4fc8b6a97e4bcead758ab1e11768d1e69160bdde18acc3", + "sha256:045b4dd0e5f6121e6f314d81759abd2c257db4634260abcfe0d3f7083c4908ef", + "sha256:047c0a04e382ef8bd74b0de01407e8d8632d7d1b4db6f2561106af812a68741b", + "sha256:068167c2d7bbeebd359665ac4fff756be5ffac9cda02375b5c5a7c4777038e73", + "sha256:148ff60e0fffa2f5fad2eb25aae7bef23d8f3b8bdaf947a65cdbe84a978092bc", + "sha256:1d1c77013a259971a72ddaa83b9f42c80a93ff12df6a4723be99d858fa30bee3", + "sha256:1d48bc124a6b7a55006d97917f695effa9725d05abe8ee78fd60d6588b8344cd", + "sha256:31dfa2fc323097f8ad7acd41aa38d7c614dd1960ac6681745b6da124093dc351", + "sha256:34f82db7f80c49f38b032c5abb605c458bac997a6c3142e0d6c130be6fb2b941", + "sha256:3d5dd8e5998fb4ace04789d1d008e2bb532de501218519d70bb672c4c5a2fc5d", + "sha256:4a6ae52bd3ee41ee0f3acf4c60ceb3f44e0e3bc52ab7da1c2b2aa6703363a3d1", + "sha256:4b02a3b2a2f01d0490dd39321c74273fed0568568ea0e7ea23e02bd1fb10a10b", + "sha256:4b843f8e1dd6a3195679d9838eb4670222e8b8d01bc36c9894d6c3538316fa0a", + "sha256:5de53a28f40ef3c4fd57aeab6b590c2c663de87a5af76136ced519923d3efbb3", + "sha256:61b2b33ede821b94fa99ce0b09c9ece049c7067a33b279f343adfe35108a4ea7", + "sha256:6a3a9b0f45fd75dc05d8e93dc21b18fc1670135ec9544d1ad4acbcf6b86781d0", + "sha256:76ad8e4c69dadbb31bad17c16baee61c0d1a4a73bed2590b741b2e1a46d3edd0", + "sha256:7ba19b777dc00194d1b473180d4ca89a054dd18de27d0ee2e42a103ec9b7d014", + "sha256:7c1b7eab7a49aa96f3db1f716f0113a8a2e93c7375dd3d5d21c4941f1405c9c5", + "sha256:7fc0eee3046041387cbace9314926aa48b681202f8897f8bff3809967a049036", + "sha256:8ccd1c5fff1aa1427100ce188557fc31f1e0a383ad8ec42c559aabd4ff08802d", + "sha256:8e08dd76de80539d613654915a2f5196dbccc67448df291e69a88712ea21e24a", + "sha256:c18498c50c59263841862ea0501da9f2b3659c00db54abfbf823a80787fde8ce", + "sha256:c49db89d602c24928e68c0d510f4fcf8989d77defd01c973d6cbe27e684833b1", + "sha256:ce20044d0317649ddbb4e54dab3c1bcc7483c78c27d3f58ab3d0c7e6bc60d26a", + "sha256:d1071414dd06ca2eafa90c85a079169bfeb0e5f57fd0b45d44c092546fcd6fd9", + "sha256:d3be11ac43ab1a3e979dac80843b42226d5d3cccd3986f2e03152720a4297cd7", + "sha256:db603a1c235d110c860d5f39988ebc8218ee028f07a7cbc056ba6424372ca31b" + ], + "version": "==4.5.2" + }, "protobuf": { "hashes": [ "sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f", diff --git a/README.md b/README.md index 9da3c77..b71f6f6 100644 --- a/README.md +++ b/README.md @@ -33,5 +33,8 @@ This project is heavily inspired by, and borrows functionality from: - [ ] Well-known Google types - [ ] JSON that isn't completely naive. - [ ] Async service stubs + - [x] Unary-unary + - [x] Server streaming response + - [ ] Client streaming request - [ ] Python package - [ ] Cleanup! diff --git a/betterproto/__init__.py b/betterproto/__init__.py index f9b6f15..6ea3891 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -3,6 +3,7 @@ import json import struct from typing import ( get_type_hints, + AsyncGenerator, Union, Generator, Any, @@ -17,6 +18,9 @@ from typing import ( ) import dataclasses +import grpclib.client +import grpclib.const + import inspect # Proto 3 data types @@ -92,7 +96,14 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] -def get_default(proto_type: int) -> Any: +class _PLACEHOLDER: + pass + + +PLACEHOLDER: Any = _PLACEHOLDER() + + +def get_default(proto_type: str) -> Any: """Get the default (zero value) for a given type.""" return { TYPE_BOOL: False, @@ -114,8 +125,6 @@ class FieldMetadata: proto_type: str # Map information if the proto_type is a map map_types: Optional[Tuple[str, str]] - # Default value if given - default: Any @staticmethod def get(field: dataclasses.Field) -> "FieldMetadata": @@ -124,23 +133,12 @@ class FieldMetadata: def dataclass_field( - number: int, - proto_type: str, - default: Any = None, - map_types: Optional[Tuple[str, str]] = None, - **kwargs: dict, + number: int, proto_type: str, map_types: Optional[Tuple[str, str]] = None ) -> dataclasses.Field: """Creates a dataclass field with attached protobuf metadata.""" - if callable(default): - kwargs["default_factory"] = default - elif isinstance(default, dict) or isinstance(default, list): - kwargs["default_factory"] = lambda: default - else: - kwargs["default"] = default - return dataclasses.field( - **kwargs, - metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)}, + default=PLACEHOLDER, + metadata={"betterproto": FieldMetadata(number, proto_type, map_types)}, ) @@ -149,68 +147,68 @@ def dataclass_field( # out at runtime. The generated dataclass variables are still typed correctly. -def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return dataclass_field(number, TYPE_ENUM, default=default) +def enum_field(number: int) -> Any: + return dataclass_field(number, TYPE_ENUM) -def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any: - return dataclass_field(number, TYPE_BOOL, default=default) +def bool_field(number: int) -> Any: + return dataclass_field(number, TYPE_BOOL) -def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return dataclass_field(number, TYPE_INT32, default=default) +def int32_field(number: int) -> Any: + return dataclass_field(number, TYPE_INT32) -def int64_field(number: int, default: int = 0) -> Any: - return dataclass_field(number, TYPE_INT64, default=default) +def int64_field(number: int) -> Any: + return dataclass_field(number, TYPE_INT64) -def uint32_field(number: int, default: int = 0) -> Any: - return dataclass_field(number, TYPE_UINT32, default=default) +def uint32_field(number: int) -> Any: + return dataclass_field(number, TYPE_UINT32) -def uint64_field(number: int, default: int = 0) -> Any: - return dataclass_field(number, TYPE_UINT64, default=default) +def uint64_field(number: int) -> Any: + return dataclass_field(number, TYPE_UINT64) -def sint32_field(number: int, default: int = 0) -> Any: - return dataclass_field(number, TYPE_SINT32, default=default) +def sint32_field(number: int) -> Any: + return dataclass_field(number, TYPE_SINT32) -def sint64_field(number: int, default: int = 0) -> Any: - return dataclass_field(number, TYPE_SINT64, default=default) +def sint64_field(number: int) -> Any: + return dataclass_field(number, TYPE_SINT64) -def float_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_FLOAT, default=default) +def float_field(number: int) -> Any: + return dataclass_field(number, TYPE_FLOAT) -def double_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_DOUBLE, default=default) +def double_field(number: int) -> Any: + return dataclass_field(number, TYPE_DOUBLE) -def fixed32_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_FIXED32, default=default) +def fixed32_field(number: int) -> Any: + return dataclass_field(number, TYPE_FIXED32) -def fixed64_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_FIXED64, default=default) +def fixed64_field(number: int) -> Any: + return dataclass_field(number, TYPE_FIXED64) -def sfixed32_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_SFIXED32, default=default) +def sfixed32_field(number: int) -> Any: + return dataclass_field(number, TYPE_SFIXED32) -def sfixed64_field(number: int, default: float = 0.0) -> Any: - return dataclass_field(number, TYPE_SFIXED64, default=default) +def sfixed64_field(number: int) -> Any: + return dataclass_field(number, TYPE_SFIXED64) -def string_field(number: int, default: str = "") -> Any: - return dataclass_field(number, TYPE_STRING, default=default) +def string_field(number: int) -> Any: + return dataclass_field(number, TYPE_STRING) -def bytes_field(number: int, default: bytes = b"") -> Any: - return dataclass_field(number, TYPE_BYTES, default=default) +def bytes_field(number: int) -> Any: + return dataclass_field(number, TYPE_BYTES) def message_field(number: int) -> Any: @@ -218,9 +216,7 @@ def message_field(number: int) -> Any: def map_field(number: int, key_type: str, value_type: str) -> Any: - return dataclass_field( - number, TYPE_MAP, default=dict, map_types=(key_type, value_type) - ) + return dataclass_field(number, TYPE_MAP, map_types=(key_type, value_type)) def _pack_fmt(proto_type: str) -> str: @@ -336,6 +332,7 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: number = num_wire >> 3 wire_type = num_wire & 0x7 + decoded: Any if wire_type == 0: decoded, i = decode_varint(value, i) elif wire_type == 1: @@ -369,11 +366,15 @@ class Message(ABC): # Set a default value for each field in the class after `__init__` has # already been run. for field in dataclasses.fields(self): + if getattr(self, field.name) != PLACEHOLDER: + # Skip anything not set (aka set to the sentinel value) + continue + meta = FieldMetadata.get(field) t = self._cls_for(field, index=-1) - value = 0 + value: Any = 0 if meta.proto_type == TYPE_MAP: # Maps cannot be repeated, so we check these first. value = {} @@ -419,6 +420,7 @@ class Message(ABC): continue for k, v in value.items(): + assert meta.map_types sk = _serialize_single(1, meta.map_types[0], k) sv = _serialize_single(2, meta.map_types[1], v) output += _serialize_single(meta.number, meta.proto_type, sk + sv) @@ -431,10 +433,13 @@ class Message(ABC): return output + # For compatibility with other libraries + SerializeToString = __bytes__ + def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type: """Get the message class for a field from the type hints.""" - module = inspect.getmodule(self) - type_hints = get_type_hints(self, vars(module)) + module = inspect.getmodule(self.__class__) + type_hints = get_type_hints(self.__class__, vars(module)) cls = type_hints[field.name] if hasattr(cls, "__args__") and index >= 0: cls = type_hints[field.name].__args__[index] @@ -465,6 +470,7 @@ class Message(ABC): elif meta.proto_type in [TYPE_MAP]: # TODO: This is slow, use a cache to make it faster since each # key/value pair will recreate the class. + assert meta.map_types kt = self._cls_for(field, index=0) vt = self._cls_for(field, index=1) Entry = dataclasses.make_dataclass( @@ -479,7 +485,7 @@ class Message(ABC): return value - def parse(self, data: bytes) -> T: + def parse(self: T, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This returns the instance itself and is therefore assignable and chainable. @@ -490,6 +496,7 @@ class Message(ABC): field = fields[parsed.number] meta = FieldMetadata.get(field) + value: Any if ( parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES @@ -528,8 +535,15 @@ class Message(ABC): # TODO: handle unknown fields pass + from typing import cast + return self + # For compatibility with other libraries. + @classmethod + def FromString(cls: Type[T], data: bytes) -> T: + return cls().parse(data) + def to_dict(self) -> dict: """ Returns a dict representation of this message instance which can be @@ -557,11 +571,11 @@ class Message(ABC): if v: output[field.name] = v - elif v != field.default: + elif v != get_default(meta.proto_type): output[field.name] = v return output - def from_dict(self, value: dict) -> T: + def from_dict(self: T, value: dict) -> T: """ Parse the key/value pairs in `value` into this message instance. This returns the instance itself and is therefore assignable and chainable. @@ -578,7 +592,7 @@ class Message(ABC): v.append(cls().from_dict(value[field.name][i])) else: v.from_dict(value[field.name]) - elif meta.proto_type == "map" and meta.map_types[1] == TYPE_MESSAGE: + elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: v = getattr(self, field.name) cls = self._cls_for(field, index=1) for k in value[field.name]: @@ -587,13 +601,48 @@ class Message(ABC): setattr(self, field.name, value[field.name]) return self - def to_json(self) -> bytes: + def to_json(self) -> str: """Returns the encoded JSON representation of this message instance.""" return json.dumps(self.to_dict()) - def from_json(self, value: bytes) -> T: + def from_json(self: T, value: Union[str, bytes]) -> T: """ Parse the key/value pairs in `value` into this message instance. This returns the instance itself and is therefore assignable and chainable. """ return self.from_dict(json.loads(value)) + + +ResponseType = TypeVar("ResponseType", bound="Message") + + +class ServiceStub(ABC): + """ + Base class for async gRPC service stubs. + """ + + def __init__(self, channel: grpclib.client.Channel) -> None: + self.channel = channel + + async def _unary_unary( + self, route: str, request_type: Type, response_type: Type[T], request: Any + ) -> T: + """Make a unary request and return the response.""" + async with self.channel.request( + route, grpclib.const.Cardinality.UNARY_UNARY, request_type, response_type + ) as stream: + await stream.send_message(request, end=True) + response = await stream.recv_message() + assert response is not None + return response + + async def _unary_stream( + self, route: str, request_type: Type, response_type: Type[T], request: Any + ) -> AsyncGenerator[T, None]: + """Make a unary request and return the stream response iterator.""" + async with self.channel.request( + route, grpclib.const.Cardinality.UNARY_STREAM, request_type, response_type + ) as stream: + await stream.send_message(request, end=True) + async for message in stream: + yield message diff --git a/betterproto/templates/main.py b/betterproto/templates/main.py index 08b5dc1..7782707 100644 --- a/betterproto/templates/main.py +++ b/betterproto/templates/main.py @@ -1,12 +1,15 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! -# source: {{ description.filename }} +# sources: {{ ', '.join(description.files) }} # plugin: python-betterproto {% if description.enums %}import enum {% endif %} from dataclasses import dataclass -from typing import Dict, List +from typing import AsyncGenerator, Dict, List, Optional import betterproto +{% if description.services %} +import grpclib +{% endif %} {% for i in description.imports %} {{ i }} @@ -48,3 +51,36 @@ class {{ message.name }}(betterproto.Message): {% endfor %} +{% for service in description.services %} +class {{ service.name }}Stub(betterproto.ServiceStub): + {% for method in service.methods %} + async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + request = {{ method.input }}() + {% for field in method.input_message.properties %} + {% if field.field_type == 'message' %} + if {{ field.name }} is not None: + request.{{ field.name }} = {{ field.name }} + {% else %} + request.{{ field.name }} = {{ field.name }} + {% endif %} + {% endfor %} + + {% if method.server_streaming %} + async for response in self._unary_stream( + "{{ method.route }}", + {{ method.input }}, + {{ method.output }}, + request, + ): + yield response + {% else %} + return await self._unary_unary( + "{{ method.route }}", + {{ method.input }}, + {{ method.output }}, + request, + ) + {% endif %} + + {% endfor %} +{% endfor %} diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py index 86f35f0..6f578ca 100755 --- a/protoc-gen-betterpy.py +++ b/protoc-gen-betterpy.py @@ -5,6 +5,7 @@ import sys import itertools import json import os.path +import re from typing import Tuple, Any, List import textwrap @@ -13,6 +14,7 @@ from google.protobuf.descriptor_pb2 import ( EnumDescriptorProto, FileDescriptorProto, FieldDescriptorProto, + ServiceDescriptorProto, ) from google.protobuf.compiler import plugin_pb2 as plugin @@ -21,6 +23,32 @@ from google.protobuf.compiler import plugin_pb2 as plugin from jinja2 import Environment, PackageLoader +def snake_case(value: str) -> str: + return ( + re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_") + ) + + +def get_ref_type(package: str, imports: set, type_name: str) -> str: + """ + Return a Python type name for a proto type reference. Adds the import if + necessary. + """ + type_name = type_name.lstrip(".") + if type_name.startswith(package): + # This is the current package, which has nested types flattened. + type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"' + + if "." in type_name: + # This is imported from another package. No need + # to use a forward ref and we need to add the import. + parts = type_name.split(".") + imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") + type_name = f"{parts[-2]}.{parts[-1]}" + + return type_name + + def py_type( package: str, imports: set, @@ -37,35 +65,29 @@ def py_type( return "str" elif descriptor.type in [11, 14]: # Type referencing another defined Message or a named enum - message_type = descriptor.type_name.lstrip(".") - if message_type.startswith(package): - # This is the current package, which has nested types flattened. - message_type = ( - f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"' - ) - - if "." in message_type: - # This is imported from another package. No need - # to use a forward ref and we need to add the import. - parts = message_type.split(".") - imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") - message_type = f"{parts[-2]}.{parts[-1]}" - - # print( - # descriptor.name, - # package, - # descriptor.type_name, - # message_type, - # file=sys.stderr, - # ) - - return message_type + return get_ref_type(package, imports, descriptor.type_name) elif descriptor.type == 12: return "bytes" else: raise NotImplementedError(f"Unknown type {descriptor.type}") +def get_py_zero(type_num: int) -> str: + zero = 0 + if type_num in []: + zero = 0.0 + elif type_num == 8: + zero = "False" + elif type_num == 9: + zero = '""' + elif type_num == 11: + zero = "None" + elif type_num == 12: + zero = 'b""' + + return zero + + def traverse(proto_file): def _traverse(path, items): for i, item in enumerate(items): @@ -73,6 +95,7 @@ def traverse(proto_file): if isinstance(item, DescriptorProto): for enum in item.enum_type: + enum.name = item.name + enum.name yield enum, path + [i, 4] if item.nested_type: @@ -103,7 +126,8 @@ def get_comment(proto_file, path: List[int]) -> str: lines[0] = lines[0].strip('"') return f' """{lines[0]}"""' else: - return f' """\n{" ".join(lines)}\n """' + joined = "\n ".join(lines) + return f' """\n {joined}\n """' return "" @@ -116,10 +140,6 @@ def generate_code(request, response): ) template = env.get_template("main.py") - # TODO: Refactor below to generate a single file per package if packages - # are being used, otherwise one output for each input. Figure out how to - # set up relative imports when needed and change the Message type refs to - # use the import names when not in the current module. output_map = {} for proto_file in request.proto_file: out = proto_file.package @@ -136,7 +156,16 @@ def generate_code(request, response): for filename, options in output_map.items(): package = options["package"] # print(package, filename, file=sys.stderr) - output = {"package": package, "imports": set(), "messages": [], "enums": []} + output = { + "package": package, + "files": [f.name for f in options["files"]], + "imports": set(), + "messages": [], + "enums": [], + "services": [], + } + + type_mapping = {} for proto_file in options["files"]: # print(proto_file.message_type, file=sys.stderr) @@ -164,6 +193,7 @@ def generate_code(request, response): for i, f in enumerate(item.field): t = py_type(package, output["imports"], item, f) + zero = get_py_zero(f.type) repeated = False packed = False @@ -172,12 +202,16 @@ def generate_code(request, response): map_types = None if f.type == 11: # This might be a map... - message_type = f.type_name.split(".").pop() - map_entry = f"{f.name.capitalize()}Entry" + message_type = f.type_name.split(".").pop().lower() + # message_type = py_type(package) + map_entry = f"{f.name.replace('_', '').lower()}entry" if message_type == map_entry: for nested in item.nested_type: - if nested.name == map_entry: + if ( + nested.name.replace("_", "").lower() + == map_entry + ): if nested.options.map_entry: # print("Found a map!", file=sys.stderr) k = py_type( @@ -203,6 +237,7 @@ def generate_code(request, response): # Repeated field repeated = True t = f"List[{t}]" + zero = "[]" if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: packed = True @@ -216,6 +251,7 @@ def generate_code(request, response): "field_type": field_type, "map_types": map_types, "type": t, + "zero": zero, "repeated": repeated, "packed": packed, } @@ -223,7 +259,6 @@ def generate_code(request, response): # print(f, file=sys.stderr) output["messages"].append(data) - elif isinstance(item, EnumDescriptorProto): # print(item.name, path, file=sys.stderr) data.update( @@ -243,6 +278,44 @@ def generate_code(request, response): output["enums"].append(data) + for service in proto_file.service: + # print(service, file=sys.stderr) + + # TODO: comments + data = {"name": service.name, "methods": []} + + for method in service.method: + if method.client_streaming: + raise NotImplementedError("Client streaming not yet supported") + + input_message = None + input_type = get_ref_type( + package, output["imports"], method.input_type + ).strip('"') + for msg in output["messages"]: + if msg["name"] == input_type: + input_message = msg + break + + data["methods"].append( + { + "name": method.name, + "py_name": snake_case(method.name), + "route": f"/{package}.{service.name}/{method.name}", + "input": get_ref_type( + package, output["imports"], method.input_type + ).strip('"'), + "input_message": input_message, + "output": get_ref_type( + package, output["imports"], method.output_type + ).strip('"'), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + ) + + output["services"].append(data) + output["imports"] = sorted(output["imports"]) # Fill response @@ -256,7 +329,7 @@ def generate_code(request, response): inits = set([""]) for f in response.file: # Ensure output paths exist - print(f.name, file=sys.stderr) + # print(f.name, file=sys.stderr) dirnames = os.path.dirname(f.name) if dirnames: os.makedirs(dirnames, exist_ok=True)