Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes
This commit is contained in:
parent
bbceff9341
commit
811b54cabb
@ -1,28 +1,28 @@
|
||||
from abc import ABC
|
||||
import dataclasses
|
||||
import inspect
|
||||
import json
|
||||
import struct
|
||||
from abc import ABC
|
||||
from typing import (
|
||||
get_type_hints,
|
||||
AsyncGenerator,
|
||||
Union,
|
||||
Generator,
|
||||
Any,
|
||||
SupportsBytes,
|
||||
List,
|
||||
Tuple,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Type,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
TypeVar,
|
||||
List,
|
||||
Optional,
|
||||
SupportsBytes,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
import dataclasses
|
||||
|
||||
import grpclib.client
|
||||
import grpclib.const
|
||||
|
||||
import inspect
|
||||
|
||||
# Proto 3 data types
|
||||
TYPE_ENUM = "enum"
|
||||
TYPE_BOOL = "bool"
|
||||
@ -54,6 +54,9 @@ FIXED_TYPES = [
|
||||
TYPE_SFIXED64,
|
||||
]
|
||||
|
||||
# Fields that are numerical 64-bit types
|
||||
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
|
||||
# Fields that are efficiently packed when
|
||||
PACKED_TYPES = [
|
||||
TYPE_ENUM,
|
||||
@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||
def _serialize_single(
|
||||
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
|
||||
) -> bytes:
|
||||
"""Serializes a single field and value."""
|
||||
value = _preprocess_single(proto_type, value)
|
||||
|
||||
@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
||||
key = encode_varint((field_number << 3) | 1)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||
if len(value):
|
||||
if len(value) or serialize_empty:
|
||||
key = encode_varint((field_number << 3) | 2)
|
||||
output += key + encode_varint(len(value)) + value
|
||||
else:
|
||||
@ -362,6 +367,11 @@ class Message(ABC):
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
|
||||
# True if this message was or should be serialized on the wire. This can
|
||||
# be used to detect presence (e.g. optional wrapper message) and is used
|
||||
# internally during parsing/serialization.
|
||||
serialized_on_wire: bool
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
@ -389,6 +399,15 @@ class Message(ABC):
|
||||
|
||||
setattr(self, field.name, value)
|
||||
|
||||
# Now that all the defaults are set, reset it!
|
||||
self.__dict__["serialized_on_wire"] = False
|
||||
|
||||
def __setattr__(self, attr: str, value: Any) -> None:
|
||||
if attr != "serialized_on_wire":
|
||||
# Track when a field has been set.
|
||||
self.__dict__["serialized_on_wire"] = True
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Get the binary encoded Protobuf representation of this instance.
|
||||
@ -429,7 +448,12 @@ class Message(ABC):
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
output += _serialize_single(meta.number, meta.proto_type, value)
|
||||
serialize_empty = False
|
||||
if isinstance(value, Message) and value.serialized_on_wire:
|
||||
serialize_empty = True
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@ -462,12 +486,13 @@ class Message(ABC):
|
||||
fmt = _pack_fmt(meta.proto_type)
|
||||
value = struct.unpack(fmt, value)[0]
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
if meta.proto_type in [TYPE_STRING]:
|
||||
if meta.proto_type == TYPE_STRING:
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type in [TYPE_MESSAGE]:
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
elif meta.proto_type in [TYPE_MAP]:
|
||||
value.serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
assert meta.map_types
|
||||
@ -535,8 +560,6 @@ class Message(ABC):
|
||||
# TODO: handle unknown fields
|
||||
pass
|
||||
|
||||
from typing import cast
|
||||
|
||||
return self
|
||||
|
||||
# For compatibility with other libraries.
|
||||
@ -549,7 +572,7 @@ class Message(ABC):
|
||||
Returns a dict representation of this message instance which can be
|
||||
used to serialize to e.g. JSON.
|
||||
"""
|
||||
output = {}
|
||||
output: Dict[str, Any] = {}
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
v = getattr(self, field.name)
|
||||
@ -557,13 +580,9 @@ class Message(ABC):
|
||||
if isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
# Filter out empty items which we won't serialize.
|
||||
v = [i for i in v if i]
|
||||
else:
|
||||
v = v.to_dict()
|
||||
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v.serialized_on_wire:
|
||||
output[field.name] = v.to_dict()
|
||||
elif meta.proto_type == "map":
|
||||
for k in v:
|
||||
if hasattr(v[k], "to_dict"):
|
||||
@ -572,7 +591,13 @@ class Message(ABC):
|
||||
if v:
|
||||
output[field.name] = v
|
||||
elif v != get_default(meta.proto_type):
|
||||
output[field.name] = v
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(v, list):
|
||||
output[field.name] = [str(n) for n in v]
|
||||
else:
|
||||
output[field.name] = str(v)
|
||||
else:
|
||||
output[field.name] = v
|
||||
return output
|
||||
|
||||
def from_dict(self: T, value: dict) -> T:
|
||||
@ -580,6 +605,7 @@ class Message(ABC):
|
||||
Parse the key/value pairs in `value` into this message instance. This
|
||||
returns the instance itself and is therefore assignable and chainable.
|
||||
"""
|
||||
self.serialized_on_wire = True
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value and value[field.name] is not None:
|
||||
@ -598,7 +624,13 @@ class Message(ABC):
|
||||
for k in value[field.name]:
|
||||
v[k] = cls().from_dict(value[field.name][k])
|
||||
else:
|
||||
setattr(self, field.name, value[field.name])
|
||||
v = value[field.name]
|
||||
if meta.proto_type in INT_64_TYPES:
|
||||
if isinstance(value[field.name], list):
|
||||
v = [int(n) for n in value[field.name]]
|
||||
else:
|
||||
v = int(value[field.name])
|
||||
setattr(self, field.name, v)
|
||||
return self
|
||||
|
||||
def to_json(self) -> str:
|
||||
@ -613,9 +645,6 @@ class Message(ABC):
|
||||
return self.from_dict(json.loads(value))
|
||||
|
||||
|
||||
ResponseType = TypeVar("ResponseType", bound="Message")
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
@ -1,19 +1,21 @@
|
||||
#!/usr/bin/env python
|
||||
import importlib
|
||||
import json
|
||||
import os # isort: skip
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from google.protobuf import symbol_database
|
||||
from google.protobuf.descriptor_pool import DescriptorPool
|
||||
from google.protobuf.json_format import MessageToJson, Parse
|
||||
|
||||
# Force pure-python implementation instead of C++, otherwise imports
|
||||
# break things because we can't properly reset the symbol database.
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
|
||||
import subprocess
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from google.protobuf.json_format import Parse
|
||||
from google.protobuf import symbol_database
|
||||
from google.protobuf.descriptor_pool import DescriptorPool
|
||||
|
||||
root = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
@ -68,5 +70,10 @@ if __name__ == "__main__":
|
||||
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
||||
|
||||
imported = importlib.import_module(f"{parts[0]}_pb2")
|
||||
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
|
||||
parsed = Parse(open(filename).read(), imported.Test())
|
||||
serialized = parsed.SerializeToString()
|
||||
serialized_json = MessageToJson(
|
||||
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
|
||||
)
|
||||
assert json.loads(serialized_json) == json.load(open(filename))
|
||||
open(out, "wb").write(serialized)
|
||||
|
@ -1,5 +1,6 @@
|
||||
{
|
||||
"nested": {
|
||||
"count": 150
|
||||
}
|
||||
},
|
||||
"sibling": {}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ message Test {
|
||||
|
||||
Nested nested = 1;
|
||||
Sibling sibling = 2;
|
||||
Sibling sibling2 = 3;
|
||||
}
|
||||
|
||||
message Sibling {
|
||||
|
@ -1,5 +1,5 @@
|
||||
{
|
||||
"counts": [1, 2, -1, -2],
|
||||
"signed": [1, 2, -1, -2],
|
||||
"signed": ["1", "2", "-1", "-2"],
|
||||
"fixed": [1.0, 2.7, 3.4]
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
{
|
||||
"signed_32": -150,
|
||||
"signed_64": -150
|
||||
"signed_64": "-150"
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
{
|
||||
"signed_32": 150,
|
||||
"signed_64": 150
|
||||
"signed_64": "150"
|
||||
}
|
||||
|
32
betterproto/tests/test_features.py
Normal file
32
betterproto/tests/test_features.py
Normal file
@ -0,0 +1,32 @@
|
||||
import betterproto
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def test_has_field():
|
||||
@dataclass
|
||||
class Bar(betterproto.Message):
|
||||
baz: int = betterproto.int32_field(1)
|
||||
|
||||
@dataclass
|
||||
class Foo(betterproto.Message):
|
||||
bar: Bar = betterproto.message_field(1)
|
||||
|
||||
# Unset by default
|
||||
foo = Foo()
|
||||
assert foo.bar.serialized_on_wire == False
|
||||
|
||||
# Serialized after setting something
|
||||
foo.bar.baz = 1
|
||||
assert foo.bar.serialized_on_wire == True
|
||||
|
||||
# Still has it after setting the default value
|
||||
foo.bar.baz = 0
|
||||
assert foo.bar.serialized_on_wire == True
|
||||
|
||||
# Manual override
|
||||
foo.bar.serialized_on_wire = False
|
||||
assert foo.bar.serialized_on_wire == False
|
||||
|
||||
# Can manually set it but defaults to false
|
||||
foo.bar = Bar()
|
||||
assert foo.bar.serialized_on_wire == False
|
@ -1,8 +1,9 @@
|
||||
import importlib
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from .generate import get_files, get_base
|
||||
import pytest
|
||||
|
||||
from .generate import get_base, get_files
|
||||
|
||||
inputs = get_files(".bin")
|
||||
|
||||
|
@ -1,27 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import sys
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os.path
|
||||
import re
|
||||
from typing import Tuple, Any, List
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from jinja2 import Environment, PackageLoader
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
|
||||
|
||||
from jinja2 import Environment, PackageLoader
|
||||
|
||||
|
||||
def snake_case(value: str) -> str:
|
||||
return (
|
||||
|
9
pyproject.toml
Normal file
9
pyproject.toml
Normal file
@ -0,0 +1,9 @@
|
||||
[tool.black]
|
||||
target-version = ['py37']
|
||||
|
||||
[tool.isort]
|
||||
multi_line_output = 3
|
||||
include_trailing_comma = true
|
||||
force_grid_wrap = 0
|
||||
use_parentheses = true
|
||||
line_length = 88
|
Loading…
x
Reference in New Issue
Block a user