Merge pull request #1 from danielgtaylor/maps
Add basic support for maps
This commit is contained in:
commit
32bc8d50fb
@ -6,7 +6,8 @@
|
||||
- [x] Don't encode zero values for nested types
|
||||
- [x] Enums
|
||||
- [x] Repeated message fields
|
||||
- [ ] Maps
|
||||
- [x] Maps
|
||||
- [ ] Maps of message fields
|
||||
- [ ] Support passthrough of unknown fields
|
||||
- [ ] Refs to nested types
|
||||
- [ ] Imports in proto files
|
||||
|
@ -13,6 +13,7 @@ from typing import (
|
||||
Type,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
Optional,
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
@ -36,6 +37,7 @@ TYPE_SFIXED64 = "sfixed64"
|
||||
TYPE_STRING = "string"
|
||||
TYPE_BYTES = "bytes"
|
||||
TYPE_MESSAGE = "message"
|
||||
TYPE_MAP = "map"
|
||||
|
||||
|
||||
# Fields that use a fixed amount of space (4 or 8 bytes)
|
||||
@ -87,7 +89,7 @@ WIRE_VARINT_TYPES = [
|
||||
|
||||
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
||||
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -98,6 +100,8 @@ class FieldMetadata:
|
||||
number: int
|
||||
# Protobuf type name
|
||||
proto_type: str
|
||||
# Map information if the proto_type is a map
|
||||
map_types: Optional[Tuple[str, str]]
|
||||
# Default value if given
|
||||
default: Any
|
||||
|
||||
@ -107,10 +111,14 @@ class FieldMetadata:
|
||||
return field.metadata["betterproto"]
|
||||
|
||||
|
||||
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
||||
def dataclass_field(
|
||||
number: int,
|
||||
proto_type: str,
|
||||
default: Any,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
**kwargs: dict,
|
||||
) -> dataclasses.Field:
|
||||
"""Creates a dataclass field with attached protobuf metadata."""
|
||||
kwargs = {}
|
||||
|
||||
if callable(default):
|
||||
kwargs["default_factory"] = default
|
||||
elif isinstance(default, dict) or isinstance(default, list):
|
||||
@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
||||
kwargs["default"] = default
|
||||
|
||||
return dataclasses.field(
|
||||
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)}
|
||||
**kwargs,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
|
||||
)
|
||||
|
||||
|
||||
@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
||||
|
||||
|
||||
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return field(number, TYPE_ENUM, default=default)
|
||||
return dataclass_field(number, TYPE_ENUM, default=default)
|
||||
|
||||
|
||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return field(number, TYPE_INT32, default=default)
|
||||
return dataclass_field(number, TYPE_INT32, default=default)
|
||||
|
||||
|
||||
def int64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, TYPE_INT64, default=default)
|
||||
return dataclass_field(number, TYPE_INT64, default=default)
|
||||
|
||||
|
||||
def uint32_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, TYPE_UINT32, default=default)
|
||||
return dataclass_field(number, TYPE_UINT32, default=default)
|
||||
|
||||
|
||||
def uint64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, TYPE_UINT64, default=default)
|
||||
return dataclass_field(number, TYPE_UINT64, default=default)
|
||||
|
||||
|
||||
def sint32_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, TYPE_SINT32, default=default)
|
||||
return dataclass_field(number, TYPE_SINT32, default=default)
|
||||
|
||||
|
||||
def sint64_field(number: int, default: int = 0) -> Any:
|
||||
return field(number, TYPE_SINT64, default=default)
|
||||
return dataclass_field(number, TYPE_SINT64, default=default)
|
||||
|
||||
|
||||
def float_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_FLOAT, default=default)
|
||||
return dataclass_field(number, TYPE_FLOAT, default=default)
|
||||
|
||||
|
||||
def double_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_DOUBLE, default=default)
|
||||
return dataclass_field(number, TYPE_DOUBLE, default=default)
|
||||
|
||||
|
||||
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_FIXED32, default=default)
|
||||
return dataclass_field(number, TYPE_FIXED32, default=default)
|
||||
|
||||
|
||||
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_FIXED64, default=default)
|
||||
return dataclass_field(number, TYPE_FIXED64, default=default)
|
||||
|
||||
|
||||
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_SFIXED32, default=default)
|
||||
return dataclass_field(number, TYPE_SFIXED32, default=default)
|
||||
|
||||
|
||||
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
||||
return field(number, TYPE_SFIXED64, default=default)
|
||||
return dataclass_field(number, TYPE_SFIXED64, default=default)
|
||||
|
||||
|
||||
def string_field(number: int, default: str = "") -> Any:
|
||||
return field(number, TYPE_STRING, default=default)
|
||||
return dataclass_field(number, TYPE_STRING, default=default)
|
||||
|
||||
|
||||
def message_field(number: int, default: Type["Message"]) -> Any:
|
||||
return field(number, TYPE_MESSAGE, default=default)
|
||||
return dataclass_field(number, TYPE_MESSAGE, default=default)
|
||||
|
||||
|
||||
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
||||
return dataclass_field(
|
||||
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
|
||||
)
|
||||
|
||||
|
||||
def _pack_fmt(proto_type: str) -> str:
|
||||
@ -354,6 +369,14 @@ class Message(ABC):
|
||||
else:
|
||||
for item in value:
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
elif isinstance(value, dict):
|
||||
if not len(value):
|
||||
continue
|
||||
|
||||
for k, v in value.items():
|
||||
sk = _serialize_single(1, meta.map_types[0], k)
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
if value == field.default:
|
||||
continue
|
||||
@ -377,23 +400,35 @@ class Message(ABC):
|
||||
) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
if meta.proto_type in ["int32", "int64"]:
|
||||
if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
|
||||
bits = int(meta.proto_type[3:])
|
||||
value = value & ((1 << bits) - 1)
|
||||
signbit = 1 << (bits - 1)
|
||||
value = int((value ^ signbit) - signbit)
|
||||
elif meta.proto_type in ["sint32", "sint64"]:
|
||||
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||
# Undo zig-zag encoding
|
||||
value = (value >> 1) ^ (-(value & 1))
|
||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in ["string"]:
|
||||
if meta.proto_type in [TYPE_STRING]:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type in ["message"]:
|
||||
elif meta.proto_type in [TYPE_MESSAGE]:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
elif meta.proto_type in [TYPE_MAP]:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
Entry = dataclasses.make_dataclass(
|
||||
"Entry",
|
||||
[
|
||||
("key", Any, dataclass_field(1, meta.map_types[0], None)),
|
||||
("value", Any, dataclass_field(2, meta.map_types[1], None)),
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
value = Entry().parse(value)
|
||||
|
||||
return value
|
||||
|
||||
@ -434,10 +469,12 @@ class Message(ABC):
|
||||
parsed.wire_type, meta, field, parsed.value
|
||||
)
|
||||
|
||||
if isinstance(getattr(self, field.name), list) and not isinstance(
|
||||
value, list
|
||||
):
|
||||
getattr(self, field.name).append(value)
|
||||
current = getattr(self, field.name)
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Value represents a single key/value pair entry in the map.
|
||||
current[value.key] = value.value
|
||||
elif isinstance(current, list) and not isinstance(value, list):
|
||||
current.append(value)
|
||||
else:
|
||||
setattr(self, field.name, value)
|
||||
else:
|
||||
|
@ -4,7 +4,7 @@
|
||||
{% if description.enums %}import enum
|
||||
{% endif %}
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import betterproto
|
||||
|
||||
@ -36,7 +36,7 @@ class {{ message.name }}(betterproto.Message):
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %})
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
|
||||
{% endfor %}
|
||||
|
||||
|
||||
|
@ -48,7 +48,7 @@ if __name__ == "__main__":
|
||||
json_files = get_files(".json")
|
||||
|
||||
for filename in proto_files:
|
||||
print(f"Generatinng code for {os.path.basename(filename)}")
|
||||
print(f"Generating code for {os.path.basename(filename)}")
|
||||
subprocess.run(
|
||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
||||
)
|
||||
|
7
betterproto/tests/map.json
Normal file
7
betterproto/tests/map.json
Normal file
@ -0,0 +1,7 @@
|
||||
{
|
||||
"counts": {
|
||||
"item1": 1,
|
||||
"item2": 2,
|
||||
"item3": 3
|
||||
}
|
||||
}
|
5
betterproto/tests/map.proto
Normal file
5
betterproto/tests/map.proto
Normal file
@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
map<string, int32> counts = 1;
|
||||
}
|
@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
)
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from jinja2 import Environment, PackageLoader
|
||||
|
||||
|
||||
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
|
||||
def py_type(
|
||||
message: DescriptorProto, descriptor: FieldDescriptorProto
|
||||
) -> Tuple[str, str]:
|
||||
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
||||
return "float", descriptor.default_value
|
||||
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
||||
@ -115,6 +118,10 @@ def generate_code(request, response):
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
if item.options.map_entry:
|
||||
# Skip generated map entry messages since we just use dicts
|
||||
continue
|
||||
|
||||
data.update(
|
||||
{
|
||||
"type": "Message",
|
||||
@ -124,11 +131,33 @@ def generate_code(request, response):
|
||||
)
|
||||
|
||||
for i, f in enumerate(item.field):
|
||||
t, zero = py_type(f)
|
||||
t, zero = py_type(item, f)
|
||||
repeated = False
|
||||
packed = False
|
||||
|
||||
if f.label == 3:
|
||||
field_type = f.Type.Name(f.type).lower()[5:]
|
||||
map_types = None
|
||||
if f.type == 11:
|
||||
# This might be a map...
|
||||
message_type = f.type_name.split(".").pop()
|
||||
map_entry = f"{f.name.capitalize()}Entry"
|
||||
|
||||
if message_type == map_entry:
|
||||
for nested in item.nested_type:
|
||||
if nested.name == map_entry:
|
||||
if nested.options.map_entry:
|
||||
print("Found a map!", file=sys.stderr)
|
||||
k, _ = py_type(item, nested.field[0])
|
||||
v, _ = py_type(item, nested.field[1])
|
||||
t = f"Dict[{k}, {v}]"
|
||||
zero = "dict"
|
||||
field_type = "map"
|
||||
map_types = (
|
||||
f.Type.Name(nested.field[0].type),
|
||||
f.Type.Name(nested.field[1].type),
|
||||
)
|
||||
|
||||
if f.label == 3 and field_type != "map":
|
||||
# Repeated field
|
||||
repeated = True
|
||||
t = f"List[{t}]"
|
||||
@ -143,7 +172,8 @@ def generate_code(request, response):
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
"field_type": f.Type.Name(f.type).lower()[5:],
|
||||
"field_type": field_type,
|
||||
"map_types": map_types,
|
||||
"type": t,
|
||||
"zero": zero,
|
||||
"repeated": repeated,
|
||||
|
Loading…
x
Reference in New Issue
Block a user