From c932fbc72cefc09f1b67f7c00f07bda6584ad057 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Tue, 8 Oct 2019 00:23:11 -0700 Subject: [PATCH] More features, refactoring --- .gitignore | 1 + README.md | 13 +- betterproto/__init__.py | 222 +++++++++++++++++++------ betterproto/parse.py | 37 ++--- betterproto/serialize.py | 30 +--- betterproto/templates/main.py | 19 +++ betterproto/tests/double-negative.json | 3 + betterproto/tests/double.json | 3 + betterproto/tests/double.proto | 5 + betterproto/tests/enums.json | 3 + betterproto/tests/enums.proto | 14 ++ betterproto/tests/repeatedpacked.json | 4 +- betterproto/tests/repeatedpacked.proto | 2 + protoc-gen-betterpy.py | 32 +++- 14 files changed, 276 insertions(+), 112 deletions(-) create mode 100644 betterproto/tests/double-negative.json create mode 100644 betterproto/tests/double.json create mode 100644 betterproto/tests/double.proto create mode 100644 betterproto/tests/enums.json create mode 100644 betterproto/tests/enums.proto diff --git a/.gitignore b/.gitignore index 68a329f..5d5b945 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ betterproto/tests/*_pb2.py betterproto/tests/*.py !betterproto/tests/generate.py !betterproto/tests/test_*.py +**/__pycache__ diff --git a/README.md b/README.md index 2069637..565c91a 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,15 @@ # TODO -- [ ] Fixed length fields +- [x] Fixed length fields + - [x] Packed fixed-length - [x] Zig-zag signed fields (sint32, sint64) -- [x] Don't encode zero values for nested types~ -- [ ] Enums +- [x] Don't encode zero values for nested types +- [x] Enums - [ ] Maps - [ ] Support passthrough of unknown fields -- [ ] JSON that isn't naive. +- [ ] Refs to nested types +- [ ] Imports in proto files +- [ ] Well-known Google types +- [ ] JSON that isn't completely naive. +- [ ] Async service stubs - [ ] Cleanup! diff --git a/betterproto/__init__.py b/betterproto/__init__.py index fc13b27..adea090 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -17,16 +17,49 @@ import dataclasses from . import parse, serialize +TYPE_ENUM = "enum" +TYPE_BOOL = "bool" +TYPE_INT32 = "int32" +TYPE_INT64 = "int64" +TYPE_UINT32 = "uint32" +TYPE_UINT64 = "uint64" +TYPE_SINT32 = "sint32" +TYPE_SINT64 = "sint64" +TYPE_FLOAT = "float" +TYPE_DOUBLE = "double" +TYPE_FIXED32 = "fixed32" +TYPE_SFIXED32 = "sfixed32" +TYPE_FIXED64 = "fixed64" +TYPE_SFIXED64 = "sfixed64" +TYPE_STRING = "string" +TYPE_BYTES = "bytes" +TYPE_MESSAGE = "message" + + +FIXED_TYPES = [ + TYPE_FLOAT, + TYPE_DOUBLE, + TYPE_FIXED32, + TYPE_SFIXED32, + TYPE_FIXED64, + TYPE_SFIXED64, +] + PACKED_TYPES = [ - "bool", - "int32", - "int64", - "uint32", - "uint64", - "sint32", - "sint64", - "float", - "double", + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + TYPE_SINT32, + TYPE_SINT64, + TYPE_FLOAT, + TYPE_DOUBLE, + TYPE_FIXED32, + TYPE_SFIXED32, + TYPE_FIXED64, + TYPE_SFIXED64, ] # Wire types @@ -36,6 +69,21 @@ WIRE_FIXED_64 = 1 WIRE_LEN_DELIM = 2 WIRE_FIXED_32 = 5 +WIRE_VARINT_TYPES = [ + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + TYPE_SINT32, + TYPE_SINT64, +] + +WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32] +WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] +WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE] + @dataclasses.dataclass(frozen=True) class _Meta: @@ -59,74 +107,135 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: ) -def int32_field( - number: int, default: Union[int, Type[Iterable]] = 0 -) -> dataclasses.Field: +# Note: the fields below return `Any` to prevent type errors in the generated +# data classes since the types won't match with `Field` and they get swapped +# out at runtime. The generated dataclass variables are still typed correctly. + + +def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: + return field(number, "enum", default=default) + + +def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: return field(number, "int32", default=default) -def int64_field(number: int, default: int = 0) -> dataclasses.Field: +def int64_field(number: int, default: int = 0) -> Any: return field(number, "int64", default=default) -def uint32_field(number: int, default: int = 0) -> dataclasses.Field: +def uint32_field(number: int, default: int = 0) -> Any: return field(number, "uint32", default=default) -def uint64_field(number: int, default: int = 0) -> dataclasses.Field: +def uint64_field(number: int, default: int = 0) -> Any: return field(number, "uint64", default=default) -def sint32_field(number: int, default: int = 0) -> dataclasses.Field: +def sint32_field(number: int, default: int = 0) -> Any: return field(number, "sint32", default=default) -def sint64_field(number: int, default: int = 0) -> dataclasses.Field: +def sint64_field(number: int, default: int = 0) -> Any: return field(number, "sint64", default=default) -def float_field(number: int, default: float = 0.0) -> dataclasses.Field: +def float_field(number: int, default: float = 0.0) -> Any: return field(number, "float", default=default) -def double_field(number: int, default: float = 0.0) -> dataclasses.Field: +def double_field(number: int, default: float = 0.0) -> Any: return field(number, "double", default=default) -def string_field(number: int, default: str = "") -> dataclasses.Field: +def fixed32_field(number: int, default: float = 0.0) -> Any: + return field(number, "fixed32", default=default) + + +def fixed64_field(number: int, default: float = 0.0) -> Any: + return field(number, "fixed64", default=default) + + +def sfixed32_field(number: int, default: float = 0.0) -> Any: + return field(number, "sfixed32", default=default) + + +def sfixed64_field(number: int, default: float = 0.0) -> Any: + return field(number, "sfixed64", default=default) + + +def string_field(number: int, default: str = "") -> Any: return field(number, "string", default=default) -def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field: +def message_field(number: int, default: Type["Message"]) -> Any: return field(number, "message", default=default) -def _serialize_single(meta: _Meta, value: Any) -> bytes: +def _pack_fmt(proto_type: str) -> str: + return { + "double": " bytes: + value = _preprocess_single(proto_type, value) + output = b"" - if meta.proto_type in ["int32", "int64", "uint32", "uint64"]: - if value < 0: - # Handle negative numbers. - value += 1 << 64 - output = serialize.varint(meta.number, value) - elif meta.proto_type in ["sint32", "sint64"]: - if value >= 0: - value = value << 1 - else: - value = (value << 1) ^ (~0) - output = serialize.varint(meta.number, value) - elif meta.proto_type == "string": - output = serialize.len_delim(meta.number, value.encode("utf-8")) - elif meta.proto_type == "message": - b = bytes(value) - if len(b): - output = serialize.len_delim(meta.number, b) + if proto_type in WIRE_VARINT_TYPES: + key = serialize._varint(field_number << 3) + output += key + value + elif proto_type in WIRE_FIXED_32_TYPES: + key = serialize._varint((field_number << 3) | 5) + output += key + value + elif proto_type in WIRE_FIXED_64_TYPES: + key = serialize._varint((field_number << 3) | 1) + output += key + value + elif proto_type in WIRE_LEN_DELIM_TYPES: + if len(value): + key = serialize._varint((field_number << 3) | 2) + output += key + serialize._varint(len(value)) + value else: - raise NotImplementedError() + raise NotImplementedError(proto_type) return output -def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: +def _preprocess_single(proto_type: str, value: Any) -> bytes: + """Adjusts values before serialization.""" + if proto_type in [ + TYPE_ENUM, + TYPE_BOOL, + TYPE_INT32, + TYPE_INT64, + TYPE_UINT32, + TYPE_UINT64, + ]: + return serialize._varint(value) + elif proto_type in [TYPE_SINT32, TYPE_SINT64]: + # Handle zig-zag encoding. + if value >= 0: + value = value << 1 + else: + value = (value << 1) ^ (~0) + return serialize._varint(value) + elif proto_type in FIXED_TYPES: + return struct.pack(_pack_fmt(proto_type), value) + elif proto_type == TYPE_STRING: + return value.encode("utf-8") + elif proto_type == TYPE_MESSAGE: + return bytes(value) + + return value + + +def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: + """Adjusts values after parsing.""" if wire_type == WIRE_VARINT: if meta.proto_type in ["int32", "int64"]: bits = int(meta.proto_type[3:]) @@ -136,6 +245,9 @@ def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: elif meta.proto_type in ["sint32", "sint64"]: # Undo zig-zag encoding value = (value >> 1) ^ (-(value & 1)) + elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: + fmt = _pack_fmt(meta.proto_type) + value = struct.unpack(fmt, value)[0] elif wire_type == WIRE_LEN_DELIM: if meta.proto_type in ["string"]: value = value.decode("utf-8") @@ -170,15 +282,21 @@ class Message(ABC): continue if meta.proto_type in PACKED_TYPES: - output += serialize.packed(meta.number, value) + # Packed lists look like a length-delimited field. First, + # preprocess/encode each value into a buffer and then + # treat it like a field of raw bytes. + buf = b"" + for item in value: + buf += _preprocess_single(meta.proto_type, item) + output += _serialize_single(meta.number, TYPE_BYTES, buf) else: for item in value: - output += _serialize_single(meta, item) + output += _serialize_single(meta.number, meta.proto_type, item) else: if value == field.default: continue - output += _serialize_single(meta, value) + output += _serialize_single(meta.number, meta.proto_type, value) return output @@ -201,11 +319,21 @@ class Message(ABC): pos = 0 value = [] while pos < len(parsed.value): - decoded, pos = parse._decode_varint(parsed.value, pos) - decoded = _parse_single(WIRE_VARINT, meta, field, decoded) + if meta.proto_type in ["float", "fixed32", "sfixed32"]: + decoded, pos = parsed.value[pos : pos + 4], pos + 4 + wire_type = WIRE_FIXED_32 + elif meta.proto_type in ["double", "fixed64", "sfixed64"]: + decoded, pos = parsed.value[pos : pos + 8], pos + 8 + wire_type = WIRE_FIXED_64 + else: + decoded, pos = parse.parse_varint(parsed.value, pos) + wire_type = WIRE_VARINT + decoded = _postprocess_single(wire_type, meta, field, decoded) value.append(decoded) else: - value = _parse_single(parsed.wire_type, meta, field, parsed.value) + value = _postprocess_single( + parsed.wire_type, meta, field, parsed.value + ) if isinstance(getattr(self, field.name), list) and not isinstance( value, list diff --git a/betterproto/parse.py b/betterproto/parse.py index 69ed554..a2c9592 100644 --- a/betterproto/parse.py +++ b/betterproto/parse.py @@ -3,9 +3,8 @@ from typing import Union, Generator, Any, SupportsBytes, List, Tuple from dataclasses import dataclass -def _decode_varint( - buffer: bytes, pos: int, signed: bool = False, result_type: type = int -) -> Tuple[int, int]: +def parse_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]: + """Parse a single varint value from a byte buffer.""" result = 0 shift = 0 while 1: @@ -13,52 +12,40 @@ def _decode_varint( result |= (b & 0x7F) << shift pos += 1 if not (b & 0x80): - result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise ValueError("Too many bytes when decoding varint.") -def packed(value: bytes, signed: bool = False, result_type: type = int) -> list: - parsed = [] - pos = 0 - while pos < len(value): - decoded, pos = _decode_varint( - value, pos, signed=signed, result_type=result_type - ) - parsed.append(decoded) - return parsed - - @dataclass(frozen=True) -class Field: +class ParsedField: number: int wire_type: int value: Any -def fields(value: bytes) -> Generator[Field, None, None]: +def fields(value: bytes) -> Generator[ParsedField, None, None]: i = 0 while i < len(value): - num_wire, i = _decode_varint(value, i) - print(num_wire, i) + num_wire, i = parse_varint(value, i) + # print(num_wire, i) number = num_wire >> 3 wire_type = num_wire & 0x7 if wire_type == 0: - decoded, i = _decode_varint(value, i) + decoded, i = parse_varint(value, i) elif wire_type == 1: - decoded, i = None, i + 4 + decoded, i = value[i : i + 8], i + 8 elif wire_type == 2: - length, i = _decode_varint(value, i) + length, i = parse_varint(value, i) decoded = value[i : i + length] i += length elif wire_type == 5: - decoded, i = None, i + 2 + decoded, i = value[i : i + 4], i + 4 else: raise NotImplementedError(f"Wire type {wire_type}") - # print(Field(number=number, wire_type=wire_type, value=decoded)) + # print(ParsedField(number=number, wire_type=wire_type, value=decoded)) - yield Field(number=number, wire_type=wire_type, value=decoded) + yield ParsedField(number=number, wire_type=wire_type, value=decoded) diff --git a/betterproto/serialize.py b/betterproto/serialize.py index 3d9ac50..bade2c6 100644 --- a/betterproto/serialize.py +++ b/betterproto/serialize.py @@ -7,6 +7,9 @@ def _varint(value: int) -> bytes: # From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372 b: List[int] = [] + if value < 0: + value += 1 << 64 + bits = value & 0x7F value >>= 7 while value: @@ -15,30 +18,3 @@ def _varint(value: int) -> bytes: value >>= 7 print(value) return bytes(b + [bits]) - - -def varint(field_number: int, value: Union[int, float]) -> bytes: - key = _varint(field_number << 3) - return key + _varint(value) - - -def len_delim(field_number: int, value: Union[str, bytes]) -> bytes: - key = _varint((field_number << 3) | 2) - - if isinstance(value, str): - value = value.encode("utf-8") - - return key + _varint(len(value)) + value - - -def packed(field_number: int, value: list) -> bytes: - key = _varint((field_number << 3) | 2) - - packed = b"" - for item in value: - if item < 0: - # Handle negative numbers. - item += 1 << 64 - packed += _varint(item) - - return key + _varint(len(packed)) + packed diff --git a/betterproto/templates/main.py b/betterproto/templates/main.py index d831cf3..1313b05 100644 --- a/betterproto/templates/main.py +++ b/betterproto/templates/main.py @@ -1,11 +1,30 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! # source: {{ description.filename }} +# plugin: python-betterproto +{% if description.enums %}import enum +{% endif %} from dataclasses import dataclass from typing import List import betterproto +{% if description.enums %}{% for enum in description.enums %} +class {{ enum.name }}(enum.IntEnum): + {% if enum.comment %} +{{ enum.comment }} + + {% endif %} + {% for entry in enum.entries %} + {% if entry.comment %} +{{ entry.comment }} + {% endif %} + {{ entry.name }} = {{ entry.value }} + {% endfor %} +{% endfor %} + + +{% endif %} {% for message in description.messages %} @dataclass class {{ message.name }}(betterproto.Message): diff --git a/betterproto/tests/double-negative.json b/betterproto/tests/double-negative.json new file mode 100644 index 0000000..e0776c7 --- /dev/null +++ b/betterproto/tests/double-negative.json @@ -0,0 +1,3 @@ +{ + "count": -123.45 +} diff --git a/betterproto/tests/double.json b/betterproto/tests/double.json new file mode 100644 index 0000000..321412e --- /dev/null +++ b/betterproto/tests/double.json @@ -0,0 +1,3 @@ +{ + "count": 123.45 +} diff --git a/betterproto/tests/double.proto b/betterproto/tests/double.proto new file mode 100644 index 0000000..88525d9 --- /dev/null +++ b/betterproto/tests/double.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Test { + double count = 1; +} diff --git a/betterproto/tests/enums.json b/betterproto/tests/enums.json new file mode 100644 index 0000000..182f73c --- /dev/null +++ b/betterproto/tests/enums.json @@ -0,0 +1,3 @@ +{ + "greeting": 1 +} diff --git a/betterproto/tests/enums.proto b/betterproto/tests/enums.proto new file mode 100644 index 0000000..421f78a --- /dev/null +++ b/betterproto/tests/enums.proto @@ -0,0 +1,14 @@ +syntax = "proto3"; + +// Enum for the different greeting types +enum Greeting { + HI = 0; + HEY = 1; + // Formal greeting + HELLO = 2; +} + +message Test { + // Greeting enum example + Greeting greeting = 1; +} diff --git a/betterproto/tests/repeatedpacked.json b/betterproto/tests/repeatedpacked.json index 7d9ae00..2a19e3d 100644 --- a/betterproto/tests/repeatedpacked.json +++ b/betterproto/tests/repeatedpacked.json @@ -1,3 +1,5 @@ { - "counts": [1, 2, -1, -2] + "counts": [1, 2, -1, -2], + "signed": [1, 2, -1, -2], + "fixed": [1.0, 2.7, 3.4] } diff --git a/betterproto/tests/repeatedpacked.proto b/betterproto/tests/repeatedpacked.proto index 0662cdb..ea86dde 100644 --- a/betterproto/tests/repeatedpacked.proto +++ b/betterproto/tests/repeatedpacked.proto @@ -2,4 +2,6 @@ syntax = "proto3"; message Test { repeated int32 counts = 1; + repeated sint64 signed = 2; + repeated double fixed = 3; } diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py index d2f96b5..a5d024c 100755 --- a/protoc-gen-betterpy.py +++ b/protoc-gen-betterpy.py @@ -43,8 +43,11 @@ def py_type(descriptor: DescriptorProto) -> Tuple[str, str]: if descriptor.default_value: default = f'b"{descriptor.default_value}"' return "bytes", default + elif descriptor.type == 14: + # print(descriptor.type_name, file=sys.stderr) + return descriptor.type_name.split(".").pop(), 0 else: - raise NotImplementedError() + raise NotImplementedError(f"Unknown type {descriptor.type}") def traverse(proto_file): @@ -101,6 +104,7 @@ def generate_code(request, response): "package": proto_file.package, "filename": proto_file.name, "messages": [], + "enums": [], } # Parse request @@ -148,14 +152,26 @@ def generate_code(request, response): ) # print(f, file=sys.stderr) - # elif isinstance(item, EnumDescriptorProto): - # data.update({ - # 'type': 'Enum', - # 'values': [{'name': v.name, 'value': v.number} - # for v in item.value] - # }) + output["messages"].append(data) - output["messages"].append(data) + elif isinstance(item, EnumDescriptorProto): + # print(item.name, path, file=sys.stderr) + data.update( + { + "type": "Enum", + "comment": get_comment(proto_file, path), + "entries": [ + { + "name": v.name, + "value": v.number, + "comment": get_comment(proto_file, path + [2, i]), + } + for i, v in enumerate(item.value) + ], + } + ) + + output["enums"].append(data) # Fill response f = response.file.add()