From e0d1611797044ad7e2385a548a476f3b799c99fb Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Thu, 10 Oct 2019 22:20:27 -0700 Subject: [PATCH] Add basic support for maps --- README.md | 3 +- betterproto/__init__.py | 93 ++++++++++++++++++++++++----------- betterproto/templates/main.py | 4 +- betterproto/tests/generate.py | 2 +- betterproto/tests/map.json | 7 +++ betterproto/tests/map.proto | 5 ++ protoc-gen-betterpy.py | 38 ++++++++++++-- 7 files changed, 116 insertions(+), 36 deletions(-) create mode 100644 betterproto/tests/map.json create mode 100644 betterproto/tests/map.proto diff --git a/README.md b/README.md index 3cff8e8..fb06616 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,8 @@ - [x] Don't encode zero values for nested types - [x] Enums - [x] Repeated message fields -- [ ] Maps +- [x] Maps + - [ ] Maps of message fields - [ ] Support passthrough of unknown fields - [ ] Refs to nested types - [ ] Imports in proto files diff --git a/betterproto/__init__.py b/betterproto/__init__.py index cf4b6e0..7d9b217 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -13,6 +13,7 @@ from typing import ( Type, Iterable, TypeVar, + Optional, ) import dataclasses @@ -36,6 +37,7 @@ TYPE_SFIXED64 = "sfixed64" TYPE_STRING = "string" TYPE_BYTES = "bytes" TYPE_MESSAGE = "message" +TYPE_MAP = "map" # Fields that use a fixed amount of space (4 or 8 bytes) @@ -87,7 +89,7 @@ WIRE_VARINT_TYPES = [ WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32] WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] -WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE] +WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] @dataclasses.dataclass(frozen=True) @@ -98,6 +100,8 @@ class FieldMetadata: number: int # Protobuf type name proto_type: str + # Map information if the proto_type is a map + map_types: Optional[Tuple[str, str]] # Default value if given default: Any @@ -107,10 +111,14 @@ class FieldMetadata: return 
field.metadata["betterproto"] -def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: +def dataclass_field( + number: int, + proto_type: str, + default: Any, + map_types: Optional[Tuple[str, str]] = None, + **kwargs: dict, +) -> dataclasses.Field: """Creates a dataclass field with attached protobuf metadata.""" - kwargs = {} - if callable(default): kwargs["default_factory"] = default elif isinstance(default, dict) or isinstance(default, list): @@ -119,7 +127,8 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: kwargs["default"] = default return dataclasses.field( - **kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)} + **kwargs, + metadata={"betterproto": FieldMetadata(number, proto_type, map_types, default)}, ) @@ -129,63 +138,69 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return field(number, TYPE_ENUM, default=default) + return dataclass_field(number, TYPE_ENUM, default=default) def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return field(number, TYPE_INT32, default=default) + return dataclass_field(number, TYPE_INT32, default=default) def int64_field(number: int, default: int = 0) -> Any: - return field(number, TYPE_INT64, default=default) + return dataclass_field(number, TYPE_INT64, default=default) def uint32_field(number: int, default: int = 0) -> Any: - return field(number, TYPE_UINT32, default=default) + return dataclass_field(number, TYPE_UINT32, default=default) def uint64_field(number: int, default: int = 0) -> Any: - return field(number, TYPE_UINT64, default=default) + return dataclass_field(number, TYPE_UINT64, default=default) def sint32_field(number: int, default: int = 0) -> Any: - return field(number, TYPE_SINT32, default=default) + return dataclass_field(number, TYPE_SINT32, default=default) def sint64_field(number: int, default: 
int = 0) -> Any: - return field(number, TYPE_SINT64, default=default) + return dataclass_field(number, TYPE_SINT64, default=default) def float_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_FLOAT, default=default) + return dataclass_field(number, TYPE_FLOAT, default=default) def double_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_DOUBLE, default=default) + return dataclass_field(number, TYPE_DOUBLE, default=default) def fixed32_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_FIXED32, default=default) + return dataclass_field(number, TYPE_FIXED32, default=default) def fixed64_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_FIXED64, default=default) + return dataclass_field(number, TYPE_FIXED64, default=default) def sfixed32_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_SFIXED32, default=default) + return dataclass_field(number, TYPE_SFIXED32, default=default) def sfixed64_field(number: int, default: float = 0.0) -> Any: - return field(number, TYPE_SFIXED64, default=default) + return dataclass_field(number, TYPE_SFIXED64, default=default) def string_field(number: int, default: str = "") -> Any: - return field(number, TYPE_STRING, default=default) + return dataclass_field(number, TYPE_STRING, default=default) def message_field(number: int, default: Type["Message"]) -> Any: - return field(number, TYPE_MESSAGE, default=default) + return dataclass_field(number, TYPE_MESSAGE, default=default) + + +def map_field(number: int, key_type: str, value_type: str) -> Any: + return dataclass_field( + number, TYPE_MAP, default=dict, map_types=(key_type, value_type) + ) def _pack_fmt(proto_type: str) -> str: @@ -354,6 +369,14 @@ class Message(ABC): else: for item in value: output += _serialize_single(meta.number, meta.proto_type, item) + elif isinstance(value, dict): + if not len(value): + continue + + for k, v in value.items(): + sk = 
_serialize_single(1, meta.map_types[0], k) + sv = _serialize_single(2, meta.map_types[1], v) + output += _serialize_single(meta.number, meta.proto_type, sk + sv) else: if value == field.default: continue @@ -377,23 +400,35 @@ class Message(ABC): ) -> Any: """Adjusts values after parsing.""" if wire_type == WIRE_VARINT: - if meta.proto_type in ["int32", "int64"]: + if meta.proto_type in [TYPE_INT32, TYPE_INT64]: bits = int(meta.proto_type[3:]) value = value & ((1 << bits) - 1) signbit = 1 << (bits - 1) value = int((value ^ signbit) - signbit) - elif meta.proto_type in ["sint32", "sint64"]: + elif meta.proto_type in [TYPE_SINT32, TYPE_SINT64]: # Undo zig-zag encoding value = (value >> 1) ^ (-(value & 1)) elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: fmt = _pack_fmt(meta.proto_type) value = struct.unpack(fmt, value)[0] elif wire_type == WIRE_LEN_DELIM: - if meta.proto_type in ["string"]: + if meta.proto_type in [TYPE_STRING]: value = value.decode("utf-8") - elif meta.proto_type in ["message"]: + elif meta.proto_type in [TYPE_MESSAGE]: cls = self._cls_for(field) value = cls().parse(value) + elif meta.proto_type in [TYPE_MAP]: + # TODO: This is slow, use a cache to make it faster since each + # key/value pair will recreate the class. + Entry = dataclasses.make_dataclass( + "Entry", + [ + ("key", Any, dataclass_field(1, meta.map_types[0], None)), + ("value", Any, dataclass_field(2, meta.map_types[1], None)), + ], + bases=(Message,), + ) + value = Entry().parse(value) return value @@ -434,10 +469,12 @@ class Message(ABC): parsed.wire_type, meta, field, parsed.value ) - if isinstance(getattr(self, field.name), list) and not isinstance( - value, list - ): - getattr(self, field.name).append(value) + current = getattr(self, field.name) + if meta.proto_type == TYPE_MAP: + # Value represents a single key/value pair entry in the map. 
+ current[value.key] = value.value + elif isinstance(current, list) and not isinstance(value, list): + current.append(value) else: setattr(self, field.name, value) else: diff --git a/betterproto/templates/main.py b/betterproto/templates/main.py index 1313b05..332a144 100644 --- a/betterproto/templates/main.py +++ b/betterproto/templates/main.py @@ -4,7 +4,7 @@ {% if description.enums %}import enum {% endif %} from dataclasses import dataclass -from typing import List +from typing import Dict, List import betterproto @@ -36,7 +36,7 @@ class {{ message.name }}(betterproto.Message): {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %}) + {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}) {% endfor %} diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index 5fd037e..ae3b095 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -48,7 +48,7 @@ if __name__ == "__main__": json_files = get_files(".json") for filename in proto_files: - print(f"Generatinng code for {os.path.basename(filename)}") + print(f"Generating code for {os.path.basename(filename)}") subprocess.run( f"protoc --python_out=. 
{os.path.basename(filename)}", shell=True ) diff --git a/betterproto/tests/map.json b/betterproto/tests/map.json new file mode 100644 index 0000000..6a1e853 --- /dev/null +++ b/betterproto/tests/map.json @@ -0,0 +1,7 @@ +{ + "counts": { + "item1": 1, + "item2": 2, + "item3": 3 + } +} diff --git a/betterproto/tests/map.proto b/betterproto/tests/map.proto new file mode 100644 index 0000000..669e287 --- /dev/null +++ b/betterproto/tests/map.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Test { + map<string, int32> counts = 1; +} diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py index a5d024c..f779199 100755 --- a/protoc-gen-betterpy.py +++ b/protoc-gen-betterpy.py @@ -12,6 +12,7 @@ from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, FileDescriptorProto, + FieldDescriptorProto, ) from google.protobuf.compiler import plugin_pb2 as plugin @@ -20,7 +21,9 @@ from google.protobuf.compiler import plugin_pb2 as plugin from jinja2 import Environment, PackageLoader -def py_type(descriptor: DescriptorProto) -> Tuple[str, str]: +def py_type( + message: DescriptorProto, descriptor: FieldDescriptorProto +) -> Tuple[str, str]: if descriptor.type in [1, 2, 6, 7, 15, 16]: return "float", descriptor.default_value elif descriptor.type in [3, 4, 5, 13, 17, 18]: @@ -115,6 +118,10 @@ def generate_code(request, response): if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) + if item.options.map_entry: + # Skip generated map entry messages since we just use dicts + continue + data.update( { "type": "Message", @@ -124,11 +131,33 @@ def generate_code(request, response): ) for i, f in enumerate(item.field): - t, zero = py_type(f) + t, zero = py_type(item, f) repeated = False packed = False - if f.label == 3: + field_type = f.Type.Name(f.type).lower()[5:] + map_types = None + if f.type == 11: + # This might be a map... 
+ message_type = f.type_name.split(".").pop() + map_entry = f"{f.name.capitalize()}Entry" + + if message_type == map_entry: + for nested in item.nested_type: + if nested.name == map_entry: + if nested.options.map_entry: + print("Found a map!", file=sys.stderr) + k, _ = py_type(item, nested.field[0]) + v, _ = py_type(item, nested.field[1]) + t = f"Dict[{k}, {v}]" + zero = "dict" + field_type = "map" + map_types = ( + f.Type.Name(nested.field[0].type), + f.Type.Name(nested.field[1].type), + ) + + if f.label == 3 and field_type != "map": # Repeated field repeated = True t = f"List[{t}]" @@ -143,7 +172,8 @@ def generate_code(request, response): "number": f.number, "comment": get_comment(proto_file, path + [2, i]), "proto_type": int(f.type), - "field_type": f.Type.Name(f.type).lower()[5:], + "field_type": field_type, + "map_types": map_types, "type": t, "zero": zero, "repeated": repeated,