Add basic support for maps

This commit is contained in:
Daniel G. Taylor
2019-10-10 22:20:27 -07:00
parent ad7162a3ec
commit e0d1611797
7 changed files with 116 additions and 36 deletions

View File

@@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
FieldDescriptorProto,
)
from google.protobuf.compiler import plugin_pb2 as plugin
@@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
def py_type(
message: DescriptorProto, descriptor: FieldDescriptorProto
) -> Tuple[str, str]:
if descriptor.type in [1, 2, 6, 7, 15, 16]:
return "float", descriptor.default_value
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
@@ -115,6 +118,10 @@ def generate_code(request, response):
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
continue
data.update(
{
"type": "Message",
@@ -124,11 +131,33 @@ def generate_code(request, response):
)
for i, f in enumerate(item.field):
t, zero = py_type(f)
t, zero = py_type(item, f)
repeated = False
packed = False
if f.label == 3:
field_type = f.Type.Name(f.type).lower()[5:]
map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop()
map_entry = f"{f.name.capitalize()}Entry"
if message_type == map_entry:
for nested in item.nested_type:
if nested.name == map_entry:
if nested.options.map_entry:
print("Found a map!", file=sys.stderr)
k, _ = py_type(item, nested.field[0])
v, _ = py_type(item, nested.field[1])
t = f"Dict[{k}, {v}]"
zero = "dict"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
if f.label == 3 and field_type != "map":
# Repeated field
repeated = True
t = f"List[{t}]"
@@ -143,7 +172,8 @@ def generate_code(request, response):
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": f.Type.Name(f.type).lower()[5:],
"field_type": field_type,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,