Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes

This commit is contained in:
Daniel G. Taylor
2019-10-17 23:36:18 -07:00
parent bbceff9341
commit 811b54cabb
11 changed files with 134 additions and 57 deletions

View File

@@ -1,28 +1,28 @@
from abc import ABC
import dataclasses
import inspect
import json
import struct
from abc import ABC
from typing import (
get_type_hints,
AsyncGenerator,
Union,
Generator,
Any,
SupportsBytes,
List,
Tuple,
AsyncGenerator,
Callable,
Type,
Dict,
Generator,
Iterable,
TypeVar,
List,
Optional,
SupportsBytes,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
)
import dataclasses
import grpclib.client
import grpclib.const
import inspect
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@@ -54,6 +54,9 @@ FIXED_TYPES = [
TYPE_SFIXED64,
]
# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
# Fields that are efficiently packed when repeated
PACKED_TYPES = [
TYPE_ENUM,
@@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
return value
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
@@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value):
if len(value) or serialize_empty:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@@ -362,6 +367,11 @@ class Message(ABC):
to go between Python, binary and JSON protobuf message representations.
"""
# True if this message was or should be serialized on the wire. This can
# be used to detect presence (e.g. optional wrapper message) and is used
# internally during parsing/serialization.
serialized_on_wire: bool
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
@@ -389,6 +399,15 @@ class Message(ABC):
setattr(self, field.name, value)
# Now that all the defaults are set, reset it!
self.__dict__["serialized_on_wire"] = False
def __setattr__(self, attr: str, value: Any) -> None:
    """Assign ``value`` to ``attr`` and record that the message was modified.

    Any assignment other than to ``serialized_on_wire`` itself switches that
    flag to ``True``; assigning the flag directly leaves it under the
    caller's control.
    """
    touches_field = attr != "serialized_on_wire"
    if touches_field:
        # Write through __dict__ so updating the flag does not re-enter
        # this method.
        self.__dict__["serialized_on_wire"] = True
    super().__setattr__(attr, value)
def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
@@ -429,7 +448,12 @@ class Message(ABC):
# Default (zero) values are not serialized
continue
output += _serialize_single(meta.number, meta.proto_type, value)
serialize_empty = False
if isinstance(value, Message) and value.serialized_on_wire:
serialize_empty = True
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
)
return output
@@ -462,12 +486,13 @@ class Message(ABC):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in [TYPE_STRING]:
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
elif meta.proto_type in [TYPE_MESSAGE]:
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]:
value.serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
assert meta.map_types
@@ -535,8 +560,6 @@ class Message(ABC):
# TODO: handle unknown fields
pass
from typing import cast
return self
# For compatibility with other libraries.
@@ -549,7 +572,7 @@ class Message(ABC):
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
"""
output = {}
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
@@ -557,13 +580,9 @@ class Message(ABC):
if isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
# Filter out empty items which we won't serialize.
v = [i for i in v if i]
else:
v = v.to_dict()
if v:
output[field.name] = v
elif v.serialized_on_wire:
output[field.name] = v.to_dict()
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
@@ -572,7 +591,13 @@ class Message(ABC):
if v:
output[field.name] = v
elif v != get_default(meta.proto_type):
output[field.name] = v
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
else:
output[field.name] = str(v)
else:
output[field.name] = v
return output
def from_dict(self: T, value: dict) -> T:
@@ -580,6 +605,7 @@ class Message(ABC):
Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
self.serialized_on_wire = True
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None:
@@ -598,7 +624,13 @@ class Message(ABC):
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
else:
setattr(self, field.name, value[field.name])
v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
setattr(self, field.name, v)
return self
def to_json(self) -> str:
@@ -613,9 +645,6 @@ class Message(ABC):
return self.from_dict(json.loads(value))
ResponseType = TypeVar("ResponseType", bound="Message")
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.

View File

@@ -1,19 +1,21 @@
#!/usr/bin/env python
import importlib
import json
import os # isort: skip
import subprocess
import sys
from typing import Generator, Tuple
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import MessageToJson, Parse
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import subprocess
import importlib
import sys
from typing import Generator, Tuple
from google.protobuf.json_format import Parse
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
root = os.path.dirname(os.path.realpath(__file__))
@@ -68,5 +70,10 @@ if __name__ == "__main__":
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
imported = importlib.import_module(f"{parts[0]}_pb2")
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
parsed = Parse(open(filename).read(), imported.Test())
serialized = parsed.SerializeToString()
serialized_json = MessageToJson(
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
)
assert json.loads(serialized_json) == json.load(open(filename))
open(out, "wb").write(serialized)

View File

@@ -1,5 +1,6 @@
{
"nested": {
"count": 150
}
},
"sibling": {}
}

View File

@@ -10,8 +10,9 @@ message Test {
Nested nested = 1;
Sibling sibling = 2;
Sibling sibling2 = 3;
}
message Sibling {
int32 foo = 1;
}
}

View File

@@ -1,5 +1,5 @@
{
"counts": [1, 2, -1, -2],
"signed": [1, 2, -1, -2],
"signed": ["1", "2", "-1", "-2"],
"fixed": [1.0, 2.7, 3.4]
}

View File

@@ -1,4 +1,4 @@
{
"signed_32": -150,
"signed_64": -150
"signed_64": "-150"
}

View File

@@ -1,4 +1,4 @@
{
"signed_32": 150,
"signed_64": 150
"signed_64": "150"
}

View File

@@ -0,0 +1,32 @@
import betterproto
from dataclasses import dataclass
def test_has_field():
    """Check that ``serialized_on_wire`` tracks presence of a nested message.

    The flag is expected to start out ``False`` on a freshly constructed
    nested message, become ``True`` once any field on it is assigned (even
    when the assigned value is the field's default), and remain manually
    overridable.
    """

    @dataclass
    class Bar(betterproto.Message):
        baz: int = betterproto.int32_field(1)

    @dataclass
    class Foo(betterproto.Message):
        bar: Bar = betterproto.message_field(1)

    # Identity checks (`is`) are used instead of `==` so that only a real
    # bool satisfies each assertion; `0 == False` would otherwise pass.

    # Unset by default
    foo = Foo()
    assert foo.bar.serialized_on_wire is False

    # Serialized after setting something
    foo.bar.baz = 1
    assert foo.bar.serialized_on_wire is True

    # Still has it after setting the default value
    foo.bar.baz = 0
    assert foo.bar.serialized_on_wire is True

    # Manual override
    foo.bar.serialized_on_wire = False
    assert foo.bar.serialized_on_wire is False

    # Can manually set it but defaults to false
    foo.bar = Bar()
    assert foo.bar.serialized_on_wire is False

View File

@@ -1,8 +1,9 @@
import importlib
import pytest
import json
from .generate import get_files, get_base
import pytest
from .generate import get_base, get_files
inputs = get_files(".bin")