Refactoring
This commit is contained in:
parent
c932fbc72c
commit
1f46e10ba7
1
Pipfile
1
Pipfile
@ -8,6 +8,7 @@ flake8 = "*"
|
|||||||
mypy = "*"
|
mypy = "*"
|
||||||
isort = "*"
|
isort = "*"
|
||||||
pytest = "*"
|
pytest = "*"
|
||||||
|
rope = "*"
|
||||||
|
|
||||||
[packages]
|
[packages]
|
||||||
protobuf = "*"
|
protobuf = "*"
|
||||||
|
23
Pipfile.lock
generated
23
Pipfile.lock
generated
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"_meta": {
|
"_meta": {
|
||||||
"hash": {
|
"hash": {
|
||||||
"sha256": "817b0f61c21a4841d0cfcc977becb16b4d55090f3d78c1ebcd6974c298a06348"
|
"sha256": "6c1797fb4eb73be97ca566206527c9d648b90f38c5bf2caf4b69537cd325ced9"
|
||||||
},
|
},
|
||||||
"pipfile-spec": 6,
|
"pipfile-spec": 6,
|
||||||
"requires": {
|
"requires": {
|
||||||
@ -18,11 +18,11 @@
|
|||||||
"default": {
|
"default": {
|
||||||
"jinja2": {
|
"jinja2": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:065c4f02ebe7f7cf559e49ee5a95fb800a9e4528727aec6f24402a5374c65013",
|
"sha256:74320bb91f31270f9551d46522e33af46a80c3d619f4a4bf42b3164d30b5911f",
|
||||||
"sha256:14dd6caf1527abb21f08f86c784eac40853ba93edb79552aa1e4b8aef1b61c7b"
|
"sha256:9fe95f19286cfefaa917656583d020be14e7859c6b0252588391e47db34527de"
|
||||||
],
|
],
|
||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==2.10.1"
|
"version": "==2.10.3"
|
||||||
},
|
},
|
||||||
"markupsafe": {
|
"markupsafe": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
@ -214,11 +214,20 @@
|
|||||||
},
|
},
|
||||||
"pytest": {
|
"pytest": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
"sha256:13c1c9b22127a77fc684eee24791efafcef343335d855e3573791c68588fe1a5",
|
"sha256:7e4800063ccfc306a53c461442526c5571e1462f61583506ce97e4da6a1d88c8",
|
||||||
"sha256:d8ba7be9466f55ef96ba203fc0f90d0cf212f2f927e69186e1353e30bc7f62e5"
|
"sha256:ca563435f4941d0cb34767301c27bc65c510cb82e90b9ecf9cb52dc2c63caaa0"
|
||||||
],
|
],
|
||||||
"index": "pypi",
|
"index": "pypi",
|
||||||
"version": "==5.2.0"
|
"version": "==5.2.1"
|
||||||
|
},
|
||||||
|
"rope": {
|
||||||
|
"hashes": [
|
||||||
|
"sha256:6b728fdc3e98a83446c27a91fc5d56808a004f8beab7a31ab1d7224cecc7d969",
|
||||||
|
"sha256:c5c5a6a87f7b1a2095fb311135e2a3d1f194f5ecb96900fdd0a9100881f48aaf",
|
||||||
|
"sha256:f0dcf719b63200d492b85535ebe5ea9b29e0d0b8aebeb87fe03fc1a65924fdaf"
|
||||||
|
],
|
||||||
|
"index": "pypi",
|
||||||
|
"version": "==0.14.0"
|
||||||
},
|
},
|
||||||
"six": {
|
"six": {
|
||||||
"hashes": [
|
"hashes": [
|
||||||
|
@ -15,8 +15,7 @@ from typing import (
|
|||||||
)
|
)
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
from . import parse, serialize
|
# Proto 3 data types
|
||||||
|
|
||||||
TYPE_ENUM = "enum"
|
TYPE_ENUM = "enum"
|
||||||
TYPE_BOOL = "bool"
|
TYPE_BOOL = "bool"
|
||||||
TYPE_INT32 = "int32"
|
TYPE_INT32 = "int32"
|
||||||
@ -36,6 +35,7 @@ TYPE_BYTES = "bytes"
|
|||||||
TYPE_MESSAGE = "message"
|
TYPE_MESSAGE = "message"
|
||||||
|
|
||||||
|
|
||||||
|
# Fields that use a fixed amount of space (4 or 8 bytes)
|
||||||
FIXED_TYPES = [
|
FIXED_TYPES = [
|
||||||
TYPE_FLOAT,
|
TYPE_FLOAT,
|
||||||
TYPE_DOUBLE,
|
TYPE_DOUBLE,
|
||||||
@ -45,6 +45,7 @@ FIXED_TYPES = [
|
|||||||
TYPE_SFIXED64,
|
TYPE_SFIXED64,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Fields that are efficiently packed when
|
||||||
PACKED_TYPES = [
|
PACKED_TYPES = [
|
||||||
TYPE_ENUM,
|
TYPE_ENUM,
|
||||||
TYPE_BOOL,
|
TYPE_BOOL,
|
||||||
@ -69,6 +70,7 @@ WIRE_FIXED_64 = 1
|
|||||||
WIRE_LEN_DELIM = 2
|
WIRE_LEN_DELIM = 2
|
||||||
WIRE_FIXED_32 = 5
|
WIRE_FIXED_32 = 5
|
||||||
|
|
||||||
|
# Mappings of which Proto 3 types correspond to which wire types.
|
||||||
WIRE_VARINT_TYPES = [
|
WIRE_VARINT_TYPES = [
|
||||||
TYPE_ENUM,
|
TYPE_ENUM,
|
||||||
TYPE_BOOL,
|
TYPE_BOOL,
|
||||||
@ -86,13 +88,24 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE]
|
|||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class _Meta:
|
class FieldMetadata:
|
||||||
|
"""Stores internal metadata used for parsing & serialization."""
|
||||||
|
|
||||||
|
# Protobuf field number
|
||||||
number: int
|
number: int
|
||||||
|
# Protobuf type name
|
||||||
proto_type: str
|
proto_type: str
|
||||||
|
# Default value if given
|
||||||
default: Any
|
default: Any
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||||
|
"""Returns the field metadata for a dataclass field."""
|
||||||
|
return field.metadata["betterproto"]
|
||||||
|
|
||||||
|
|
||||||
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
||||||
|
"""Creates a dataclass field with attached protobuf metadata."""
|
||||||
kwargs = {}
|
kwargs = {}
|
||||||
|
|
||||||
if callable(default):
|
if callable(default):
|
||||||
@ -103,7 +116,7 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
|||||||
kwargs["default"] = default
|
kwargs["default"] = default
|
||||||
|
|
||||||
return dataclasses.field(
|
return dataclasses.field(
|
||||||
**kwargs, metadata={"betterproto": _Meta(number, proto_type, default)}
|
**kwargs, metadata={"betterproto": FieldMetadata(number, proto_type, default)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -113,97 +126,91 @@ def field(number: int, proto_type: str, default: Any) -> dataclasses.Field:
|
|||||||
|
|
||||||
|
|
||||||
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||||
return field(number, "enum", default=default)
|
return field(number, TYPE_ENUM, default=default)
|
||||||
|
|
||||||
|
|
||||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||||
return field(number, "int32", default=default)
|
return field(number, TYPE_INT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def int64_field(number: int, default: int = 0) -> Any:
|
def int64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, "int64", default=default)
|
return field(number, TYPE_INT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def uint32_field(number: int, default: int = 0) -> Any:
|
def uint32_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, "uint32", default=default)
|
return field(number, TYPE_UINT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def uint64_field(number: int, default: int = 0) -> Any:
|
def uint64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, "uint64", default=default)
|
return field(number, TYPE_UINT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sint32_field(number: int, default: int = 0) -> Any:
|
def sint32_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, "sint32", default=default)
|
return field(number, TYPE_SINT32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sint64_field(number: int, default: int = 0) -> Any:
|
def sint64_field(number: int, default: int = 0) -> Any:
|
||||||
return field(number, "sint64", default=default)
|
return field(number, TYPE_SINT64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def float_field(number: int, default: float = 0.0) -> Any:
|
def float_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "float", default=default)
|
return field(number, TYPE_FLOAT, default=default)
|
||||||
|
|
||||||
|
|
||||||
def double_field(number: int, default: float = 0.0) -> Any:
|
def double_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "double", default=default)
|
return field(number, TYPE_DOUBLE, default=default)
|
||||||
|
|
||||||
|
|
||||||
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
def fixed32_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "fixed32", default=default)
|
return field(number, TYPE_FIXED32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
def fixed64_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "fixed64", default=default)
|
return field(number, TYPE_FIXED64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
def sfixed32_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "sfixed32", default=default)
|
return field(number, TYPE_SFIXED32, default=default)
|
||||||
|
|
||||||
|
|
||||||
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
def sfixed64_field(number: int, default: float = 0.0) -> Any:
|
||||||
return field(number, "sfixed64", default=default)
|
return field(number, TYPE_SFIXED64, default=default)
|
||||||
|
|
||||||
|
|
||||||
def string_field(number: int, default: str = "") -> Any:
|
def string_field(number: int, default: str = "") -> Any:
|
||||||
return field(number, "string", default=default)
|
return field(number, TYPE_STRING, default=default)
|
||||||
|
|
||||||
|
|
||||||
def message_field(number: int, default: Type["Message"]) -> Any:
|
def message_field(number: int, default: Type["Message"]) -> Any:
|
||||||
return field(number, "message", default=default)
|
return field(number, TYPE_MESSAGE, default=default)
|
||||||
|
|
||||||
|
|
||||||
def _pack_fmt(proto_type: str) -> str:
|
def _pack_fmt(proto_type: str) -> str:
|
||||||
|
"""Returns a little-endian format string for reading/writing binary."""
|
||||||
return {
|
return {
|
||||||
"double": "<d",
|
TYPE_DOUBLE: "<d",
|
||||||
"float": "<f",
|
TYPE_FLOAT: "<f",
|
||||||
"fixed32": "<I",
|
TYPE_FIXED32: "<I",
|
||||||
"fixed64": "<Q",
|
TYPE_FIXED64: "<Q",
|
||||||
"sfixed32": "<i",
|
TYPE_SFIXED32: "<i",
|
||||||
"sfixed64": "<q",
|
TYPE_SFIXED64: "<q",
|
||||||
}[proto_type]
|
}[proto_type]
|
||||||
|
|
||||||
|
|
||||||
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
def encode_varint(value: int) -> bytes:
|
||||||
value = _preprocess_single(proto_type, value)
|
"""Encodes a single varint value for serialization."""
|
||||||
|
b: List[int] = []
|
||||||
|
|
||||||
output = b""
|
if value < 0:
|
||||||
if proto_type in WIRE_VARINT_TYPES:
|
value += 1 << 64
|
||||||
key = serialize._varint(field_number << 3)
|
|
||||||
output += key + value
|
|
||||||
elif proto_type in WIRE_FIXED_32_TYPES:
|
|
||||||
key = serialize._varint((field_number << 3) | 5)
|
|
||||||
output += key + value
|
|
||||||
elif proto_type in WIRE_FIXED_64_TYPES:
|
|
||||||
key = serialize._varint((field_number << 3) | 1)
|
|
||||||
output += key + value
|
|
||||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
|
||||||
if len(value):
|
|
||||||
key = serialize._varint((field_number << 3) | 2)
|
|
||||||
output += key + serialize._varint(len(value)) + value
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(proto_type)
|
|
||||||
|
|
||||||
return output
|
bits = value & 0x7F
|
||||||
|
value >>= 7
|
||||||
|
while value:
|
||||||
|
b.append(0x80 | bits)
|
||||||
|
bits = value & 0x7F
|
||||||
|
value >>= 7
|
||||||
|
return bytes(b + [bits])
|
||||||
|
|
||||||
|
|
||||||
def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||||
@ -216,14 +223,14 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
|||||||
TYPE_UINT32,
|
TYPE_UINT32,
|
||||||
TYPE_UINT64,
|
TYPE_UINT64,
|
||||||
]:
|
]:
|
||||||
return serialize._varint(value)
|
return encode_varint(value)
|
||||||
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||||
# Handle zig-zag encoding.
|
# Handle zig-zag encoding.
|
||||||
if value >= 0:
|
if value >= 0:
|
||||||
value = value << 1
|
value = value << 1
|
||||||
else:
|
else:
|
||||||
value = (value << 1) ^ (~0)
|
value = (value << 1) ^ (~0)
|
||||||
return serialize._varint(value)
|
return encode_varint(value)
|
||||||
elif proto_type in FIXED_TYPES:
|
elif proto_type in FIXED_TYPES:
|
||||||
return struct.pack(_pack_fmt(proto_type), value)
|
return struct.pack(_pack_fmt(proto_type), value)
|
||||||
elif proto_type == TYPE_STRING:
|
elif proto_type == TYPE_STRING:
|
||||||
@ -234,7 +241,51 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) -> Any:
|
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||||
|
"""Serializes a single field and value."""
|
||||||
|
value = _preprocess_single(proto_type, value)
|
||||||
|
|
||||||
|
output = b""
|
||||||
|
if proto_type in WIRE_VARINT_TYPES:
|
||||||
|
key = encode_varint(field_number << 3)
|
||||||
|
output += key + value
|
||||||
|
elif proto_type in WIRE_FIXED_32_TYPES:
|
||||||
|
key = encode_varint((field_number << 3) | 5)
|
||||||
|
output += key + value
|
||||||
|
elif proto_type in WIRE_FIXED_64_TYPES:
|
||||||
|
key = encode_varint((field_number << 3) | 1)
|
||||||
|
output += key + value
|
||||||
|
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||||
|
if len(value):
|
||||||
|
key = encode_varint((field_number << 3) | 2)
|
||||||
|
output += key + encode_varint(len(value)) + value
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(proto_type)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Decode a single varint value from a byte buffer. Returns the value and the
|
||||||
|
new position in the buffer.
|
||||||
|
"""
|
||||||
|
result = 0
|
||||||
|
shift = 0
|
||||||
|
while 1:
|
||||||
|
b = buffer[pos]
|
||||||
|
result |= (b & 0x7F) << shift
|
||||||
|
pos += 1
|
||||||
|
if not (b & 0x80):
|
||||||
|
return (result, pos)
|
||||||
|
shift += 7
|
||||||
|
if shift >= 64:
|
||||||
|
raise ValueError("Too many bytes when decoding varint.")
|
||||||
|
|
||||||
|
|
||||||
|
def _postprocess_single(
|
||||||
|
wire_type: int, meta: FieldMetadata, field: Any, value: Any
|
||||||
|
) -> Any:
|
||||||
"""Adjusts values after parsing."""
|
"""Adjusts values after parsing."""
|
||||||
if wire_type == WIRE_VARINT:
|
if wire_type == WIRE_VARINT:
|
||||||
if meta.proto_type in ["int32", "int64"]:
|
if meta.proto_type in ["int32", "int64"]:
|
||||||
@ -257,6 +308,39 @@ def _postprocess_single(wire_type: int, meta: _Meta, field: Any, value: Any) ->
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class ParsedField:
|
||||||
|
number: int
|
||||||
|
wire_type: int
|
||||||
|
value: Any
|
||||||
|
|
||||||
|
|
||||||
|
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||||
|
i = 0
|
||||||
|
while i < len(value):
|
||||||
|
num_wire, i = decode_varint(value, i)
|
||||||
|
# print(num_wire, i)
|
||||||
|
number = num_wire >> 3
|
||||||
|
wire_type = num_wire & 0x7
|
||||||
|
|
||||||
|
if wire_type == 0:
|
||||||
|
decoded, i = decode_varint(value, i)
|
||||||
|
elif wire_type == 1:
|
||||||
|
decoded, i = value[i : i + 8], i + 8
|
||||||
|
elif wire_type == 2:
|
||||||
|
length, i = decode_varint(value, i)
|
||||||
|
decoded = value[i : i + length]
|
||||||
|
i += length
|
||||||
|
elif wire_type == 5:
|
||||||
|
decoded, i = value[i : i + 4], i + 4
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Wire type {wire_type}")
|
||||||
|
|
||||||
|
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
||||||
|
|
||||||
|
yield ParsedField(number=number, wire_type=wire_type, value=decoded)
|
||||||
|
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
T = TypeVar("T", bound="Message")
|
T = TypeVar("T", bound="Message")
|
||||||
|
|
||||||
@ -274,7 +358,7 @@ class Message(ABC):
|
|||||||
"""
|
"""
|
||||||
output = b""
|
output = b""
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta: _Meta = field.metadata.get("betterproto")
|
meta = FieldMetadata.get(field)
|
||||||
value = getattr(self, field.name)
|
value = getattr(self, field.name)
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
@ -306,10 +390,10 @@ class Message(ABC):
|
|||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
"""
|
"""
|
||||||
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
|
fields = {f.metadata["betterproto"].number: f for f in dataclasses.fields(self)}
|
||||||
for parsed in parse.fields(data):
|
for parsed in parse_fields(data):
|
||||||
if parsed.number in fields:
|
if parsed.number in fields:
|
||||||
field = fields[parsed.number]
|
field = fields[parsed.number]
|
||||||
meta: _Meta = field.metadata.get("betterproto")
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
parsed.wire_type == WIRE_LEN_DELIM
|
parsed.wire_type == WIRE_LEN_DELIM
|
||||||
@ -326,7 +410,7 @@ class Message(ABC):
|
|||||||
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
decoded, pos = parsed.value[pos : pos + 8], pos + 8
|
||||||
wire_type = WIRE_FIXED_64
|
wire_type = WIRE_FIXED_64
|
||||||
else:
|
else:
|
||||||
decoded, pos = parse.parse_varint(parsed.value, pos)
|
decoded, pos = decode_varint(parsed.value, pos)
|
||||||
wire_type = WIRE_VARINT
|
wire_type = WIRE_VARINT
|
||||||
decoded = _postprocess_single(wire_type, meta, field, decoded)
|
decoded = _postprocess_single(wire_type, meta, field, decoded)
|
||||||
value.append(decoded)
|
value.append(decoded)
|
||||||
@ -354,7 +438,7 @@ class Message(ABC):
|
|||||||
"""
|
"""
|
||||||
output = {}
|
output = {}
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta: Meta_ = field.metadata.get("betterproto")
|
meta = FieldMetadata.get(field)
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
v = v.to_dict()
|
v = v.to_dict()
|
||||||
@ -370,7 +454,7 @@ class Message(ABC):
|
|||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
"""
|
"""
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta: Meta_ = field.metadata.get("betterproto")
|
meta = FieldMetadata.get(field)
|
||||||
if field.name in value:
|
if field.name in value:
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
getattr(self, field.name).from_dict(value[field.name])
|
getattr(self, field.name).from_dict(value[field.name])
|
||||||
|
@ -1,51 +0,0 @@
|
|||||||
import struct
|
|
||||||
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
def parse_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, int]:
|
|
||||||
"""Parse a single varint value from a byte buffer."""
|
|
||||||
result = 0
|
|
||||||
shift = 0
|
|
||||||
while 1:
|
|
||||||
b = buffer[pos]
|
|
||||||
result |= (b & 0x7F) << shift
|
|
||||||
pos += 1
|
|
||||||
if not (b & 0x80):
|
|
||||||
return (result, pos)
|
|
||||||
shift += 7
|
|
||||||
if shift >= 64:
|
|
||||||
raise ValueError("Too many bytes when decoding varint.")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ParsedField:
|
|
||||||
number: int
|
|
||||||
wire_type: int
|
|
||||||
value: Any
|
|
||||||
|
|
||||||
|
|
||||||
def fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|
||||||
i = 0
|
|
||||||
while i < len(value):
|
|
||||||
num_wire, i = parse_varint(value, i)
|
|
||||||
# print(num_wire, i)
|
|
||||||
number = num_wire >> 3
|
|
||||||
wire_type = num_wire & 0x7
|
|
||||||
|
|
||||||
if wire_type == 0:
|
|
||||||
decoded, i = parse_varint(value, i)
|
|
||||||
elif wire_type == 1:
|
|
||||||
decoded, i = value[i : i + 8], i + 8
|
|
||||||
elif wire_type == 2:
|
|
||||||
length, i = parse_varint(value, i)
|
|
||||||
decoded = value[i : i + length]
|
|
||||||
i += length
|
|
||||||
elif wire_type == 5:
|
|
||||||
decoded, i = value[i : i + 4], i + 4
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Wire type {wire_type}")
|
|
||||||
|
|
||||||
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
|
||||||
|
|
||||||
yield ParsedField(number=number, wire_type=wire_type, value=decoded)
|
|
@ -1,20 +0,0 @@
|
|||||||
import struct
|
|
||||||
from typing import Union, Generator, Any, SupportsBytes, List, Tuple
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
def _varint(value: int) -> bytes:
|
|
||||||
# From https://github.com/protocolbuffers/protobuf/blob/master/python/google/protobuf/internal/encoder.py#L372
|
|
||||||
b: List[int] = []
|
|
||||||
|
|
||||||
if value < 0:
|
|
||||||
value += 1 << 64
|
|
||||||
|
|
||||||
bits = value & 0x7F
|
|
||||||
value >>= 7
|
|
||||||
while value:
|
|
||||||
b.append(0x80 | bits)
|
|
||||||
bits = value & 0x7F
|
|
||||||
value >>= 7
|
|
||||||
print(value)
|
|
||||||
return bytes(b + [bits])
|
|
@ -1,6 +1,8 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import os # isort: skip
|
import os # isort: skip
|
||||||
|
|
||||||
|
# Force pure-python implementation instead of C++, otherwise imports
|
||||||
|
# break things because we can't properly reset the symbol database.
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user