diff --git a/Pipfile b/Pipfile index c66afd5..18f83ea 100644 --- a/Pipfile +++ b/Pipfile @@ -8,6 +8,7 @@ flake8 = "*" mypy = "*" isort = "*" pytest = "*" +rope = "*" [packages] protobuf = "*" diff --git a/Pipfile.lock b/Pipfile.lock index 415b48c..03ab532 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "817b0f61c21a4841d0cfcc977becb16b4d55090f3d78c1ebcd6974c298a06348" + "sha256": "6c1797fb4eb73be97ca566206527c9d648b90f38c5bf2caf4b69537cd325ced9" }, "pipfile-spec": 6, "requires": { @@ -18,11 +18,11 @@ "default": { "jinja2": { "hashes": [ - "sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013", - "sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b" + "sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f", + "sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de" ], "index": "pypi", - "version": "==2.10.1" + "version": "==2.10.3" }, "markupsafe": { "hashes": [ @@ -214,11 +214,20 @@ }, "pytest": { "hashes": [ - "sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5", - "sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5" + "sha256:7e4800063ccfc306a53c461442526c5571e1462f61583506ce97e4da6a1d88c8", + "sha256:ca563435f4941d0cb34767301c27bc65c510cb82e90b9ecf9cb52dc2c63caaa0" ], "index": "pypi", - "version": "==5.2.0" + "version": "==5.2.1" + }, + "rope": { + "hashes": [ + "sha256:6b728fdc3e98a83446c27a91fc5d56808a004f8beab7a31ab1d7224cecc7d969", + "sha256:c5c5a6a87f7b1a2095fb311135e2a3d1f194f5ecb96900fdd0a9100881f48aaf", + "sha256:f0dcf719b63200d492b85535ebe5ea9b29e0d0b8aebeb87fe03fc1a65924fdaf" + ], + "index": "pypi", + "version": "==0.14.0" }, "six": { "hashes": [ diff --git a/betterproto/__init__.py b/betterproto/__init__.py index adea090..d256c3e 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -15,8 +15,7 @@ from typing import ( ) import dataclasses -from . 
import parse, serialize - +# Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" TYPE_INT32 = "int32" @@ -36,6 +35,7 @@ TYPE_BYTES = "bytes" TYPE_MESSAGE = "message" +# Fields that use a fixed amount of space (4 or 8 bytes) FIXED_TYPES = [ TYPE_FLOAT, TYPE_DOUBLE, @@ -45,6 +45,7 @@ FIXED_TYPES = [ TYPE_SFIXED64, ] +# Fields that are efficiently packed when repeated PACKED_TYPES = [ TYPE_ENUM, TYPE_BOOL, @@ -69,6 +70,7 @@ WIRE_FIXED_64 = 1 WIRE_LEN_DELIM = 2 WIRE_FIXED_32 = 5 +# Mappings of which Proto 3 types correspond to which wire types. WIRE_VARINT_TYPES = [ TYPE_ENUM, TYPE_BOOL, @@ -86,13 +88,24 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE] @dataclasses.dataclass(frozen=True) -class _Meta: +class FieldMetadata: + """Stores internal metadata used for parsing & serialization.""" + + # Protobuf field number number: int + # Protobuf type name proto_type: str + # Default value if given default: Any + @staticmethod + def get(field: dataclasses.Field) -> "FieldMetadata": + """Returns the field metadata for a dataclass field.""" + return field.metadata["betterproto"] + def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: + """Creates a dataclass field with attached protobuf metadata.""" kwargs = {} if callable(default): @@ -103,7 +116,7 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: kwargs["default"] = default return dataclasses.field( - **kwargs, metadata={"betterproto": _Meta(number, proto_type, default)} + **kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)} ) @@ -113,97 +126,91 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field: def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return field(number, "enum", default=default) + return field(number, TYPE_ENUM, default=default) def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any: - return field(number, "int32", default=default) + return
field(number, TYPE_INT32, default=default) def int64_field(number: int, default: int = 0) -> Any: - return field(number, "int64", default=default) + return field(number, TYPE_INT64, default=default) def uint32_field(number: int, default: int = 0) -> Any: - return field(number, "uint32", default=default) + return field(number, TYPE_UINT32, default=default) def uint64_field(number: int, default: int = 0) -> Any: - return field(number, "uint64", default=default) + return field(number, TYPE_UINT64, default=default) def sint32_field(number: int, default: int = 0) -> Any: - return field(number, "sint32", default=default) + return field(number, TYPE_SINT32, default=default) def sint64_field(number: int, default: int = 0) -> Any: - return field(number, "sint64", default=default) + return field(number, TYPE_SINT64, default=default) def float_field(number: int, default: float = 0.0) -> Any: - return field(number, "float", default=default) + return field(number, TYPE_FLOAT, default=default) def double_field(number: int, default: float = 0.0) -> Any: - return field(number, "double", default=default) + return field(number, TYPE_DOUBLE, default=default) def fixed32_field(number: int, default: float = 0.0) -> Any: - return field(number, "fixed32", default=default) + return field(number, TYPE_FIXED32, default=default) def fixed64_field(number: int, default: float = 0.0) -> Any: - return field(number, "fixed64", default=default) + return field(number, TYPE_FIXED64, default=default) def sfixed32_field(number: int, default: float = 0.0) -> Any: - return field(number, "sfixed32", default=default) + return field(number, TYPE_SFIXED32, default=default) def sfixed64_field(number: int, default: float = 0.0) -> Any: - return field(number, "sfixed64", default=default) + return field(number, TYPE_SFIXED64, default=default) def string_field(number: int, default: str = "") -> Any: - return field(number, "string", default=default) + return field(number, TYPE_STRING, default=default) def 
message_field(number: int, default: Type["Message"]) -> Any: - return field(number, "message", default=default) + return field(number, TYPE_MESSAGE, default=default) def _pack_fmt(proto_type: str) -> str: + """Returns a little-endian format string for reading/writing binary.""" return { - "double": "<d", - "float": "<f", - "fixed32": "<I", - "fixed64": "<Q", - "sfixed32": "<i", - "sfixed64": "<q", + TYPE_DOUBLE: "<d", + TYPE_FLOAT: "<f", + TYPE_FIXED32: "<I", + TYPE_FIXED64: "<Q", + TYPE_SFIXED32: "<i", + TYPE_SFIXED64: "<q", }[proto_type] -def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes: - value = _preprocess_single(proto_type, value) +def encode_varint(value: int) -> bytes: + """Encodes a single varint value for serialization.""" + b: List[int] = [] - output = b"" - if proto_type in WIRE_VARINT_TYPES: - key = serialize._varint(field_number << 3) - output += key + value - elif proto_type in WIRE_FIXED_32_TYPES: - key = serialize._varint((field_number << 3) | 5) - output += key + value - elif proto_type in WIRE_FIXED_64_TYPES: - key = serialize._varint((field_number << 3) | 1) - output += key + value - elif proto_type in WIRE_LEN_DELIM_TYPES: - if len(value): - key = serialize._varint((field_number << 3) | 2) - output += key + serialize._varint(len(value)) + value - else: - raise NotImplementedError(proto_type) + if value < 0: + value += 1 << 64 - return output + bits = value & 0x7F + value >>= 7 + while value: + b.append(0x80 | bits) + bits = value & 0x7F + value >>= 7 + return bytes(b + [bits]) def _preprocess_single(proto_type: str, value: Any) -> bytes: @@ -216,14 +223,14 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: TYPE_UINT32, TYPE_UINT64, ]: - return serialize._varint(value) + return encode_varint(value) elif proto_type in [TYPE_SINT32, TYPE_SINT64]: # Handle zig-zag encoding.
if value >= 0: value = value << 1 else: value = (value << 1) ^ (~0) - return serialize._varint(value) + return encode_varint(value) elif proto_type in FIXED_TYPES: return struct.pack(_pack_fmt(proto_type), value) elif proto_type == TYPE_STRING: @@ -234,7 +241,51 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: return value -def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any: +def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes: + """Serializes a single field and value.""" + value = _preprocess_single(proto_type, value) + + output = b"" + if proto_type in WIRE_VARINT_TYPES: + key = encode_varint(field_number << 3) + output += key + value + elif proto_type in WIRE_FIXED_32_TYPES: + key = encode_varint((field_number << 3) | 5) + output += key + value + elif proto_type in WIRE_FIXED_64_TYPES: + key = encode_varint((field_number << 3) | 1) + output += key + value + elif proto_type in WIRE_LEN_DELIM_TYPES: + if len(value): + key = encode_varint((field_number << 3) | 2) + output += key + encode_varint(len(value)) + value + else: + raise NotImplementedError(proto_type) + + return output + + +def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]: + """ + Decode a single varint value from a byte buffer. Returns the value and the + new position in the buffer. 
+ """ + result = 0 + shift = 0 + while 1: + b = buffer[pos] + result |= (b & 0x7F) << shift + pos += 1 + if not (b & 0x80): + return (result, pos) + shift += 7 + if shift >= 64: + raise ValueError("Too many bytes when decoding varint.") + + +def _postprocess_single( + wire_type: int, meta: FieldMetadata, field: Any, value: Any +) -> Any: """Adjusts values after parsing.""" if wire_type == WIRE_VARINT: if meta.proto_type in ["int32", "int64"]: @@ -257,6 +308,39 @@ def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> return value +@dataclasses.dataclass(frozen=True) +class ParsedField: + number: int + wire_type: int + value: Any + + +def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: + i = 0 + while i < len(value): + num_wire, i = decode_varint(value, i) + # print(num_wire, i) + number = num_wire >> 3 + wire_type = num_wire & 0x7 + + if wire_type == 0: + decoded, i = decode_varint(value, i) + elif wire_type == 1: + decoded, i = value[i : i + 8], i + 8 + elif wire_type == 2: + length, i = decode_varint(value, i) + decoded = value[i : i + length] + i += length + elif wire_type == 5: + decoded, i = value[i : i + 4], i + 4 + else: + raise NotImplementedError(f"Wire type {wire_type}") + + # print(ParsedField(number=number, wire_type=wire_type, value=decoded)) + + yield ParsedField(number=number, wire_type=wire_type, value=decoded) + + # Bound type variable to allow methods to return `self` of subclasses T = TypeVar("T", bound="Message") @@ -274,7 +358,7 @@ class Message(ABC): """ output = b"" for field in dataclasses.fields(self): - meta: _Meta = field.metadata.get("betterproto") + meta = FieldMetadata.get(field) value = getattr(self, field.name) if isinstance(value, list): @@ -306,10 +390,10 @@ class Message(ABC): returns the instance itself and is therefore assignable and chainable. 
""" fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)} - for parsed in parse.fields(data): + for parsed in parse_fields(data): if parsed.number in fields: field = fields[parsed.number] - meta: _Meta = field.metadata.get("betterproto") + meta = FieldMetadata.get(field) if ( parsed.wire_type == WIRE_LEN_DELIM @@ -326,7 +410,7 @@ class Message(ABC): decoded, pos = parsed.value[pos : pos + 8], pos + 8 wire_type = WIRE_FIXED_64 else: - decoded, pos = parse.parse_varint(parsed.value, pos) + decoded, pos = decode_varint(parsed.value, pos) wire_type = WIRE_VARINT decoded = _postprocess_single(wire_type, meta, field, decoded) value.append(decoded) @@ -354,7 +438,7 @@ class Message(ABC): """ output = {} for field in dataclasses.fields(self): - meta: Meta_ = field.metadata.get("betterproto") + meta = FieldMetadata.get(field) v = getattr(self, field.name) if meta.proto_type == "message": v = v.to_dict() @@ -370,7 +454,7 @@ class Message(ABC): returns the instance itself and is therefore assignable and chainable. 
""" for field in dataclasses.fields(self): - meta: Meta_ = field.metadata.get("betterproto") + meta = FieldMetadata.get(field) if field.name in value: if meta.proto_type == "message": getattr(self, field.name).from_dict(value[field.name]) diff --git a/betterproto/parse.py b/betterproto/parse.py deleted file mode 100644 index a2c9592..0000000 --- a/betterproto/parse.py +++ /dev/null @@ -1,51 +0,0 @@ -import struct -from typing import Union, Generator, Any, SupportsBytes, List, Tuple -from dataclasses import dataclass - - -def parse_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]: - """Parse a single varint value from a byte buffer.""" - result = 0 - shift = 0 - while 1: - b = buffer[pos] - result |= (b & 0x7F) << shift - pos += 1 - if not (b & 0x80): - return (result, pos) - shift += 7 - if shift >= 64: - raise ValueError("Too many bytes when decoding varint.") - - -@dataclass(frozen=True) -class ParsedField: - number: int - wire_type: int - value: Any - - -def fields(value: bytes) -> Generator[ParsedField, None, None]: - i = 0 - while i < len(value): - num_wire, i = parse_varint(value, i) - # print(num_wire, i) - number = num_wire >> 3 - wire_type = num_wire & 0x7 - - if wire_type == 0: - decoded, i = parse_varint(value, i) - elif wire_type == 1: - decoded, i = value[i : i + 8], i + 8 - elif wire_type == 2: - length, i = parse_varint(value, i) - decoded = value[i : i + length] - i += length - elif wire_type == 5: - decoded, i = value[i : i + 4], i + 4 - else: - raise NotImplementedError(f"Wire type {wire_type}") - - # print(ParsedField(number=number, wire_type=wire_type, value=decoded)) - - yield ParsedField(number=number, wire_type=wire_type, value=decoded) diff --git a/betterproto/serialize.py b/betterproto/serialize.py deleted file mode 100644 index bade2c6..0000000 --- a/betterproto/serialize.py +++ /dev/null @@ -1,20 +0,0 @@ -import struct -from typing import Union, Generator, Any, SupportsBytes, List, Tuple -from dataclasses import 
dataclass - - -def _varint(value: int) -> bytes: - # From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372 - b: List[int] = [] - - if value < 0: - value += 1 << 64 - - bits = value & 0x7F - value >>= 7 - while value: - b.append(0x80 | bits) - bits = value & 0x7F - value >>= 7 - print(value) - return bytes(b + [bits]) diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index 51a9ca7..c1cfbde 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,6 +1,8 @@ #!/usr/bin/env python import os # isort: skip +# Force pure-python implementation instead of C++, otherwise imports +# break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"