Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes
This commit is contained in:
@@ -1,28 +1,28 @@
|
||||
from abc import ABC
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import struct
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
get_type_hints,
|
||||
AsyncGenerator,
|
||||
Union,
|
||||
Generator,
|
||||
Any,
|
||||
SupportsBytes,
|
||||
List,
|
||||
Tuple,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Type,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
List,
|
||||
Optional,
|
||||
SupportsBytes,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
|
||||
import inspect
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
TYPE_BOOL = "bool"
|
||||
@@ -54,6 +54,9 @@ FIXED_TYPES = [
|
||||
TYPE_SFIXED64,
|
||||
]
|
||||
|
||||
# Fields that are numerical 64-bit types
|
||||
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
|
||||
# Fields that are efficiently packed when
|
||||
PACKED_TYPES = [
|
||||
TYPE_ENUM,
|
||||
@@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||
def _serialize_single(
|
||||
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
|
||||
) -> bytes:
|
||||
"""Serializes a single field and value."""
|
||||
value = _preprocess_single(proto_type, value)
|
||||
|
||||
@@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||
key = encode_varint((field_number << 3) | 1)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||
if len(value):
|
||||
if len(value) or serialize_empty:
|
||||
key = encode_varint((field_number << 3) | 2)
|
||||
output += key + encode_varint(len(value)) + value
|
||||
else:
|
||||
@@ -362,6 +367,11 @@ class Message(ABC):
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
|
||||
# True if this message was or should be serialized on the wire. This can
|
||||
# be used to detect presence (e.g. optional wrapper message) and is used
|
||||
# internally during parsing/serialization.
|
||||
serialized_on_wire: bool
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
@@ -389,6 +399,15 @@ class Message(ABC):
|
||||
|
||||
setattr(self, field.name, value)
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["serialized_on_wire"] = False
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "serialized_on_wire":
|
||||
# Track when a field has been set.
|
||||
self.__dict__["serialized_on_wire"] = True
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Get the binary encoded Protobuf representation of this instance.
|
||||
@@ -429,7 +448,12 @@ class Message(ABC):
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
output += _serialize_single(meta.number, meta.proto_type, value)
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value.serialized_on_wire:
|
||||
serialize_empty = True
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@@ -462,12 +486,13 @@ class Message(ABC):
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in [TYPE_STRING]:
|
||||
if meta.proto_type == TYPE_STRING:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type in [TYPE_MESSAGE]:
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
elif meta.proto_type in [TYPE_MAP]:
|
||||
value.serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
assert meta.map_types
|
||||
@@ -535,8 +560,6 @@ class Message(ABC):
|
||||
# TODO: handle unknown fields
|
||||
pass
|
||||
|
||||
from typing import cast
|
||||
|
||||
return self
|
||||
|
||||
# For compatibility with other libraries.
|
||||
@@ -549,7 +572,7 @@ class Message(ABC):
|
||||
Returns a dict representation of this message instance which can be
|
||||
used to serialize to e.g. JSON.
|
||||
"""
|
||||
output = {}
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
@@ -557,13 +580,9 @@ class Message(ABC):
|
||||
if isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
# Filter out empty items which we won't serialize.
|
||||
v = [i for i in v if i]
|
||||
else:
|
||||
v = v.to_dict()
|
||||
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v.serialized_on_wire:
|
||||
output[field.name] = v.to_dict()
|
||||
elif meta.proto_type == "map":
|
||||
for k in v:
|
||||
if hasattr(v[k], "to_dict"):
|
||||
@@ -572,7 +591,13 @@ class Message(ABC):
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v != get_default(meta.proto_type):
|
||||
output[field.name] = v
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [str(n) for n in v]
|
||||
else:
|
||||
output[field.name] = str(v)
|
||||
else:
|
||||
output[field.name] = v
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: dict) -> T:
|
||||
@@ -580,6 +605,7 @@ class Message(ABC):
|
||||
Parse the key/value pairs in `value` into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
self.serialized_on_wire = True
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value and value[field.name] is not None:
|
||||
@@ -598,7 +624,13 @@ class Message(ABC):
|
||||
for k in value[field.name]:
|
||||
v[k] = cls().from_dict(value[field.name][k])
|
||||
else:
|
||||
setattr(self, field.name, value[field.name])
|
||||
v = value[field.name]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [int(n) for n in value[field.name]]
|
||||
else:
|
||||
v = int(value[field.name])
|
||||
setattr(self, field.name, v)
|
||||
return self
|
||||
|
||||
def to_json(self) -> str:
|
||||
@@ -613,9 +645,6 @@ class Message(ABC):
|
||||
return self.from_dict(json.loads(value))
|
||||
|
||||
|
||||
ResponseType = TypeVar("ResponseType", bound="Message")
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
||||
Reference in New Issue
Block a user