Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes

This commit is contained in:
Daniel G. Taylor 2019-10-17 23:36:18 -07:00
parent bbceff9341
commit 811b54cabb
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
11 changed files with 134 additions and 57 deletions

View File

@ -1,28 +1,28 @@
from abc import ABC
import dataclasses
import inspect
import json
import struct
from abc import ABC
from typing import (
get_type_hints,
AsyncGenerator,
Union,
Generator,
Any,
SupportsBytes,
List,
Tuple,
AsyncGenerator,
Callable,
Type,
Dict,
Generator,
Iterable,
TypeVar,
List,
Optional,
SupportsBytes,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
)
import dataclasses
import grpclib.client
import grpclib.const
import inspect
# Proto 3 data types
TYPE_ENUM = "enum"
TYPE_BOOL = "bool"
@ -54,6 +54,9 @@ FIXED_TYPES = [
TYPE_SFIXED64,
]
# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
# Fields that are efficiently packed when
PACKED_TYPES = [
TYPE_ENUM,
@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
return value
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value):
if len(value) or serialize_empty:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@ -362,6 +367,11 @@ class Message(ABC):
to go between Python, binary and JSON protobuf message representations.
"""
# True if this message was or should be serialized on the wire. This can
# be used to detect presence (e.g. optional wrapper message) and is used
# internally during parsing/serialization.
serialized_on_wire: bool
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
@ -389,6 +399,15 @@ class Message(ABC):
setattr(self, field.name, value)
# Now that all the defaults are set, reset it!
self.__dict__["serialized_on_wire"] = False
def __setattr__(self, attr: str, value: Any) -> None:
if attr != "serialized_on_wire":
# Track when a field has been set.
self.__dict__["serialized_on_wire"] = True
super().__setattr__(attr, value)
def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
@ -429,7 +448,12 @@ class Message(ABC):
# Default (zero) values are not serialized
continue
output += _serialize_single(meta.number, meta.proto_type, value)
serialize_empty = False
if isinstance(value, Message) and value.serialized_on_wire:
serialize_empty = True
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
)
return output
@ -462,12 +486,13 @@ class Message(ABC):
fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in [TYPE_STRING]:
if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8")
elif meta.proto_type in [TYPE_MESSAGE]:
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]:
value.serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
assert meta.map_types
@ -535,8 +560,6 @@ class Message(ABC):
# TODO: handle unknown fields
pass
from typing import cast
return self
# For compatibility with other libraries.
@ -549,7 +572,7 @@ class Message(ABC):
Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON.
"""
output = {}
output: Dict[str, Any] = {}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
@ -557,13 +580,9 @@ class Message(ABC):
if isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
# Filter out empty items which we won't serialize.
v = [i for i in v if i]
else:
v = v.to_dict()
if v:
output[field.name] = v
elif v.serialized_on_wire:
output[field.name] = v.to_dict()
elif meta.proto_type == "map":
for k in v:
if hasattr(v[k], "to_dict"):
@ -572,7 +591,13 @@ class Message(ABC):
if v:
output[field.name] = v
elif v != get_default(meta.proto_type):
output[field.name] = v
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
else:
output[field.name] = str(v)
else:
output[field.name] = v
return output
def from_dict(self: T, value: dict) -> T:
@ -580,6 +605,7 @@ class Message(ABC):
Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable.
"""
self.serialized_on_wire = True
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None:
@ -598,7 +624,13 @@ class Message(ABC):
for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k])
else:
setattr(self, field.name, value[field.name])
v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
setattr(self, field.name, v)
return self
def to_json(self) -> str:
@ -613,9 +645,6 @@ class Message(ABC):
return self.from_dict(json.loads(value))
ResponseType = TypeVar("ResponseType", bound="Message")
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.

View File

@ -1,19 +1,21 @@
#!/usr/bin/env python
import importlib
import json
import os # isort: skip
import subprocess
import sys
from typing import Generator, Tuple
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import MessageToJson, Parse
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import subprocess
import importlib
import sys
from typing import Generator, Tuple
from google.protobuf.json_format import Parse
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
root = os.path.dirname(os.path.realpath(__file__))
@ -68,5 +70,10 @@ if __name__ == "__main__":
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
imported = importlib.import_module(f"{parts[0]}_pb2")
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
parsed = Parse(open(filename).read(), imported.Test())
serialized = parsed.SerializeToString()
serialized_json = MessageToJson(
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
)
assert json.loads(serialized_json) == json.load(open(filename))
open(out, "wb").write(serialized)

View File

@ -1,5 +1,6 @@
{
"nested": {
"count": 150
}
},
"sibling": {}
}

View File

@ -10,6 +10,7 @@ message Test {
Nested nested = 1;
Sibling sibling = 2;
Sibling sibling2 = 3;
}
message Sibling {

View File

@ -1,5 +1,5 @@
{
"counts": [1, 2, -1, -2],
"signed": [1, 2, -1, -2],
"signed": ["1", "2", "-1", "-2"],
"fixed": [1.0, 2.7, 3.4]
}

View File

@ -1,4 +1,4 @@
{
"signed_32": -150,
"signed_64": -150
"signed_64": "-150"
}

View File

@ -1,4 +1,4 @@
{
"signed_32": 150,
"signed_64": 150
"signed_64": "150"
}

View File

@ -0,0 +1,32 @@
import betterproto
from dataclasses import dataclass
def test_has_field():
@dataclass
class Bar(betterproto.Message):
baz: int = betterproto.int32_field(1)
@dataclass
class Foo(betterproto.Message):
bar: Bar = betterproto.message_field(1)
# Unset by default
foo = Foo()
assert foo.bar.serialized_on_wire == False
# Serialized after setting something
foo.bar.baz = 1
assert foo.bar.serialized_on_wire == True
# Still has it after setting the default value
foo.bar.baz = 0
assert foo.bar.serialized_on_wire == True
# Manual override
foo.bar.serialized_on_wire = False
assert foo.bar.serialized_on_wire == False
# Can manually set it but defaults to false
foo.bar = Bar()
assert foo.bar.serialized_on_wire == False

View File

@ -1,8 +1,9 @@
import importlib
import pytest
import json
from .generate import get_files, get_base
import pytest
from .generate import get_base, get_files
inputs = get_files(".bin")

View File

@ -1,27 +1,24 @@
#!/usr/bin/env python
import sys
import itertools
import json
import os.path
import re
from typing import Tuple, Any, List
import sys
import textwrap
from typing import Any, List, Tuple
from jinja2 import Environment, PackageLoader
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def snake_case(value: str) -> str:
return (

9
pyproject.toml Normal file
View File

@ -0,0 +1,9 @@
[tool.black]
target-version = ['py37']
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 88