Add basic support for maps
This commit is contained in:
@@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
)
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
@@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from jinja2 import Environment, PackageLoader
|
||||
|
||||
|
||||
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
|
||||
def py_type(
|
||||
message: DescriptorProto, descriptor: FieldDescriptorProto
|
||||
) -> Tuple[str, str]:
|
||||
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
||||
return "float", descriptor.default_value
|
||||
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
||||
@@ -115,6 +118,10 @@ def generate_code(request, response):
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
if item.options.map_entry:
|
||||
# Skip generated map entry messages since we just use dicts
|
||||
continue
|
||||
|
||||
data.update(
|
||||
{
|
||||
"type": "Message",
|
||||
@@ -124,11 +131,33 @@ def generate_code(request, response):
|
||||
)
|
||||
|
||||
for i, f in enumerate(item.field):
|
||||
t, zero = py_type(f)
|
||||
t, zero = py_type(item, f)
|
||||
repeated = False
|
||||
packed = False
|
||||
|
||||
if f.label == 3:
|
||||
field_type = f.Type.Name(f.type).lower()[5:]
|
||||
map_types = None
|
||||
if f.type == 11:
|
||||
# This might be a map...
|
||||
message_type = f.type_name.split(".").pop()
|
||||
map_entry = f"{f.name.capitalize()}Entry"
|
||||
|
||||
if message_type == map_entry:
|
||||
for nested in item.nested_type:
|
||||
if nested.name == map_entry:
|
||||
if nested.options.map_entry:
|
||||
print("Found a map!", file=sys.stderr)
|
||||
k, _ = py_type(item, nested.field[0])
|
||||
v, _ = py_type(item, nested.field[1])
|
||||
t = f"Dict[{k}, {v}]"
|
||||
zero = "dict"
|
||||
field_type = "map"
|
||||
map_types = (
|
||||
f.Type.Name(nested.field[0].type),
|
||||
f.Type.Name(nested.field[1].type),
|
||||
)
|
||||
|
||||
if f.label == 3 and field_type != "map":
|
||||
# Repeated field
|
||||
repeated = True
|
||||
t = f"List[{t}]"
|
||||
@@ -143,7 +172,8 @@ def generate_code(request, response):
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
"field_type": f.Type.Name(f.type).lower()[5:],
|
||||
"field_type": field_type,
|
||||
"map_types": map_types,
|
||||
"type": t,
|
||||
"zero": zero,
|
||||
"repeated": repeated,
|
||||
|
||||
Reference in New Issue
Block a user