Merge pull request #1 from danielgtaylor/maps
Add basic support for maps
This commit is contained in:
commit
32bc8d50fb
@ -6,7 +6,8 @@
|
|||||||
- [x] Don't encode zero values for nested types
|
- [x] Don't encode zero values for nested types
|
||||||
- [x] Enums
|
- [x] Enums
|
||||||
- [x] Repeated message fields
|
- [x] Repeated message fields
|
||||||
- [ ] Maps
|
- [x] Maps
|
||||||
|
- [ ] Maps of message fields
|
||||||
- [ ] Support passthrough of unknown fields
|
- [ ] Support passthrough of unknown fields
|
||||||
- [ ] Refs to nested types
|
- [ ] Refs to nested types
|
||||||
- [ ] Imports in proto files
|
- [ ] Imports in proto files
|
||||||
|
@ -13,6 +13,7 @@ from typing import (
|
|||||||
Type,
|
Type,
|
||||||
Iterable,
|
Iterable,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
|
Optional,
|
||||||
)
|
)
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
@ -36,6 +37,7 @@ TYPE_SFIXED64 = "sfixed64"
|
|||||||
TYPE_STRING = "string"
|
TYPE_STRING = "string"
|
||||||
TYPE_BYTES = "bytes"
|
TYPE_BYTES = "bytes"
|
||||||
TYPE_MESSAGE = "message"
|
TYPE_MESSAGE = "message"
|
||||||
|
TYPE_MAP = "map"
|
||||||
|
|
||||||
|
|
||||||
# Fields that use a fixed amount of space (4 or 8 bytes)
|
# Fields that use a fixed amount of space (4 or 8 bytes)
|
||||||
@ -87,7 +89,7 @@ WIRE_VARINT_TYPES = [
|
|||||||
|
|
||||||
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
|
||||||
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
|
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
@ -98,6 +100,8 @@ class FieldMetadata:
|
|||||||
number: int
|
number: int
|
||||||
# Protobuf type name
|
# Protobuf type name
|
||||||
proto_type: str
|
proto_type: str
|
||||||
|
# Map information if the proto_type is a map
|
||||||
|
map_types: Optional[Tuple[str, str]]
|
||||||
# Default value if given
|
# Default value if given
|
||||||
default: Any
|
default: Any
|
||||||
|
|
||||||
@ -107,10 +111,14 @@ class FieldMetadata:
|
|||||||
return field.metadata["betterproto"]
|
return field.metadata["betterproto"]
|
||||||
|
|
||||||
|
|
||||||
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
def dataclass_field(
|
||||||
|
number: int,
|
||||||
|
proto_type: str,
|
||||||
|
default: Any,
|
||||||
|
map_types: Optional[Tuple[str, str]] = None,
|
||||||
|
**kwargs: dict,
|
||||||
|
) -> dataclasses.Field:
|
||||||
"""Creates a dataclass field with attached protobuf metadata."""
|
"""Creates a dataclass field with attached protobuf metadata."""
|
||||||
kwargs = {}
|
|
||||||
|
|
||||||
if callable(default):
|
if callable(default):
|
||||||
kwargs["default_factory"] = default
|
kwargs["default_factory"] = default
|
||||||
elif isinstance(default, dict) or isinstance(default, list):
|
elif isinstance(default, dict) or isinstance(default, list):
|
||||||
@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
|||||||
kwargs["default"] = default
|
kwargs["default"] = default
|
||||||
|
|
||||||
return dataclasses.field(
|
return dataclasses.field(
|
||||||
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)}
|
**kwargs,
|
||||||
|
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
|||||||
|
|
||||||
|
|
||||||
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||||
return field(number, TYPE_ENUM, default=default)
|
return dataclass_field(number, TYPE_ENUM, default=default)
|
||||||
|
|
||||||
|
|
||||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||||
return field(number, TYPE_INT32, default=default)
|
return dataclass_field(number, TYPE_INT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def int64_field(number: int, default: int = 0) -> Any:
|
def int64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, TYPE_INT64, default=default)
|
return dataclass_field(number, TYPE_INT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def uint32_field(number: int, default: int = 0) -> Any:
|
def uint32_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, TYPE_UINT32, default=default)
|
return dataclass_field(number, TYPE_UINT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def uint64_field(number: int, default: int = 0) -> Any:
|
def uint64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, TYPE_UINT64, default=default)
|
return dataclass_field(number, TYPE_UINT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sint32_field(number: int, default: int = 0) -> Any:
|
def sint32_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, TYPE_SINT32, default=default)
|
return dataclass_field(number, TYPE_SINT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sint64_field(number: int, default: int = 0) -> Any:
|
def sint64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, TYPE_SINT64, default=default)
|
return dataclass_field(number, TYPE_SINT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def float_field(number: int, default: float = 0.0) -> Any:
|
def float_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_FLOAT, default=default)
|
return dataclass_field(number, TYPE_FLOAT, default=default)
|
||||||
|
|
||||||
|
|
||||||
def double_field(number: int, default: float = 0.0) -> Any:
|
def double_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_DOUBLE, default=default)
|
return dataclass_field(number, TYPE_DOUBLE, default=default)
|
||||||
|
|
||||||
|
|
||||||
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_FIXED32, default=default)
|
return dataclass_field(number, TYPE_FIXED32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_FIXED64, default=default)
|
return dataclass_field(number, TYPE_FIXED64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_SFIXED32, default=default)
|
return dataclass_field(number, TYPE_SFIXED32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, TYPE_SFIXED64, default=default)
|
return dataclass_field(number, TYPE_SFIXED64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def string_field(number: int, default: str = "") -> Any:
|
def string_field(number: int, default: str = "") -> Any:
|
||||||
return field(number, TYPE_STRING, default=default)
|
return dataclass_field(number, TYPE_STRING, default=default)
|
||||||
|
|
||||||
|
|
||||||
def message_field(number: int, default: Type["Message"]) -> Any:
|
def message_field(number: int, default: Type["Message"]) -> Any:
|
||||||
return field(number, TYPE_MESSAGE, default=default)
|
return dataclass_field(number, TYPE_MESSAGE, default=default)
|
||||||
|
|
||||||
|
|
||||||
|
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
||||||
|
return dataclass_field(
|
||||||
|
number, TYPE_MAP, default=dict, map_types=(key_type, value_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _pack_fmt(proto_type: str) -> str:
|
def _pack_fmt(proto_type: str) -> str:
|
||||||
@ -354,6 +369,14 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
for item in value:
|
for item in value:
|
||||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
if not len(value):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for k, v in value.items():
|
||||||
|
sk = _serialize_single(1, meta.map_types[0], k)
|
||||||
|
sv = _serialize_single(2, meta.map_types[1], v)
|
||||||
|
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||||
else:
|
else:
|
||||||
if value == field.default:
|
if value == field.default:
|
||||||
continue
|
continue
|
||||||
@ -377,23 +400,35 @@ class Message(ABC):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
"""Adjusts values after parsing."""
|
"""Adjusts values after parsing."""
|
||||||
if wire_type == WIRE_VARINT:
|
if wire_type == WIRE_VARINT:
|
||||||
if meta.proto_type in ["int32", "int64"]:
|
if meta.proto_type in [TYPE_INT32, TYPE_INT64]:
|
||||||
bits = int(meta.proto_type[3:])
|
bits = int(meta.proto_type[3:])
|
||||||
value = value & ((1 << bits) - 1)
|
value = value & ((1 << bits) - 1)
|
||||||
signbit = 1 << (bits - 1)
|
signbit = 1 << (bits - 1)
|
||||||
value = int((value ^ signbit) - signbit)
|
value = int((value ^ signbit) - signbit)
|
||||||
elif meta.proto_type in ["sint32", "sint64"]:
|
elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||||
# Undo zig-zag encoding
|
# Undo zig-zag encoding
|
||||||
value = (value >> 1) ^ (-(value & 1))
|
value = (value >> 1) ^ (-(value & 1))
|
||||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||||
fmt = _pack_fmt(meta.proto_type)
|
fmt = _pack_fmt(meta.proto_type)
|
||||||
value = struct.unpack(fmt, value)[0]
|
value = struct.unpack(fmt, value)[0]
|
||||||
elif wire_type == WIRE_LEN_DELIM:
|
elif wire_type == WIRE_LEN_DELIM:
|
||||||
if meta.proto_type in ["string"]:
|
if meta.proto_type in [TYPE_STRING]:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type in ["message"]:
|
elif meta.proto_type in [TYPE_MESSAGE]:
|
||||||
cls = self._cls_for(field)
|
cls = self._cls_for(field)
|
||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
|
elif meta.proto_type in [TYPE_MAP]:
|
||||||
|
# TODO: This is slow, use a cache to make it faster since each
|
||||||
|
# key/value pair will recreate the class.
|
||||||
|
Entry = dataclasses.make_dataclass(
|
||||||
|
"Entry",
|
||||||
|
[
|
||||||
|
("key", Any, dataclass_field(1, meta.map_types[0], None)),
|
||||||
|
("value", Any, dataclass_field(2, meta.map_types[1], None)),
|
||||||
|
],
|
||||||
|
bases=(Message,),
|
||||||
|
)
|
||||||
|
value = Entry().parse(value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -434,10 +469,12 @@ class Message(ABC):
|
|||||||
parsed.wire_type, meta, field, parsed.value
|
parsed.wire_type, meta, field, parsed.value
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(getattr(self, field.name), list) and not isinstance(
|
current = getattr(self, field.name)
|
||||||
value, list
|
if meta.proto_type == TYPE_MAP:
|
||||||
):
|
# Value represents a single key/value pair entry in the map.
|
||||||
getattr(self, field.name).append(value)
|
current[value.key] = value.value
|
||||||
|
elif isinstance(current, list) and not isinstance(value, list):
|
||||||
|
current.append(value)
|
||||||
else:
|
else:
|
||||||
setattr(self, field.name, value)
|
setattr(self, field.name, value)
|
||||||
else:
|
else:
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
{% if description.enums %}import enum
|
{% if description.enums %}import enum
|
||||||
{% endif %}
|
{% endif %}
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
|
|
||||||
@ -36,7 +36,7 @@ class {{ message.name }}(betterproto.Message):
|
|||||||
{% if field.comment %}
|
{% if field.comment %}
|
||||||
{{ field.comment }}
|
{{ field.comment }}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %})
|
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ if __name__ == "__main__":
|
|||||||
json_files = get_files(".json")
|
json_files = get_files(".json")
|
||||||
|
|
||||||
for filename in proto_files:
|
for filename in proto_files:
|
||||||
print(f"Generatinng code for {os.path.basename(filename)}")
|
print(f"Generating code for {os.path.basename(filename)}")
|
||||||
subprocess.run(
|
subprocess.run(
|
||||||
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
|
||||||
)
|
)
|
||||||
|
7
betterproto/tests/map.json
Normal file
7
betterproto/tests/map.json
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"counts": {
|
||||||
|
"item1": 1,
|
||||||
|
"item2": 2,
|
||||||
|
"item3": 3
|
||||||
|
}
|
||||||
|
}
|
5
betterproto/tests/map.proto
Normal file
5
betterproto/tests/map.proto
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
map<string, int32> counts = 1;
|
||||||
|
}
|
@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import (
|
|||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
FileDescriptorProto,
|
FileDescriptorProto,
|
||||||
|
FieldDescriptorProto,
|
||||||
)
|
)
|
||||||
|
|
||||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||||
@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin
|
|||||||
from jinja2 import Environment, PackageLoader
|
from jinja2 import Environment, PackageLoader
|
||||||
|
|
||||||
|
|
||||||
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
|
def py_type(
|
||||||
|
message: DescriptorProto, descriptor: FieldDescriptorProto
|
||||||
|
) -> Tuple[str, str]:
|
||||||
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
if descriptor.type in [1, 2, 6, 7, 15, 16]:
|
||||||
return "float", descriptor.default_value
|
return "float", descriptor.default_value
|
||||||
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
|
||||||
@ -115,6 +118,10 @@ def generate_code(request, response):
|
|||||||
|
|
||||||
if isinstance(item, DescriptorProto):
|
if isinstance(item, DescriptorProto):
|
||||||
# print(item, file=sys.stderr)
|
# print(item, file=sys.stderr)
|
||||||
|
if item.options.map_entry:
|
||||||
|
# Skip generated map entry messages since we just use dicts
|
||||||
|
continue
|
||||||
|
|
||||||
data.update(
|
data.update(
|
||||||
{
|
{
|
||||||
"type": "Message",
|
"type": "Message",
|
||||||
@ -124,11 +131,33 @@ def generate_code(request, response):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for i, f in enumerate(item.field):
|
for i, f in enumerate(item.field):
|
||||||
t, zero = py_type(f)
|
t, zero = py_type(item, f)
|
||||||
repeated = False
|
repeated = False
|
||||||
packed = False
|
packed = False
|
||||||
|
|
||||||
if f.label == 3:
|
field_type = f.Type.Name(f.type).lower()[5:]
|
||||||
|
map_types = None
|
||||||
|
if f.type == 11:
|
||||||
|
# This might be a map...
|
||||||
|
message_type = f.type_name.split(".").pop()
|
||||||
|
map_entry = f"{f.name.capitalize()}Entry"
|
||||||
|
|
||||||
|
if message_type == map_entry:
|
||||||
|
for nested in item.nested_type:
|
||||||
|
if nested.name == map_entry:
|
||||||
|
if nested.options.map_entry:
|
||||||
|
print("Found a map!", file=sys.stderr)
|
||||||
|
k, _ = py_type(item, nested.field[0])
|
||||||
|
v, _ = py_type(item, nested.field[1])
|
||||||
|
t = f"Dict[{k}, {v}]"
|
||||||
|
zero = "dict"
|
||||||
|
field_type = "map"
|
||||||
|
map_types = (
|
||||||
|
f.Type.Name(nested.field[0].type),
|
||||||
|
f.Type.Name(nested.field[1].type),
|
||||||
|
)
|
||||||
|
|
||||||
|
if f.label == 3 and field_type != "map":
|
||||||
# Repeated field
|
# Repeated field
|
||||||
repeated = True
|
repeated = True
|
||||||
t = f"List[{t}]"
|
t = f"List[{t}]"
|
||||||
@ -143,7 +172,8 @@ def generate_code(request, response):
|
|||||||
"number": f.number,
|
"number": f.number,
|
||||||
"comment": get_comment(proto_file, path + [2, i]),
|
"comment": get_comment(proto_file, path + [2, i]),
|
||||||
"proto_type": int(f.type),
|
"proto_type": int(f.type),
|
||||||
"field_type": f.Type.Name(f.type).lower()[5:],
|
"field_type": field_type,
|
||||||
|
"map_types": map_types,
|
||||||
"type": t,
|
"type": t,
|
||||||
"zero": zero,
|
"zero": zero,
|
||||||
"repeated": repeated,
|
"repeated": repeated,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user