Initial commit

This commit is contained in:
Daniel G. Taylor 2019-10-05 08:36:23 -07:00
commit 6ed3b09f44
26 changed files with 1026 additions and 0 deletions

1
.env.default Normal file
View File

@ -0,0 +1 @@
PYTHONPATH=.

9
.gitignore vendored Normal file
View File

@ -0,0 +1,9 @@
.env
.vscode/settings.json
.mypy_cache
.pytest_cache
betterproto/tests/*.bin
betterproto/tests/*_pb2.py
betterproto/tests/*.py
!betterproto/tests/generate.py
!betterproto/tests/test_*.py

21
Pipfile Normal file
View File

@ -0,0 +1,21 @@
[[source]]
name = "pypi"
url = "https://pypi.org/simple"
verify_ssl = true
[dev-packages]
flake8 = "*"
mypy = "*"
isort = "*"
pytest = "*"
[packages]
protobuf = "*"
jinja2 = "*"
[requires]
python_version = "3.7"
[scripts]
generate = "python betterproto/tests/generate.py"
test = "pytest ./betterproto/tests"

273
Pipfile.lock generated Normal file
View File

@ -0,0 +1,273 @@
{
"_meta": {
"hash": {
"sha256": "817b0f61c21a4841d0cfcc977becb16b4d55090f3d78c1ebcd6974c298a06348"
},
"pipfile-spec": 6,
"requires": {
"python_version": "3.7"
},
"sources": [
{
"name": "pypi",
"url": "https://pypi.org/simple",
"verify_ssl": true
}
]
},
"default": {
"jinja2": {
"hashes": [
"sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013",
"sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b"
],
"index": "pypi",
"version": "==2.10.1"
},
"markupsafe": {
"hashes": [
"sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473",
"sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161",
"sha256:09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235",
"sha256:1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5",
"sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff",
"sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b",
"sha256:43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1",
"sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e",
"sha256:500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183",
"sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66",
"sha256:62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1",
"sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1",
"sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e",
"sha256:79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b",
"sha256:7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905",
"sha256:88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735",
"sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d",
"sha256:98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e",
"sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d",
"sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c",
"sha256:ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21",
"sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2",
"sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5",
"sha256:b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b",
"sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6",
"sha256:c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f",
"sha256:cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f",
"sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7"
],
"version": "==1.1.1"
},
"protobuf": {
"hashes": [
"sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f",
"sha256:1accdb7a47e51503be64d9a57543964ba674edac103215576399d2d0e34eac77",
"sha256:27003d12d4f68e3cbea9eb67427cab3bfddd47ff90670cb367fcd7a3a89b9657",
"sha256:3264f3c431a631b0b31e9db2ae8c927b79fc1a7b1b06b31e8e5bcf2af91fe896",
"sha256:3c5ab0f5c71ca5af27143e60613729e3488bb45f6d3f143dc918a20af8bab0bf",
"sha256:45dcf8758873e3f69feab075e5f3177270739f146255225474ee0b90429adef6",
"sha256:56a77d61a91186cc5676d8e11b36a5feb513873e4ae88d2ee5cf530d52bbcd3b",
"sha256:5984e4947bbcef5bd849d6244aec507d31786f2dd3344139adc1489fb403b300",
"sha256:6b0441da73796dd00821763bb4119674eaf252776beb50ae3883bed179a60b2a",
"sha256:6f6677c5ade94d4fe75a912926d6796d5c71a2a90c2aeefe0d6f211d75c74789",
"sha256:84a825a9418d7196e2acc48f8746cf1ee75877ed2f30433ab92a133f3eaf8fbe",
"sha256:b842c34fe043ccf78b4a6cf1019d7b80113707d68c88842d061fa2b8fb6ddedc",
"sha256:ca33d2f09dae149a1dcf942d2d825ebb06343b77b437198c9e2ef115cf5d5bc1",
"sha256:db83b5c12c0cd30150bb568e6feb2435c49ce4e68fe2d7b903113f0e221e58fe",
"sha256:f50f3b1c5c1c1334ca7ce9cad5992f098f460ffd6388a3cabad10b66c2006b09",
"sha256:f99f127909731cafb841c52f9216e447d3e4afb99b17bebfad327a75aee206de"
],
"index": "pypi",
"version": "==3.10.0"
},
"six": {
"hashes": [
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
],
"version": "==1.12.0"
}
},
"develop": {
"atomicwrites": {
"hashes": [
"sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4",
"sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6"
],
"version": "==1.3.0"
},
"attrs": {
"hashes": [
"sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2",
"sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396"
],
"version": "==19.2.0"
},
"entrypoints": {
"hashes": [
"sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19",
"sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451"
],
"version": "==0.3"
},
"flake8": {
"hashes": [
"sha256:19241c1cbc971b9962473e4438a2ca19749a7dd002dd1a946eaba171b4114548",
"sha256:8e9dfa3cecb2400b3738a42c54c3043e821682b9c840b0448c0503f781130696"
],
"index": "pypi",
"version": "==3.7.8"
},
"importlib-metadata": {
"hashes": [
"sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26",
"sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af"
],
"markers": "python_version < '3.8'",
"version": "==0.23"
},
"isort": {
"hashes": [
"sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1",
"sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd"
],
"index": "pypi",
"version": "==4.3.21"
},
"mccabe": {
"hashes": [
"sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42",
"sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"
],
"version": "==0.6.1"
},
"more-itertools": {
"hashes": [
"sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832",
"sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4"
],
"version": "==7.2.0"
},
"mypy": {
"hashes": [
"sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b",
"sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac",
"sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c",
"sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879",
"sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795",
"sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021",
"sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3",
"sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23",
"sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd",
"sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b",
"sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046"
],
"index": "pypi",
"version": "==0.730"
},
"mypy-extensions": {
"hashes": [
"sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458"
],
"version": "==0.4.2"
},
"packaging": {
"hashes": [
"sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47",
"sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108"
],
"version": "==19.2"
},
"pluggy": {
"hashes": [
"sha256:0db4b7601aae1d35b4a033282da476845aa19185c1e6964b25cf324b5e4ec3e6",
"sha256:fa5fa1622fa6dd5c030e9cad086fa19ef6a0cf6d7a2d12318e10cb49d6d68f34"
],
"version": "==0.13.0"
},
"py": {
"hashes": [
"sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa",
"sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53"
],
"version": "==1.8.0"
},
"pycodestyle": {
"hashes": [
"sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56",
"sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c"
],
"version": "==2.5.0"
},
"pyflakes": {
"hashes": [
"sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0",
"sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2"
],
"version": "==2.1.1"
},
"pyparsing": {
"hashes": [
"sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80",
"sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4"
],
"version": "==2.4.2"
},
"pytest": {
"hashes": [
"sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5",
"sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5"
],
"index": "pypi",
"version": "==5.2.0"
},
"six": {
"hashes": [
"sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
"sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
],
"version": "==1.12.0"
},
"typed-ast": {
"hashes": [
"sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e",
"sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e",
"sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0",
"sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c",
"sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631",
"sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4",
"sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34",
"sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b",
"sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a",
"sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233",
"sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1",
"sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36",
"sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d",
"sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a",
"sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12"
],
"version": "==1.4.0"
},
"typing-extensions": {
"hashes": [
"sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95",
"sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87",
"sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed"
],
"version": "==3.7.4"
},
"wcwidth": {
"hashes": [
"sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e",
"sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c"
],
"version": "==0.1.7"
},
"zipp": {
"hashes": [
"sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e",
"sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335"
],
"version": "==0.6.0"
}
}
}

10
README.md Normal file
View File

@ -0,0 +1,10 @@
# TODO
- [ ] Fixed length fields
- [x] Zig-zag signed fields (sint32, sint64)
- [x] Don't encode zero values for nested types~
- [ ] Enums
- [ ] Maps
- [ ] Support passthrough of unknown fields
- [ ] JSON that isn't naive.
- [ ] Cleanup!

262
betterproto/__init__.py Normal file
View File

@ -0,0 +1,262 @@
from abc import ABC
import json
import struct
from typing import (
Union,
Generator,
Any,
SupportsBytes,
List,
Tuple,
Callable,
Type,
Iterable,
TypeVar,
)
import dataclasses
from . import parse, serialize
PACKED_TYPES = [
"bool",
"int32",
"int64",
"uint32",
"uint64",
"sint32",
"sint64",
"float",
"double",
]
# Wire types
# https://developers.google.com/protocol-buffers/docs/encoding#structure
WIRE_VARINT = 0
WIRE_FIXED_64 = 1
WIRE_LEN_DELIM = 2
WIRE_FIXED_32 = 5
@dataclasses.dataclass(frozen=True)
class _Meta:
number: int
proto_type: str
default: Any
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
kwargs = {}
if callable(default):
kwargs["default_factory"] = default
elif isinstance(default, dict) or isinstance(default, list):
kwargs["default_factory"] = lambda: default
else:
kwargs["default"] = default
return dataclasses.field(
**kwargs, metadata={"betterproto": _Meta(number, proto_type, default)}
)
def int32_field(
number: int, default: Union[int, Type[Iterable]] = 0
) -> dataclasses.Field:
return field(number, "int32", default=default)
def int64_field(number: int, default: int = 0) -> dataclasses.Field:
return field(number, "int64", default=default)
def uint32_field(number: int, default: int = 0) -> dataclasses.Field:
return field(number, "uint32", default=default)
def uint64_field(number: int, default: int = 0) -> dataclasses.Field:
return field(number, "uint64", default=default)
def sint32_field(number: int, default: int = 0) -> dataclasses.Field:
return field(number, "sint32", default=default)
def sint64_field(number: int, default: int = 0) -> dataclasses.Field:
return field(number, "sint64", default=default)
def float_field(number: int, default: float = 0.0) -> dataclasses.Field:
return field(number, "float", default=default)
def double_field(number: int, default: float = 0.0) -> dataclasses.Field:
return field(number, "double", default=default)
def string_field(number: int, default: str = "") -> dataclasses.Field:
return field(number, "string", default=default)
def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field:
return field(number, "message", default=default)
def _serialize_single(meta: _Meta, value: Any) -> bytes:
output = b""
if meta.proto_type in ["int32", "int64", "uint32", "uint64"]:
if value < 0:
# Handle negative numbers.
value += 1 << 64
output = serialize.varint(meta.number, value)
elif meta.proto_type in ["sint32", "sint64"]:
if value >= 0:
value = value << 1
else:
value = (value << 1) ^ (~0)
output = serialize.varint(meta.number, value)
elif meta.proto_type == "string":
output = serialize.len_delim(meta.number, value.encode("utf-8"))
elif meta.proto_type == "message":
b = bytes(value)
if len(b):
output = serialize.len_delim(meta.number, b)
else:
raise NotImplementedError()
return output
def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
if wire_type == WIRE_VARINT:
if meta.proto_type in ["int32", "int64"]:
bits = int(meta.proto_type[3:])
value = value & ((1 << bits) - 1)
signbit = 1 << (bits - 1)
value = int((value ^ signbit) - signbit)
elif meta.proto_type in ["sint32", "sint64"]:
# Undo zig-zag encoding
value = (value >> 1) ^ (-(value & 1))
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in ["string"]:
value = value.decode("utf-8")
elif meta.proto_type in ["message"]:
value = field.default_factory().parse(value)
return value
# Bound type variable to allow methods to return `self` of subclasses
T = TypeVar("T", bound="Message")
class Message(ABC):
"""
A protobuf message base class. Generated code will inherit from this and
register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations.
"""
def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
"""
output = b""
for field in dataclasses.fields(self):
meta: _Meta = field.metadata.get("betterproto")
value = getattr(self, field.name)
if isinstance(value, list):
if not len(value):
continue
if meta.proto_type in PACKED_TYPES:
output += serialize.packed(meta.number, value)
else:
for item in value:
output += _serialize_single(meta, item)
else:
if value == field.default:
continue
output += _serialize_single(meta, value)
return output
def parse(self, data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
for parsed in parse.fields(data):
if parsed.number in fields:
field = fields[parsed.number]
meta: _Meta = field.metadata.get("betterproto")
if (
parsed.wire_type == WIRE_LEN_DELIM
and meta.proto_type in PACKED_TYPES
):
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
decoded, pos = parse._decode_varint(parsed.value, pos)
decoded = _parse_single(WIRE_VARINT, meta, field, decoded)
value.append(decoded)
else:
value = _parse_single(parsed.wire_type, meta, field, parsed.value)
if isinstance(getattr(self, field.name), list) and not isinstance(
value, list
):
getattr(self, field.name).append(value)
else:
setattr(self, field.name, value)
else:
# TODO: handle unknown fields
pass
return self
def to_dict(self) -> dict:
"""
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
"""
output = {}
for field in dataclasses.fields(self):
meta: Meta_ = field.metadata.get("betterproto")
v = getattr(self, field.name)
if meta.proto_type == "message":
v = v.to_dict()
if v:
output[field.name] = v
elif v != field.default:
output[field.name] = getattr(self, field.name)
return output
def from_dict(self, value: dict) -> T:
"""
Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
for field in dataclasses.fields(self):
meta: Meta_ = field.metadata.get("betterproto")
if field.name in value:
if meta.proto_type == "message":
getattr(self, field.name).from_dict(value[field.name])
else:
setattr(self, field.name, value[field.name])
return self
def to_json(self) -> bytes:
"""Returns the encoded JSON representation of this message instance."""
return json.dumps(self.to_dict())
def from_json(self, value: bytes) -> T:
"""
Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
return self.from_dict(json.loads(value))

Binary file not shown.

64
betterproto/parse.py Normal file
View File

@ -0,0 +1,64 @@
import struct
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
from dataclasses import dataclass
def _decode_varint(
buffer: bytes, pos: int, signed: bool = False, result_type: type = int
) -> Tuple[int, int]:
result = 0
shift = 0
while 1:
b = buffer[pos]
result |= (b & 0x7F) << shift
pos += 1
if not (b & 0x80):
result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise ValueError("Too many bytes when decoding varint.")
def packed(value: bytes, signed: bool = False, result_type: type = int) -> list:
parsed = []
pos = 0
while pos < len(value):
decoded, pos = _decode_varint(
value, pos, signed=signed, result_type=result_type
)
parsed.append(decoded)
return parsed
@dataclass(frozen=True)
class Field:
number: int
wire_type: int
value: Any
def fields(value: bytes) -> Generator[Field, None, None]:
i = 0
while i < len(value):
num_wire, i = _decode_varint(value, i)
print(num_wire, i)
number = num_wire >> 3
wire_type = num_wire & 0x7
if wire_type == 0:
decoded, i = _decode_varint(value, i)
elif wire_type == 1:
decoded, i = None, i + 4
elif wire_type == 2:
length, i = _decode_varint(value, i)
decoded = value[i : i + length]
i += length
elif wire_type == 5:
decoded, i = None, i + 2
else:
raise NotImplementedError(f"Wire type {wire_type}")
# print(Field(number=number, wire_type=wire_type, value=decoded))
yield Field(number=number, wire_type=wire_type, value=decoded)

0
betterproto/py.typed Normal file
View File

44
betterproto/serialize.py Normal file
View File

@ -0,0 +1,44 @@
import struct
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
from dataclasses import dataclass
def _varint(value: int) -> bytes:
# From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372
b: List[int] = []
bits = value & 0x7F
value >>= 7
while value:
b.append(0x80 | bits)
bits = value & 0x7F
value >>= 7
print(value)
return bytes(b + [bits])
def varint(field_number: int, value: Union[int, float]) -> bytes:
key = _varint(field_number << 3)
return key + _varint(value)
def len_delim(field_number: int, value: Union[str, bytes]) -> bytes:
key = _varint((field_number << 3) | 2)
if isinstance(value, str):
value = value.encode("utf-8")
return key + _varint(len(value)) + value
def packed(field_number: int, value: list) -> bytes:
key = _varint((field_number << 3) | 2)
packed = b""
for item in value:
if item < 0:
# Handle negative numbers.
item += 1 << 64
packed += _varint(item)
return key + _varint(len(packed)) + packed

View File

@ -0,0 +1,24 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: {{ description.filename }}
from dataclasses import dataclass
from typing import List
import betterproto
{% for message in description.messages %}
@dataclass
class {{ message.name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
{% endif %}
{% for field in message.properties %}
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %})
{% endfor %}
{% endfor %}

View File

@ -0,0 +1,45 @@
#!/usr/bin/env python
import os # isort: skip
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import subprocess
import importlib
from typing import Generator, Tuple
from google.protobuf.json_format import Parse
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
root = os.path.dirname(os.path.realpath(__file__))
def get_files(end: str) -> Generator[Tuple[str, str], None, None]:
for r, dirs, files in os.walk(root):
for filename in [f for f in files if f.endswith(end)]:
parts = os.path.splitext(filename)[0].split("-")
yield [parts[0], os.path.join(r, filename)]
if __name__ == "__main__":
os.chdir(root)
for base, filename in get_files(".proto"):
subprocess.run(
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
)
subprocess.run(
f"protoc --plugin=protoc-gen-custom=../../protoc-gen-betterpy.py --custom_out=. {os.path.basename(filename)}",
shell=True,
)
for base, filename in get_files(".json"):
# Reset the internal symbol database so we can import the `Test` message
# multiple times. Ugh.
sym = symbol_database.Default()
sym.pool = DescriptorPool()
imported = importlib.import_module(f"{base}_pb2")
out = filename.replace(".json", ".bin")
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
open(out, "wb").write(serialized)

View File

@ -0,0 +1,3 @@
{
"count": -150
}

View File

@ -0,0 +1,3 @@
{
"count": 150
}

View File

@ -0,0 +1,7 @@
syntax = "proto3";
// Some documentation about the Test message.
message Test {
// Some documentation about the count.
int32 count = 1;
}

View File

@ -0,0 +1,5 @@
{
"nested": {
"count": 150
}
}

View File

@ -0,0 +1,17 @@
syntax = "proto3";
// A test message with a nested message inside of it.
message Test {
// This is the nested type.
message Nested {
// Stores a simple counter.
int32 count = 1;
}
Nested nested = 1;
Sibling sibling = 2;
}
message Sibling {
int32 foo = 1;
}

View File

@ -0,0 +1,3 @@
{
"names": ["one", "two", "three"]
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
repeated string names = 1;
}

View File

@ -0,0 +1,3 @@
{
"counts": [1, 2, -1, -2]
}

View File

@ -0,0 +1,5 @@
syntax = "proto3";
message Test {
repeated int32 counts = 1;
}

View File

@ -0,0 +1,4 @@
{
"signed_32": -150,
"signed_64": -150
}

View File

@ -0,0 +1,4 @@
{
"signed_32": 150,
"signed_64": 150
}

View File

@ -0,0 +1,6 @@
syntax = "proto3";
message Test {
sint32 signed_32 = 1;
sint64 signed_64 = 2;
}

View File

@ -0,0 +1,23 @@
import importlib
import pytest
import json
from generate import get_files
inputs = get_files(".bin")
@pytest.mark.parametrize("name,filename", inputs)
def test_sample(name: str, filename: str) -> None:
imported = importlib.import_module(name)
data_binary = open(filename, "rb").read()
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
t1 = imported.Test().parse(data_binary)
t2 = imported.Test().from_dict(data_dict)
print(t1)
print(t2)
assert t1 == t2
assert bytes(t1) == data_binary
assert bytes(t2) == data_binary
assert t1.to_dict() == data_dict
assert t2.to_dict() == data_dict

185
protoc-gen-betterpy.py Executable file
View File

@ -0,0 +1,185 @@
#!/usr/bin/env python
import sys
import itertools
import json
import os.path
from typing import Tuple, Any, List
import textwrap
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
)
from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def py_type(descriptor: DescriptorProto) -> Tuple[str, str]:
if descriptor.type in [1, 2, 6, 7, 15, 16]:
return "float", descriptor.default_value
elif descriptor.type in [3, 4, 5, 13, 17, 18]:
return "int", descriptor.default_value
elif descriptor.type == 8:
return "bool", descriptor.default_value.capitalize()
elif descriptor.type == 9:
default = ""
if descriptor.default_value:
default = f'"{descriptor.default_value}"'
return "str", default
elif descriptor.type == 11:
# Type referencing another defined Message
# print(descriptor.type_name, file=sys.stderr)
# message_type = descriptor.type_name.replace(".", "")
message_type = descriptor.type_name.split(".").pop()
return f'"{message_type}"', f"lambda: {message_type}()"
elif descriptor.type == 12:
default = ""
if descriptor.default_value:
default = f'b"{descriptor.default_value}"'
return "bytes", default
else:
raise NotImplementedError()
def traverse(proto_file):
def _traverse(path, items):
for i, item in enumerate(items):
yield item, path + [i]
if isinstance(item, DescriptorProto):
for enum in item.enum_type:
yield enum, path + [i, 4]
if item.nested_type:
for n, p in _traverse(path + [i, 3], item.nested_type):
yield n, p
return itertools.chain(
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
)
def get_comment(proto_file, path: List[int]) -> str:
for sci in proto_file.source_code_info.location:
# print(list(sci.path), path, file=sys.stderr)
if list(sci.path) == path and sci.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=75
)
if path[-2] == 2:
# This is a field
return " # " + " # ".join(lines)
else:
# This is a class
if len(lines) == 1 and len(lines[0]) < 70:
return f' """{lines[0]}"""'
else:
return f' """\n{" ".join(lines)}\n """'
return ""
def generate_code(request, response):
env = Environment(
trim_blocks=True,
lstrip_blocks=True,
loader=PackageLoader("betterproto", "templates"),
)
template = env.get_template("main.py")
for proto_file in request.proto_file:
# print(proto_file.message_type, file=sys.stderr)
# print(proto_file.source_code_info, file=sys.stderr)
output = {
"package": proto_file.package,
"filename": proto_file.name,
"messages": [],
}
# Parse request
for item, path in traverse(proto_file):
# print(item, file=sys.stderr)
# print(path, file=sys.stderr)
data = {"name": item.name}
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
data.update(
{
"type": "Message",
"comment": get_comment(proto_file, path),
"properties": [],
}
)
for i, f in enumerate(item.field):
t, zero = py_type(f)
repeated = False
packed = False
if f.label == 3:
# Repeated field
repeated = True
t = f"List[{t}]"
zero = "list"
if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
packed = True
data["properties"].append(
{
"name": f.name,
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": f.Type.Name(f.type).lower()[5:],
"type": t,
"zero": zero,
"repeated": repeated,
"packed": packed,
}
)
# print(f, file=sys.stderr)
# elif isinstance(item, EnumDescriptorProto):
# data.update({
# 'type': 'Enum',
# 'values': [{'name': v.name, 'value': v.number}
# for v in item.value]
# })
output["messages"].append(data)
# Fill response
f = response.file.add()
f.name = os.path.splitext(proto_file.name)[0] + ".py"
# f.content = json.dumps(output, indent=2)
f.content = template.render(description=output).rstrip("\n") + "\n"
if __name__ == "__main__":
# Read request message from stdin
data = sys.stdin.buffer.read()
# Parse request
request = plugin.CodeGeneratorRequest()
request.ParseFromString(data)
# Create response
response = plugin.CodeGeneratorResponse()
# Generate code
generate_code(request, response)
# Serialise response message
output = response.SerializeToString()
# Write to stdout
sys.stdout.buffer.write(output)