More features, refactoring

This commit is contained in:
Daniel G. Taylor 2019-10-08 00:23:11 -07:00
parent 6ed3b09f44
commit c932fbc72c
14 changed files with 276 additions and 112 deletions

1
.gitignore vendored
View File

@ -7,3 +7,4 @@ betterproto/tests/*_pb2.py
betterproto/tests/*.py
!betterproto/tests/generate.py
!betterproto/tests/test_*.py
**/__pycache__

View File

@ -1,10 +1,15 @@
# TODO
- [ ] Fixed length fields
- [x] Fixed length fields
- [x] Packed fixed-length
- [x] Zig-zag signed fields (sint32, sint64)
- [x] Don't encode zero values for nested types~
- [ ] Enums
- [x] Don't encode zero values for nested types
- [x] Enums
- [ ] Maps
- [ ] Support passthrough of unknown fields
- [ ] JSON that isn't naive.
- [ ] Refs to nested types
- [ ] Imports in proto files
- [ ] Well-known Google types
- [ ] JSON that isn't completely naive.
- [ ] Async service stubs
- [ ] Cleanup!

View File

@ -17,16 +17,49 @@ import dataclasses
from . import parse, serialize
# Protobuf field type names. These are the strings passed as `proto_type` to
# `field()` and compared against `meta.proto_type` during (de)serialization.
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
TYPE_INT32 = "int32"
TYPE_INT64 = "int64"
TYPE_UINT32 = "uint32"
TYPE_UINT64 = "uint64"
TYPE_SINT32 = "sint32"
TYPE_SINT64 = "sint64"
TYPE_FLOAT = "float"
TYPE_DOUBLE = "double"
TYPE_FIXED32 = "fixed32"
TYPE_SFIXED32 = "sfixed32"
TYPE_FIXED64 = "fixed64"
TYPE_SFIXED64 = "sfixed64"
TYPE_STRING = "string"
TYPE_BYTES = "bytes"
TYPE_MESSAGE = "message"
# Types that `_preprocess_single` encodes with `struct.pack` (using the format
# from `_pack_fmt`) instead of as a varint.
FIXED_TYPES = [
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
]
PACKED_TYPES = [
"bool",
"int32",
"int64",
"uint32",
"uint64",
"sint32",
"sint64",
"float",
"double",
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
TYPE_SINT32,
TYPE_SINT64,
TYPE_FLOAT,
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_SFIXED32,
TYPE_FIXED64,
TYPE_SFIXED64,
]
# Wire types
@ -36,6 +69,21 @@ WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5
# Proto types grouped by the wire type that `_serialize_single` writes into the
# field key for them: varint (0), fixed 32-bit (5), fixed 64-bit (1) and
# length-delimited (2).
WIRE_VARINT_TYPES = [
TYPE_ENUM,
TYPE_BOOL,
TYPE_INT32,
TYPE_INT64,
TYPE_UINT32,
TYPE_UINT64,
TYPE_SINT32,
TYPE_SINT64,
]
WIRE_FIXED_32_TYPES = [TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32]
WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
@dataclasses.dataclass(frozen=True)
class _Meta:
@ -59,74 +107,135 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
)
def int32_field(
number: int, default: Union[int, Type[Iterable]] = 0
) -> dataclasses.Field:
# Note: the fields below return `Any` to prevent type errors in the generated
# data classes since the types won't match with `Field` and they get swapped
# out at runtime. The generated dataclass variables are still typed correctly.
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
    """Declare an `enum` field with the given field number and default."""
    proto_type = "enum"
    return field(number, proto_type, default=default)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return field(number, "int32", default=default)
def int64_field(number: int, default: int = 0) -> dataclasses.Field:
def int64_field(number: int, default: int = 0) -> Any:
return field(number, "int64", default=default)
def uint32_field(number: int, default: int = 0) -> dataclasses.Field:
def uint32_field(number: int, default: int = 0) -> Any:
return field(number, "uint32", default=default)
def uint64_field(number: int, default: int = 0) -> dataclasses.Field:
def uint64_field(number: int, default: int = 0) -> Any:
return field(number, "uint64", default=default)
def sint32_field(number: int, default: int = 0) -> dataclasses.Field:
def sint32_field(number: int, default: int = 0) -> Any:
return field(number, "sint32", default=default)
def sint64_field(number: int, default: int = 0) -> dataclasses.Field:
def sint64_field(number: int, default: int = 0) -> Any:
return field(number, "sint64", default=default)
def float_field(number: int, default: float = 0.0) -> dataclasses.Field:
def float_field(number: int, default: float = 0.0) -> Any:
return field(number, "float", default=default)
def double_field(number: int, default: float = 0.0) -> dataclasses.Field:
def double_field(number: int, default: float = 0.0) -> Any:
return field(number, "double", default=default)
def string_field(number: int, default: str = "") -> dataclasses.Field:
def fixed32_field(number: int, default: int = 0) -> Any:
    """Declare a `fixed32` field with the given field number and default.

    `fixed32` values are packed with the unsigned-integer struct format
    "<I" (see `_pack_fmt`), which rejects floats, so the default is the
    integer 0 rather than 0.0.
    """
    return field(number, "fixed32", default=default)
def fixed64_field(number: int, default: int = 0) -> Any:
    """Declare a `fixed64` field with the given field number and default.

    `fixed64` values are packed with the unsigned-integer struct format
    "<Q" (see `_pack_fmt`), which rejects floats, so the default is the
    integer 0 rather than 0.0.
    """
    return field(number, "fixed64", default=default)
def sfixed32_field(number: int, default: int = 0) -> Any:
    """Declare an `sfixed32` field with the given field number and default.

    `sfixed32` values are packed with the signed-integer struct format
    "<i" (see `_pack_fmt`), which rejects floats, so the default is the
    integer 0 rather than 0.0.
    """
    return field(number, "sfixed32", default=default)
def sfixed64_field(number: int, default: int = 0) -> Any:
    """Declare an `sfixed64` field with the given field number and default.

    `sfixed64` values are packed with the signed-integer struct format
    "<q" (see `_pack_fmt`), which rejects floats, so the default is the
    integer 0 rather than 0.0.
    """
    return field(number, "sfixed64", default=default)
def string_field(number: int, default: str = "") -> Any:
return field(number, "string", default=default)
def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field:
def message_field(number: int, default: Type["Message"]) -> Any:
return field(number, "message", default=default)
def _serialize_single(meta: _Meta, value: Any) -> bytes:
def _pack_fmt(proto_type: str) -> str:
return {
"double": "<d",
"float": "<f",
"fixed32": "<I",
"fixed64": "<Q",
"sfixed32": "<i",
"sfixed64": "<q",
}[proto_type]
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
value = _preprocess_single(proto_type, value)
output = b""
if meta.proto_type in ["int32", "int64", "uint32", "uint64"]:
if value < 0:
# Handle negative numbers.
value += 1 << 64
output = serialize.varint(meta.number, value)
elif meta.proto_type in ["sint32", "sint64"]:
if value >= 0:
value = value << 1
else:
value = (value << 1) ^ (~0)
output = serialize.varint(meta.number, value)
elif meta.proto_type == "string":
output = serialize.len_delim(meta.number, value.encode("utf-8"))
elif meta.proto_type == "message":
b = bytes(value)
if len(b):
output = serialize.len_delim(meta.number, b)
if proto_type in WIRE_VARINT_TYPES:
key = serialize._varint(field_number << 3)
output += key + value
elif proto_type in WIRE_FIXED_32_TYPES:
key = serialize._varint((field_number << 3) | 5)
output += key + value
elif proto_type in WIRE_FIXED_64_TYPES:
key = serialize._varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value):
key = serialize._varint((field_number << 3) | 2)
output += key + serialize._varint(len(value)) + value
else:
raise NotImplementedError()
raise NotImplementedError(proto_type)
return output
def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
def _preprocess_single(proto_type: str, value: Any) -> bytes:
    """Convert a value into the bytes that follow the field key on the wire.

    Integer-like types become varints (zig-zag encoded first for `sint32`
    and `sint64`), fixed-width types are packed with `struct`, strings are
    UTF-8 encoded and messages are serialized with `bytes()`. A value of any
    other type is returned untouched.
    """
    plain_varints = (
        TYPE_ENUM,
        TYPE_BOOL,
        TYPE_INT32,
        TYPE_INT64,
        TYPE_UINT32,
        TYPE_UINT64,
    )
    if proto_type in plain_varints:
        return serialize._varint(value)

    if proto_type in (TYPE_SINT32, TYPE_SINT64):
        # Zig-zag: non-negative n maps to 2n, negative n maps to -2n - 1.
        zigzag = (value << 1) if value >= 0 else ((value << 1) ^ (~0))
        return serialize._varint(zigzag)

    if proto_type in FIXED_TYPES:
        return struct.pack(_pack_fmt(proto_type), value)

    if proto_type == TYPE_STRING:
        return value.encode("utf-8")

    if proto_type == TYPE_MESSAGE:
        return bytes(value)

    return value
def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
"""Adjusts values after parsing."""
if wire_type == WIRE_VARINT:
if meta.proto_type in ["int32", "int64"]:
bits = int(meta.proto_type[3:])
@ -136,6 +245,9 @@ def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
elif meta.proto_type in ["sint32", "sint64"]:
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in ["string"]:
value = value.decode("utf-8")
@ -170,15 +282,21 @@ class Message(ABC):
continue
if meta.proto_type in PACKED_TYPES:
output += serialize.packed(meta.number, value)
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
# treat it like a field of raw bytes.
buf = b""
for item in value:
buf += _preprocess_single(meta.proto_type, item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(meta, item)
output += _serialize_single(meta.number, meta.proto_type, item)
else:
if value == field.default:
continue
output += _serialize_single(meta, value)
output += _serialize_single(meta.number, meta.proto_type, value)
return output
@ -201,11 +319,21 @@ class Message(ABC):
pos = 0
value = []
while pos < len(parsed.value):
decoded, pos = parse._decode_varint(parsed.value, pos)
decoded = _parse_single(WIRE_VARINT, meta, field, decoded)
if meta.proto_type in ["float", "fixed32", "sfixed32"]:
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in ["double", "fixed64", "sfixed64"]:
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
decoded, pos = parse.parse_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = _postprocess_single(wire_type, meta, field, decoded)
value.append(decoded)
else:
value = _parse_single(parsed.wire_type, meta, field, parsed.value)
value = _postprocess_single(
parsed.wire_type, meta, field, parsed.value
)
if isinstance(getattr(self, field.name), list) and not isinstance(
value, list

View File

@ -3,9 +3,8 @@ from typing import Union, Generator, Any, SupportsBytes, List, Tuple
from dataclasses import dataclass
def _decode_varint(
buffer: bytes, pos: int, signed: bool = False, result_type: type = int
) -> Tuple[int, int]:
def parse_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
"""Parse a single varint value from a byte buffer."""
result = 0
shift = 0
while 1:
@ -13,52 +12,40 @@ def _decode_varint(
result |= (b & 0x7F) << shift
pos += 1
if not (b & 0x80):
result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise ValueError("Too many bytes when decoding varint.")
def packed(value: bytes, signed: bool = False, result_type: type = int) -> list:
parsed = []
pos = 0
while pos < len(value):
decoded, pos = _decode_varint(
value, pos, signed=signed, result_type=result_type
)
parsed.append(decoded)
return parsed
@dataclass(frozen=True)
class Field:
class ParsedField:
number: int
wire_type: int
value: Any
def fields(value: bytes) -> Generator[Field, None, None]:
def fields(value: bytes) -> Generator[ParsedField, None, None]:
i = 0
while i < len(value):
num_wire, i = _decode_varint(value, i)
print(num_wire, i)
num_wire, i = parse_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3
wire_type = num_wire & 0x7
if wire_type == 0:
decoded, i = _decode_varint(value, i)
decoded, i = parse_varint(value, i)
elif wire_type == 1:
decoded, i = None, i + 4
decoded, i = value[i : i + 8], i + 8
elif wire_type == 2:
length, i = _decode_varint(value, i)
length, i = parse_varint(value, i)
decoded = value[i : i + length]
i += length
elif wire_type == 5:
decoded, i = None, i + 2
decoded, i = value[i : i + 4], i + 4
else:
raise NotImplementedError(f"Wire type {wire_type}")
# print(Field(number=number, wire_type=wire_type, value=decoded))
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield Field(number=number, wire_type=wire_type, value=decoded)
yield ParsedField(number=number, wire_type=wire_type, value=decoded)

View File

@ -7,6 +7,9 @@ def _varint(value: int) -> bytes:
# From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372
b: List[int] = []
if value < 0:
value += 1 << 64
bits = value & 0x7F
value >>= 7
while value:
@ -15,30 +18,3 @@ def _varint(value: int) -> bytes:
value >>= 7
print(value)
return bytes(b + [bits])
def varint(field_number: int, value: Union[int, float]) -> bytes:
key = _varint(field_number << 3)
return key + _varint(value)
def len_delim(field_number: int, value: Union[str, bytes]) -> bytes:
key = _varint((field_number << 3) | 2)
if isinstance(value, str):
value = value.encode("utf-8")
return key + _varint(len(value)) + value
def packed(field_number: int, value: list) -> bytes:
key = _varint((field_number << 3) | 2)
packed = b""
for item in value:
if item < 0:
# Handle negative numbers.
item += 1 << 64
packed += _varint(item)
return key + _varint(len(packed)) + packed

View File

@ -1,11 +1,30 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: {{ description.filename }}
# plugin: python-betterproto
{% if description.enums %}import enum
{% endif %}
from dataclasses import dataclass
from typing import List
import betterproto
{% if description.enums %}{% for enum in description.enums %}
class {{ enum.name }}(enum.IntEnum):
{% if enum.comment %}
{{ enum.comment }}
{% endif %}
{% for entry in enum.entries %}
{% if entry.comment %}
{{ entry.comment }}
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in description.messages %}
@dataclass
class {{ message.name }}(betterproto.Message):

View File

@ -0,0 +1,3 @@
{
"count": -123.45
}

View File

@ -0,0 +1,3 @@
{
"count": 123.45
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
double count = 1;
}

View File

@ -0,0 +1,3 @@
{
"greeting": 1
}

View File

@ -0,0 +1,14 @@
syntax = "proto3";
// Enum for the different greeting types
enum Greeting {
HI = 0;
HEY = 1;
// Formal greeting
HELLO = 2;
}
message Test {
// Greeting enum example
Greeting greeting = 1;
}

View File

@ -1,3 +1,5 @@
{
"counts": [1, 2, -1, -2]
"counts": [1, 2, -1, -2],
"signed": [1, 2, -1, -2],
"fixed": [1.0, 2.7, 3.4]
}

View File

@ -2,4 +2,6 @@ syntax = "proto3";
message Test {
repeated int32 counts = 1;
repeated sint64 signed = 2;
repeated double fixed = 3;
}

View File

@ -43,8 +43,11 @@ def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
if descriptor.default_value:
default = f'b"{descriptor.default_value}"'
return "bytes", default
elif descriptor.type == 14:
# print(descriptor.type_name, file=sys.stderr)
return descriptor.type_name.split(".").pop(), 0
else:
raise NotImplementedError()
raise NotImplementedError(f"Unknown type {descriptor.type}")
def traverse(proto_file):
@ -101,6 +104,7 @@ def generate_code(request, response):
"package": proto_file.package,
"filename": proto_file.name,
"messages": [],
"enums": [],
}
# Parse request
@ -148,14 +152,26 @@ def generate_code(request, response):
)
# print(f, file=sys.stderr)
# elif isinstance(item, EnumDescriptorProto):
# data.update({
# 'type': 'Enum',
# 'values': [{'name': v.name, 'value': v.number}
# for v in item.value]
# })
output["messages"].append(data)
output["messages"].append(data)
elif isinstance(item, EnumDescriptorProto):
# print(item.name, path, file=sys.stderr)
data.update(
{
"type": "Enum",
"comment": get_comment(proto_file, path),
"entries": [
{
"name": v.name,
"value": v.number,
"comment": get_comment(proto_file, path + [2, i]),
}
for i, v in enumerate(item.value)
],
}
)
output["enums"].append(data)
# Fill response
f = response.file.add()