Initial commit
This commit is contained in:
262
betterproto/__init__.py
Normal file
262
betterproto/__init__.py
Normal file
@@ -0,0 +1,262 @@
|
||||
from abc import ABC
|
||||
import json
|
||||
import struct
|
||||
from typing import (
|
||||
Union,
|
||||
Generator,
|
||||
Any,
|
||||
SupportsBytes,
|
||||
List,
|
||||
Tuple,
|
||||
Callable,
|
||||
Type,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
from . import parse, serialize
|
||||
|
||||
# Names of the scalar proto types that are treated as "packed" when they
# appear in a repeated (list) field.
PACKED_TYPES = [
    "bool",
    "int32", "int64",
    "uint32", "uint64",
    "sint32", "sint64",
    "float", "double",
]

# Wire types
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0  # variable-length integer
WIRE_FIXED_64 = 1  # fixed-width 64-bit value
WIRE_LEN_DELIM = 2  # length prefix followed by that many bytes
WIRE_FIXED_32 = 5  # fixed-width 32-bit value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class _Meta:
    """Immutable protobuf details stored in a dataclass field's metadata."""

    # Field number used to identify the field on the wire.
    number: int

    # Name of the protobuf type, e.g. "int32", "string" or "message".
    proto_type: str

    # The default value (or factory) associated with the field.
    default: Any
|
||||
|
||||
|
||||
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
    """
    Create a dataclass field that carries protobuf metadata.

    The field number, proto type name and default are stored as a `_Meta`
    instance under the ``"betterproto"`` key of the field's metadata.

    How `default` is applied depends on what it is:

    * a callable (e.g. `list` or a class) is used as the `default_factory`;
    * a `dict` or `list` instance is shallow-copied for every new instance;
    * anything else is used directly as the dataclass default.
    """
    if callable(default):
        kwargs = {"default_factory": default}
    elif isinstance(default, (dict, list)):
        # Give each instance its own shallow copy. Returning `default` itself
        # from the factory would make every instance share, and mutate, the
        # very same container object.
        kwargs = {"default_factory": default.copy}
    else:
        kwargs = {"default": default}

    return dataclasses.field(
        **kwargs, metadata={"betterproto": _Meta(number, proto_type, default)}
    )
|
||||
|
||||
|
||||
def int32_field(
    number: int, default: Union[int, Type[Iterable]] = 0
) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``int32`` and field `number`."""
    return field(number, "int32", default)
|
||||
|
||||
|
||||
def int64_field(number: int, default: int = 0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``int64`` and field `number`."""
    return field(number, "int64", default)
|
||||
|
||||
|
||||
def uint32_field(number: int, default: int = 0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``uint32`` and field `number`."""
    return field(number, "uint32", default)
|
||||
|
||||
|
||||
def uint64_field(number: int, default: int = 0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``uint64`` and field `number`."""
    return field(number, "uint64", default)
|
||||
|
||||
|
||||
def sint32_field(number: int, default: int = 0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``sint32`` and field `number`."""
    return field(number, "sint32", default)
|
||||
|
||||
|
||||
def sint64_field(number: int, default: int = 0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``sint64`` and field `number`."""
    return field(number, "sint64", default)
|
||||
|
||||
|
||||
def float_field(number: int, default: float = 0.0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``float`` and field `number`."""
    return field(number, "float", default)
|
||||
|
||||
|
||||
def double_field(number: int, default: float = 0.0) -> dataclasses.Field:
    """Declare a dataclass field with proto type ``double`` and field `number`."""
    return field(number, "double", default)
|
||||
|
||||
|
||||
def string_field(number: int, default: str = "") -> dataclasses.Field:
    """Declare a dataclass field with proto type ``string`` and field `number`."""
    return field(number, "string", default)
|
||||
|
||||
|
||||
def message_field(number: int, default: Type["Message"]) -> dataclasses.Field:
    """
    Declare a dataclass field with proto type ``message`` and field `number`.

    `default` has no fallback value and must be supplied by the caller. It is
    forwarded unchanged to `field()`, which uses a callable default as the
    dataclass `default_factory`.
    """
    # The annotation previously named "ProtoMessage", which does not exist in
    # this module; the message base class is `Message`.
    return field(number, "message", default=default)
|
||||
|
||||
|
||||
def _serialize_single(meta: _Meta, value: Any) -> bytes:
    """
    Encode one value as a complete record (key plus payload) for its field.

    Returns empty bytes for a message that itself serializes to nothing.
    Raises `NotImplementedError` for any proto type not handled here.
    """
    if meta.proto_type in ["int32", "int64", "uint32", "uint64"]:
        # Handle negative numbers by adding 2**64, giving a non-negative value.
        unsigned = value + (1 << 64) if value < 0 else value
        return serialize.varint(meta.number, unsigned)

    if meta.proto_type in ["sint32", "sint64"]:
        # Zig-zag: shift left by one, and flip every bit for negatives so the
        # result is always non-negative.
        doubled = value << 1
        zigzag = doubled if value >= 0 else doubled ^ (~0)
        return serialize.varint(meta.number, zigzag)

    if meta.proto_type == "string":
        return serialize.len_delim(meta.number, value.encode("utf-8"))

    if meta.proto_type == "message":
        payload = bytes(value)
        # A message with no encoded content is left out entirely.
        return serialize.len_delim(meta.number, payload) if payload else b""

    raise NotImplementedError()
|
||||
|
||||
|
||||
def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
    """
    Convert one raw wire value into the Python value for its field.

    Values whose wire type / proto type combination is not handled below are
    returned unchanged.
    """
    if wire_type == WIRE_VARINT:
        if meta.proto_type in ["int32", "int64"]:
            # The bit width is the numeric suffix of the type name.
            width = int(meta.proto_type[3:])
            sign = 1 << (width - 1)
            truncated = value & ((1 << width) - 1)
            # XOR then subtract the sign bit to sign-extend the value.
            return int((truncated ^ sign) - sign)
        if meta.proto_type in ["sint32", "sint64"]:
            # Undo zig-zag encoding
            return (value >> 1) ^ (-(value & 1))
        return value

    if wire_type == WIRE_LEN_DELIM:
        if meta.proto_type == "string":
            return value.decode("utf-8")
        if meta.proto_type == "message":
            # Build a fresh nested message and let it parse its own bytes.
            return field.default_factory().parse(value)

    return value
|
||||
|
||||
|
||||
# Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message")


class Message(ABC):
    """
    A protobuf message base class. Generated code will inherit from this and
    register the message fields which get used by the serializers and parsers
    to go between Python, binary and JSON protobuf message representations.
    """

    def __bytes__(self) -> bytes:
        """
        Get the binary encoded Protobuf representation of this instance.

        Empty lists and values equal to the dataclass field default are
        skipped.
        """
        # Collect the pieces and join once instead of repeatedly
        # concatenating immutable bytes objects.
        chunks = []
        for field in dataclasses.fields(self):
            meta: _Meta = field.metadata.get("betterproto")
            value = getattr(self, field.name)

            if isinstance(value, list):
                if not value:
                    continue

                if meta.proto_type in PACKED_TYPES:
                    chunks.append(serialize.packed(meta.number, value))
                else:
                    chunks.extend(_serialize_single(meta, item) for item in value)
            else:
                if value == field.default:
                    continue

                chunks.append(_serialize_single(meta, value))

        return b"".join(chunks)

    def parse(self: T, data: bytes) -> T:
        """
        Parse the binary encoded Protobuf into this message instance. This
        returns the instance itself and is therefore assignable and chainable.

        Values for list-valued fields are added to the existing list, whether
        they arrive one element at a time or as a packed chunk.
        """
        by_number = {
            f.metadata["betterproto"].number: f for f in dataclasses.fields(self)
        }
        for parsed in parse.fields(data):
            field = by_number.get(parsed.number)
            if field is None:
                # TODO: handle unknown fields
                continue

            meta: _Meta = field.metadata.get("betterproto")

            if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
                # This is a packed repeated field.
                pos = 0
                value = []
                while pos < len(parsed.value):
                    decoded, pos = parse._decode_varint(parsed.value, pos)
                    value.append(_parse_single(WIRE_VARINT, meta, field, decoded))
            else:
                value = _parse_single(parsed.wire_type, meta, field, parsed.value)

            current = getattr(self, field.name)
            if isinstance(current, list):
                if isinstance(value, list):
                    # Several packed chunks for one field concatenate rather
                    # than the last one replacing the earlier ones.
                    current.extend(value)
                else:
                    current.append(value)
            else:
                setattr(self, field.name, value)

        return self

    def to_dict(self) -> dict:
        """
        Returns a dict representation of this message instance which can be
        used to serialize to e.g. JSON.

        Nested messages that convert to an empty dict, empty lists and values
        equal to the field default are left out.
        """
        output = {}
        for field in dataclasses.fields(self):
            meta: _Meta = field.metadata.get("betterproto")
            v = getattr(self, field.name)
            if meta.proto_type == "message":
                v = v.to_dict()
                if v:
                    output[field.name] = v
            elif isinstance(v, list):
                # List fields use a default factory, so `field.default` is a
                # "missing" marker and can never equal an empty list.
                if v:
                    output[field.name] = v
            elif v != field.default:
                output[field.name] = v
        return output

    def from_dict(self: T, value: dict) -> T:
        """
        Parse the key/value pairs in `value` into this message instance. This
        returns the instance itself and is therefore assignable and chainable.
        """
        for field in dataclasses.fields(self):
            meta: _Meta = field.metadata.get("betterproto")
            if field.name in value:
                if meta.proto_type == "message":
                    # Nested messages are filled in place from their sub-dict.
                    getattr(self, field.name).from_dict(value[field.name])
                else:
                    setattr(self, field.name, value[field.name])
        return self

    def to_json(self) -> str:
        """Returns the encoded JSON representation of this message instance."""
        return json.dumps(self.to_dict())

    def from_json(self: T, value: Union[str, bytes]) -> T:
        """
        Parse the JSON document in `value` into this message instance. This
        returns the instance itself and is therefore assignable and chainable.
        """
        return self.from_dict(json.loads(value))
|
||||
BIN
betterproto/__pycache__/__init__.cpython-37.pyc
Normal file
BIN
betterproto/__pycache__/__init__.cpython-37.pyc
Normal file
Binary file not shown.
64
betterproto/parse.py
Normal file
64
betterproto/parse.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import struct
|
||||
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _decode_varint(
    buffer: bytes, pos: int, signed: bool = False, result_type: type = int
) -> Tuple[int, int]:
    """
    Decode one varint from `buffer` starting at index `pos`.

    Returns a `(value, next_pos)` tuple where `value` has been passed through
    `result_type` and `next_pos` is the index just past the varint. The
    `signed` argument is accepted but not used here.

    Raises `ValueError` when the varint has not ended after ten bytes.
    """
    result = 0
    # Each byte carries seven payload bits; shifts 0, 7, ..., 63 allow for at
    # most ten bytes.
    for shift in range(0, 64, 7):
        byte = buffer[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not (byte & 0x80):
            # A clear high bit marks the last byte of the varint.
            return (result_type(result), pos)

    raise ValueError("Too many bytes when decoding varint.")
|
||||
|
||||
|
||||
def packed(value: bytes, signed: bool = False, result_type: type = int) -> list:
    """
    Decode every varint stored back to back in `value`.

    `signed` and `result_type` are forwarded to `_decode_varint` for each
    element. Returns the decoded values in the order they appear.
    """
    items = []
    offset = 0
    end = len(value)
    while offset < end:
        item, offset = _decode_varint(
            value, offset, signed=signed, result_type=result_type
        )
        items.append(item)
    return items
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Field:
    """One immutable field record as read from an encoded message."""

    # Field number taken from the record's key.
    number: int

    # Wire type taken from the low three bits of the record's key.
    wire_type: int

    # The record's payload.
    value: Any
|
||||
|
||||
|
||||
def fields(value: bytes) -> Generator[Field, None, None]:
    """
    Iterate over the field records contained in the encoded message `value`.

    Each record starts with a varint key holding the field number (upper
    bits) and the wire type (lowest three bits). The yielded `Field.value` is:

    * wire type 0: the decoded varint;
    * wire type 1: the raw 8 payload bytes;
    * wire type 2: the payload bytes following the length prefix;
    * wire type 5: the raw 4 payload bytes.

    Raises `NotImplementedError` for any other wire type.
    """
    i = 0
    while i < len(value):
        num_wire, i = _decode_varint(value, i)
        number = num_wire >> 3
        wire_type = num_wire & 0x7

        if wire_type == 0:
            decoded, i = _decode_varint(value, i)
        elif wire_type == 1:
            # 64-bit fixed-width payload: always eight bytes. Skipping fewer
            # would leave the cursor in the middle of the value.
            decoded, i = value[i : i + 8], i + 8
        elif wire_type == 2:
            length, i = _decode_varint(value, i)
            decoded = value[i : i + length]
            i += length
        elif wire_type == 5:
            # 32-bit fixed-width payload: always four bytes.
            decoded, i = value[i : i + 4], i + 4
        else:
            raise NotImplementedError(f"Wire type {wire_type}")

        yield Field(number=number, wire_type=wire_type, value=decoded)
|
||||
0
betterproto/py.typed
Normal file
0
betterproto/py.typed
Normal file
44
betterproto/serialize.py
Normal file
44
betterproto/serialize.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import struct
|
||||
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def _varint(value: int) -> bytes:
    """
    Encode the non-negative integer `value` as a varint.

    The value is written seven bits at a time, least significant group first;
    every byte except the last has its high bit set.

    Raises `ValueError` for a negative `value`, which could otherwise never be
    shifted down to zero and would loop forever.
    """
    # From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372
    if value < 0:
        raise ValueError("Cannot encode a negative number as a varint.")

    b: List[int] = []

    bits = value & 0x7F
    value >>= 7
    while value:
        # More groups follow, so flag this byte with the continuation bit.
        b.append(0x80 | bits)
        bits = value & 0x7F
        value >>= 7
    return bytes(b + [bits])
|
||||
|
||||
|
||||
def varint(field_number: int, value: Union[int, float]) -> bytes:
    """Encode a varint record: the field's key followed by the encoded value."""
    # The wire type lives in the low three bits of the key; shifting the field
    # number alone leaves them at zero.
    tag = _varint(field_number << 3)
    payload = _varint(value)
    return tag + payload
|
||||
|
||||
|
||||
def len_delim(field_number: int, value: Union[str, bytes]) -> bytes:
    """
    Encode a length-delimited record: key, payload length, then the payload.

    A `str` value is encoded as UTF-8 first; `bytes` are written as given.
    """
    # Wire type 2 goes in the low three bits of the key.
    tag = _varint((field_number << 3) | 2)
    payload = value.encode("utf-8") if isinstance(value, str) else value
    return tag + _varint(len(payload)) + payload
|
||||
|
||||
|
||||
def packed(field_number: int, value: list) -> bytes:
    """
    Encode a list of integers as a single length-delimited record.

    The record is the key, the byte length of the body, then every element
    written as a varint with nothing in between.
    """
    # Wire type 2 goes in the low three bits of the key.
    tag = _varint((field_number << 3) | 2)

    # Handle negative numbers by adding 2**64 before encoding.
    body = b"".join(
        _varint(item + (1 << 64) if item < 0 else item) for item in value
    )

    return tag + _varint(len(body)) + body
|
||||
24
betterproto/templates/main.py
Normal file
24
betterproto/templates/main.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# source: {{ description.filename }}
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import betterproto
|
||||
|
||||
|
||||
{% for message in description.messages %}
|
||||
@dataclass
|
||||
class {{ message.name }}(betterproto.Message):
|
||||
{% if message.comment %}
|
||||
{{ message.comment }}
|
||||
|
||||
{% endif %}
|
||||
{% for field in message.properties %}
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %})
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% endfor %}
|
||||
45
betterproto/tests/generate.py
Normal file
45
betterproto/tests/generate.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#!/usr/bin/env python
|
||||
import os # isort: skip
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
|
||||
import subprocess
|
||||
import importlib
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from google.protobuf.json_format import Parse
|
||||
from google.protobuf import symbol_database
|
||||
from google.protobuf.descriptor_pool import DescriptorPool
|
||||
|
||||
root = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
|
||||
def get_files(end: str) -> Generator[Tuple[str, str], None, None]:
    """
    Find every file below `root` whose name ends with `end`.

    Yields `(base, path)` tuples: `base` is the file name without its
    extension, cut at the first ``-`` (so ``int32-negative.json`` gives
    ``int32``), and `path` is the file's location joined onto the directory
    it was found in.
    """
    for r, dirs, files in os.walk(root):
        for filename in files:
            if not filename.endswith(end):
                continue
            base = os.path.splitext(filename)[0].split("-")[0]
            # Yield a real tuple, as promised by the return annotation.
            yield (base, os.path.join(r, filename))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    os.chdir(root)

    # Run protoc twice per .proto file: once with the stock Python output and
    # once through the custom plugin. Arguments are passed as a list so that
    # file names never go through a shell.
    for base, filename in get_files(".proto"):
        proto = os.path.basename(filename)
        subprocess.run(["protoc", "--python_out=.", proto])
        subprocess.run(
            [
                "protoc",
                "--plugin=protoc-gen-custom=../../protoc-gen-betterpy.py",
                "--custom_out=.",
                proto,
            ]
        )

    # Turn every JSON sample into its binary encoding, stored next to it.
    for base, filename in get_files(".json"):
        # Reset the internal symbol database so we can import the `Test` message
        # multiple times. Ugh.
        sym = symbol_database.Default()
        sym.pool = DescriptorPool()
        imported = importlib.import_module(f"{base}_pb2")
        out = filename.replace(".json", ".bin")

        # Context managers make sure both files are closed promptly.
        with open(filename) as source:
            serialized = Parse(source.read(), imported.Test()).SerializeToString()
        with open(out, "wb") as target:
            target.write(serialized)
|
||||
3
betterproto/tests/int32-negative.json
Normal file
3
betterproto/tests/int32-negative.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"count": -150
|
||||
}
|
||||
3
betterproto/tests/int32.json
Normal file
3
betterproto/tests/int32.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"count": 150
|
||||
}
|
||||
7
betterproto/tests/int32.proto
Normal file
7
betterproto/tests/int32.proto
Normal file
@@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
// Some documentation about the Test message.
|
||||
message Test {
|
||||
// Some documentation about the count.
|
||||
int32 count = 1;
|
||||
}
|
||||
5
betterproto/tests/nested.json
Normal file
5
betterproto/tests/nested.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"nested": {
|
||||
"count": 150
|
||||
}
|
||||
}
|
||||
17
betterproto/tests/nested.proto
Normal file
17
betterproto/tests/nested.proto
Normal file
@@ -0,0 +1,17 @@
|
||||
syntax = "proto3";
|
||||
|
||||
// A test message with a nested message inside of it.
|
||||
message Test {
|
||||
// This is the nested type.
|
||||
message Nested {
|
||||
// Stores a simple counter.
|
||||
int32 count = 1;
|
||||
}
|
||||
|
||||
Nested nested = 1;
|
||||
Sibling sibling = 2;
|
||||
}
|
||||
|
||||
message Sibling {
|
||||
int32 foo = 1;
|
||||
}
|
||||
3
betterproto/tests/repeated.json
Normal file
3
betterproto/tests/repeated.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"names": ["one", "two", "three"]
|
||||
}
|
||||
5
betterproto/tests/repeated.proto
Normal file
5
betterproto/tests/repeated.proto
Normal file
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
repeated string names = 1;
|
||||
}
|
||||
3
betterproto/tests/repeatedpacked.json
Normal file
3
betterproto/tests/repeatedpacked.json
Normal file
@@ -0,0 +1,3 @@
|
||||
{
|
||||
"counts": [1, 2, -1, -2]
|
||||
}
|
||||
5
betterproto/tests/repeatedpacked.proto
Normal file
5
betterproto/tests/repeatedpacked.proto
Normal file
@@ -0,0 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
repeated int32 counts = 1;
|
||||
}
|
||||
4
betterproto/tests/signed-negative.json
Normal file
4
betterproto/tests/signed-negative.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"signed_32": -150,
|
||||
"signed_64": -150
|
||||
}
|
||||
4
betterproto/tests/signed.json
Normal file
4
betterproto/tests/signed.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"signed_32": 150,
|
||||
"signed_64": 150
|
||||
}
|
||||
6
betterproto/tests/signed.proto
Normal file
6
betterproto/tests/signed.proto
Normal file
@@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
sint32 signed_32 = 1;
|
||||
sint64 signed_64 = 2;
|
||||
}
|
||||
23
betterproto/tests/test_inputs.py
Normal file
23
betterproto/tests/test_inputs.py
Normal file
@@ -0,0 +1,23 @@
|
||||
import importlib
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from generate import get_files
|
||||
|
||||
inputs = get_files(".bin")


@pytest.mark.parametrize("name,filename", inputs)
def test_sample(name: str, filename: str) -> None:
    """
    Check one sample: the `.bin` file and the `.json` file beside it must
    load into equal messages and convert back to the data they came from.
    """
    imported = importlib.import_module(name)

    # Read both fixtures inside context managers so the handles get closed.
    with open(filename, "rb") as binary_file:
        data_binary = binary_file.read()
    with open(filename.replace(".bin", ".json")) as json_file:
        data_dict = json.loads(json_file.read())

    t1 = imported.Test().parse(data_binary)
    t2 = imported.Test().from_dict(data_dict)
    print(t1)
    print(t2)

    # Both ways of loading must agree with each other...
    assert t1 == t2
    # ...and each must convert back to both original representations.
    assert bytes(t1) == data_binary
    assert bytes(t2) == data_binary
    assert t1.to_dict() == data_dict
    assert t2.to_dict() == data_dict
|
||||
Reference in New Issue
Block a user