Add basic support for maps

This commit is contained in:
Daniel G. Taylor 2019-10-10 22:20:27 -07:00
parent ad7162a3ec
commit e0d1611797
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
7 changed files with 116 additions and 36 deletions

View File

@ -6,7 +6,8 @@
- [x] Don't encode zero values for nested types
- [x] Enums
- [x] Repeated message fields
- [ ] Maps
- [x] Maps
- [ ] Maps of message fields
- [ ] Support passthrough of unknown fields
- [ ] Refs to nested types
- [ ] Imports in proto files

View File

@ -13,6 +13,7 @@ from typing import (
Type,
Iterable,
TypeVar,
Optional,
)
import dataclasses
@ -36,6 +37,7 @@ TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
TYPE_MAP = "map"
# Fields that use a fixed amount of space (4 or 8 bytes)
@ -87,7 +89,7 @@ WIRE_VARINT_TYPES = [
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
@dataclasses.dataclass(frozen=True)
@ -98,6 +100,8 @@ class FieldMetadata:
number: int
# Protobuf type name
proto_type: str
# Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]]
# Default value if given
default: Any
@ -107,10 +111,14 @@ class FieldMetadata:
return field.metadata["betterproto"]
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
def dataclass_field(
number: int,
proto_type: str,
default: Any,
map_types: Optional[Tuple[str, str]] = None,
**kwargs: dict,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
kwargs = {}
if callable(default):
kwargs["default_factory"] = default
elif isinstance(default, dict) or isinstance(default, list):
@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
kwargs["default"] = default
return dataclasses.field(
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)}
**kwargs,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
)
@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, TYPE_ENUM, default=default)
return dataclass_field(number, TYPE_ENUM, default=default)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, TYPE_INT32, default=default)
return dataclass_field(number, TYPE_INT32, default=default)
def int64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_INT64, default=default)
return dataclass_field(number, TYPE_INT64, default=default)
def uint32_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_UINT32, default=default)
return dataclass_field(number, TYPE_UINT32, default=default)
def uint64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_UINT64, default=default)
return dataclass_field(number, TYPE_UINT64, default=default)
def sint32_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_SINT32, default=default)
return dataclass_field(number, TYPE_SINT32, default=default)
def sint64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_SINT64, default=default)
return dataclass_field(number, TYPE_SINT64, default=default)
def float_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FLOAT, default=default)
return dataclass_field(number, TYPE_FLOAT, default=default)
def double_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_DOUBLE, default=default)
return dataclass_field(number, TYPE_DOUBLE, default=default)
def fixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FIXED32, default=default)
return dataclass_field(number, TYPE_FIXED32, default=default)
def fixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FIXED64, default=default)
return dataclass_field(number, TYPE_FIXED64, default=default)
def sfixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_SFIXED32, default=default)
return dataclass_field(number, TYPE_SFIXED32, default=default)
def sfixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_SFIXED64, default=default)
return dataclass_field(number, TYPE_SFIXED64, default=default)
def string_field(number: int, default: str = "") -> Any:
return field(number, TYPE_STRING, default=default)
return dataclass_field(number, TYPE_STRING, default=default)
def message_field(number: int, default: Type["Message"]) -> Any:
return field(number, TYPE_MESSAGE, default=default)
return dataclass_field(number, TYPE_MESSAGE, default=default)
def map_field(number: int, key_type: str, value_type: str) -> Any:
    """Creates a dataclass field describing a protobuf map.

    The field is registered with the ``TYPE_MAP`` proto type and uses
    ``dict`` as its default, while the key and value protobuf type names
    are forwarded together as the ``map_types`` pair.
    """
    entry_types = (key_type, value_type)
    return dataclass_field(number, TYPE_MAP, default=dict, map_types=entry_types)
def _pack_fmt(proto_type: str) -> str:
@ -354,6 +369,14 @@ class Message(ABC):
else:
for item in value:
output += _serialize_single(meta.number, meta.proto_type, item)
elif isinstance(value, dict):
if not len(value):
continue
for k, v in value.items():
sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
if value == field.default:
continue
@ -377,23 +400,35 @@ class Message(ABC):
) -> Any:
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT:
if meta.proto_type in ["int32", "int64"]:
if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
bits = int(meta.proto_type[3:])
value = value & ((1 << bits) - 1)
signbit = 1 << (bits - 1)
value = int((value ^ signbit) - signbit)
elif meta.proto_type in ["sint32", "sint64"]:
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in ["string"]:
if meta.proto_type in [TYPE_STRING]:
value = value.decode("utf-8")
elif meta.proto_type in ["message"]:
elif meta.proto_type in [TYPE_MESSAGE]:
cls = self._cls_for(field)
value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
Entry = dataclasses.make_dataclass(
"Entry",
[
("key", Any, dataclass_field(1, meta.map_types[0], None)),
("value", Any, dataclass_field(2, meta.map_types[1], None)),
],
bases=(Message,),
)
value = Entry().parse(value)
return value
@ -434,10 +469,12 @@ class Message(ABC):
parsed.wire_type, meta, field, parsed.value
)
if isinstance(getattr(self, field.name), list) and not isinstance(
value, list
):
getattr(self, field.name).append(value)
current = getattr(self, field.name)
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else:
setattr(self, field.name, value)
else:

View File

@ -4,7 +4,7 @@
{% if description.enums %}import enum
{% endif %}
from dataclasses import dataclass
from typing import List
from typing import Dict, List
import betterproto
@ -36,7 +36,7 @@ class {{ message.name }}(betterproto.Message):
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %})
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
{% endfor %}

View File

@ -48,7 +48,7 @@ if __name__ == "__main__":
json_files = get_files(".json")
for filename in proto_files:
print(f"Generatinng code for {os.path.basename(filename)}")
print(f"Generating code for {os.path.basename(filename)}")
subprocess.run(
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
)

View File

@ -0,0 +1,7 @@
{
"counts": {
"item1": 1,
"item2": 2,
"item3": 3
}
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
// Message with a single map field.
message Test {
// Map with string keys and int32 values, stored as field number 1.
map<string, int32> counts = 1;
}

View File

@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
FieldDescriptorProto,
)
from google.protobuf.compiler import plugin_pb2 as plugin
@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
def py_type(
message: DescriptorProto, descriptor: FieldDescriptorProto
) -> Tuple[str, str]:
if descriptor.type in [1, 2, 6, 7, 15, 16]:
return "float", descriptor.default_value
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
@ -115,6 +118,10 @@ def generate_code(request, response):
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
continue
data.update(
{
"type": "Message",
@ -124,11 +131,33 @@ def generate_code(request, response):
)
for i, f in enumerate(item.field):
t, zero = py_type(f)
t, zero = py_type(item, f)
repeated = False
packed = False
if f.label == 3:
field_type = f.Type.Name(f.type).lower()[5:]
map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop()
map_entry = f"{f.name.capitalize()}Entry"
if message_type == map_entry:
for nested in item.nested_type:
if nested.name == map_entry:
if nested.options.map_entry:
print("Found a map!", file=sys.stderr)
k, _ = py_type(item, nested.field[0])
v, _ = py_type(item, nested.field[1])
t = f"Dict[{k}, {v}]"
zero = "dict"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
if f.label == 3 and field_type != "map":
# Repeated field
repeated = True
t = f"List[{t}]"
@ -143,7 +172,8 @@ def generate_code(request, response):
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": f.Type.Name(f.type).lower()[5:],
"field_type": field_type,
"map_types": map_types,
"type": t,
"zero": zero,
"repeated": repeated,