Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes

This commit is contained in:
Daniel G. Taylor 2019-10-17 23:36:18 -07:00
parent bbceff9341
commit 811b54cabb
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
11 changed files with 134 additions and 57 deletions

View File

@ -1,28 +1,28 @@
from abc import ABC import dataclasses
import inspect
import json import json
import struct import struct
from abc import ABC
from typing import ( from typing import (
get_type_hints,
AsyncGenerator,
Union,
Generator,
Any, Any,
SupportsBytes, AsyncGenerator,
List,
Tuple,
Callable, Callable,
Type, Dict,
Generator,
Iterable, Iterable,
TypeVar, List,
Optional, Optional,
SupportsBytes,
Tuple,
Type,
TypeVar,
Union,
get_type_hints,
) )
import dataclasses
import grpclib.client import grpclib.client
import grpclib.const import grpclib.const
import inspect
# Proto 3 data types # Proto 3 data types
TYPE_ENUM = "enum" TYPE_ENUM = "enum"
TYPE_BOOL = "bool" TYPE_BOOL = "bool"
@ -54,6 +54,9 @@ FIXED_TYPES = [
TYPE_SFIXED64, TYPE_SFIXED64,
] ]
# Fields that are numerical 64-bit types
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
# Fields that are efficiently packed when # Fields that are efficiently packed when
PACKED_TYPES = [ PACKED_TYPES = [
TYPE_ENUM, TYPE_ENUM,
@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
return value return value
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes: def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
) -> bytes:
"""Serializes a single field and value.""" """Serializes a single field and value."""
value = _preprocess_single(proto_type, value) value = _preprocess_single(proto_type, value)
@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
key = encode_varint((field_number << 3) | 1) key = encode_varint((field_number << 3) | 1)
output += key + value output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES: elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value): if len(value) or serialize_empty:
key = encode_varint((field_number << 3) | 2) key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value output += key + encode_varint(len(value)) + value
else: else:
@ -362,6 +367,11 @@ class Message(ABC):
to go between Python, binary and JSON protobuf message representations. to go between Python, binary and JSON protobuf message representations.
""" """
# True if this message was or should be serialized on the wire. This can
# be used to detect presence (e.g. optional wrapper message) and is used
# internally during parsing/serialization.
serialized_on_wire: bool
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has # Set a default value for each field in the class after `__init__` has
# already been run. # already been run.
@ -389,6 +399,15 @@ class Message(ABC):
setattr(self, field.name, value) setattr(self, field.name, value)
# Now that all the defaults are set, reset it!
self.__dict__["serialized_on_wire"] = False
def __setattr__(self, attr: str, value: Any) -> None:
if attr != "serialized_on_wire":
# Track when a field has been set.
self.__dict__["serialized_on_wire"] = True
super().__setattr__(attr, value)
def __bytes__(self) -> bytes: def __bytes__(self) -> bytes:
""" """
Get the binary encoded Protobuf representation of this instance. Get the binary encoded Protobuf representation of this instance.
@ -429,7 +448,12 @@ class Message(ABC):
# Default (zero) values are not serialized # Default (zero) values are not serialized
continue continue
output += _serialize_single(meta.number, meta.proto_type, value) serialize_empty = False
if isinstance(value, Message) and value.serialized_on_wire:
serialize_empty = True
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
)
return output return output
@ -462,12 +486,13 @@ class Message(ABC):
fmt = _pack_fmt(meta.proto_type) fmt = _pack_fmt(meta.proto_type)
value = struct.unpack(fmt, value)[0] value = struct.unpack(fmt, value)[0]
elif wire_type == WIRE_LEN_DELIM: elif wire_type == WIRE_LEN_DELIM:
if meta.proto_type in [TYPE_STRING]: if meta.proto_type == TYPE_STRING:
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type in [TYPE_MESSAGE]: elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field) cls = self._cls_for(field)
value = cls().parse(value) value = cls().parse(value)
elif meta.proto_type in [TYPE_MAP]: value.serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each # TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class. # key/value pair will recreate the class.
assert meta.map_types assert meta.map_types
@ -535,8 +560,6 @@ class Message(ABC):
# TODO: handle unknown fields # TODO: handle unknown fields
pass pass
from typing import cast
return self return self
# For compatibility with other libraries. # For compatibility with other libraries.
@ -549,7 +572,7 @@ class Message(ABC):
Returns a dict representation of this message instance which can be Returns a dict representation of this message instance which can be
used to serialize to e.g. JSON. used to serialize to e.g. JSON.
""" """
output = {} output: Dict[str, Any] = {}
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
v = getattr(self, field.name) v = getattr(self, field.name)
@ -557,13 +580,9 @@ class Message(ABC):
if isinstance(v, list): if isinstance(v, list):
# Convert each item. # Convert each item.
v = [i.to_dict() for i in v] v = [i.to_dict() for i in v]
# Filter out empty items which we won't serialize.
v = [i for i in v if i]
else:
v = v.to_dict()
if v:
output[field.name] = v output[field.name] = v
elif v.serialized_on_wire:
output[field.name] = v.to_dict()
elif meta.proto_type == "map": elif meta.proto_type == "map":
for k in v: for k in v:
if hasattr(v[k], "to_dict"): if hasattr(v[k], "to_dict"):
@ -572,6 +591,12 @@ class Message(ABC):
if v: if v:
output[field.name] = v output[field.name] = v
elif v != get_default(meta.proto_type): elif v != get_default(meta.proto_type):
if meta.proto_type in INT_64_TYPES:
if isinstance(v, list):
output[field.name] = [str(n) for n in v]
else:
output[field.name] = str(v)
else:
output[field.name] = v output[field.name] = v
return output return output
@ -580,6 +605,7 @@ class Message(ABC):
Parse the key/value pairs in `value` into this message instance. This Parse the key/value pairs in `value` into this message instance. This
returns the instance itself and is therefore assignable and chainable. returns the instance itself and is therefore assignable and chainable.
""" """
self.serialized_on_wire = True
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
if field.name in value and value[field.name] is not None: if field.name in value and value[field.name] is not None:
@ -598,7 +624,13 @@ class Message(ABC):
for k in value[field.name]: for k in value[field.name]:
v[k] = cls().from_dict(value[field.name][k]) v[k] = cls().from_dict(value[field.name][k])
else: else:
setattr(self, field.name, value[field.name]) v = value[field.name]
if meta.proto_type in INT_64_TYPES:
if isinstance(value[field.name], list):
v = [int(n) for n in value[field.name]]
else:
v = int(value[field.name])
setattr(self, field.name, v)
return self return self
def to_json(self) -> str: def to_json(self) -> str:
@ -613,9 +645,6 @@ class Message(ABC):
return self.from_dict(json.loads(value)) return self.from_dict(json.loads(value))
ResponseType = TypeVar("ResponseType", bound="Message")
class ServiceStub(ABC): class ServiceStub(ABC):
""" """
Base class for async gRPC service stubs. Base class for async gRPC service stubs.

View File

@ -1,19 +1,21 @@
#!/usr/bin/env python #!/usr/bin/env python
import importlib
import json
import os # isort: skip import os # isort: skip
import subprocess
import sys
from typing import Generator, Tuple
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import MessageToJson, Parse
# Force pure-python implementation instead of C++, otherwise imports # Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database. # break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import subprocess
import importlib
import sys
from typing import Generator, Tuple
from google.protobuf.json_format import Parse
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
root = os.path.dirname(os.path.realpath(__file__)) root = os.path.dirname(os.path.realpath(__file__))
@ -68,5 +70,10 @@ if __name__ == "__main__":
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}") print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
imported = importlib.import_module(f"{parts[0]}_pb2") imported = importlib.import_module(f"{parts[0]}_pb2")
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString() parsed = Parse(open(filename).read(), imported.Test())
serialized = parsed.SerializeToString()
serialized_json = MessageToJson(
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
)
assert json.loads(serialized_json) == json.load(open(filename))
open(out, "wb").write(serialized) open(out, "wb").write(serialized)

View File

@ -1,5 +1,6 @@
{ {
"nested": { "nested": {
"count": 150 "count": 150
} },
"sibling": {}
} }

View File

@ -10,6 +10,7 @@ message Test {
Nested nested = 1; Nested nested = 1;
Sibling sibling = 2; Sibling sibling = 2;
Sibling sibling2 = 3;
} }
message Sibling { message Sibling {

View File

@ -1,5 +1,5 @@
{ {
"counts": [1, 2, -1, -2], "counts": [1, 2, -1, -2],
"signed": [1, 2, -1, -2], "signed": ["1", "2", "-1", "-2"],
"fixed": [1.0, 2.7, 3.4] "fixed": [1.0, 2.7, 3.4]
} }

View File

@ -1,4 +1,4 @@
{ {
"signed_32": -150, "signed_32": -150,
"signed_64": -150 "signed_64": "-150"
} }

View File

@ -1,4 +1,4 @@
{ {
"signed_32": 150, "signed_32": 150,
"signed_64": 150 "signed_64": "150"
} }

View File

@ -0,0 +1,32 @@
import betterproto
from dataclasses import dataclass
def test_has_field():
    """Check that `serialized_on_wire` tracks presence of a nested message.

    The flag starts out false, flips to true as soon as any field on the
    message is assigned (even when the assigned value is the default), can be
    overridden by hand, and is false again on a freshly constructed message.
    """

    @dataclass
    class Bar(betterproto.Message):
        baz: int = betterproto.int32_field(1)

    @dataclass
    class Foo(betterproto.Message):
        bar: Bar = betterproto.message_field(1)

    # Unset by default
    foo = Foo()
    assert foo.bar.serialized_on_wire is False

    # Serialized after setting something
    foo.bar.baz = 1
    assert foo.bar.serialized_on_wire is True

    # Still has it after setting the default value: the flag records that an
    # assignment happened, not that the value differs from the default.
    foo.bar.baz = 0
    assert foo.bar.serialized_on_wire is True

    # Manual override
    foo.bar.serialized_on_wire = False
    assert foo.bar.serialized_on_wire is False

    # Replacing the nested message with a fresh instance leaves the new
    # instance's flag at its initial false value.
    foo.bar = Bar()
    assert foo.bar.serialized_on_wire is False

View File

@ -1,8 +1,9 @@
import importlib import importlib
import pytest
import json import json
from .generate import get_files, get_base import pytest
from .generate import get_base, get_files
inputs = get_files(".bin") inputs = get_files(".bin")

View File

@ -1,27 +1,24 @@
#!/usr/bin/env python #!/usr/bin/env python
import sys
import itertools import itertools
import json import json
import os.path import os.path
import re import re
from typing import Tuple, Any, List import sys
import textwrap import textwrap
from typing import Any, List, Tuple
from jinja2 import Environment, PackageLoader
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import ( from google.protobuf.descriptor_pb2 import (
DescriptorProto, DescriptorProto,
EnumDescriptorProto, EnumDescriptorProto,
FileDescriptorProto,
FieldDescriptorProto, FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto, ServiceDescriptorProto,
) )
from google.protobuf.compiler import plugin_pb2 as plugin
from jinja2 import Environment, PackageLoader
def snake_case(value: str) -> str: def snake_case(value: str) -> str:
return ( return (

9
pyproject.toml Normal file
View File

@ -0,0 +1,9 @@
[tool.black]
target-version = ['py37']
[tool.isort]
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
line_length = 88