Support for repeated message fields

This commit is contained in:
Daniel G. Taylor 2019-10-09 20:46:16 -07:00
parent 1a488faf7a
commit ad7162a3ec
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
4 changed files with 77 additions and 34 deletions

View File

@ -5,7 +5,7 @@
- [x] Zig-zag signed fields (sint32, sint64)
- [x] Don't encode zero values for nested types
- [x] Enums
- [ ] Repeated message fields
- [x] Repeated message fields
- [ ] Maps
- [ ] Support passthrough of unknown fields
- [ ] Refs to nested types

View File

@ -2,6 +2,7 @@ from abc import ABC
import json
import struct
from typing import (
get_type_hints,
Union,
Generator,
Any,
@ -15,6 +16,8 @@ from typing import (
)
import dataclasses
import inspect
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i
raise ValueError("Too many bytes when decoding varint.")
def _postprocess_single(
    wire_type: int, meta: FieldMetadata, field: Any, value: Any
) -> Any:
    """Turn a raw value read off the wire into its final Python form.

    ``wire_type`` selects the family of conversions and ``meta.proto_type``
    picks the specific one. Values that match no conversion are returned
    unchanged.
    """
    if wire_type == WIRE_VARINT:
        if meta.proto_type in ("int32", "int64"):
            # The bit width is the numeric suffix of the type name. Truncate
            # to that width, then sign-extend from the top bit.
            width = int(meta.proto_type[3:])
            truncated = value & ((1 << width) - 1)
            sign_mask = 1 << (width - 1)
            return int((truncated ^ sign_mask) - sign_mask)
        if meta.proto_type in ("sint32", "sint64"):
            # Reverse the zig-zag mapping back to a signed integer.
            return (value >> 1) ^ (-(value & 1))
        return value

    if wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
        unpacked = struct.unpack(_pack_fmt(meta.proto_type), value)
        return unpacked[0]

    if wire_type == WIRE_LEN_DELIM:
        if meta.proto_type in ("string",):
            return value.decode("utf-8")
        if meta.proto_type in ("message",):
            instance = field.default_factory()
            # Only real messages are parsed; any other default (e.g. a list)
            # is handed back exactly as the factory produced it.
            if isinstance(instance, Message):
                instance.parse(value)
            return instance

    return value
@dataclasses.dataclass(frozen=True)
class ParsedField:
number: int
@ -388,6 +362,41 @@ class Message(ABC):
return output
def _cls_for(self, field: dataclasses.Field) -> Type:
"""Get the message class for a field from the type hints."""
module = inspect.getmodule(self)
type_hints = get_type_hints(self, vars(module))
cls = type_hints[field.name]
if hasattr(cls, "__args__"):
print(type_hints[field.name].__args__[0])
cls = type_hints[field.name].__args__[0]
return cls
def _postprocess_single(
    self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
) -> Any:
    """Turn a raw value read off the wire into its final Python form.

    ``wire_type`` selects the family of conversions and ``meta.proto_type``
    picks the specific one. Values that match no conversion are returned
    unchanged.
    """
    if wire_type == WIRE_VARINT:
        if meta.proto_type in ("int32", "int64"):
            # The bit width is the numeric suffix of the type name. Truncate
            # to that width, then sign-extend from the top bit.
            width = int(meta.proto_type[3:])
            truncated = value & ((1 << width) - 1)
            sign_mask = 1 << (width - 1)
            return int((truncated ^ sign_mask) - sign_mask)
        if meta.proto_type in ("sint32", "sint64"):
            # Reverse the zig-zag mapping back to a signed integer.
            return (value >> 1) ^ (-(value & 1))
        return value

    if wire_type in (WIRE_FIXED_32, WIRE_FIXED_64):
        unpacked = struct.unpack(_pack_fmt(meta.proto_type), value)
        return unpacked[0]

    if wire_type == WIRE_LEN_DELIM:
        if meta.proto_type in ("string",):
            return value.decode("utf-8")
        if meta.proto_type in ("message",):
            # Build a fresh instance of the field's message class and let it
            # parse the embedded bytes; the result of parse() is the value.
            message_cls = self._cls_for(field)
            return message_cls().parse(value)

    return value
def parse(self, data: bytes) -> T:
"""
Parse the binary encoded Protobuf into this message instance. This
@ -416,10 +425,12 @@ class Message(ABC):
else:
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = _postprocess_single(wire_type, meta, field, decoded)
decoded = self._postprocess_single(
wire_type, meta, field, decoded
)
value.append(decoded)
else:
value = _postprocess_single(
value = self._postprocess_single(
parsed.wire_type, meta, field, parsed.value
)
@ -445,7 +456,13 @@ class Message(ABC):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
if meta.proto_type == "message":
v = v.to_dict()
if isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
# Filter out empty items which we won't serialize.
v = [i for i in v if i]
else:
v = v.to_dict()
if v:
output[field.name] = v
elif v != field.default:
@ -461,7 +478,14 @@ class Message(ABC):
meta = FieldMetadata.get(field)
if field.name in value:
if meta.proto_type == "message":
getattr(self, field.name).from_dict(value[field.name])
v = getattr(self, field.name)
print(v, value[field.name])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[field.name])):
v.append(cls().from_dict(value[field.name][i]))
else:
v.from_dict(value[field.name])
else:
setattr(self, field.name, value[field.name])
return self

View File

@ -0,0 +1,10 @@
{
"greetings": [
{
"greeting": "hello"
},
{
"greeting": "hi"
}
]
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
message Test {
repeated Sub greetings = 1;
}
message Sub {
string greeting = 1;
}