Support for repeated message fields
This commit is contained in:
parent
1a488faf7a
commit
ad7162a3ec
@ -5,7 +5,7 @@
|
|||||||
- [x] Zig-zag signed fields (sint32, sint64)
|
- [x] Zig-zag signed fields (sint32, sint64)
|
||||||
- [x] Don't encode zero values for nested types
|
- [x] Don't encode zero values for nested types
|
||||||
- [x] Enums
|
- [x] Enums
|
||||||
- [ ] Repeated message fields
|
- [x] Repeated message fields
|
||||||
- [ ] Maps
|
- [ ] Maps
|
||||||
- [ ] Support passthrough of unknown fields
|
- [ ] Support passthrough of unknown fields
|
||||||
- [ ] Refs to nested types
|
- [ ] Refs to nested types
|
||||||
|
@ -2,6 +2,7 @@ from abc import ABC
|
|||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
from typing import (
|
from typing import (
|
||||||
|
get_type_hints,
|
||||||
Union,
|
Union,
|
||||||
Generator,
|
Generator,
|
||||||
Any,
|
Any,
|
||||||
@ -15,6 +16,8 @@ from typing import (
|
|||||||
)
|
)
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
|
||||||
# Proto 3 data types
|
# Proto 3 data types
|
||||||
TYPE_ENUM = "enum"
|
TYPE_ENUM = "enum"
|
||||||
TYPE_BOOL = "bool"
|
TYPE_BOOL = "bool"
|
||||||
@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
|
|||||||
raise ValueError("Too many bytes when decoding varint.")
|
raise ValueError("Too many bytes when decoding varint.")
|
||||||
|
|
||||||
|
|
||||||
def _postprocess_single(
|
|
||||||
wire_type: int, meta: FieldMetadata, field: Any, value: Any
|
|
||||||
) -> Any:
|
|
||||||
"""Adjusts values after parsing."""
|
|
||||||
if wire_type == WIRE_VARINT:
|
|
||||||
if meta.proto_type in ["int32", "int64"]:
|
|
||||||
bits = int(meta.proto_type[3:])
|
|
||||||
value = value & ((1 << bits) - 1)
|
|
||||||
signbit = 1 << (bits - 1)
|
|
||||||
value = int((value ^ signbit) - signbit)
|
|
||||||
elif meta.proto_type in ["sint32", "sint64"]:
|
|
||||||
# Undo zig-zag encoding
|
|
||||||
value = (value >> 1) ^ (-(value & 1))
|
|
||||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
|
||||||
fmt = _pack_fmt(meta.proto_type)
|
|
||||||
value = struct.unpack(fmt, value)[0]
|
|
||||||
elif wire_type == WIRE_LEN_DELIM:
|
|
||||||
if meta.proto_type in ["string"]:
|
|
||||||
value = value.decode("utf-8")
|
|
||||||
elif meta.proto_type in ["message"]:
|
|
||||||
orig = value
|
|
||||||
value = field.default_factory()
|
|
||||||
if isinstance(value, Message):
|
|
||||||
# If it's a message (instead of e.g. list) then keep going!
|
|
||||||
value.parse(orig)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ParsedField:
|
class ParsedField:
|
||||||
number: int
|
number: int
|
||||||
@ -388,6 +362,41 @@ class Message(ABC):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def _cls_for(self, field: dataclasses.Field) -> Type:
|
||||||
|
"""Get the message class for a field from the type hints."""
|
||||||
|
module = inspect.getmodule(self)
|
||||||
|
type_hints = get_type_hints(self, vars(module))
|
||||||
|
cls = type_hints[field.name]
|
||||||
|
if hasattr(cls, "__args__"):
|
||||||
|
print(type_hints[field.name].__args__[0])
|
||||||
|
cls = type_hints[field.name].__args__[0]
|
||||||
|
return cls
|
||||||
|
|
||||||
|
def _postprocess_single(
|
||||||
|
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||||
|
) -> Any:
|
||||||
|
"""Adjusts values after parsing."""
|
||||||
|
if wire_type == WIRE_VARINT:
|
||||||
|
if meta.proto_type in ["int32", "int64"]:
|
||||||
|
bits = int(meta.proto_type[3:])
|
||||||
|
value = value & ((1 << bits) - 1)
|
||||||
|
signbit = 1 << (bits - 1)
|
||||||
|
value = int((value ^ signbit) - signbit)
|
||||||
|
elif meta.proto_type in ["sint32", "sint64"]:
|
||||||
|
# Undo zig-zag encoding
|
||||||
|
value = (value >> 1) ^ (-(value & 1))
|
||||||
|
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||||
|
fmt = _pack_fmt(meta.proto_type)
|
||||||
|
value = struct.unpack(fmt, value)[0]
|
||||||
|
elif wire_type == WIRE_LEN_DELIM:
|
||||||
|
if meta.proto_type in ["string"]:
|
||||||
|
value = value.decode("utf-8")
|
||||||
|
elif meta.proto_type in ["message"]:
|
||||||
|
cls = self._cls_for(field)
|
||||||
|
value = cls().parse(value)
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
def parse(self, data: bytes) -> T:
|
def parse(self, data: bytes) -> T:
|
||||||
"""
|
"""
|
||||||
Parse the binary encoded Protobuf into this message instance. This
|
Parse the binary encoded Protobuf into this message instance. This
|
||||||
@ -416,10 +425,12 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
decoded, pos = decode_varint(parsed.value, pos)
|
decoded, pos = decode_varint(parsed.value, pos)
|
||||||
wire_type = WIRE_VARINT
|
wire_type = WIRE_VARINT
|
||||||
decoded = _postprocess_single(wire_type, meta, field, decoded)
|
decoded = self._postprocess_single(
|
||||||
|
wire_type, meta, field, decoded
|
||||||
|
)
|
||||||
value.append(decoded)
|
value.append(decoded)
|
||||||
else:
|
else:
|
||||||
value = _postprocess_single(
|
value = self._postprocess_single(
|
||||||
parsed.wire_type, meta, field, parsed.value
|
parsed.wire_type, meta, field, parsed.value
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -445,7 +456,13 @@ class Message(ABC):
|
|||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
v = v.to_dict()
|
if isinstance(v, list):
|
||||||
|
# Convert each item.
|
||||||
|
v = [i.to_dict() for i in v]
|
||||||
|
# Filter out empty items which we won't serialize.
|
||||||
|
v = [i for i in v if i]
|
||||||
|
else:
|
||||||
|
v = v.to_dict()
|
||||||
if v:
|
if v:
|
||||||
output[field.name] = v
|
output[field.name] = v
|
||||||
elif v != field.default:
|
elif v != field.default:
|
||||||
@ -461,7 +478,14 @@ class Message(ABC):
|
|||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
if field.name in value:
|
if field.name in value:
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
getattr(self, field.name).from_dict(value[field.name])
|
v = getattr(self, field.name)
|
||||||
|
print(v, value[field.name])
|
||||||
|
if isinstance(v, list):
|
||||||
|
cls = self._cls_for(field)
|
||||||
|
for i in range(len(value[field.name])):
|
||||||
|
v.append(cls().from_dict(value[field.name][i]))
|
||||||
|
else:
|
||||||
|
v.from_dict(value[field.name])
|
||||||
else:
|
else:
|
||||||
setattr(self, field.name, value[field.name])
|
setattr(self, field.name, value[field.name])
|
||||||
return self
|
return self
|
||||||
|
10
betterproto/tests/repeatedmessage.json
Normal file
10
betterproto/tests/repeatedmessage.json
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"greetings": [
|
||||||
|
{
|
||||||
|
"greeting": "hello"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"greeting": "hi"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
9
betterproto/tests/repeatedmessage.proto
Normal file
9
betterproto/tests/repeatedmessage.proto
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
message Test {
|
||||||
|
repeated Sub greetings = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message Sub {
|
||||||
|
string greeting = 1;
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user