186 lines
5.8 KiB
Python
Executable File
186 lines
5.8 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
import sys
|
|
|
|
import itertools
|
|
import json
|
|
import os.path
|
|
from typing import Tuple, Any, List
|
|
import textwrap
|
|
|
|
from google.protobuf.descriptor_pb2 import (
|
|
DescriptorProto,
|
|
EnumDescriptorProto,
|
|
FileDescriptorProto,
|
|
)
|
|
|
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
|
|
|
|
|
from jinja2 import Environment, PackageLoader
|
|
|
|
|
|
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
|
|
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
|
return "float", descriptor.default_value
|
|
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
|
return "int", descriptor.default_value
|
|
elif descriptor.type == 8:
|
|
return "bool", descriptor.default_value.capitalize()
|
|
elif descriptor.type == 9:
|
|
default = ""
|
|
if descriptor.default_value:
|
|
default = f'"{descriptor.default_value}"'
|
|
return "str", default
|
|
elif descriptor.type == 11:
|
|
# Type referencing another defined Message
|
|
# print(descriptor.type_name, file=sys.stderr)
|
|
# message_type = descriptor.type_name.replace(".", "")
|
|
message_type = descriptor.type_name.split(".").pop()
|
|
return f'"{message_type}"', f"lambda: {message_type}()"
|
|
elif descriptor.type == 12:
|
|
default = ""
|
|
if descriptor.default_value:
|
|
default = f'b"{descriptor.default_value}"'
|
|
return "bytes", default
|
|
else:
|
|
raise NotImplementedError()
|
|
|
|
|
|
def traverse(proto_file):
|
|
def _traverse(path, items):
|
|
for i, item in enumerate(items):
|
|
yield item, path + [i]
|
|
|
|
if isinstance(item, DescriptorProto):
|
|
for enum in item.enum_type:
|
|
yield enum, path + [i, 4]
|
|
|
|
if item.nested_type:
|
|
for n, p in _traverse(path + [i, 3], item.nested_type):
|
|
yield n, p
|
|
|
|
return itertools.chain(
|
|
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
|
|
)
|
|
|
|
|
|
def get_comment(proto_file, path: List[int]) -> str:
|
|
for sci in proto_file.source_code_info.location:
|
|
# print(list(sci.path), path, file=sys.stderr)
|
|
if list(sci.path) == path and sci.leading_comments:
|
|
lines = textwrap.wrap(
|
|
sci.leading_comments.strip().replace("\n", ""), width=75
|
|
)
|
|
|
|
if path[-2] == 2:
|
|
# This is a field
|
|
return " # " + " # ".join(lines)
|
|
else:
|
|
# This is a class
|
|
if len(lines) == 1 and len(lines[0]) < 70:
|
|
return f' """{lines[0]}"""'
|
|
else:
|
|
return f' """\n{" ".join(lines)}\n """'
|
|
|
|
return ""
|
|
|
|
|
|
def generate_code(request, response):
|
|
env = Environment(
|
|
trim_blocks=True,
|
|
lstrip_blocks=True,
|
|
loader=PackageLoader("betterproto", "templates"),
|
|
)
|
|
template = env.get_template("main.py")
|
|
|
|
for proto_file in request.proto_file:
|
|
# print(proto_file.message_type, file=sys.stderr)
|
|
# print(proto_file.source_code_info, file=sys.stderr)
|
|
output = {
|
|
"package": proto_file.package,
|
|
"filename": proto_file.name,
|
|
"messages": [],
|
|
}
|
|
|
|
# Parse request
|
|
for item, path in traverse(proto_file):
|
|
# print(item, file=sys.stderr)
|
|
# print(path, file=sys.stderr)
|
|
data = {"name": item.name}
|
|
|
|
if isinstance(item, DescriptorProto):
|
|
# print(item, file=sys.stderr)
|
|
data.update(
|
|
{
|
|
"type": "Message",
|
|
"comment": get_comment(proto_file, path),
|
|
"properties": [],
|
|
}
|
|
)
|
|
|
|
for i, f in enumerate(item.field):
|
|
t, zero = py_type(f)
|
|
repeated = False
|
|
packed = False
|
|
|
|
if f.label == 3:
|
|
# Repeated field
|
|
repeated = True
|
|
t = f"List[{t}]"
|
|
zero = "list"
|
|
|
|
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
|
|
packed = True
|
|
|
|
data["properties"].append(
|
|
{
|
|
"name": f.name,
|
|
"number": f.number,
|
|
"comment": get_comment(proto_file, path + [2, i]),
|
|
"proto_type": int(f.type),
|
|
"field_type": f.Type.Name(f.type).lower()[5:],
|
|
"type": t,
|
|
"zero": zero,
|
|
"repeated": repeated,
|
|
"packed": packed,
|
|
}
|
|
)
|
|
# print(f, file=sys.stderr)
|
|
|
|
# elif isinstance(item, EnumDescriptorProto):
|
|
# data.update({
|
|
# 'type': 'Enum',
|
|
# 'values': [{'name': v.name, 'value': v.number}
|
|
# for v in item.value]
|
|
# })
|
|
|
|
output["messages"].append(data)
|
|
|
|
# Fill response
|
|
f = response.file.add()
|
|
f.name = os.path.splitext(proto_file.name)[0] + ".py"
|
|
# f.content = json.dumps(output, indent=2)
|
|
f.content = template.render(description=output).rstrip("\n") + "\n"
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# Read request message from stdin
|
|
data = sys.stdin.buffer.read()
|
|
|
|
# Parse request
|
|
request = plugin.CodeGeneratorRequest()
|
|
request.ParseFromString(data)
|
|
|
|
# Create response
|
|
response = plugin.CodeGeneratorResponse()
|
|
|
|
# Generate code
|
|
generate_code(request, response)
|
|
|
|
# Serialise response message
|
|
output = response.SerializeToString()
|
|
|
|
# Write to stdout
|
|
sys.stdout.buffer.write(output)
|