From ad7162a3ec11c193bbc4d3acf8b94138d2c22cc4 Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Wed, 9 Oct 2019 20:46:16 -0700 Subject: [PATCH] Support for repeated message fields --- README.md | 2 +- betterproto/__init__.py | 90 ++++++++++++++++--------- betterproto/tests/repeatedmessage.json | 10 +++ betterproto/tests/repeatedmessage.proto | 9 +++ 4 files changed, 77 insertions(+), 34 deletions(-) create mode 100644 betterproto/tests/repeatedmessage.json create mode 100644 betterproto/tests/repeatedmessage.proto diff --git a/README.md b/README.md index 897b31d..3cff8e8 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ - [x] Zig-zag signed fields (sint32, sint64) - [x] Don't encode zero values for nested types - [x] Enums -- [ ] Repeated message fields +- [x] Repeated message fields - [ ] Maps - [ ] Support passthrough of unknown fields - [ ] Refs to nested types diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 966b9a6..cf4b6e0 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -2,6 +2,7 @@ from abc import ABC import json import struct from typing import ( + get_type_hints, Union, Generator, Any, @@ -15,6 +16,8 @@ from typing import ( ) import dataclasses +import inspect + # Proto 3 data types TYPE_ENUM = "enum" TYPE_BOOL = "bool" @@ -283,35 +286,6 @@ def decode_varint(buffer: bytes, pos: int, signed: bool = False) -> Tuple[int, i raise ValueError("Too many bytes when decoding varint.") -def _postprocess_single( - wire_type: int, meta: FieldMetadata, field: Any, value: Any -) -> Any: - """Adjusts values after parsing.""" - if wire_type == WIRE_VARINT: - if meta.proto_type in ["int32", "int64"]: - bits = int(meta.proto_type[3:]) - value = value & ((1 << bits) - 1) - signbit = 1 << (bits - 1) - value = int((value ^ signbit) - signbit) - elif meta.proto_type in ["sint32", "sint64"]: - # Undo zig-zag encoding - value = (value >> 1) ^ (-(value & 1)) - elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: - fmt = _pack_fmt(meta.proto_type) - value = struct.unpack(fmt, value)[0] - elif wire_type == WIRE_LEN_DELIM: - if meta.proto_type in ["string"]: - value = value.decode("utf-8") - elif meta.proto_type in ["message"]: - orig = value - value = field.default_factory() - if isinstance(value, Message): - # If it's a message (instead of e.g. list) then keep going! - value.parse(orig) - - return value - - @dataclasses.dataclass(frozen=True) class ParsedField: number: int @@ -388,6 +362,41 @@ class Message(ABC): return output + def _cls_for(self, field: dataclasses.Field) -> Type: + """Get the message class for a field from the type hints.""" + module = inspect.getmodule(self) + type_hints = get_type_hints(self, vars(module)) + cls = type_hints[field.name] + if hasattr(cls, "__args__"): + print(type_hints[field.name].__args__[0]) + cls = type_hints[field.name].__args__[0] + return cls + + def _postprocess_single( + self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any + ) -> Any: + """Adjusts values after parsing.""" + if wire_type == WIRE_VARINT: + if meta.proto_type in ["int32", "int64"]: + bits = int(meta.proto_type[3:]) + value = value & ((1 << bits) - 1) + signbit = 1 << (bits - 1) + value = int((value ^ signbit) - signbit) + elif meta.proto_type in ["sint32", "sint64"]: + # Undo zig-zag encoding + value = (value >> 1) ^ (-(value & 1)) + elif wire_type in [WIRE_FIXED_32, WIRE_FIXED_64]: + fmt = _pack_fmt(meta.proto_type) + value = struct.unpack(fmt, value)[0] + elif wire_type == WIRE_LEN_DELIM: + if meta.proto_type in ["string"]: + value = value.decode("utf-8") + elif meta.proto_type in ["message"]: + cls = self._cls_for(field) + value = cls().parse(value) + + return value + def parse(self, data: bytes) -> T: """ Parse the binary encoded Protobuf into this message instance. This @@ -416,10 +425,12 @@ class Message(ABC): else: decoded, pos = decode_varint(parsed.value, pos) wire_type = WIRE_VARINT - decoded = _postprocess_single(wire_type, meta, field, decoded) + decoded = self._postprocess_single( + wire_type, meta, field, decoded + ) value.append(decoded) else: - value = _postprocess_single( + value = self._postprocess_single( parsed.wire_type, meta, field, parsed.value ) @@ -445,7 +456,13 @@ class Message(ABC): meta = FieldMetadata.get(field) v = getattr(self, field.name) if meta.proto_type == "message": - v = v.to_dict() + if isinstance(v, list): + # Convert each item. + v = [i.to_dict() for i in v] + # Filter out empty items which we won't serialize. + v = [i for i in v if i] + else: + v = v.to_dict() if v: output[field.name] = v elif v != field.default: @@ -461,7 +478,14 @@ class Message(ABC): meta = FieldMetadata.get(field) if field.name in value: if meta.proto_type == "message": - getattr(self, field.name).from_dict(value[field.name]) + v = getattr(self, field.name) + print(v, value[field.name]) + if isinstance(v, list): + cls = self._cls_for(field) + for i in range(len(value[field.name])): + v.append(cls().from_dict(value[field.name][i])) + else: + v.from_dict(value[field.name]) else: setattr(self, field.name, value[field.name]) return self diff --git a/betterproto/tests/repeatedmessage.json b/betterproto/tests/repeatedmessage.json new file mode 100644 index 0000000..90ec596 --- /dev/null +++ b/betterproto/tests/repeatedmessage.json @@ -0,0 +1,10 @@ +{ + "greetings": [ + { + "greeting": "hello" + }, + { + "greeting": "hi" + } + ] +} diff --git a/betterproto/tests/repeatedmessage.proto b/betterproto/tests/repeatedmessage.proto new file mode 100644 index 0000000..ea4c01f --- /dev/null +++ b/betterproto/tests/repeatedmessage.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +message Test { + repeated Sub greetings = 1; +} + +message Sub { + string greeting = 1; +} \ No newline at end of file