Implement basic async gRPC support

This commit is contained in:
Daniel G. Taylor
2019-10-16 22:52:38 -07:00
parent 41a96f65ee
commit d93214eccd
6 changed files with 322 additions and 98 deletions

View File

@@ -5,6 +5,7 @@ import sys
import itertools
import json
import os.path
import re
from typing import Tuple, Any, List
import textwrap
@@ -13,6 +14,7 @@ from google.protobuf.descriptor_pb2 import (
EnumDescriptorProto,
FileDescriptorProto,
FieldDescriptorProto,
ServiceDescriptorProto,
)
from google.protobuf.compiler import plugin_pb2 as plugin
@@ -21,6 +23,32 @@ from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def snake_case(value: str) -> str:
return (
re.sub(r"(?<=[a-z])[A-Z]|[A-Z](?=[^A-Z])", r"_\g<0>", value).lower().strip("_")
)
def get_ref_type(package: str, imports: set, type_name: str) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary.
"""
type_name = type_name.lstrip(".")
if type_name.startswith(package):
# This is the current package, which has nested types flattened.
type_name = f'"{type_name.lstrip(package).lstrip(".").replace(".", "")}"'
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
return type_name
def py_type(
package: str,
imports: set,
@@ -37,35 +65,29 @@ def py_type(
return "str"
elif descriptor.type in [11, 14]:
# Type referencing another defined Message or a named enum
message_type = descriptor.type_name.lstrip(".")
if message_type.startswith(package):
# This is the current package, which has nested types flattened.
message_type = (
f'"{message_type.lstrip(package).lstrip(".").replace(".", "")}"'
)
if "." in message_type:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = message_type.split(".")
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
message_type = f"{parts[-2]}.{parts[-1]}"
# print(
# descriptor.name,
# package,
# descriptor.type_name,
# message_type,
# file=sys.stderr,
# )
return message_type
return get_ref_type(package, imports, descriptor.type_name)
elif descriptor.type == 12:
return "bytes"
else:
raise NotImplementedError(f"Unknown type {descriptor.type}")
def get_py_zero(type_num: int) -> str:
zero = 0
if type_num in []:
zero = 0.0
elif type_num == 8:
zero = "False"
elif type_num == 9:
zero = '""'
elif type_num == 11:
zero = "None"
elif type_num == 12:
zero = 'b""'
return zero
def traverse(proto_file):
def _traverse(path, items):
for i, item in enumerate(items):
@@ -73,6 +95,7 @@ def traverse(proto_file):
if isinstance(item, DescriptorProto):
for enum in item.enum_type:
enum.name = item.name + enum.name
yield enum, path + [i, 4]
if item.nested_type:
@@ -103,7 +126,8 @@ def get_comment(proto_file, path: List[int]) -> str:
lines[0] = lines[0].strip('"')
return f' """{lines[0]}"""'
else:
return f' """\n{" ".join(lines)}\n """'
joined = "\n ".join(lines)
return f' """\n {joined}\n """'
return ""
@@ -116,10 +140,6 @@ def generate_code(request, response):
)
template = env.get_template("main.py")
# TODO: Refactor below to generate a single file per package if packages
# are being used, otherwise one output for each input. Figure out how to
# set up relative imports when needed and change the Message type refs to
# use the import names when not in the current module.
output_map = {}
for proto_file in request.proto_file:
out = proto_file.package
@@ -136,7 +156,16 @@ def generate_code(request, response):
for filename, options in output_map.items():
package = options["package"]
# print(package, filename, file=sys.stderr)
output = {"package": package, "imports": set(), "messages": [], "enums": []}
output = {
"package": package,
"files": [f.name for f in options["files"]],
"imports": set(),
"messages": [],
"enums": [],
"services": [],
}
type_mapping = {}
for proto_file in options["files"]:
# print(proto_file.message_type, file=sys.stderr)
@@ -164,6 +193,7 @@ def generate_code(request, response):
for i, f in enumerate(item.field):
t = py_type(package, output["imports"], item, f)
zero = get_py_zero(f.type)
repeated = False
packed = False
@@ -172,12 +202,16 @@ def generate_code(request, response):
map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop()
map_entry = f"{f.name.capitalize()}Entry"
message_type = f.type_name.split(".").pop().lower()
# message_type = py_type(package)
map_entry = f"{f.name.replace('_', '').lower()}entry"
if message_type == map_entry:
for nested in item.nested_type:
if nested.name == map_entry:
if (
nested.name.replace("_", "").lower()
== map_entry
):
if nested.options.map_entry:
# print("Found a map!", file=sys.stderr)
k = py_type(
@@ -203,6 +237,7 @@ def generate_code(request, response):
# Repeated field
repeated = True
t = f"List[{t}]"
zero = "[]"
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True
@@ -216,6 +251,7 @@ def generate_code(request, response):
"field_type": field_type,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,
"packed": packed,
}
@@ -223,7 +259,6 @@ def generate_code(request, response):
# print(f, file=sys.stderr)
output["messages"].append(data)
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
@@ -243,6 +278,44 @@ def generate_code(request, response):
output["enums"].append(data)
for service in proto_file.service:
# print(service, file=sys.stderr)
# TODO: comments
data = {"name": service.name, "methods": []}
for method in service.method:
if method.client_streaming:
raise NotImplementedError("Client streaming not yet supported")
input_message = None
input_type = get_ref_type(
package, output["imports"], method.input_type
).strip('"')
for msg in output["messages"]:
if msg["name"] == input_type:
input_message = msg
break
data["methods"].append(
{
"name": method.name,
"py_name": snake_case(method.name),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type(
package, output["imports"], method.input_type
).strip('"'),
"input_message": input_message,
"output": get_ref_type(
package, output["imports"], method.output_type
).strip('"'),
"client_streaming": method.client_streaming,
"server_streaming": method.server_streaming,
}
)
output["services"].append(data)
output["imports"] = sorted(output["imports"])
# Fill response
@@ -256,7 +329,7 @@ def generate_code(request, response):
inits = set([""])
for f in response.file:
# Ensure output paths exist
print(f.name, file=sys.stderr)
# print(f.name, file=sys.stderr)
dirnames = os.path.dirname(f.name)
if dirnames:
os.makedirs(dirnames, exist_ok=True)