diff --git a/README.md b/README.md index b1b5fa1..7fcde09 100644 --- a/README.md +++ b/README.md @@ -218,7 +218,7 @@ $ pipenv run tests - [x] Repeated message fields - [x] Maps - [x] Maps of message fields -- [ ] Support passthrough of unknown fields +- [x] Support passthrough of unknown fields - [x] Refs to nested types - [x] Imports in proto files - [x] Well-known Google types diff --git a/betterproto/__init__.py b/betterproto/__init__.py index e53f869..b0fb445 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -341,17 +341,19 @@ class ParsedField: number: int wire_type: int value: Any + raw: bytes def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: i = 0 while i < len(value): + start = i num_wire, i = decode_varint(value, i) # print(num_wire, i) number = num_wire >> 3 wire_type = num_wire & 0x7 - decoded: Any + decoded: Any = None if wire_type == 0: decoded, i = decode_varint(value, i) elif wire_type == 1: @@ -362,12 +364,12 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: i += length elif wire_type == 5: decoded, i = value[i : i + 4], i + 4 - else: - raise NotImplementedError(f"Wire type {wire_type}") # print(ParsedField(number=number, wire_type=wire_type, value=decoded)) - yield ParsedField(number=number, wire_type=wire_type, value=decoded) + yield ParsedField( + number=number, wire_type=wire_type, value=decoded, raw=value[start:i] + ) # Bound type variable to allow methods to return `self` of subclasses @@ -415,6 +417,7 @@ class Message(ABC): # Now that all the defaults are set, reset it! self.__dict__["serialized_on_wire"] = False + self.__dict__["_unknown_fields"] = b"" def __setattr__(self, attr: str, value: Any) -> None: if attr != "serialized_on_wire": @@ -469,7 +472,7 @@ class Message(ABC): meta.number, meta.proto_type, value, serialize_empty=serialize_empty ) - return output + return output + self._unknown_fields # For compatibility with other libraries SerializeToString = __bytes__ @@ -571,8 +574,7 @@ class Message(ABC): else: setattr(self, field.name, value) else: - # TODO: handle unknown fields - pass + self._unknown_fields += parsed.raw return self diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 4a6c4d4..8afaa74 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -48,3 +48,25 @@ def test_enum_as_int_json(): # Plain-ol'-ints should serialize properly too. foo.bar = 1 assert foo.to_dict() == {"bar": "ONE"} + + +def test_unknown_fields(): + @dataclass + class Newer(betterproto.Message): + foo: bool = betterproto.bool_field(1) + bar: int = betterproto.int32_field(2) + baz: str = betterproto.string_field(3) + + @dataclass + class Older(betterproto.Message): + foo: bool = betterproto.bool_field(1) + + newer = Newer(foo=True, bar=1, baz="Hello") + serialized_newer = bytes(newer) + + # Unknown fields in `Newer` should round trip with `Older` + round_trip = bytes(Older().parse(serialized_newer)) + assert serialized_newer == round_trip + + new_again = Newer().parse(round_trip) + assert newer == new_again