Support nested messages, fix casing. Support test-cases in packages.

This commit is contained in:
boukeversteegh
2020-06-07 16:57:57 +02:00
parent d8abb850f8
commit f7c2fd1194
19 changed files with 333 additions and 163 deletions

View File

@@ -10,6 +10,11 @@ from typing import List
from betterproto.casing import safe_snake_case
from betterproto.compile.importing import get_ref_type
import betterproto
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
@@ -35,27 +40,22 @@ except ImportError as err:
raise SystemExit(1)
def py_type(
package: str,
imports: set,
message: DescriptorProto,
descriptor: FieldDescriptorProto,
) -> str:
if descriptor.type in [1, 2, 6, 7, 15, 16]:
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
if field.type in [1, 2, 6, 7, 15, 16]:
return "float"
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
elif field.type in [3, 4, 5, 13, 17, 18]:
return "int"
elif descriptor.type == 8:
elif field.type == 8:
return "bool"
elif descriptor.type == 9:
elif field.type == 9:
return "str"
elif descriptor.type in [11, 14]:
elif field.type in [11, 14]:
# Type referencing another defined Message or a named enum
return get_ref_type(package, imports, descriptor.type_name)
elif descriptor.type == 12:
return get_ref_type(package, imports, field.type_name)
elif field.type == 12:
return "bytes"
else:
raise NotImplementedError(f"Unknown type {descriptor.type}")
raise NotImplementedError(f"Unknown type {field.type}")
def get_py_zero(type_num: int) -> str:
@@ -160,17 +160,10 @@ def generate_code(request, response):
"services": [],
}
type_mapping = {}
for proto_file in options["files"]:
# print(proto_file.message_type, file=sys.stderr)
# print(proto_file.service, file=sys.stderr)
# print(proto_file.source_code_info, file=sys.stderr)
item: DescriptorProto
for item, path in traverse(proto_file):
# print(item, file=sys.stderr)
# print(path, file=sys.stderr)
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
@@ -187,7 +180,7 @@ def generate_code(request, response):
)
for i, f in enumerate(item.field):
t = py_type(package, output["imports"], item, f)
t = py_type(package, output["imports"], f)
zero = get_py_zero(f.type)
repeated = False
@@ -222,13 +215,11 @@ def generate_code(request, response):
k = py_type(
package,
output["imports"],
item,
nested.field[0],
)
v = py_type(
package,
output["imports"],
item,
nested.field[1],
)
t = f"Dict[{k}, {v}]"
@@ -264,7 +255,7 @@ def generate_code(request, response):
data["properties"].append(
{
"name": f.name,
"py_name": safe_snake_case(f.name),
"py_name": pythonize_field_name(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
@@ -305,7 +296,7 @@ def generate_code(request, response):
data = {
"name": service.name,
"py_name": stringcase.pascalcase(service.name),
"py_name": pythonize_class_name(service.name),
"comment": get_comment(proto_file, [6, i]),
"methods": [],
}
@@ -329,7 +320,7 @@ def generate_code(request, response):
data["methods"].append(
{
"name": method.name,
"py_name": stringcase.snakecase(method.name),
"py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type(