Better JSON 64-bit int handling, add way to determine whether a message was sent on the wire, various fixes
This commit is contained in:
parent
bbceff9341
commit
811b54cabb
@ -1,28 +1,28 @@
|
|||||||
from abc import ABC
|
import dataclasses
|
||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
get_type_hints,
|
|
||||||
AsyncGenerator,
|
|
||||||
Union,
|
|
||||||
Generator,
|
|
||||||
Any,
|
Any,
|
||||||
SupportsBytes,
|
AsyncGenerator,
|
||||||
List,
|
|
||||||
Tuple,
|
|
||||||
Callable,
|
Callable,
|
||||||
Type,
|
Dict,
|
||||||
|
Generator,
|
||||||
Iterable,
|
Iterable,
|
||||||
TypeVar,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
SupportsBytes,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
get_type_hints,
|
||||||
)
|
)
|
||||||
import dataclasses
|
|
||||||
|
|
||||||
import grpclib.client
|
import grpclib.client
|
||||||
import grpclib.const
|
import grpclib.const
|
||||||
|
|
||||||
import inspect
|
|
||||||
|
|
||||||
# Proto 3 data types
|
# Proto 3 data types
|
||||||
TYPE_ENUM = "enum"
|
TYPE_ENUM = "enum"
|
||||||
TYPE_BOOL = "bool"
|
TYPE_BOOL = "bool"
|
||||||
@ -54,6 +54,9 @@ FIXED_TYPES = [
|
|||||||
TYPE_SFIXED64,
|
TYPE_SFIXED64,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Fields that are numerical 64-bit types
|
||||||
|
INT_64_TYPES = [TYPE_INT64, TYPE_UINT64, TYPE_SINT64, TYPE_FIXED64, TYPE_SFIXED64]
|
||||||
|
|
||||||
# Fields that are efficiently packed when
|
# Fields that are efficiently packed when
|
||||||
PACKED_TYPES = [
|
PACKED_TYPES = [
|
||||||
TYPE_ENUM,
|
TYPE_ENUM,
|
||||||
@ -275,7 +278,9 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
def _serialize_single(
|
||||||
|
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
|
||||||
|
) -> bytes:
|
||||||
"""Serializes a single field and value."""
|
"""Serializes a single field and value."""
|
||||||
value = _preprocess_single(proto_type, value)
|
value = _preprocess_single(proto_type, value)
|
||||||
|
|
||||||
@ -290,7 +295,7 @@ def _serialize_single(field_number: int, proto_type: str, value: Any) -> bytes:
|
|||||||
key = encode_varint((field_number << 3) | 1)
|
key = encode_varint((field_number << 3) | 1)
|
||||||
output += key + value
|
output += key + value
|
||||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||||
if len(value):
|
if len(value) or serialize_empty:
|
||||||
key = encode_varint((field_number << 3) | 2)
|
key = encode_varint((field_number << 3) | 2)
|
||||||
output += key + encode_varint(len(value)) + value
|
output += key + encode_varint(len(value)) + value
|
||||||
else:
|
else:
|
||||||
@ -362,6 +367,11 @@ class Message(ABC):
|
|||||||
to go between Python, binary and JSON protobuf message representations.
|
to go between Python, binary and JSON protobuf message representations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# True if this message was or should be serialized on the wire. This can
|
||||||
|
# be used to detect presence (e.g. optional wrapper message) and is used
|
||||||
|
# internally during parsing/serialization.
|
||||||
|
serialized_on_wire: bool
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Set a default value for each field in the class after `__init__` has
|
# Set a default value for each field in the class after `__init__` has
|
||||||
# already been run.
|
# already been run.
|
||||||
@ -389,6 +399,15 @@ class Message(ABC):
|
|||||||
|
|
||||||
setattr(self, field.name, value)
|
setattr(self, field.name, value)
|
||||||
|
|
||||||
|
# Now that all the defaults are set, reset it!
|
||||||
|
self.__dict__["serialized_on_wire"] = False
|
||||||
|
|
||||||
|
def __setattr__(self, attr: str, value: Any) -> None:
|
||||||
|
if attr != "serialized_on_wire":
|
||||||
|
# Track when a field has been set.
|
||||||
|
self.__dict__["serialized_on_wire"] = True
|
||||||
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
Get the binary encoded Protobuf representation of this instance.
|
Get the binary encoded Protobuf representation of this instance.
|
||||||
@ -429,7 +448,12 @@ class Message(ABC):
|
|||||||
# Default (zero) values are not serialized
|
# Default (zero) values are not serialized
|
||||||
continue
|
continue
|
||||||
|
|
||||||
output += _serialize_single(meta.number, meta.proto_type, value)
|
serialize_empty = False
|
||||||
|
if isinstance(value, Message) and value.serialized_on_wire:
|
||||||
|
serialize_empty = True
|
||||||
|
output += _serialize_single(
|
||||||
|
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||||
|
)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -462,12 +486,13 @@ class Message(ABC):
|
|||||||
fmt = _pack_fmt(meta.proto_type)
|
fmt = _pack_fmt(meta.proto_type)
|
||||||
value = struct.unpack(fmt, value)[0]
|
value = struct.unpack(fmt, value)[0]
|
||||||
elif wire_type == WIRE_LEN_DELIM:
|
elif wire_type == WIRE_LEN_DELIM:
|
||||||
if meta.proto_type in [TYPE_STRING]:
|
if meta.proto_type == TYPE_STRING:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type in [TYPE_MESSAGE]:
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
cls = self._cls_for(field)
|
cls = self._cls_for(field)
|
||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
elif meta.proto_type in [TYPE_MAP]:
|
value.serialized_on_wire = True
|
||||||
|
elif meta.proto_type == TYPE_MAP:
|
||||||
# TODO: This is slow, use a cache to make it faster since each
|
# TODO: This is slow, use a cache to make it faster since each
|
||||||
# key/value pair will recreate the class.
|
# key/value pair will recreate the class.
|
||||||
assert meta.map_types
|
assert meta.map_types
|
||||||
@ -535,8 +560,6 @@ class Message(ABC):
|
|||||||
# TODO: handle unknown fields
|
# TODO: handle unknown fields
|
||||||
pass
|
pass
|
||||||
|
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
# For compatibility with other libraries.
|
# For compatibility with other libraries.
|
||||||
@ -549,7 +572,7 @@ class Message(ABC):
|
|||||||
Returns a dict representation of this message instance which can be
|
Returns a dict representation of this message instance which can be
|
||||||
used to serialize to e.g. JSON.
|
used to serialize to e.g. JSON.
|
||||||
"""
|
"""
|
||||||
output = {}
|
output: Dict[str, Any] = {}
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
@ -557,13 +580,9 @@ class Message(ABC):
|
|||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
# Convert each item.
|
# Convert each item.
|
||||||
v = [i.to_dict() for i in v]
|
v = [i.to_dict() for i in v]
|
||||||
# Filter out empty items which we won't serialize.
|
|
||||||
v = [i for i in v if i]
|
|
||||||
else:
|
|
||||||
v = v.to_dict()
|
|
||||||
|
|
||||||
if v:
|
|
||||||
output[field.name] = v
|
output[field.name] = v
|
||||||
|
elif v.serialized_on_wire:
|
||||||
|
output[field.name] = v.to_dict()
|
||||||
elif meta.proto_type == "map":
|
elif meta.proto_type == "map":
|
||||||
for k in v:
|
for k in v:
|
||||||
if hasattr(v[k], "to_dict"):
|
if hasattr(v[k], "to_dict"):
|
||||||
@ -572,7 +591,13 @@ class Message(ABC):
|
|||||||
if v:
|
if v:
|
||||||
output[field.name] = v
|
output[field.name] = v
|
||||||
elif v != get_default(meta.proto_type):
|
elif v != get_default(meta.proto_type):
|
||||||
output[field.name] = v
|
if meta.proto_type in INT_64_TYPES:
|
||||||
|
if isinstance(v, list):
|
||||||
|
output[field.name] = [str(n) for n in v]
|
||||||
|
else:
|
||||||
|
output[field.name] = str(v)
|
||||||
|
else:
|
||||||
|
output[field.name] = v
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def from_dict(self: T, value: dict) -> T:
|
def from_dict(self: T, value: dict) -> T:
|
||||||
@ -580,6 +605,7 @@ class Message(ABC):
|
|||||||
Parse the key/value pairs in `value` into this message instance. This
|
Parse the key/value pairs in `value` into this message instance. This
|
||||||
returns the instance itself and is therefore assignable and chainable.
|
returns the instance itself and is therefore assignable and chainable.
|
||||||
"""
|
"""
|
||||||
|
self.serialized_on_wire = True
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
if field.name in value and value[field.name] is not None:
|
if field.name in value and value[field.name] is not None:
|
||||||
@ -598,7 +624,13 @@ class Message(ABC):
|
|||||||
for k in value[field.name]:
|
for k in value[field.name]:
|
||||||
v[k] = cls().from_dict(value[field.name][k])
|
v[k] = cls().from_dict(value[field.name][k])
|
||||||
else:
|
else:
|
||||||
setattr(self, field.name, value[field.name])
|
v = value[field.name]
|
||||||
|
if meta.proto_type in INT_64_TYPES:
|
||||||
|
if isinstance(value[field.name], list):
|
||||||
|
v = [int(n) for n in value[field.name]]
|
||||||
|
else:
|
||||||
|
v = int(value[field.name])
|
||||||
|
setattr(self, field.name, v)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_json(self) -> str:
|
def to_json(self) -> str:
|
||||||
@ -613,9 +645,6 @@ class Message(ABC):
|
|||||||
return self.from_dict(json.loads(value))
|
return self.from_dict(json.loads(value))
|
||||||
|
|
||||||
|
|
||||||
ResponseType = TypeVar("ResponseType", bound="Message")
|
|
||||||
|
|
||||||
|
|
||||||
class ServiceStub(ABC):
|
class ServiceStub(ABC):
|
||||||
"""
|
"""
|
||||||
Base class for async gRPC service stubs.
|
Base class for async gRPC service stubs.
|
||||||
|
@ -1,19 +1,21 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
import os # isort: skip
|
import os # isort: skip
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from typing import Generator, Tuple
|
||||||
|
|
||||||
|
from google.protobuf import symbol_database
|
||||||
|
from google.protobuf.descriptor_pool import DescriptorPool
|
||||||
|
from google.protobuf.json_format import MessageToJson, Parse
|
||||||
|
|
||||||
# Force pure-python implementation instead of C++, otherwise imports
|
# Force pure-python implementation instead of C++, otherwise imports
|
||||||
# break things because we can't properly reset the symbol database.
|
# break things because we can't properly reset the symbol database.
|
||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
|
|
||||||
import subprocess
|
|
||||||
import importlib
|
|
||||||
import sys
|
|
||||||
from typing import Generator, Tuple
|
|
||||||
|
|
||||||
from google.protobuf.json_format import Parse
|
|
||||||
from google.protobuf import symbol_database
|
|
||||||
from google.protobuf.descriptor_pool import DescriptorPool
|
|
||||||
|
|
||||||
root = os.path.dirname(os.path.realpath(__file__))
|
root = os.path.dirname(os.path.realpath(__file__))
|
||||||
|
|
||||||
@ -68,5 +70,10 @@ if __name__ == "__main__":
|
|||||||
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
|
||||||
|
|
||||||
imported = importlib.import_module(f"{parts[0]}_pb2")
|
imported = importlib.import_module(f"{parts[0]}_pb2")
|
||||||
serialized = Parse(open(filename).read(), imported.Test()).SerializeToString()
|
parsed = Parse(open(filename).read(), imported.Test())
|
||||||
|
serialized = parsed.SerializeToString()
|
||||||
|
serialized_json = MessageToJson(
|
||||||
|
parsed, preserving_proto_field_name=True, use_integers_for_enums=True
|
||||||
|
)
|
||||||
|
assert json.loads(serialized_json) == json.load(open(filename))
|
||||||
open(out, "wb").write(serialized)
|
open(out, "wb").write(serialized)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
{
|
{
|
||||||
"nested": {
|
"nested": {
|
||||||
"count": 150
|
"count": 150
|
||||||
}
|
},
|
||||||
|
"sibling": {}
|
||||||
}
|
}
|
||||||
|
@ -10,6 +10,7 @@ message Test {
|
|||||||
|
|
||||||
Nested nested = 1;
|
Nested nested = 1;
|
||||||
Sibling sibling = 2;
|
Sibling sibling = 2;
|
||||||
|
Sibling sibling2 = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Sibling {
|
message Sibling {
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"counts": [1, 2, -1, -2],
|
"counts": [1, 2, -1, -2],
|
||||||
"signed": [1, 2, -1, -2],
|
"signed": ["1", "2", "-1", "-2"],
|
||||||
"fixed": [1.0, 2.7, 3.4]
|
"fixed": [1.0, 2.7, 3.4]
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"signed_32": -150,
|
"signed_32": -150,
|
||||||
"signed_64": -150
|
"signed_64": "-150"
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
{
|
{
|
||||||
"signed_32": 150,
|
"signed_32": 150,
|
||||||
"signed_64": 150
|
"signed_64": "150"
|
||||||
}
|
}
|
||||||
|
32
betterproto/tests/test_features.py
Normal file
32
betterproto/tests/test_features.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
import betterproto
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_field():
|
||||||
|
@dataclass
|
||||||
|
class Bar(betterproto.Message):
|
||||||
|
baz: int = betterproto.int32_field(1)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Foo(betterproto.Message):
|
||||||
|
bar: Bar = betterproto.message_field(1)
|
||||||
|
|
||||||
|
# Unset by default
|
||||||
|
foo = Foo()
|
||||||
|
assert foo.bar.serialized_on_wire == False
|
||||||
|
|
||||||
|
# Serialized after setting something
|
||||||
|
foo.bar.baz = 1
|
||||||
|
assert foo.bar.serialized_on_wire == True
|
||||||
|
|
||||||
|
# Still has it after setting the default value
|
||||||
|
foo.bar.baz = 0
|
||||||
|
assert foo.bar.serialized_on_wire == True
|
||||||
|
|
||||||
|
# Manual override
|
||||||
|
foo.bar.serialized_on_wire = False
|
||||||
|
assert foo.bar.serialized_on_wire == False
|
||||||
|
|
||||||
|
# Can manually set it but defaults to false
|
||||||
|
foo.bar = Bar()
|
||||||
|
assert foo.bar.serialized_on_wire == False
|
@ -1,8 +1,9 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import pytest
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
from .generate import get_files, get_base
|
import pytest
|
||||||
|
|
||||||
|
from .generate import get_base, get_files
|
||||||
|
|
||||||
inputs = get_files(".bin")
|
inputs = get_files(".bin")
|
||||||
|
|
||||||
|
@ -1,27 +1,24 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
import os.path
|
import os.path
|
||||||
import re
|
import re
|
||||||
from typing import Tuple, Any, List
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from typing import Any, List, Tuple
|
||||||
|
|
||||||
|
from jinja2 import Environment, PackageLoader
|
||||||
|
|
||||||
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||||
from google.protobuf.descriptor_pb2 import (
|
from google.protobuf.descriptor_pb2 import (
|
||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
FileDescriptorProto,
|
|
||||||
FieldDescriptorProto,
|
FieldDescriptorProto,
|
||||||
|
FileDescriptorProto,
|
||||||
ServiceDescriptorProto,
|
ServiceDescriptorProto,
|
||||||
)
|
)
|
||||||
|
|
||||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
|
||||||
|
|
||||||
|
|
||||||
from jinja2 import Environment, PackageLoader
|
|
||||||
|
|
||||||
|
|
||||||
def snake_case(value: str) -> str:
|
def snake_case(value: str) -> str:
|
||||||
return (
|
return (
|
||||||
|
9
pyproject.toml
Normal file
9
pyproject.toml
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
[tool.black]
|
||||||
|
target-version = ['py37']
|
||||||
|
|
||||||
|
[tool.isort]
|
||||||
|
multi_line_output = 3
|
||||||
|
include_trailing_comma = true
|
||||||
|
force_grid_wrap = 0
|
||||||
|
use_parentheses = true
|
||||||
|
line_length = 88
|
Loading…
x
Reference in New Issue
Block a user