Support pass-through of unknown fields
This commit is contained in:
parent
b5c1f1aa7c
commit
a5fac1c2ae
@ -218,7 +218,7 @@ $ pipenv run tests
|
|||||||
- [x] Repeated message fields
|
- [x] Repeated message fields
|
||||||
- [x] Maps
|
- [x] Maps
|
||||||
- [x] Maps of message fields
|
- [x] Maps of message fields
|
||||||
- [ ] Support passthrough of unknown fields
|
- [x] Support passthrough of unknown fields
|
||||||
- [x] Refs to nested types
|
- [x] Refs to nested types
|
||||||
- [x] Imports in proto files
|
- [x] Imports in proto files
|
||||||
- [x] Well-known Google types
|
- [x] Well-known Google types
|
||||||
|
@ -341,17 +341,19 @@ class ParsedField:
|
|||||||
number: int
|
number: int
|
||||||
wire_type: int
|
wire_type: int
|
||||||
value: Any
|
value: Any
|
||||||
|
raw: bytes
|
||||||
|
|
||||||
|
|
||||||
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||||
i = 0
|
i = 0
|
||||||
while i < len(value):
|
while i < len(value):
|
||||||
|
start = i
|
||||||
num_wire, i = decode_varint(value, i)
|
num_wire, i = decode_varint(value, i)
|
||||||
# print(num_wire, i)
|
# print(num_wire, i)
|
||||||
number = num_wire >> 3
|
number = num_wire >> 3
|
||||||
wire_type = num_wire & 0x7
|
wire_type = num_wire & 0x7
|
||||||
|
|
||||||
decoded: Any
|
decoded: Any = None
|
||||||
if wire_type == 0:
|
if wire_type == 0:
|
||||||
decoded, i = decode_varint(value, i)
|
decoded, i = decode_varint(value, i)
|
||||||
elif wire_type == 1:
|
elif wire_type == 1:
|
||||||
@ -362,12 +364,12 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
i += length
|
i += length
|
||||||
elif wire_type == 5:
|
elif wire_type == 5:
|
||||||
decoded, i = value[i : i + 4], i + 4
|
decoded, i = value[i : i + 4], i + 4
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Wire type {wire_type}")
|
|
||||||
|
|
||||||
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
||||||
|
|
||||||
yield ParsedField(number=number, wire_type=wire_type, value=decoded)
|
yield ParsedField(
|
||||||
|
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
@ -415,6 +417,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
# Now that all the defaults are set, reset it!
|
# Now that all the defaults are set, reset it!
|
||||||
self.__dict__["serialized_on_wire"] = False
|
self.__dict__["serialized_on_wire"] = False
|
||||||
|
self.__dict__["_unknown_fields"] = b""
|
||||||
|
|
||||||
def __setattr__(self, attr: str, value: Any) -> None:
|
def __setattr__(self, attr: str, value: Any) -> None:
|
||||||
if attr != "serialized_on_wire":
|
if attr != "serialized_on_wire":
|
||||||
@ -469,7 +472,7 @@ class Message(ABC):
|
|||||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||||
)
|
)
|
||||||
|
|
||||||
return output
|
return output + self._unknown_fields
|
||||||
|
|
||||||
# For compatibility with other libraries
|
# For compatibility with other libraries
|
||||||
SerializeToString = __bytes__
|
SerializeToString = __bytes__
|
||||||
@ -571,8 +574,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
setattr(self, field.name, value)
|
setattr(self, field.name, value)
|
||||||
else:
|
else:
|
||||||
# TODO: handle unknown fields
|
self._unknown_fields += parsed.raw
|
||||||
pass
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -48,3 +48,25 @@ def test_enum_as_int_json():
|
|||||||
# Plain-ol'-ints should serialize properly too.
|
# Plain-ol'-ints should serialize properly too.
|
||||||
foo.bar = 1
|
foo.bar = 1
|
||||||
assert foo.to_dict() == {"bar": "ONE"}
|
assert foo.to_dict() == {"bar": "ONE"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_fields():
|
||||||
|
@dataclass
|
||||||
|
class Newer(betterproto.Message):
|
||||||
|
foo: bool = betterproto.bool_field(1)
|
||||||
|
bar: int = betterproto.int32_field(2)
|
||||||
|
baz: str = betterproto.string_field(3)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Older(betterproto.Message):
|
||||||
|
foo: bool = betterproto.bool_field(1)
|
||||||
|
|
||||||
|
newer = Newer(foo=True, bar=1, baz="Hello")
|
||||||
|
serialized_newer = bytes(newer)
|
||||||
|
|
||||||
|
# Unknown fields in `Newer` should round trip with `Older`
|
||||||
|
round_trip = bytes(Older().parse(serialized_newer))
|
||||||
|
assert serialized_newer == round_trip
|
||||||
|
|
||||||
|
new_again = Newer().parse(round_trip)
|
||||||
|
assert newer == new_again
|
||||||
|
Loading…
x
Reference in New Issue
Block a user