commit 6ed3b09f44fa93f718f862a7361dd35ff43ff139 Author: Daniel G. Taylor Date: Sat Oct 5 08:36:23 2019 -0700 Initial commit diff --git a/.env.default b/.env.default new file mode 100644 index 0000000..ae53c39 --- /dev/null +++ b/.env.default @@ -0,0 +1 @@ +PYTHONPATH=. diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..68a329f --- /dev/null +++ b/.gitignore @@ -0,0 +1,9 @@ +.env +.vscode/settings.json +.mypy_cache +.pytest_cache +betterproto/tests/*.bin +betterproto/tests/*_pb2.py +betterproto/tests/*.py +!betterproto/tests/generate.py +!betterproto/tests/test_*.py diff --git a/Pipfile b/Pipfile new file mode 100644 index 0000000..c66afd5 --- /dev/null +++ b/Pipfile @@ -0,0 +1,21 @@ +[[source]] +name = "pypi" +url = "https://pypi.org/simple" +verify_ssl = true + +[dev-packages] +flake8 = "*" +mypy = "*" +isort = "*" +pytest = "*" + +[packages] +protobuf = "*" +jinja2 = "*" + +[requires] +python_version = "3.7" + +[scripts] +generate = "python betterproto/tests/generate.py" +test = "pytest ./betterproto/tests" diff --git a/Pipfile.lock b/Pipfile.lock new file mode 100644 index 0000000..415b48c --- /dev/null +++ b/Pipfile.lock @@ -0,0 +1,273 @@ +{ + "_meta": { + "hash": { + "sha256": "817b0f61c21a4841d0cfcc977becb16b4d55090f3d78c1ebcd6974c298a06348" + }, + "pipfile-spec": 6, + "requires": { + "python_version": "3.7" + }, + "sources": [ + { + "name": "pypi", + "url": "https://pypi.org/simple", + "verify_ssl": true + } + ] + }, + "default": { + "jinja2": { + "hashes": [ + "sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013", + "sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b" + ], + "index": "pypi", + "version": "==2.10.1" + }, + "markupsafe": { + "hashes": [ + "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473", + "sha256:09027a7803a62ca78792ad89403b1b7a73a01c8cb65909cd876f7fcebd79b161", + "sha256:09c4b7f37d6c648cb13f9230d847adf22f8171b1ccc4d5682398e77f40309235", + "sha256:1027c282dad077d0bae18be6794e6b6b8c91d58ed8a8d89a89d59693b9131db5", + "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff", + "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b", + "sha256:43a55c2930bbc139570ac2452adf3d70cdbb3cfe5912c71cdce1c2c6bbd9c5d1", + "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e", + "sha256:500d4957e52ddc3351cabf489e79c91c17f6e0899158447047588650b5e69183", + "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66", + "sha256:62fe6c95e3ec8a7fad637b7f3d372c15ec1caa01ab47926cfdf7a75b40e0eac1", + "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1", + "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e", + "sha256:79855e1c5b8da654cf486b830bd42c06e8780cea587384cf6545b7d9ac013a0b", + "sha256:7c1699dfe0cf8ff607dbdcc1e9b9af1755371f92a68f706051cc8c37d447c905", + "sha256:88e5fcfb52ee7b911e8bb6d6aa2fd21fbecc674eadd44118a9cc3863f938e735", + "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d", + "sha256:98c7086708b163d425c67c7a91bad6e466bb99d797aa64f965e9d25c12111a5e", + "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d", + "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c", + "sha256:ade5e387d2ad0d7ebf59146cc00c8044acbd863725f887353a10df825fc8ae21", + "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2", + "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5", + "sha256:b2051432115498d3562c084a49bba65d97cf251f5a331c64a12ee7e04dacc51b", + "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6", + "sha256:c8716a48d94b06bb3b2524c2b77e055fb313aeb4ea620c8dd03a105574ba704f", + "sha256:cd5df75523866410809ca100dc9681e301e3c27567cf498077e8551b6d20e42f", + "sha256:e249096428b3ae81b08327a63a485ad0878de3fb939049038579ac0ef61e17e7" + ], + "version": "==1.1.1" + }, + "protobuf": { + "hashes": [ + "sha256:125713564d8cfed7610e52444c9769b8dcb0b55e25cc7841f2290ee7bc86636f", + "sha256:1accdb7a47e51503be64d9a57543964ba674edac103215576399d2d0e34eac77", + "sha256:27003d12d4f68e3cbea9eb67427cab3bfddd47ff90670cb367fcd7a3a89b9657", + "sha256:3264f3c431a631b0b31e9db2ae8c927b79fc1a7b1b06b31e8e5bcf2af91fe896", + "sha256:3c5ab0f5c71ca5af27143e60613729e3488bb45f6d3f143dc918a20af8bab0bf", + "sha256:45dcf8758873e3f69feab075e5f3177270739f146255225474ee0b90429adef6", + "sha256:56a77d61a91186cc5676d8e11b36a5feb513873e4ae88d2ee5cf530d52bbcd3b", + "sha256:5984e4947bbcef5bd849d6244aec507d31786f2dd3344139adc1489fb403b300", + "sha256:6b0441da73796dd00821763bb4119674eaf252776beb50ae3883bed179a60b2a", + "sha256:6f6677c5ade94d4fe75a912926d6796d5c71a2a90c2aeefe0d6f211d75c74789", + "sha256:84a825a9418d7196e2acc48f8746cf1ee75877ed2f30433ab92a133f3eaf8fbe", + "sha256:b842c34fe043ccf78b4a6cf1019d7b80113707d68c88842d061fa2b8fb6ddedc", + "sha256:ca33d2f09dae149a1dcf942d2d825ebb06343b77b437198c9e2ef115cf5d5bc1", + "sha256:db83b5c12c0cd30150bb568e6feb2435c49ce4e68fe2d7b903113f0e221e58fe", + "sha256:f50f3b1c5c1c1334ca7ce9cad5992f098f460ffd6388a3cabad10b66c2006b09", + "sha256:f99f127909731cafb841c52f9216e447d3e4afb99b17bebfad327a75aee206de" + ], + "index": "pypi", + "version": "==3.10.0" + }, + "six": { + "hashes": [ + "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", + "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + ], + "version": "==1.12.0" + } + }, + "develop": { + "atomicwrites": { + "hashes": [ + "sha256:03472c30eb2c5d1ba9227e4c2ca66ab8287fbfbbda3888aa93dc2e28fc6811b4", + "sha256:75a9445bac02d8d058d5e1fe689654ba5a6556a1dfd8ce6ec55a0ed79866cfa6" + ], + "version": "==1.3.0" + }, + "attrs": { + "hashes": [ + "sha256:ec20e7a4825331c1b5ebf261d111e16fa9612c1f7a5e1f884f12bd53a664dfd2", + "sha256:f913492e1663d3c36f502e5e9ba6cd13cf19d7fab50aa13239e420fef95e1396" + ], + "version": "==19.2.0" + }, + "entrypoints": { + "hashes": [ + "sha256:589f874b313739ad35be6e0cd7efde2a4e9b6fea91edcc34e58ecbb8dbe56d19", + "sha256:c70dd71abe5a8c85e55e12c19bd91ccfeec11a6e99044204511f9ed547d48451" + ], + "version": "==0.3" + }, + "flake8": { + "hashes": [ + "sha256:19241c1cbc971b9962473e4438a2ca19749a7dd002dd1a946eaba171b4114548", + "sha256:8e9dfa3cecb2400b3738a42c54c3043e821682b9c840b0448c0503f781130696" + ], + "index": "pypi", + "version": "==3.7.8" + }, + "importlib-metadata": { + "hashes": [ + "sha256:aa18d7378b00b40847790e7c27e11673d7fed219354109d0e7b9e5b25dc3ad26", + "sha256:d5f18a79777f3aa179c145737780282e27b508fc8fd688cb17c7a813e8bd39af" + ], + "markers": "python_version < '3.8'", + "version": "==0.23" + }, + "isort": { + "hashes": [ + "sha256:54da7e92468955c4fceacd0c86bd0ec997b0e1ee80d97f67c35a78b719dccab1", + "sha256:6e811fcb295968434526407adb8796944f1988c5b65e8139058f2014cbe100fd" + ], + "index": "pypi", + "version": "==4.3.21" + }, + "mccabe": { + "hashes": [ + "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42", + "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f" + ], + "version": "==0.6.1" + }, + "more-itertools": { + "hashes": [ + "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832", + "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4" + ], + "version": "==7.2.0" + }, + "mypy": { + "hashes": [ + "sha256:1d98fd818ad3128a5408148c9e4a5edce6ed6b58cc314283e631dd5d9216527b", + "sha256:22ee018e8fc212fe601aba65d3699689dd29a26410ef0d2cc1943de7bec7e3ac", + "sha256:3a24f80776edc706ec8d05329e854d5b9e464cd332e25cde10c8da2da0a0db6c", + "sha256:42a78944e80770f21609f504ca6c8173f7768043205b5ac51c9144e057dcf879", + "sha256:4b2b20106973548975f0c0b1112eceb4d77ed0cafe0a231a1318f3b3a22fc795", + "sha256:591a9625b4d285f3ba69f541c84c0ad9e7bffa7794da3fa0585ef13cf95cb021", + "sha256:5b4b70da3d8bae73b908a90bb2c387b977e59d484d22c604a2131f6f4397c1a3", + "sha256:84edda1ffeda0941b2ab38ecf49302326df79947fa33d98cdcfbf8ca9cf0bb23", + "sha256:b2b83d29babd61b876ae375786960a5374bba0e4aba3c293328ca6ca5dc448dd", + "sha256:cc4502f84c37223a1a5ab700649b5ab1b5e4d2bf2d426907161f20672a21930b", + "sha256:e29e24dd6e7f39f200a5bb55dcaa645d38a397dd5a6674f6042ef02df5795046" + ], + "index": "pypi", + "version": "==0.730" + }, + "mypy-extensions": { + "hashes": [ + "sha256:a161e3b917053de87dbe469987e173e49fb454eca10ef28b48b384538cc11458" + ], + "version": "==0.4.2" + }, + "packaging": { + "hashes": [ + "sha256:28b924174df7a2fa32c1953825ff29c61e2f5e082343165438812f00d3a7fc47", + "sha256:d9551545c6d761f3def1677baf08ab2a3ca17c56879e70fecba2fc4dde4ed108" + ], + "version": "==19.2" + }, + "pluggy": { + "hashes": [ + "sha256:0db4b7601aae1d35b4a033282da476845aa19185c1e6964b25cf324b5e4ec3e6", + "sha256:fa5fa1622fa6dd5c030e9cad086fa19ef6a0cf6d7a2d12318e10cb49d6d68f34" + ], + "version": "==0.13.0" + }, + "py": { + "hashes": [ + "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", + "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" + ], + "version": "==1.8.0" + }, + "pycodestyle": { + "hashes": [ + "sha256:95a2219d12372f05704562a14ec30bc76b05a5b297b21a5dfe3f6fac3491ae56", + "sha256:e40a936c9a450ad81df37f549d676d127b1b66000a6c500caa2b085bc0ca976c" + ], + "version": "==2.5.0" + }, + "pyflakes": { + "hashes": [ + "sha256:17dbeb2e3f4d772725c777fabc446d5634d1038f234e77343108ce445ea69ce0", + "sha256:d976835886f8c5b31d47970ed689944a0262b5f3afa00a5a7b4dc81e5449f8a2" + ], + "version": "==2.1.1" + }, + "pyparsing": { + "hashes": [ + "sha256:6f98a7b9397e206d78cc01df10131398f1c8b8510a2f4d97d9abd82e1aacdd80", + "sha256:d9338df12903bbf5d65a0e4e87c2161968b10d2e489652bb47001d82a9b028b4" + ], + "version": "==2.4.2" + }, + "pytest": { + "hashes": [ + "sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5", + "sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5" + ], + "index": "pypi", + "version": "==5.2.0" + }, + "six": { + "hashes": [ + "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", + "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" + ], + "version": "==1.12.0" + }, + "typed-ast": { + "hashes": [ + "sha256:18511a0b3e7922276346bcb47e2ef9f38fb90fd31cb9223eed42c85d1312344e", + "sha256:262c247a82d005e43b5b7f69aff746370538e176131c32dda9cb0f324d27141e", + "sha256:2b907eb046d049bcd9892e3076c7a6456c93a25bebfe554e931620c90e6a25b0", + "sha256:354c16e5babd09f5cb0ee000d54cfa38401d8b8891eefa878ac772f827181a3c", + "sha256:4e0b70c6fc4d010f8107726af5fd37921b666f5b31d9331f0bd24ad9a088e631", + "sha256:630968c5cdee51a11c05a30453f8cd65e0cc1d2ad0d9192819df9978984529f4", + "sha256:66480f95b8167c9c5c5c87f32cf437d585937970f3fc24386f313a4c97b44e34", + "sha256:71211d26ffd12d63a83e079ff258ac9d56a1376a25bc80b1cdcdf601b855b90b", + "sha256:95bd11af7eafc16e829af2d3df510cecfd4387f6453355188342c3e79a2ec87a", + "sha256:bc6c7d3fa1325a0c6613512a093bc2a2a15aeec350451cbdf9e1d4bffe3e3233", + "sha256:cc34a6f5b426748a507dd5d1de4c1978f2eb5626d51326e43280941206c209e1", + "sha256:d755f03c1e4a51e9b24d899561fec4ccaf51f210d52abdf8c07ee2849b212a36", + "sha256:d7c45933b1bdfaf9f36c579671fec15d25b06c8398f113dab64c18ed1adda01d", + "sha256:d896919306dd0aa22d0132f62a1b78d11aaf4c9fc5b3410d3c666b818191630a", + "sha256:ffde2fbfad571af120fcbfbbc61c72469e72f550d676c3342492a9dfdefb8f12" + ], + "version": "==1.4.0" + }, + "typing-extensions": { + "hashes": [ + "sha256:2ed632b30bb54fc3941c382decfd0ee4148f5c591651c9272473fea2c6397d95", + "sha256:b1edbbf0652660e32ae780ac9433f4231e7339c7f9a8057d0f042fcbcea49b87", + "sha256:d8179012ec2c620d3791ca6fe2bf7979d979acdbef1fca0bc56b37411db682ed" + ], + "version": "==3.7.4" + }, + "wcwidth": { + "hashes": [ + "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", + "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" + ], + "version": "==0.1.7" + }, + "zipp": { + "hashes": [ + "sha256:3718b1cbcd963c7d4c5511a8240812904164b7f381b647143a89d3b98f9bcd8e", + "sha256:f06903e9f1f43b12d371004b4ac7b06ab39a44adc747266928ae6debfa7b3335" + ], + "version": "==0.6.0" + } + } +} diff --git a/README.md b/README.md new file mode 100644 index 0000000..2069637 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# TODO + +- [ ] Fixed length fields +- [x] Zig-zag signed fields (sint32, sint64) +- [x] Don't encode zero values for nested types~ +- [ ] Enums +- [ ] Maps +- [ ] Support passthrough of unknown fields +- [ ] JSON that isn't naive. +- [ ] Cleanup! diff --git a/betterproto/__init__.py b/betterproto/__init__.py new file mode 100644 index 0000000..fc13b27 --- /dev/null +++ b/betterproto/__init__.py @@ -0,0 +1,262 @@ +from abc import ABC +import json +import struct +from typing import ( + Union, + Generator, + Any, + SupportsBytes, + List, + Tuple, + Callable, + Type, + Iterable, + TypeVar, +) +import dataclasses + +from . import parse, serialize + +PACKED_TYPES = [ + "bool", + "int32", + "int64", + "uint32", + "uint64", + "sint32", + "sint64", + "float", + "double", +] + +# Wire types +# https://developers.google.com/protocol-buffers/docs/encoding#structure +WIRE_VARINT = 0 +WIRE_FIXED_64 = 1 +WIRE_LEN_DELIM = 2 +WIRE_FIXED_32 = 5 + + +@dataclasses.dataclass(frozen=True) +class _Meta: + number: int + proto_type: str + default: Any + + +def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: + kwargs = {} + + if callable(default): + kwargs["default_factory"] = default + elif isinstance(default, dict) or isinstance(default, list): + kwargs["default_factory"] = lambda: default + else: + kwargs["default"] = default + + return dataclasses.field( + **kwargs, metadata={"betterproto": _Meta(number, proto_type, default)} + ) + + +def int32_field( + number: int, default: Union[int, Type[Iterable]] = 0 +) -> dataclasses.Field: + return field(number, "int32", default=default) + + +def int64_field(number: int, default: int = 0) -> dataclasses.Field: + return field(number, "int64", default=default) + + +def uint32_field(number: int, default: int = 0) -> dataclasses.Field: + return field(number, "uint32", default=default) + + +def uint64_field(number: int, default: int = 0) -> dataclasses.Field: + return field(number, "uint64", default=default) + + +def sint32_field(number: int, default: int = 0) -> dataclasses.Field: + return field(number, "sint32", default=default) + + +def sint64_field(number: int, default: int = 0) -> dataclasses.Field: + return field(number, "sint64", default=default) + + +def float_field(number: int, default: float = 0.0) -> dataclasses.Field: + return field(number, "float", default=default) + + +def double_field(number: int, default: float = 0.0) -> dataclasses.Field: + return field(number, "double", default=default) + + +def string_field(number: int, default: str = "") -> dataclasses.Field: + return field(number, "string", default=default) + + +def message_field(number: int, default: Type["ProtoMessage"]) -> dataclasses.Field: + return field(number, "message", default=default) + + +def _serialize_single(meta: _Meta, value: Any) -> bytes: + output = b"" + if meta.proto_type in ["int32", "int64", "uint32", "uint64"]: + if value < 0: + # Handle negative numbers. + value += 1 << 64 + output = serialize.varint(meta.number, value) + elif meta.proto_type in ["sint32", "sint64"]: + if value >= 0: + value = value << 1 + else: + value = (value << 1) ^ (~0) + output = serialize.varint(meta.number, value) + elif meta.proto_type == "string": + output = serialize.len_delim(meta.number, value.encode("utf-8")) + elif meta.proto_type == "message": + b = bytes(value) + if len(b): + output = serialize.len_delim(meta.number, b) + else: + raise NotImplementedError() + + return output + + +def _parse_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: + if wire_type == WIRE_VARINT: + if meta.proto_type in ["int32", "int64"]: + bits = int(meta.proto_type[3:]) + value = value & ((1 << bits) - 1) + signbit = 1 << (bits - 1) + value = int((value ^ signbit) - signbit) + elif meta.proto_type in ["sint32", "sint64"]: + # Undo zig-zag encoding + value = (value >> 1) ^ (-(value & 1)) + elif wire_type == WIRE_LEN_DELIM: + if meta.proto_type in ["string"]: + value = value.decode("utf-8") + elif meta.proto_type in ["message"]: + value = field.default_factory().parse(value) + + return value + + +# Bound type variable to allow methods to return `self` of subclasses +T = TypeVar("T", bound="Message") + + +class Message(ABC): + """ + A protobuf message base class. Generated code will inherit from this and + register the message fields which get used by the serializers and parsers + to go between Python, binary and JSON protobuf message representations. + """ + + def __bytes__(self) -> bytes: + """ + Get the binary encoded Protobuf representation of this instance. + """ + output = b"" + for field in dataclasses.fields(self): + meta: _Meta = field.metadata.get("betterproto") + value = getattr(self, field.name) + + if isinstance(value, list): + if not len(value): + continue + + if meta.proto_type in PACKED_TYPES: + output += serialize.packed(meta.number, value) + else: + for item in value: + output += _serialize_single(meta, item) + else: + if value == field.default: + continue + + output += _serialize_single(meta, value) + + return output + + def parse(self, data: bytes) -> T: + """ + Parse the binary encoded Protobuf into this message instance. This + returns the instance itself and is therefore assignable and chainable. + """ + fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)} + for parsed in parse.fields(data): + if parsed.number in fields: + field = fields[parsed.number] + meta: _Meta = field.metadata.get("betterproto") + + if ( + parsed.wire_type == WIRE_LEN_DELIM + and meta.proto_type in PACKED_TYPES + ): + # This is a packed repeated field. + pos = 0 + value = [] + while pos < len(parsed.value): + decoded, pos = parse._decode_varint(parsed.value, pos) + decoded = _parse_single(WIRE_VARINT, meta, field, decoded) + value.append(decoded) + else: + value = _parse_single(parsed.wire_type, meta, field, parsed.value) + + if isinstance(getattr(self, field.name), list) and not isinstance( + value, list + ): + getattr(self, field.name).append(value) + else: + setattr(self, field.name, value) + else: + # TODO: handle unknown fields + pass + + return self + + def to_dict(self) -> dict: + """ + Returns a dict representation of this message instance which can be + used to serialize to e.g. JSON. + """ + output = {} + for field in dataclasses.fields(self): + meta: Meta_ = field.metadata.get("betterproto") + v = getattr(self, field.name) + if meta.proto_type == "message": + v = v.to_dict() + if v: + output[field.name] = v + elif v != field.default: + output[field.name] = getattr(self, field.name) + return output + + def from_dict(self, value: dict) -> T: + """ + Parse the key/value pairs in `value` into this message instance. This + returns the instance itself and is therefore assignable and chainable. + """ + for field in dataclasses.fields(self): + meta: Meta_ = field.metadata.get("betterproto") + if field.name in value: + if meta.proto_type == "message": + getattr(self, field.name).from_dict(value[field.name]) + else: + setattr(self, field.name, value[field.name]) + return self + + def to_json(self) -> bytes: + """Returns the encoded JSON representation of this message instance.""" + return json.dumps(self.to_dict()) + + def from_json(self, value: bytes) -> T: + """ + Parse the key/value pairs in `value` into this message instance. This + returns the instance itself and is therefore assignable and chainable. + """ + return self.from_dict(json.loads(value)) diff --git a/betterproto/__pycache__/__init__.cpython-37.pyc b/betterproto/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..c8073d1 Binary files /dev/null and b/betterproto/__pycache__/__init__.cpython-37.pyc differ diff --git a/betterproto/parse.py b/betterproto/parse.py new file mode 100644 index 0000000..69ed554 --- /dev/null +++ b/betterproto/parse.py @@ -0,0 +1,64 @@ +import struct +from typing import Union, Generator, Any, SupportsBytes, List, Tuple +from dataclasses import dataclass + + +def _decode_varint( + buffer: bytes, pos: int, signed: bool = False, result_type: type = int +) -> Tuple[int, int]: + result = 0 + shift = 0 + while 1: + b = buffer[pos] + result |= (b & 0x7F) << shift + pos += 1 + if not (b & 0x80): + result = result_type(result) + return (result, pos) + shift += 7 + if shift >= 64: + raise ValueError("Too many bytes when decoding varint.") + + +def packed(value: bytes, signed: bool = False, result_type: type = int) -> list: + parsed = [] + pos = 0 + while pos < len(value): + decoded, pos = _decode_varint( + value, pos, signed=signed, result_type=result_type + ) + parsed.append(decoded) + return parsed + + +@dataclass(frozen=True) +class Field: + number: int + wire_type: int + value: Any + + +def fields(value: bytes) -> Generator[Field, None, None]: + i = 0 + while i < len(value): + num_wire, i = _decode_varint(value, i) + print(num_wire, i) + number = num_wire >> 3 + wire_type = num_wire & 0x7 + + if wire_type == 0: + decoded, i = _decode_varint(value, i) + elif wire_type == 1: + decoded, i = None, i + 4 + elif wire_type == 2: + length, i = _decode_varint(value, i) + decoded = value[i : i + length] + i += length + elif wire_type == 5: + decoded, i = None, i + 2 + else: + raise NotImplementedError(f"Wire type {wire_type}") + + # print(Field(number=number, wire_type=wire_type, value=decoded)) + + yield Field(number=number, wire_type=wire_type, value=decoded) diff --git a/betterproto/py.typed b/betterproto/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/betterproto/serialize.py b/betterproto/serialize.py new file mode 100644 index 0000000..3d9ac50 --- /dev/null +++ b/betterproto/serialize.py @@ -0,0 +1,44 @@ +import struct +from typing import Union, Generator, Any, SupportsBytes, List, Tuple +from dataclasses import dataclass + + +def _varint(value: int) -> bytes: + # From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372 + b: List[int] = [] + + bits = value & 0x7F + value >>= 7 + while value: + b.append(0x80 | bits) + bits = value & 0x7F + value >>= 7 + print(value) + return bytes(b + [bits]) + + +def varint(field_number: int, value: Union[int, float]) -> bytes: + key = _varint(field_number << 3) + return key + _varint(value) + + +def len_delim(field_number: int, value: Union[str, bytes]) -> bytes: + key = _varint((field_number << 3) | 2) + + if isinstance(value, str): + value = value.encode("utf-8") + + return key + _varint(len(value)) + value + + +def packed(field_number: int, value: list) -> bytes: + key = _varint((field_number << 3) | 2) + + packed = b"" + for item in value: + if item < 0: + # Handle negative numbers. + item += 1 << 64 + packed += _varint(item) + + return key + _varint(len(packed)) + packed diff --git a/betterproto/templates/main.py b/betterproto/templates/main.py new file mode 100644 index 0000000..d831cf3 --- /dev/null +++ b/betterproto/templates/main.py @@ -0,0 +1,24 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: {{ description.filename }} +from dataclasses import dataclass +from typing import List + +import betterproto + + +{% for message in description.messages %} +@dataclass +class {{ message.name }}(betterproto.Message): + {% if message.comment %} +{{ message.comment }} + + {% endif %} + {% for field in message.properties %} + {% if field.comment %} +{{ field.comment }} + {% endif %} + {{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero %}, default={{ field.zero }}{% endif %}) + {% endfor %} + + +{% endfor %} diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py new file mode 100644 index 0000000..51a9ca7 --- /dev/null +++ b/betterproto/tests/generate.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +import os # isort: skip + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +import subprocess +import importlib +from typing import Generator, Tuple + +from google.protobuf.json_format import Parse +from google.protobuf import symbol_database +from google.protobuf.descriptor_pool import DescriptorPool + +root = os.path.dirname(os.path.realpath(__file__)) + + +def get_files(end: str) -> Generator[Tuple[str, str], None, None]: + for r, dirs, files in os.walk(root): + for filename in [f for f in files if f.endswith(end)]: + parts = os.path.splitext(filename)[0].split("-") + yield [parts[0], os.path.join(r, filename)] + + +if __name__ == "__main__": + os.chdir(root) + + for base, filename in get_files(".proto"): + subprocess.run( + f"protoc --python_out=. {os.path.basename(filename)}", shell=True + ) + subprocess.run( + f"protoc --plugin=protoc-gen-custom=../../protoc-gen-betterpy.py --custom_out=. {os.path.basename(filename)}", + shell=True, + ) + + for base, filename in get_files(".json"): + # Reset the internal symbol database so we can import the `Test` message + # multiple times. Ugh. + sym = symbol_database.Default() + sym.pool = DescriptorPool() + imported = importlib.import_module(f"{base}_pb2") + out = filename.replace(".json", ".bin") + serialized = Parse(open(filename).read(), imported.Test()).SerializeToString() + open(out, "wb").write(serialized) diff --git a/betterproto/tests/int32-negative.json b/betterproto/tests/int32-negative.json new file mode 100644 index 0000000..0d2bb48 --- /dev/null +++ b/betterproto/tests/int32-negative.json @@ -0,0 +1,3 @@ +{ + "count": -150 +} diff --git a/betterproto/tests/int32.json b/betterproto/tests/int32.json new file mode 100644 index 0000000..9514828 --- /dev/null +++ b/betterproto/tests/int32.json @@ -0,0 +1,3 @@ +{ + "count": 150 +} diff --git a/betterproto/tests/int32.proto b/betterproto/tests/int32.proto new file mode 100644 index 0000000..6b46857 --- /dev/null +++ b/betterproto/tests/int32.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +// Some documentation about the Test message. +message Test { + // Some documentation about the count. + int32 count = 1; +} diff --git a/betterproto/tests/nested.json b/betterproto/tests/nested.json new file mode 100644 index 0000000..217a7d4 --- /dev/null +++ b/betterproto/tests/nested.json @@ -0,0 +1,5 @@ +{ + "nested": { + "count": 150 + } +} diff --git a/betterproto/tests/nested.proto b/betterproto/tests/nested.proto new file mode 100644 index 0000000..0ed4540 --- /dev/null +++ b/betterproto/tests/nested.proto @@ -0,0 +1,17 @@ +syntax = "proto3"; + +// A test message with a nested message inside of it. +message Test { + // This is the nested type. + message Nested { + // Stores a simple counter. + int32 count = 1; + } + + Nested nested = 1; + Sibling sibling = 2; +} + +message Sibling { + int32 foo = 1; +} \ No newline at end of file diff --git a/betterproto/tests/repeated.json b/betterproto/tests/repeated.json new file mode 100644 index 0000000..b8a7c4e --- /dev/null +++ b/betterproto/tests/repeated.json @@ -0,0 +1,3 @@ +{ + "names": ["one", "two", "three"] +} diff --git a/betterproto/tests/repeated.proto b/betterproto/tests/repeated.proto new file mode 100644 index 0000000..42c1132 --- /dev/null +++ b/betterproto/tests/repeated.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Test { + repeated string names = 1; +} diff --git a/betterproto/tests/repeatedpacked.json b/betterproto/tests/repeatedpacked.json new file mode 100644 index 0000000..7d9ae00 --- /dev/null +++ b/betterproto/tests/repeatedpacked.json @@ -0,0 +1,3 @@ +{ + "counts": [1, 2, -1, -2] +} diff --git a/betterproto/tests/repeatedpacked.proto b/betterproto/tests/repeatedpacked.proto new file mode 100644 index 0000000..0662cdb --- /dev/null +++ b/betterproto/tests/repeatedpacked.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message Test { + repeated int32 counts = 1; +} diff --git a/betterproto/tests/signed-negative.json b/betterproto/tests/signed-negative.json new file mode 100644 index 0000000..85e74c8 --- /dev/null +++ b/betterproto/tests/signed-negative.json @@ -0,0 +1,4 @@ +{ + "signed_32": -150, + "signed_64": -150 +} diff --git a/betterproto/tests/signed.json b/betterproto/tests/signed.json new file mode 100644 index 0000000..3d5696a --- /dev/null +++ b/betterproto/tests/signed.json @@ -0,0 +1,4 @@ +{ + "signed_32": 150, + "signed_64": 150 +} diff --git a/betterproto/tests/signed.proto b/betterproto/tests/signed.proto new file mode 100644 index 0000000..49b2bfd --- /dev/null +++ b/betterproto/tests/signed.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message Test { + sint32 signed_32 = 1; + sint64 signed_64 = 2; +} diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py new file mode 100644 index 0000000..1ac7f6f --- /dev/null +++ b/betterproto/tests/test_inputs.py @@ -0,0 +1,23 @@ +import importlib +import pytest +import json + +from generate import get_files + +inputs = get_files(".bin") + + +@pytest.mark.parametrize("name,filename", inputs) +def test_sample(name: str, filename: str) -> None: + imported = importlib.import_module(name) + data_binary = open(filename, "rb").read() + data_dict = json.loads(open(filename.replace(".bin", ".json")).read()) + t1 = imported.Test().parse(data_binary) + t2 = imported.Test().from_dict(data_dict) + print(t1) + print(t2) + assert t1 == t2 + assert bytes(t1) == data_binary + assert bytes(t2) == data_binary + assert t1.to_dict() == data_dict + assert t2.to_dict() == data_dict diff --git a/protoc-gen-betterpy.py b/protoc-gen-betterpy.py new file mode 100755 index 0000000..d2f96b5 --- /dev/null +++ b/protoc-gen-betterpy.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python + +import sys + +import itertools +import json +import os.path +from typing import Tuple, Any, List +import textwrap + +from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + EnumDescriptorProto, + FileDescriptorProto, +) + +from google.protobuf.compiler import plugin_pb2 as plugin + + +from jinja2 import Environment, PackageLoader + + +def py_type(descriptor: DescriptorProto) -> Tuple[str, str]: + if descriptor.type in [1, 2, 6, 7, 15, 16]: + return "float", descriptor.default_value + elif descriptor.type in [3, 4, 5, 13, 17, 18]: + return "int", descriptor.default_value + elif descriptor.type == 8: + return "bool", descriptor.default_value.capitalize() + elif descriptor.type == 9: + default = "" + if descriptor.default_value: + default = f'"{descriptor.default_value}"' + return "str", default + elif descriptor.type == 11: + # Type referencing another defined Message + # print(descriptor.type_name, file=sys.stderr) + # message_type = descriptor.type_name.replace(".", "") + message_type = descriptor.type_name.split(".").pop() + return f'"{message_type}"', f"lambda: {message_type}()" + elif descriptor.type == 12: + default = "" + if descriptor.default_value: + default = f'b"{descriptor.default_value}"' + return "bytes", default + else: + raise NotImplementedError() + + +def traverse(proto_file): + def _traverse(path, items): + for i, item in enumerate(items): + yield item, path + [i] + + if isinstance(item, DescriptorProto): + for enum in item.enum_type: + yield enum, path + [i, 4] + + if item.nested_type: + for n, p in _traverse(path + [i, 3], item.nested_type): + yield n, p + + return itertools.chain( + _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) + ) + + +def get_comment(proto_file, path: List[int]) -> str: + for sci in proto_file.source_code_info.location: + # print(list(sci.path), path, file=sys.stderr) + if list(sci.path) == path and sci.leading_comments: + lines = textwrap.wrap( + sci.leading_comments.strip().replace("\n", ""), width=75 + ) + + if path[-2] == 2: + # This is a field + return " # " + " # ".join(lines) + else: + # This is a class + if len(lines) == 1 and len(lines[0]) < 70: + return f' """{lines[0]}"""' + else: + return f' """\n{" ".join(lines)}\n """' + + return "" + + +def generate_code(request, response): + env = Environment( + trim_blocks=True, + lstrip_blocks=True, + loader=PackageLoader("betterproto", "templates"), + ) + template = env.get_template("main.py") + + for proto_file in request.proto_file: + # print(proto_file.message_type, file=sys.stderr) + # print(proto_file.source_code_info, file=sys.stderr) + output = { + "package": proto_file.package, + "filename": proto_file.name, + "messages": [], + } + + # Parse request + for item, path in traverse(proto_file): + # print(item, file=sys.stderr) + # print(path, file=sys.stderr) + data = {"name": item.name} + + if isinstance(item, DescriptorProto): + # print(item, file=sys.stderr) + data.update( + { + "type": "Message", + "comment": get_comment(proto_file, path), + "properties": [], + } + ) + + for i, f in enumerate(item.field): + t, zero = py_type(f) + repeated = False + packed = False + + if f.label == 3: + # Repeated field + repeated = True + t = f"List[{t}]" + zero = "list" + + if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: + packed = True + + data["properties"].append( + { + "name": f.name, + "number": f.number, + "comment": get_comment(proto_file, path + [2, i]), + "proto_type": int(f.type), + "field_type": f.Type.Name(f.type).lower()[5:], + "type": t, + "zero": zero, + "repeated": repeated, + "packed": packed, + } + ) + # print(f, file=sys.stderr) + + # elif isinstance(item, EnumDescriptorProto): + # data.update({ + # 'type': 'Enum', + # 'values': [{'name': v.name, 'value': v.number} + # for v in item.value] + # }) + + output["messages"].append(data) + + # Fill response + f = response.file.add() + f.name = os.path.splitext(proto_file.name)[0] + ".py" + # f.content = json.dumps(output, indent=2) + f.content = template.render(description=output).rstrip("\n") + "\n" + + +if __name__ == "__main__": + # Read request message from stdin + data = sys.stdin.buffer.read() + + # Parse request + request = plugin.CodeGeneratorRequest() + request.ParseFromString(data) + + # Create response + response = plugin.CodeGeneratorResponse() + + # Generate code + generate_code(request, response) + + # Serialise response message + output = response.SerializeToString() + + # Write to stdout + sys.stdout.buffer.write(output)