Merge pull request #1 from danielgtaylor/maps

Add basic support for maps
This commit is contained in:
Daniel G. Taylor 2019-10-10 22:25:25 -07:00 committed by GitHub
commit 32bc8d50fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 116 additions and 36 deletions

View File

@ -6,7 +6,8 @@
- [x] Don't encode zero values for nested types - [x] Don't encode zero values for nested types
- [x] Enums - [x] Enums
- [x] Repeated message fields - [x] Repeated message fields
- [ ] Maps - [x] Maps
- [ ] Maps of message fields
- [ ] Support passthrough of unknown fields - [ ] Support passthrough of unknown fields
- [ ] Refs to nested types - [ ] Refs to nested types
- [ ] Imports in proto files - [ ] Imports in proto files

View File

@ -13,6 +13,7 @@ from typing import (
Type, Type,
Iterable, Iterable,
TypeVar, TypeVar,
Optional,
) )
import dataclasses import dataclasses
@ -36,6 +37,7 @@ TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string" TYPE_STRING = "string"
TYPE_BYTES = "bytes" TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message" TYPE_MESSAGE = "message"
TYPE_MAP = "map"
# Fields that use a fixed amount of space (4 or 8 bytes) # Fields that use a fixed amount of space (4 or 8 bytes)
@ -87,7 +89,7 @@ WIRE_VARINT_TYPES = [
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32] WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -98,6 +100,8 @@ class FieldMetadata:
number: int number: int
# Protobuf type name # Protobuf type name
proto_type: str proto_type: str
# Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]]
# Default value if given # Default value if given
default: Any default: Any
@ -107,10 +111,14 @@ class FieldMetadata:
return field.metadata["betterproto"] return field.metadata["betterproto"]
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: def dataclass_field(
number: int,
proto_type: str,
default: Any,
map_types: Optional[Tuple[str, str]] = None,
**kwargs: dict,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata.""" """Creates a dataclass field with attached protobuf metadata."""
kwargs = {}
if callable(default): if callable(default):
kwargs["default_factory"] = default kwargs["default_factory"] = default
elif isinstance(default, dict) or isinstance(default, list): elif isinstance(default, dict) or isinstance(default, list):
@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
kwargs["default"] = default kwargs["default"] = default
return dataclasses.field( return dataclasses.field(
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)} **kwargs,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
) )
@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, TYPE_ENUM, default=default) return dataclass_field(number, TYPE_ENUM, default=default)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, TYPE_INT32, default=default) return dataclass_field(number, TYPE_INT32, default=default)
def int64_field(number: int, default: int = 0) -> Any: def int64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_INT64, default=default) return dataclass_field(number, TYPE_INT64, default=default)
def uint32_field(number: int, default: int = 0) -> Any: def uint32_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_UINT32, default=default) return dataclass_field(number, TYPE_UINT32, default=default)
def uint64_field(number: int, default: int = 0) -> Any: def uint64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_UINT64, default=default) return dataclass_field(number, TYPE_UINT64, default=default)
def sint32_field(number: int, default: int = 0) -> Any: def sint32_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_SINT32, default=default) return dataclass_field(number, TYPE_SINT32, default=default)
def sint64_field(number: int, default: int = 0) -> Any: def sint64_field(number: int, default: int = 0) -> Any:
return field(number, TYPE_SINT64, default=default) return dataclass_field(number, TYPE_SINT64, default=default)
def float_field(number: int, default: float = 0.0) -> Any: def float_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FLOAT, default=default) return dataclass_field(number, TYPE_FLOAT, default=default)
def double_field(number: int, default: float = 0.0) -> Any: def double_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_DOUBLE, default=default) return dataclass_field(number, TYPE_DOUBLE, default=default)
def fixed32_field(number: int, default: float = 0.0) -> Any: def fixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FIXED32, default=default) return dataclass_field(number, TYPE_FIXED32, default=default)
def fixed64_field(number: int, default: float = 0.0) -> Any: def fixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_FIXED64, default=default) return dataclass_field(number, TYPE_FIXED64, default=default)
def sfixed32_field(number: int, default: float = 0.0) -> Any: def sfixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_SFIXED32, default=default) return dataclass_field(number, TYPE_SFIXED32, default=default)
def sfixed64_field(number: int, default: float = 0.0) -> Any: def sfixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, TYPE_SFIXED64, default=default) return dataclass_field(number, TYPE_SFIXED64, default=default)
def string_field(number: int, default: str = "") -> Any: def string_field(number: int, default: str = "") -> Any:
return field(number, TYPE_STRING, default=default) return dataclass_field(number, TYPE_STRING, default=default)
def message_field(number: int, default: Type["Message"]) -> Any: def message_field(number: int, default: Type["Message"]) -> Any:
return field(number, TYPE_MESSAGE, default=default) return dataclass_field(number, TYPE_MESSAGE, default=default)
def map_field(number: int, key_type: str, value_type: str) -> Any:
return dataclass_field(
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
)
def _pack_fmt(proto_type: str) -> str: def _pack_fmt(proto_type: str) -> str:
@ -354,6 +369,14 @@ class Message(ABC):
else: else:
for item in value: for item in value:
output += _serialize_single(meta.number, meta.proto_type, item) output += _serialize_single(meta.number, meta.proto_type, item)
elif isinstance(value, dict):
if not len(value):
continue
for k, v in value.items():
sk = _serialize_single(1, meta.map_types[0], k)
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else: else:
if value == field.default: if value == field.default:
continue continue
@ -377,23 +400,35 @@ class Message(ABC):
) -> Any: ) -> Any:
"""Adjusts values after parsing.""" """Adjusts values after parsing."""
if wire_type == WIRE_VARINT: if wire_type == WIRE_VARINT:
if meta.proto_type in ["int32", "int64"]: if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
bits = int(meta.proto_type[3:]) bits = int(meta.proto_type[3:])
value = value & ((1 << bits) - 1) value = value & ((1 << bits) - 1)
signbit = 1 << (bits - 1) signbit = 1 << (bits - 1)
value = int((value ^ signbit) - signbit) value = int((value ^ signbit) - signbit)
elif meta.proto_type in ["sint32", "sint64"]: elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Undo zig-zag encoding # Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1)) value = (value >> 1) ^ (-(value & 1))
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type) fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0] value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM: elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in ["string"]: if meta.proto_type in [TYPE_STRING]:
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type in ["message"]: elif meta.proto_type in [TYPE_MESSAGE]:
cls = self._cls_for(field) cls = self._cls_for(field)
value = cls().parse(value) value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
Entry = dataclasses.make_dataclass(
"Entry",
[
("key", Any, dataclass_field(1, meta.map_types[0], None)),
("value", Any, dataclass_field(2, meta.map_types[1], None)),
],
bases=(Message,),
)
value = Entry().parse(value)
return value return value
@ -434,10 +469,12 @@ class Message(ABC):
parsed.wire_type, meta, field, parsed.value parsed.wire_type, meta, field, parsed.value
) )
if isinstance(getattr(self, field.name), list) and not isinstance( current = getattr(self, field.name)
value, list if meta.proto_type == TYPE_MAP:
): # Value represents a single key/value pair entry in the map.
getattr(self, field.name).append(value) current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
current.append(value)
else: else:
setattr(self, field.name, value) setattr(self, field.name, value)
else: else:

View File

@ -4,7 +4,7 @@
{% if description.enums %}import enum {% if description.enums %}import enum
{% endif %} {% endif %}
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import Dict, List
import betterproto import betterproto
@ -36,7 +36,7 @@ class {{ message.name }}(betterproto.Message):
{% if field.comment %} {% if field.comment %}
{{ field.comment }} {{ field.comment }}
{% endif %} {% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %}) {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
{% endfor %} {% endfor %}

View File

@ -48,7 +48,7 @@ if __name__ == "__main__":
json_files = get_files(".json") json_files = get_files(".json")
for filename in proto_files: for filename in proto_files:
print(f"Generatinng code for {os.path.basename(filename)}") print(f"Generating code for {os.path.basename(filename)}")
subprocess.run( subprocess.run(
f"protoc --python_out=. {os.path.basename(filename)}", shell=True f"protoc --python_out=. {os.path.basename(filename)}", shell=True
) )

View File

@ -0,0 +1,7 @@
{
"counts": {
"item1": 1,
"item2": 2,
"item3": 3
}
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
map<string, int32> counts = 1;
}

View File

@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
FileDescriptorProto, FileDescriptorProto,
FieldDescriptorProto,
) )
from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.compiler import plugin_pb2 as plugin
@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader from jinja2 import Environment, PackageLoader
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]: def py_type(
message: DescriptorProto, descriptor: FieldDescriptorProto
) -> Tuple[str, str]:
if descriptor.type in [1, 2, 6, 7, 15, 16]: if descriptor.type in [1, 2, 6, 7, 15, 16]:
return "float", descriptor.default_value return "float", descriptor.default_value
elif descriptor.type in [3, 4, 5, 13, 17, 18]: elif descriptor.type in [3, 4, 5, 13, 17, 18]:
@ -115,6 +118,10 @@ def generate_code(request, response):
if isinstance(item, DescriptorProto): if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr) # print(item, file=sys.stderr)
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
continue
data.update( data.update(
{ {
"type": "Message", "type": "Message",
@ -124,11 +131,33 @@ def generate_code(request, response):
) )
for i, f in enumerate(item.field): for i, f in enumerate(item.field):
t, zero = py_type(f) t, zero = py_type(item, f)
repeated = False repeated = False
packed = False packed = False
if f.label == 3: field_type = f.Type.Name(f.type).lower()[5:]
map_types = None
if f.type == 11:
# This might be a map...
message_type = f.type_name.split(".").pop()
map_entry = f"{f.name.capitalize()}Entry"
if message_type == map_entry:
for nested in item.nested_type:
if nested.name == map_entry:
if nested.options.map_entry:
print("Found a map!", file=sys.stderr)
k, _ = py_type(item, nested.field[0])
v, _ = py_type(item, nested.field[1])
t = f"Dict[{k}, {v}]"
zero = "dict"
field_type = "map"
map_types = (
f.Type.Name(nested.field[0].type),
f.Type.Name(nested.field[1].type),
)
if f.label == 3 and field_type != "map":
# Repeated field # Repeated field
repeated = True repeated = True
t = f"List[{t}]" t = f"List[{t}]"
@ -143,7 +172,8 @@ def generate_code(request, response):
"number": f.number, "number": f.number,
"comment": get_comment(proto_file, path + [2, i]), "comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type), "proto_type": int(f.type),
"field_type": f.Type.Name(f.type).lower()[5:], "field_type": field_type,
"map_types": map_types,
"type": t, "type": t,
"zero": zero, "zero": zero,
"repeated": repeated, "repeated": repeated,