Support pass-through of unknown fields
This commit is contained in:
parent
b5c1f1aa7c
commit
a5fac1c2ae
@ -218,7 +218,7 @@ $ pipenv run tests
|
||||
- [x] Repeated message fields
|
||||
- [x] Maps
|
||||
- [x] Maps of message fields
|
||||
- [ ] Support passthrough of unknown fields
|
||||
- [x] Support passthrough of unknown fields
|
||||
- [x] Refs to nested types
|
||||
- [x] Imports in proto files
|
||||
- [x] Well-known Google types
|
||||
|
@ -341,17 +341,19 @@ class ParsedField:
|
||||
number: int
|
||||
wire_type: int
|
||||
value: Any
|
||||
raw: bytes
|
||||
|
||||
|
||||
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
i = 0
|
||||
while i < len(value):
|
||||
start = i
|
||||
num_wire, i = decode_varint(value, i)
|
||||
# print(num_wire, i)
|
||||
number = num_wire >> 3
|
||||
wire_type = num_wire & 0x7
|
||||
|
||||
decoded: Any
|
||||
decoded: Any = None
|
||||
if wire_type == 0:
|
||||
decoded, i = decode_varint(value, i)
|
||||
elif wire_type == 1:
|
||||
@ -362,12 +364,12 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
i += length
|
||||
elif wire_type == 5:
|
||||
decoded, i = value[i : i + 4], i + 4
|
||||
else:
|
||||
raise NotImplementedError(f"Wire type {wire_type}")
|
||||
|
||||
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
||||
|
||||
yield ParsedField(number=number, wire_type=wire_type, value=decoded)
|
||||
yield ParsedField(
|
||||
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
|
||||
)
|
||||
|
||||
|
||||
# Bound type variable to allow methods to return `self` of subclasses
|
||||
@ -415,6 +417,7 @@ class Message(ABC):
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["serialized_on_wire"] = False
|
||||
self.__dict__["_unknown_fields"] = b""
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "serialized_on_wire":
|
||||
@ -469,7 +472,7 @@ class Message(ABC):
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
)
|
||||
|
||||
return output
|
||||
return output + self._unknown_fields
|
||||
|
||||
# For compatibility with other libraries
|
||||
SerializeToString = __bytes__
|
||||
@ -571,8 +574,7 @@ class Message(ABC):
|
||||
else:
|
||||
setattr(self, field.name, value)
|
||||
else:
|
||||
# TODO: handle unknown fields
|
||||
pass
|
||||
self._unknown_fields += parsed.raw
|
||||
|
||||
return self
|
||||
|
||||
|
@ -48,3 +48,25 @@ def test_enum_as_int_json():
|
||||
# Plain-ol'-ints should serialize properly too.
|
||||
foo.bar = 1
|
||||
assert foo.to_dict() == {"bar": "ONE"}
|
||||
|
||||
|
||||
def test_unknown_fields():
|
||||
@dataclass
|
||||
class Newer(betterproto.Message):
|
||||
foo: bool = betterproto.bool_field(1)
|
||||
bar: int = betterproto.int32_field(2)
|
||||
baz: str = betterproto.string_field(3)
|
||||
|
||||
@dataclass
|
||||
class Older(betterproto.Message):
|
||||
foo: bool = betterproto.bool_field(1)
|
||||
|
||||
newer = Newer(foo=True, bar=1, baz="Hello")
|
||||
serialized_newer = bytes(newer)
|
||||
|
||||
# Unknown fields in `Newer` should round trip with `Older`
|
||||
round_trip = bytes(Older().parse(serialized_newer))
|
||||
assert serialized_newer == round_trip
|
||||
|
||||
new_again = Newer().parse(round_trip)
|
||||
assert newer == new_again
|
||||
|
Loading…
x
Reference in New Issue
Block a user