More features, refactoring

This commit is contained in:
Daniel G. Taylor 2019-10-08 00:23:11 -07:00
parent 6ed3b09f44
commit c932fbc72c
14 changed files with 276 additions and 112 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ betterproto/tests/*_pb2.py
betterproto/tests/*.py betterproto/tests/*.py
!betterproto/tests/generate.py !betterproto/tests/generate.py
!betterproto/tests/test_*.py !betterproto/tests/test_*.py
**/__pycache__

View File

@ -1,10 +1,15 @@
# TODO # TODO
- [ ] Fixed length fields - [x] Fixed length fields
- [x] Packed fixed-length
- [x] Zig-zag signed fields (sint32, sint64) - [x] Zig-zag signed fields (sint32, sint64)
- [x] Don't encode zero values for nested types~ - [x] Don't encode zero values for nested types
- [ ] Enums - [x] Enums
- [ ] Maps - [ ] Maps
- [ ] Support passthrough of unknown fields - [ ] Support passthrough of unknown fields
- [ ] JSON that isn't naive. - [ ] Refs to nested types
- [ ] Imports in proto files
- [ ] Well-known Google types
- [ ] JSON that isn't completely naive.
- [ ] Async service stubs
- [ ] Cleanup! - [ ] Cleanup!

View File

@ -17,16 +17,49 @@ import dataclasses
from . import parse, serialize from . import parse, serialize
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
FIXED_TYPES = [
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
]
PACKED_TYPES = [ PACKED_TYPES = [
"bool", TYPE_ENUM,
"int32", TYPE_BOOL,
"int64", TYPE_INT32,
"uint32", TYPE_INT64,
"uint64", TYPE_UINT32,
"sint32", TYPE_UINT64,
"sint64", TYPE_SINT32,
"float", TYPE_SINT64,
"double", TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
] ]
# Wire types # Wire types
@ -36,6 +69,21 @@ WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2 WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5 WIRE_FIXED_32 = 5
WIRE_VARINT_TYPES = [
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
TYPE_SINT32,
TYPE_SINT64,
]
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class _Meta: class _Meta:
@ -59,74 +107,135 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
) )
def int32_field( # Note: the fields below return `Any` to prevent type errors in the generated
number: int, default: Union[int, Type[Iterable]] = 0 # data classes since the types won't match with `Field` and they get swapped
) -> dataclasses.Field: # out at runtime. The generated dataclass variables are still typed correctly.
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, "enum", default=default)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, "int32", default=default) return field(number, "int32", default=default)
def int64_field(number: int, default: int = 0) -> dataclasses.Field: def int64_field(number: int, default: int = 0) -> Any:
return field(number, "int64", default=default) return field(number, "int64", default=default)
def uint32_field(number: int, default: int = 0) -> dataclasses.Field: def uint32_field(number: int, default: int = 0) -> Any:
return field(number, "uint32", default=default) return field(number, "uint32", default=default)
def uint64_field(number: int, default: int = 0) -> dataclasses.Field: def uint64_field(number: int, default: int = 0) -> Any:
return field(number, "uint64", default=default) return field(number, "uint64", default=default)
def sint32_field(number: int, default: int = 0) -> dataclasses.Field: def sint32_field(number: int, default: int = 0) -> Any:
return field(number, "sint32", default=default) return field(number, "sint32", default=default)
def sint64_field(number: int, default: int = 0) -> dataclasses.Field: def sint64_field(number: int, default: int = 0) -> Any:
return field(number, "sint64", default=default) return field(number, "sint64", default=default)
def float_field(number: int, default: float = 0.0) -> dataclasses.Field: def float_field(number: int, default: float = 0.0) -> Any:
return field(number, "float", default=default) return field(number, "float", default=default)
def double_field(number: int, default: float = 0.0) -> dataclasses.Field: def double_field(number: int, default: float = 0.0) -> Any:
return field(number, "double", default=default) return field(number, "double", default=default)
def string_field(number: int, default: str = "") -> dataclasses.Field: def fixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, "fixed32", default=default)
def fixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, "fixed64", default=default)
def sfixed32_field(number: int, default: float = 0.0) -> Any:
return field(number, "sfixed32", default=default)
def sfixed64_field(number: int, default: float = 0.0) -> Any:
return field(number, "sfixed64", default=default)
def string_field(number: int, default: str = "") -> Any:
return field(number, "string", default=default) return field(number, "string", default=default)
def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field: def message_field(number: int, default: Type["Message"]) -> Any:
return field(number, "message", default=default) return field(number, "message", default=default)
def _serialize_single(meta: _Meta, value: Any) -> bytes: def _pack_fmt(proto_type: str) -> str:
return {
"double": "<d",
"float": "<f",
"fixed32": "<I",
"fixed64": "<Q",
"sfixed32": "<i",
"sfixed64": "<q",
}[proto_type]
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
value = _preprocess_single(proto_type, value)
output = b"" output = b""
if meta.proto_type in ["int32", "int64", "uint32", "uint64"]: if proto_type in WIRE_VARINT_TYPES:
if value < 0: key = serialize._varint(field_number << 3)
# Handle negative numbers. output += key + value
value += 1 << 64 elif proto_type in WIRE_FIXED_32_TYPES:
output = serialize.varint(meta.number, value) key = serialize._varint((field_number << 3) | 5)
elif meta.proto_type in ["sint32", "sint64"]: output += key + value
if value >= 0: elif proto_type in WIRE_FIXED_64_TYPES:
value = value << 1 key = serialize._varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value):
key = serialize._varint((field_number << 3) | 2)
output += key + serialize._varint(len(value)) + value
else: else:
value = (value << 1) ^ (~0) raise NotImplementedError(proto_type)
output = serialize.varint(meta.number, value)
elif meta.proto_type == "string":
output = serialize.len_delim(meta.number, value.encode("utf-8"))
elif meta.proto_type == "message":
b = bytes(value)
if len(b):
output = serialize.len_delim(meta.number, b)
else:
raise NotImplementedError()
return output return output
def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: def _preprocess_single(proto_type: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in [
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
]:
return serialize._varint(value)
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Handle zig-zag encoding.
if value >= 0:
value = value << 1
else:
value = (value << 1) ^ (~0)
return serialize._varint(value)
elif proto_type in FIXED_TYPES:
return struct.pack(_pack_fmt(proto_type), value)
elif proto_type == TYPE_STRING:
return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE:
return bytes(value)
return value
def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT: if wire_type == WIRE_VARINT:
if meta.proto_type in ["int32", "int64"]: if meta.proto_type in ["int32", "int64"]:
bits = int(meta.proto_type[3:]) bits = int(meta.proto_type[3:])
@ -136,6 +245,9 @@ def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
elif meta.proto_type in ["sint32", "sint64"]: elif meta.proto_type in ["sint32", "sint64"]:
# Undo zig-zag encoding # Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1)) value = (value >> 1) ^ (-(value & 1))
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM: elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in ["string"]: if meta.proto_type in ["string"]:
value = value.decode("utf-8") value = value.decode("utf-8")
@ -170,15 +282,21 @@ class Message(ABC):
continue continue
if meta.proto_type in PACKED_TYPES: if meta.proto_type in PACKED_TYPES:
output += serialize.packed(meta.number, value) # Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = b""
for item in value:
buf += _preprocess_single(meta.proto_type, item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else: else:
for item in value: for item in value:
output += _serialize_single(meta, item) output += _serialize_single(meta.number, meta.proto_type, item)
else: else:
if value == field.default: if value == field.default:
continue continue
output += _serialize_single(meta, value) output += _serialize_single(meta.number, meta.proto_type, value)
return output return output
@ -201,11 +319,21 @@ class Message(ABC):
pos = 0 pos = 0
value = [] value = []
while pos < len(parsed.value): while pos < len(parsed.value):
decoded, pos = parse._decode_varint(parsed.value, pos) if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded = _parse_single(WIRE_VARINT, meta, field, decoded) decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = parse.parse_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = _postprocess_single(wire_type, meta, field, decoded)
value.append(decoded) value.append(decoded)
else: else:
value = _parse_single(parsed.wire_type, meta, field, parsed.value) value = _postprocess_single(
parsed.wire_type, meta, field, parsed.value
)
if isinstance(getattr(self, field.name), list) and not isinstance( if isinstance(getattr(self, field.name), list) and not isinstance(
value, list value, list

View File

@ -3,9 +3,8 @@ from typing import Union, Generator, Any, SupportsBytes, List, Tuple
from dataclasses import dataclass from dataclasses import dataclass
def _decode_varint( def parse_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
buffer: bytes, pos: int, signed: bool = False, result_type: type = int """Parse a single varint value from a byte buffer."""
) -> Tuple[int, int]:
result = 0 result = 0
shift = 0 shift = 0
while 1: while 1:
@ -13,52 +12,40 @@ def _decode_varint(
result |= (b & 0x7F) << shift result |= (b & 0x7F) << shift
pos += 1 pos += 1
if not (b & 0x80): if not (b & 0x80):
result = result_type(result)
return (result, pos) return (result, pos)
shift += 7 shift += 7
if shift >= 64: if shift >= 64:
raise ValueError("Too many bytes when decoding varint.") raise ValueError("Too many bytes when decoding varint.")
def packed(value: bytes, signed: bool = False, result_type: type = int) -> list:
parsed = []
pos = 0
while pos < len(value):
decoded, pos = _decode_varint(
value, pos, signed=signed, result_type=result_type
)
parsed.append(decoded)
return parsed
@dataclass(frozen=True) @dataclass(frozen=True)
class Field: class ParsedField:
number: int number: int
wire_type: int wire_type: int
value: Any value: Any
def fields(value: bytes) -> Generator[Field, None, None]: def fields(value: bytes) -> Generator[ParsedField, None, None]:
i = 0 i = 0
while i < len(value): while i < len(value):
num_wire, i = _decode_varint(value, i) num_wire, i = parse_varint(value, i)
print(num_wire, i) # print(num_wire, i)
number = num_wire >> 3 number = num_wire >> 3
wire_type = num_wire & 0x7 wire_type = num_wire & 0x7
if wire_type == 0: if wire_type == 0:
decoded, i = _decode_varint(value, i) decoded, i = parse_varint(value, i)
elif wire_type == 1: elif wire_type == 1:
decoded, i = None, i + 4 decoded, i = value[i : i + 8], i + 8
elif wire_type == 2: elif wire_type == 2:
length, i = _decode_varint(value, i) length, i = parse_varint(value, i)
decoded = value[i : i + length] decoded = value[i : i + length]
i += length i += length
elif wire_type == 5: elif wire_type == 5:
decoded, i = None, i + 2 decoded, i = value[i : i + 4], i + 4
else: else:
raise NotImplementedError(f"Wire type {wire_type}") raise NotImplementedError(f"Wire type {wire_type}")
# print(Field(number=number, wire_type=wire_type, value=decoded)) # print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield Field(number=number, wire_type=wire_type, value=decoded) yield ParsedField(number=number, wire_type=wire_type, value=decoded)

View File

@ -7,6 +7,9 @@ def _varint(value: int) -> bytes:
# From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372 # From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372
b: List[int] = [] b: List[int] = []
if value < 0:
value += 1 << 64
bits = value & 0x7F bits = value & 0x7F
value >>= 7 value >>= 7
while value: while value:
@ -15,30 +18,3 @@ def _varint(value: int) -> bytes:
value >>= 7 value >>= 7
print(value) print(value)
return bytes(b + [bits]) return bytes(b + [bits])
def varint(field_number: int, value: Union[int, float]) -> bytes:
key = _varint(field_number << 3)
return key + _varint(value)
def len_delim(field_number: int, value: Union[str, bytes]) -> bytes:
key = _varint((field_number << 3) | 2)
if isinstance(value, str):
value = value.encode("utf-8")
return key + _varint(len(value)) + value
def packed(field_number: int, value: list) -> bytes:
key = _varint((field_number << 3) | 2)
packed = b""
for item in value:
if item < 0:
# Handle negative numbers.
item += 1 << 64
packed += _varint(item)
return key + _varint(len(packed)) + packed

View File

@ -1,11 +1,30 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# source: {{ description.filename }} # source: {{ description.filename }}
# plugin: python-betterproto
{% if description.enums %}import enum
{% endif %}
from dataclasses import dataclass from dataclasses import dataclass
from typing import List from typing import List
import betterproto import betterproto
{% if description.enums %}{% for enum in description.enums %}
class {{ enum.name }}(enum.IntEnum):
{% if enum.comment %}
{{ enum.comment }}
{% endif %}
{% for entry in enum.entries %}
{% if entry.comment %}
{{ entry.comment }}
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in description.messages %} {% for message in description.messages %}
@dataclass @dataclass
class {{ message.name }}(betterproto.Message): class {{ message.name }}(betterproto.Message):

View File

@ -0,0 +1,3 @@
{
"count": -123.45
}

View File

@ -0,0 +1,3 @@
{
"count": 123.45
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
double count = 1;
}

View File

@ -0,0 +1,3 @@
{
"greeting": 1
}

View File

@ -0,0 +1,14 @@
syntax = "proto3";
// Enum for the different greeting types
enum Greeting {
HI = 0;
HEY = 1;
// Formal greeting
HELLO = 2;
}
message Test {
// Greeting enum example
Greeting greeting = 1;
}

View File

@ -1,3 +1,5 @@
{ {
"counts": [1, 2, -1, -2] "counts": [1, 2, -1, -2],
"signed": [1, 2, -1, -2],
"fixed": [1.0, 2.7, 3.4]
} }

View File

@ -2,4 +2,6 @@ syntax = "proto3";
message Test { message Test {
repeated int32 counts = 1; repeated int32 counts = 1;
repeated sint64 signed = 2;
repeated double fixed = 3;
} }

View File

@ -43,8 +43,11 @@ def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
if descriptor.default_value: if descriptor.default_value:
default = f'b"{descriptor.default_value}"' default = f'b"{descriptor.default_value}"'
return "bytes", default return "bytes", default
elif descriptor.type == 14:
# print(descriptor.type_name, file=sys.stderr)
return descriptor.type_name.split(".").pop(), 0
else: else:
raise NotImplementedError() raise NotImplementedError(f"Unknown type {descriptor.type}")
def traverse(proto_file): def traverse(proto_file):
@ -101,6 +104,7 @@ def generate_code(request, response):
"package": proto_file.package, "package": proto_file.package,
"filename": proto_file.name, "filename": proto_file.name,
"messages": [], "messages": [],
"enums": [],
} }
# Parse request # Parse request
@ -148,15 +152,27 @@ def generate_code(request, response):
) )
# print(f, file=sys.stderr) # print(f, file=sys.stderr)
# elif isinstance(item, EnumDescriptorProto):
# data.update({
# 'type': 'Enum',
# 'values': [{'name': v.name, 'value': v.number}
# for v in item.value]
# })
output["messages"].append(data) output["messages"].append(data)
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
{
"type": "Enum",
"comment": get_comment(proto_file, path),
"entries": [
{
"name": v.name,
"value": v.number,
"comment": get_comment(proto_file, path + [2, i]),
}
for i, v in enumerate(item.value)
],
}
)
output["enums"].append(data)
# Fill response # Fill response
f = response.file.add() f = response.file.add()
f.name = os.path.splitext(proto_file.name)[0] + ".py" f.name = os.path.splitext(proto_file.name)[0] + ".py"