Support for repeated message fields
This commit is contained in:
parent
1a488faf7a
commit
ad7162a3ec
@ -5,7 +5,7 @@
|
||||
- [x] Zig-zag signed fields (sint32, sint64)
|
||||
- [x] Don't encode zero values for nested types
|
||||
- [x] Enums
|
||||
- [ ] Repeated message fields
|
||||
- [x] Repeated message fields
|
||||
- [ ] Maps
|
||||
- [ ] Support passthrough of unknown fields
|
||||
- [ ] Refs to nested types
|
||||
|
@ -2,6 +2,7 @@ from abc import ABC
|
||||
import json
|
||||
import struct
|
||||
from typing import (
|
||||
get_type_hints,
|
||||
Union,
|
||||
Generator,
|
||||
Any,
|
||||
@ -15,6 +16,8 @@ from typing import (
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
import inspect
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
TYPE_BOOL = "bool"
|
||||
@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
|
||||
raise ValueError("Too many bytes when decoding varint.")
|
||||
|
||||
|
||||
def _postprocess_single(
|
||||
wire_type: int, meta: FieldMetadata, field: Any, value: Any
|
||||
) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
if meta.proto_type in ["int32", "int64"]:
|
||||
bits = int(meta.proto_type[3:])
|
||||
value = value & ((1 << bits) - 1)
|
||||
signbit = 1 << (bits - 1)
|
||||
value = int((value ^ signbit) - signbit)
|
||||
elif meta.proto_type in ["sint32", "sint64"]:
|
||||
# Undo zig-zag encoding
|
||||
value = (value >> 1) ^ (-(value & 1))
|
||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in ["string"]:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type in ["message"]:
|
||||
orig = value
|
||||
value = field.default_factory()
|
||||
if isinstance(value, Message):
|
||||
# If it's a message (instead of e.g. list) then keep going!
|
||||
value.parse(orig)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ParsedField:
|
||||
number: int
|
||||
@ -388,6 +362,41 @@ class Message(ABC):
|
||||
|
||||
return output
|
||||
|
||||
def _cls_for(self, field: dataclasses.Field) -> Type:
|
||||
"""Get the message class for a field from the type hints."""
|
||||
module = inspect.getmodule(self)
|
||||
type_hints = get_type_hints(self, vars(module))
|
||||
cls = type_hints[field.name]
|
||||
if hasattr(cls, "__args__"):
|
||||
print(type_hints[field.name].__args__[0])
|
||||
cls = type_hints[field.name].__args__[0]
|
||||
return cls
|
||||
|
||||
def _postprocess_single(
|
||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||
) -> Any:
|
||||
"""Adjusts values after parsing."""
|
||||
if wire_type == WIRE_VARINT:
|
||||
if meta.proto_type in ["int32", "int64"]:
|
||||
bits = int(meta.proto_type[3:])
|
||||
value = value & ((1 << bits) - 1)
|
||||
signbit = 1 << (bits - 1)
|
||||
value = int((value ^ signbit) - signbit)
|
||||
elif meta.proto_type in ["sint32", "sint64"]:
|
||||
# Undo zig-zag encoding
|
||||
value = (value >> 1) ^ (-(value & 1))
|
||||
elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]:
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in ["string"]:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type in ["message"]:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
|
||||
return value
|
||||
|
||||
def parse(self, data: bytes) -> T:
|
||||
"""
|
||||
Parse the binary encoded Protobuf into this message instance. This
|
||||
@ -416,10 +425,12 @@ class Message(ABC):
|
||||
else:
|
||||
decoded, pos = decode_varint(parsed.value, pos)
|
||||
wire_type = WIRE_VARINT
|
||||
decoded = _postprocess_single(wire_type, meta, field, decoded)
|
||||
decoded = self._postprocess_single(
|
||||
wire_type, meta, field, decoded
|
||||
)
|
||||
value.append(decoded)
|
||||
else:
|
||||
value = _postprocess_single(
|
||||
value = self._postprocess_single(
|
||||
parsed.wire_type, meta, field, parsed.value
|
||||
)
|
||||
|
||||
@ -445,7 +456,13 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
if meta.proto_type == "message":
|
||||
v = v.to_dict()
|
||||
if isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
# Filter out empty items which we won't serialize.
|
||||
v = [i for i in v if i]
|
||||
else:
|
||||
v = v.to_dict()
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v != field.default:
|
||||
@ -461,7 +478,14 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value:
|
||||
if meta.proto_type == "message":
|
||||
getattr(self, field.name).from_dict(value[field.name])
|
||||
v = getattr(self, field.name)
|
||||
print(v, value[field.name])
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[field.name])):
|
||||
v.append(cls().from_dict(value[field.name][i]))
|
||||
else:
|
||||
v.from_dict(value[field.name])
|
||||
else:
|
||||
setattr(self, field.name, value[field.name])
|
||||
return self
|
||||
|
10
betterproto/tests/repeatedmessage.json
Normal file
10
betterproto/tests/repeatedmessage.json
Normal file
@ -0,0 +1,10 @@
|
||||
{
|
||||
"greetings": [
|
||||
{
|
||||
"greeting": "hello"
|
||||
},
|
||||
{
|
||||
"greeting": "hi"
|
||||
}
|
||||
]
|
||||
}
|
9
betterproto/tests/repeatedmessage.proto
Normal file
9
betterproto/tests/repeatedmessage.proto
Normal file
@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
repeated Sub greetings = 1;
|
||||
}
|
||||
|
||||
message Sub {
|
||||
string greeting = 1;
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user