Implement imports, simplified default value handling
@@ -22,33 +22,45 @@ from jinja2 import Environment, PackageLoader
 def py_type(
-    message: DescriptorProto, descriptor: FieldDescriptorProto
-) -> Tuple[str, str]:
+    package: str,
+    imports: set,
+    message: DescriptorProto,
+    descriptor: FieldDescriptorProto,
+) -> str:
     if descriptor.type in [1, 2, 6, 7, 15, 16]:
-        return "float", descriptor.default_value
+        return "float"
     elif descriptor.type in [3, 4, 5, 13, 17, 18]:
-        return "int", descriptor.default_value
+        return "int"
     elif descriptor.type == 8:
-        return "bool", descriptor.default_value.capitalize()
+        return "bool"
     elif descriptor.type == 9:
-        default = ""
-        if descriptor.default_value:
-            default = f'"{descriptor.default_value}"'
-        return "str", default
-    elif descriptor.type == 11:
-        # Type referencing another defined Message
-        # print(descriptor.type_name, file=sys.stderr)
-        # message_type = descriptor.type_name.replace(".", "")
-        message_type = descriptor.type_name.split(".").pop()
-        return f'"{message_type}"', f"lambda: {message_type}()"
+        return "str"
+    elif descriptor.type in [11, 14]:
+        # Type referencing another defined Message or a named enum
+        message_type = descriptor.type_name.lstrip(".")
+        if message_type.startswith(package):
+            # This is the current package, which has nested types flattened.
+            message_type = message_type.lstrip(package).lstrip(".").replace(".", "")
+
+        if "." in message_type:
+            # This is imported from another package. No need
+            # to use a forward ref and we need to add the import.
+            message_type = message_type.strip('"')
+            parts = message_type.split(".")
+            imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
+            message_type = f"{parts[-2]}.{parts[-1]}"
+
+        # print(
+        #     descriptor.name,
+        #     package,
+        #     descriptor.type_name,
+        #     message_type,
+        #     file=sys.stderr,
+        # )
+
+        return f'"{message_type}"'
     elif descriptor.type == 12:
-        default = ""
-        if descriptor.default_value:
-            default = f'b"{descriptor.default_value}"'
-        return "bytes", default
-    elif descriptor.type == 14:
-        # print(descriptor.type_name, file=sys.stderr)
-        return descriptor.type_name.split(".").pop(), 0
+        return "bytes"
     else:
        raise NotImplementedError(f"Unknown type {descriptor.type}")
@@ -64,6 +76,8 @@ def traverse(proto_file):
 
                 if item.nested_type:
                     for n, p in _traverse(path + [i, 3], item.nested_type):
+                        # Adjust the name since we flatten the heirarchy.
+                        n.name = item.name + n.name
                         yield n, p
 
     return itertools.chain(
@@ -85,6 +99,7 @@ def get_comment(proto_file, path: List[int]) -> str:
         else:
             # This is a class
             if len(lines) == 1 and len(lines[0]) < 70:
+                lines[0] = lines[0].strip('"')
                 return f' """{lines[0]}"""'
             else:
                 return f' """\n{" ".join(lines)}\n """'
@@ -100,112 +115,139 @@ def generate_code(request, response):
     )
     template = env.get_template("main.py")

+    # TODO: Refactor below to generate a single file per package if packages
+    # are being used, otherwise one output for each input. Figure out how to
+    # set up relative imports when needed and change the Message type refs to
+    # use the import names when not in the current module.
+    output_map = {}
     for proto_file in request.proto_file:
-        # print(proto_file.message_type, file=sys.stderr)
-        # print(proto_file.source_code_info, file=sys.stderr)
-        output = {
-            "package": proto_file.package,
-            "filename": proto_file.name,
-            "messages": [],
-            "enums": [],
-        }
+        out = proto_file.package
+        if not out:
+            out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")

-        # Parse request
-        for item, path in traverse(proto_file):
-            # print(item, file=sys.stderr)
-            # print(path, file=sys.stderr)
-            data = {"name": item.name}
+        if out not in output_map:
+            output_map[out] = {"package": proto_file.package, "files": []}
+        output_map[out]["files"].append(proto_file)

-            if isinstance(item, DescriptorProto):
+    # TODO: Figure out how to handle gRPC request/response messages and add
+    # processing below for Service.
+
+    for filename, options in output_map.items():
+        package = options["package"]
+        # print(package, filename, file=sys.stderr)
+        output = {"package": package, "imports": set(), "messages": [], "enums": []}
+
+        for proto_file in options["files"]:
+            # print(proto_file.message_type, file=sys.stderr)
+            # print(proto_file.service, file=sys.stderr)
+            # print(proto_file.source_code_info, file=sys.stderr)
+
+            for item, path in traverse(proto_file):
                 # print(item, file=sys.stderr)
-                if item.options.map_entry:
-                    # Skip generated map entry messages since we just use dicts
-                    continue
+                # print(path, file=sys.stderr)
+                data = {"name": item.name}

-                data.update(
-                    {
-                        "type": "Message",
-                        "comment": get_comment(proto_file, path),
-                        "properties": [],
-                    }
-                )
+                if isinstance(item, DescriptorProto):
+                    # print(item, file=sys.stderr)
+                    if item.options.map_entry:
+                        # Skip generated map entry messages since we just use dicts
+                        continue

-                for i, f in enumerate(item.field):
-                    t, zero = py_type(item, f)
-                    repeated = False
-                    packed = False
-
-                    field_type = f.Type.Name(f.type).lower()[5:]
-                    map_types = None
-                    if f.type == 11:
-                        # This might be a map...
-                        message_type = f.type_name.split(".").pop()
-                        map_entry = f"{f.name.capitalize()}Entry"
-
-                        if message_type == map_entry:
-                            for nested in item.nested_type:
-                                if nested.name == map_entry:
-                                    if nested.options.map_entry:
-                                        print("Found a map!", file=sys.stderr)
-                                        k, _ = py_type(item, nested.field[0])
-                                        v, _ = py_type(item, nested.field[1])
-                                        t = f"Dict[{k}, {v}]"
-                                        zero = "dict"
-                                        field_type = "map"
-                                        map_types = (
-                                            f.Type.Name(nested.field[0].type),
-                                            f.Type.Name(nested.field[1].type),
-                                        )
-
-                    if f.label == 3 and field_type != "map":
-                        # Repeated field
-                        repeated = True
-                        t = f"List[{t}]"
-                        zero = "list"
-
-                    if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
-                        packed = True
-
-                    data["properties"].append(
+                    data.update(
                         {
-                            "name": f.name,
-                            "number": f.number,
-                            "comment": get_comment(proto_file, path + [2, i]),
-                            "proto_type": int(f.type),
-                            "field_type": field_type,
-                            "map_types": map_types,
-                            "type": t,
-                            "zero": zero,
-                            "repeated": repeated,
-                            "packed": packed,
+                            "type": "Message",
+                            "comment": get_comment(proto_file, path),
+                            "properties": [],
                         }
                     )
-                    # print(f, file=sys.stderr)

-                output["messages"].append(data)
+                    for i, f in enumerate(item.field):
+                        t = py_type(package, output["imports"], item, f)

-            elif isinstance(item, EnumDescriptorProto):
-                # print(item.name, path, file=sys.stderr)
-                data.update(
-                    {
-                        "type": "Enum",
-                        "comment": get_comment(proto_file, path),
-                        "entries": [
+                        repeated = False
+                        packed = False
+
+                        field_type = f.Type.Name(f.type).lower()[5:]
+                        map_types = None
+                        if f.type == 11:
+                            # This might be a map...
+                            message_type = f.type_name.split(".").pop()
+                            map_entry = f"{f.name.capitalize()}Entry"
+
+                            if message_type == map_entry:
+                                for nested in item.nested_type:
+                                    if nested.name == map_entry:
+                                        if nested.options.map_entry:
+                                            # print("Found a map!", file=sys.stderr)
+                                            k = py_type(
+                                                package,
+                                                output["imports"],
+                                                item,
+                                                nested.field[0],
+                                            )
+                                            v = py_type(
+                                                package,
+                                                output["imports"],
+                                                item,
+                                                nested.field[1],
+                                            )
+                                            t = f"Dict[{k}, {v}]"
+                                            field_type = "map"
+                                            map_types = (
+                                                f.Type.Name(nested.field[0].type),
+                                                f.Type.Name(nested.field[1].type),
+                                            )
+
+                        if f.label == 3 and field_type != "map":
+                            # Repeated field
+                            repeated = True
+                            t = f"List[{t}]"
+
+                        if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
+                            packed = True
+
+                        data["properties"].append(
                             {
-                                "name": v.name,
-                                "value": v.number,
+                                "name": f.name,
+                                "number": f.number,
                                 "comment": get_comment(proto_file, path + [2, i]),
+                                "proto_type": int(f.type),
+                                "field_type": field_type,
+                                "map_types": map_types,
+                                "type": t,
+                                "repeated": repeated,
+                                "packed": packed,
                             }
-                            for i, v in enumerate(item.value)
-                        ],
-                    }
-                )
+                        )
+                        # print(f, file=sys.stderr)

-                output["enums"].append(data)
+                    output["messages"].append(data)

+                elif isinstance(item, EnumDescriptorProto):
+                    # print(item.name, path, file=sys.stderr)
+                    data.update(
+                        {
+                            "type": "Enum",
+                            "comment": get_comment(proto_file, path),
+                            "entries": [
+                                {
+                                    "name": v.name,
+                                    "value": v.number,
+                                    "comment": get_comment(proto_file, path + [2, i]),
+                                }
+                                for i, v in enumerate(item.value)
+                            ],
+                        }
+                    )
+
+                    output["enums"].append(data)
+
+        output["imports"] = sorted(output["imports"])

         # Fill response
         f = response.file.add()
-        f.name = os.path.splitext(proto_file.name)[0] + ".py"
+        # print(filename, file=sys.stderr)
+        f.name = filename + ".py"
         # f.content = json.dumps(output, indent=2)
         f.content = template.render(description=output).rstrip("\n") + "\n"
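
To make the new import handling concrete, here is a rough, self-contained sketch of what the added `descriptor.type in [11, 14]` branch of `py_type` does with a fully qualified type name. The helper name `resolve_type_name` and the example package and type names (`pet`, `.pet.Pet.Species`, `.google.protobuf.Timestamp`) are hypothetical illustrations, not part of the commit; the string handling mirrors the added lines above.

# Sketch only: mirrors the type-name logic from the new py_type branch.
# The package and type names below are made-up examples.
def resolve_type_name(package: str, imports: set, type_name: str) -> str:
    message_type = type_name.lstrip(".")
    if message_type.startswith(package):
        # Current package: nested types are flattened, so the dots are dropped.
        message_type = message_type.lstrip(package).lstrip(".").replace(".", "")

    if "." in message_type:
        # Another package: record a relative import and reference it as
        # module.Type instead of using a forward reference.
        parts = message_type.split(".")
        imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
        message_type = f"{parts[-2]}.{parts[-1]}"

    return f'"{message_type}"'


imports = set()
# Same-package nested type: flattened into a forward reference,
# matching the n.name = item.name + n.name adjustment in traverse().
print(resolve_type_name("pet", imports, ".pet.Pet.Species"))          # "PetSpecies"
# Type from another package: referenced via a collected import.
print(resolve_type_name("pet", imports, ".google.protobuf.Timestamp"))  # "protobuf.Timestamp"
print(imports)  # {'from .google import protobuf'}

Run as-is, this prints the flattened forward reference for the in-package nested type, and a module-qualified name plus the collected relative import for the external one.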