Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes

This commit is contained in:
Daniel G. Taylor
2019-10-17 23:36:18 -07:00
parent bbceff9341
commit 811b54cabb
11 changed files with 134 additions and 57 deletions

View File

@@ -1,28 +1,28 @@
from abc import ABC
import dataclasses
import inspect
import json
import struct
from abc import ABC
from typing import (
get_type_hints,
AsyncGenerator,
Union,
Generator,
Any,
SupportsBytes,
List,
Tuple,
AsyncGenerator,
Callable,
Type,
Dict,
Generator,
Iterable,
TypeVar,
List,
Optional,
SupportsBytes,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
)
import dataclasses
import grpclib.client
import grpclib.const
import inspect
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@@ -54,6 +54,9 @@ FIXED_TYPES = [
TYPE_SFIXED64,
]
# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
# Fields that are efficiently packed when
PACKED_TYPES = [
TYPE_ENUM,
@@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
return value
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
@@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value):
if len(value) or serialize_empty:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@@ -362,6 +367,11 @@ class Message(ABC):
to go between Python, binary and JSON protobuf message representations.
"""
# True if this message was or should be serialized on the wire. This can
# be used to detect presence (e.g. optional wrapper message) and is used
# internally during parsing/serialization.
serialized_on_wire: bool
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
@@ -389,6 +399,15 @@ class Message(ABC):
setattr(self, field.name, value)
# Now that all the defaults are set, reset it!
self.__dict__["serialized_on_wire"] = False
def __setattr__(self, attr: str, value: Any) -> None:
if attr != "serialized_on_wire":
# Track when a field has been set.
self.__dict__["serialized_on_wire"] = True
super().__setattr__(attr, value)
def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
@@ -429,7 +448,12 @@ class Message(ABC):
# Default (zero) values are not serialized
continue
output += _serialize_single(meta.number, meta.proto_type, value)
serialize_empty = False
if isinstance(value, Message) and value.serialized_on_wire:
serialize_empty = True
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
)
return output
@@ -462,12 +486,13 @@ class Message(ABC):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in [TYPE_STRING]:
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
elif meta.proto_type in [TYPE_MESSAGE]:
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]:
value.serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
assert meta.map_types
@@ -535,8 +560,6 @@ class Message(ABC):
# TODO: handle unknown fields
pass
from typing import cast
return self
# For compatibility with other libraries.
@@ -549,7 +572,7 @@ class Message(ABC):
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
"""
output = {}
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
@@ -557,13 +580,9 @@ class Message(ABC):
if isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
# Filter out empty items which we won't serialize.
v = [i for i in v if i]
else:
v = v.to_dict()
if v:
output[field.name] = v
elif v.serialized_on_wire:
output[field.name] = v.to_dict()
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
@@ -572,7 +591,13 @@ class Message(ABC):
if v:
output[field.name] = v
elif v != get_default(meta.proto_type):
output[field.name] = v
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
else:
output[field.name] = str(v)
else:
output[field.name] = v
return output
def from_dict(self: T, value: dict) -> T:
@@ -580,6 +605,7 @@ class Message(ABC):
Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
self.serialized_on_wire = True
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None:
@@ -598,7 +624,13 @@ class Message(ABC):
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
else:
setattr(self, field.name, value[field.name])
v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
setattr(self, field.name, v)
return self
def to_json(self) -> str:
@@ -613,9 +645,6 @@ class Message(ABC):
return self.from_dict(json.loads(value))
ResponseType = TypeVar("ResponseType", bound="Message")
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.