Merge pull request #46 from jameslan/perf/class-cache
Improve performance of serialize/deserialize by caching type information of fields in class
This commit is contained in:
commit
4a2baf3f0a
@ -120,7 +120,11 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
|||||||
|
|
||||||
|
|
||||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
def datetime_default_gen():
|
||||||
|
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
DATETIME_ZERO = datetime_default_gen()
|
||||||
|
|
||||||
|
|
||||||
class Casing(enum.Enum):
|
class Casing(enum.Enum):
|
||||||
@ -428,6 +432,63 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
T = TypeVar("T", bound="Message")
|
T = TypeVar("T", bound="Message")
|
||||||
|
|
||||||
|
|
||||||
|
class ProtoClassMetadata:
|
||||||
|
cls: Type["Message"]
|
||||||
|
|
||||||
|
def __init__(self, cls: Type["Message"]):
|
||||||
|
self.cls = cls
|
||||||
|
by_field = {}
|
||||||
|
by_group = {}
|
||||||
|
|
||||||
|
for field in dataclasses.fields(cls):
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
|
if meta.group:
|
||||||
|
# This is part of a one-of group.
|
||||||
|
by_field[field.name] = meta.group
|
||||||
|
|
||||||
|
by_group.setdefault(meta.group, set()).add(field)
|
||||||
|
|
||||||
|
self.oneof_group_by_field = by_field
|
||||||
|
self.oneof_field_by_group = by_group
|
||||||
|
|
||||||
|
self.init_default_gen()
|
||||||
|
self.init_cls_by_field()
|
||||||
|
|
||||||
|
def init_default_gen(self):
|
||||||
|
default_gen = {}
|
||||||
|
|
||||||
|
for field in dataclasses.fields(self.cls):
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
default_gen[field.name] = self.cls._get_field_default_gen(field, meta)
|
||||||
|
|
||||||
|
self.default_gen = default_gen
|
||||||
|
|
||||||
|
def init_cls_by_field(self):
|
||||||
|
field_cls = {}
|
||||||
|
|
||||||
|
for field in dataclasses.fields(self.cls):
|
||||||
|
meta = FieldMetadata.get(field)
|
||||||
|
if meta.proto_type == TYPE_MAP:
|
||||||
|
assert meta.map_types
|
||||||
|
kt = self.cls._cls_for(field, index=0)
|
||||||
|
vt = self.cls._cls_for(field, index=1)
|
||||||
|
Entry = dataclasses.make_dataclass(
|
||||||
|
"Entry",
|
||||||
|
[
|
||||||
|
("key", kt, dataclass_field(1, meta.map_types[0])),
|
||||||
|
("value", vt, dataclass_field(2, meta.map_types[1])),
|
||||||
|
],
|
||||||
|
bases=(Message,),
|
||||||
|
)
|
||||||
|
field_cls[field.name] = Entry
|
||||||
|
field_cls[field.name + ".value"] = vt
|
||||||
|
else:
|
||||||
|
field_cls[field.name] = self.cls._cls_for(field)
|
||||||
|
|
||||||
|
self.cls_by_field = field_cls
|
||||||
|
|
||||||
|
|
||||||
class Message(ABC):
|
class Message(ABC):
|
||||||
"""
|
"""
|
||||||
A protobuf message base class. Generated code will inherit from this and
|
A protobuf message base class. Generated code will inherit from this and
|
||||||
@ -445,17 +506,12 @@ class Message(ABC):
|
|||||||
|
|
||||||
# Set a default value for each field in the class after `__init__` has
|
# Set a default value for each field in the class after `__init__` has
|
||||||
# already been run.
|
# already been run.
|
||||||
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
|
group_map: Dict[str, dataclasses.Field] = {}
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
# This is part of a one-of group.
|
group_map.setdefault(meta.group)
|
||||||
group_map["fields"][field.name] = meta.group
|
|
||||||
|
|
||||||
if meta.group not in group_map["groups"]:
|
|
||||||
group_map["groups"][meta.group] = {"current": None, "fields": set()}
|
|
||||||
group_map["groups"][meta.group]["fields"].add(field)
|
|
||||||
|
|
||||||
if getattr(self, field.name) != PLACEHOLDER:
|
if getattr(self, field.name) != PLACEHOLDER:
|
||||||
# Skip anything not set to the sentinel value
|
# Skip anything not set to the sentinel value
|
||||||
@ -463,7 +519,7 @@ class Message(ABC):
|
|||||||
|
|
||||||
if meta.group:
|
if meta.group:
|
||||||
# This was set, so make it the selected value of the one-of.
|
# This was set, so make it the selected value of the one-of.
|
||||||
group_map["groups"][meta.group]["current"] = field
|
group_map[meta.group] = field
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -479,19 +535,33 @@ class Message(ABC):
|
|||||||
# Track when a field has been set.
|
# Track when a field has been set.
|
||||||
self.__dict__["_serialized_on_wire"] = True
|
self.__dict__["_serialized_on_wire"] = True
|
||||||
|
|
||||||
if attr in getattr(self, "_group_map", {}).get("fields", {}):
|
if hasattr(self, "_group_map"): # __post_init__ had already run
|
||||||
group = self._group_map["fields"][attr]
|
if attr in self._betterproto.oneof_group_by_field:
|
||||||
for field in self._group_map["groups"][group]["fields"]:
|
group = self._betterproto.oneof_group_by_field[attr]
|
||||||
if field.name == attr:
|
for field in self._betterproto.oneof_field_by_group[group]:
|
||||||
self._group_map["groups"][group]["current"] = field
|
if field.name == attr:
|
||||||
else:
|
self._group_map[group] = field
|
||||||
super().__setattr__(
|
else:
|
||||||
field.name,
|
super().__setattr__(
|
||||||
self._get_field_default(field, FieldMetadata.get(field)),
|
field.name,
|
||||||
)
|
self._get_field_default(field, FieldMetadata.get(field)),
|
||||||
|
)
|
||||||
|
|
||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _betterproto(self):
|
||||||
|
"""
|
||||||
|
Lazy initialize metadata for each protobuf class.
|
||||||
|
It may be initialized multiple times in a multi-threaded environment,
|
||||||
|
but that won't affect the correctness.
|
||||||
|
"""
|
||||||
|
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||||
|
if not meta:
|
||||||
|
meta = ProtoClassMetadata(self.__class__)
|
||||||
|
self.__class__._betterproto_meta = meta
|
||||||
|
return meta
|
||||||
|
|
||||||
def __bytes__(self) -> bytes:
|
def __bytes__(self) -> bytes:
|
||||||
"""
|
"""
|
||||||
Get the binary encoded Protobuf representation of this instance.
|
Get the binary encoded Protobuf representation of this instance.
|
||||||
@ -510,7 +580,7 @@ class Message(ABC):
|
|||||||
# currently set in a `oneof` group, so it must be serialized even
|
# currently set in a `oneof` group, so it must be serialized even
|
||||||
# if the value is the default zero value.
|
# if the value is the default zero value.
|
||||||
selected_in_group = False
|
selected_in_group = False
|
||||||
if meta.group and self._group_map["groups"][meta.group]["current"] == field:
|
if meta.group and self._group_map[meta.group] == field:
|
||||||
selected_in_group = True
|
selected_in_group = True
|
||||||
|
|
||||||
serialize_empty = False
|
serialize_empty = False
|
||||||
@ -562,47 +632,50 @@ class Message(ABC):
|
|||||||
# For compatibility with other libraries
|
# For compatibility with other libraries
|
||||||
SerializeToString = __bytes__
|
SerializeToString = __bytes__
|
||||||
|
|
||||||
def _type_hint(self, field_name: str) -> Type:
|
@classmethod
|
||||||
module = inspect.getmodule(self.__class__)
|
def _type_hint(cls, field_name: str) -> Type:
|
||||||
type_hints = get_type_hints(self.__class__, vars(module))
|
module = inspect.getmodule(cls)
|
||||||
|
type_hints = get_type_hints(cls, vars(module))
|
||||||
return type_hints[field_name]
|
return type_hints[field_name]
|
||||||
|
|
||||||
def _cls_for(self, field: dataclasses.Field, index: int = 0) -> Type:
|
@classmethod
|
||||||
|
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
"""Get the message class for a field from the type hints."""
|
"""Get the message class for a field from the type hints."""
|
||||||
cls = self._type_hint(field.name)
|
field_cls = cls._type_hint(field.name)
|
||||||
if hasattr(cls, "__args__") and index >= 0:
|
if hasattr(field_cls, "__args__") and index >= 0:
|
||||||
cls = cls.__args__[index]
|
field_cls = field_cls.__args__[index]
|
||||||
return cls
|
return field_cls
|
||||||
|
|
||||||
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
def _get_field_default(self, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||||
t = self._type_hint(field.name)
|
return self._betterproto.default_gen[field.name]()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_field_default_gen(cls, field: dataclasses.Field, meta: FieldMetadata) -> Any:
|
||||||
|
t = cls._type_hint(field.name)
|
||||||
|
|
||||||
value: Any = 0
|
|
||||||
if hasattr(t, "__origin__"):
|
if hasattr(t, "__origin__"):
|
||||||
if t.__origin__ in (dict, Dict):
|
if t.__origin__ in (dict, Dict):
|
||||||
# This is some kind of map (dict in Python).
|
# This is some kind of map (dict in Python).
|
||||||
value = {}
|
return dict
|
||||||
elif t.__origin__ in (list, List):
|
elif t.__origin__ in (list, List):
|
||||||
# This is some kind of list (repeated) field.
|
# This is some kind of list (repeated) field.
|
||||||
value = []
|
return list
|
||||||
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
elif t.__origin__ == Union and t.__args__[1] == type(None):
|
||||||
# This is an optional (wrapped) field. For setting the default we
|
# This is an optional (wrapped) field. For setting the default we
|
||||||
# really don't care what kind of field it is.
|
# really don't care what kind of field it is.
|
||||||
value = None
|
return type(None)
|
||||||
else:
|
else:
|
||||||
value = t()
|
return t
|
||||||
elif issubclass(t, Enum):
|
elif issubclass(t, Enum):
|
||||||
# Enums always default to zero.
|
# Enums always default to zero.
|
||||||
value = 0
|
return int
|
||||||
elif t == datetime:
|
elif t == datetime:
|
||||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||||
value = DATETIME_ZERO
|
return datetime_default_gen
|
||||||
else:
|
else:
|
||||||
# This is either a primitive scalar or another message type. Calling
|
# This is either a primitive scalar or another message type. Calling
|
||||||
# it should result in its zero value.
|
# it should result in its zero value.
|
||||||
value = t()
|
return t
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _postprocess_single(
|
def _postprocess_single(
|
||||||
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
self, wire_type: int, meta: FieldMetadata, field: dataclasses.Field, value: Any
|
||||||
@ -627,7 +700,7 @@ class Message(ABC):
|
|||||||
if meta.proto_type == TYPE_STRING:
|
if meta.proto_type == TYPE_STRING:
|
||||||
value = value.decode("utf-8")
|
value = value.decode("utf-8")
|
||||||
elif meta.proto_type == TYPE_MESSAGE:
|
elif meta.proto_type == TYPE_MESSAGE:
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field.name]
|
||||||
|
|
||||||
if cls == datetime:
|
if cls == datetime:
|
||||||
value = _Timestamp().parse(value).to_datetime()
|
value = _Timestamp().parse(value).to_datetime()
|
||||||
@ -641,20 +714,7 @@ class Message(ABC):
|
|||||||
value = cls().parse(value)
|
value = cls().parse(value)
|
||||||
value._serialized_on_wire = True
|
value._serialized_on_wire = True
|
||||||
elif meta.proto_type == TYPE_MAP:
|
elif meta.proto_type == TYPE_MAP:
|
||||||
# TODO: This is slow, use a cache to make it faster since each
|
value = self._betterproto.cls_by_field[field.name]().parse(value)
|
||||||
# key/value pair will recreate the class.
|
|
||||||
assert meta.map_types
|
|
||||||
kt = self._cls_for(field, index=0)
|
|
||||||
vt = self._cls_for(field, index=1)
|
|
||||||
Entry = dataclasses.make_dataclass(
|
|
||||||
"Entry",
|
|
||||||
[
|
|
||||||
("key", kt, dataclass_field(1, meta.map_types[0])),
|
|
||||||
("value", vt, dataclass_field(2, meta.map_types[1])),
|
|
||||||
],
|
|
||||||
bases=(Message,),
|
|
||||||
)
|
|
||||||
value = Entry().parse(value)
|
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@ -769,7 +829,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
output[cased_name] = b64encode(v).decode("utf8")
|
output[cased_name] = b64encode(v).decode("utf8")
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_values = list(self._cls_for(field)) # type: ignore
|
enum_values = list(self._betterproto.cls_by_field[field.name]) # type: ignore
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [enum_values[e].name for e in v]
|
output[cased_name] = [enum_values[e].name for e in v]
|
||||||
else:
|
else:
|
||||||
@ -795,7 +855,7 @@ class Message(ABC):
|
|||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
cls = self._cls_for(field)
|
cls = self._betterproto.cls_by_field[field.name]
|
||||||
for i in range(len(value[key])):
|
for i in range(len(value[key])):
|
||||||
v.append(cls().from_dict(value[key][i]))
|
v.append(cls().from_dict(value[key][i]))
|
||||||
elif isinstance(v, datetime):
|
elif isinstance(v, datetime):
|
||||||
@ -812,7 +872,7 @@ class Message(ABC):
|
|||||||
v.from_dict(value[key])
|
v.from_dict(value[key])
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
cls = self._cls_for(field, index=1)
|
cls = self._betterproto.cls_by_field[field.name + ".value"]
|
||||||
for k in value[key]:
|
for k in value[key]:
|
||||||
v[k] = cls().from_dict(value[key][k])
|
v[k] = cls().from_dict(value[key][k])
|
||||||
else:
|
else:
|
||||||
@ -828,7 +888,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
v = b64decode(value[key])
|
v = b64decode(value[key])
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_cls = self._cls_for(field)
|
enum_cls = self._betterproto.cls_by_field[field.name]
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = [enum_cls.from_string(e) for e in v]
|
v = [enum_cls.from_string(e) for e in v]
|
||||||
elif isinstance(v, str):
|
elif isinstance(v, str):
|
||||||
@ -861,7 +921,7 @@ def serialized_on_wire(message: Message) -> bool:
|
|||||||
|
|
||||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||||
"""Return the name and value of a message's one-of field group."""
|
"""Return the name and value of a message's one-of field group."""
|
||||||
field = message._group_map["groups"].get(group_name, {}).get("current")
|
field = message._group_map.get(group_name)
|
||||||
if not field:
|
if not field:
|
||||||
return ("", None)
|
return ("", None)
|
||||||
return (field.name, getattr(message, field.name))
|
return (field.name, getattr(message, field.name))
|
||||||
|
@ -5,6 +5,7 @@ import sys
|
|||||||
import pytest
|
import pytest
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto.tests.util import get_directories, inputs_path
|
from betterproto.tests.util import get_directories, inputs_path
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
# Force pure-python implementation instead of C++, otherwise imports
|
# Force pure-python implementation instead of C++, otherwise imports
|
||||||
# break things because we can't properly reset the symbol database.
|
# break things because we can't properly reset the symbol database.
|
||||||
@ -22,47 +23,13 @@ plugin_output_package = "betterproto.tests.output_betterproto"
|
|||||||
reference_output_package = "betterproto.tests.output_reference"
|
reference_output_package = "betterproto.tests.output_reference"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
TestData = namedtuple("TestData", "plugin_module, reference_module, json_data")
|
||||||
def test_message_can_be_imported(test_case_name: str) -> None:
|
|
||||||
importlib.import_module(
|
|
||||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
@pytest.fixture(scope="module", params=test_case_names)
|
||||||
def test_message_can_instantiated(test_case_name: str) -> None:
|
def test_data(request):
|
||||||
plugin_module = importlib.import_module(
|
test_case_name = request.param
|
||||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
|
||||||
)
|
|
||||||
plugin_module.Test()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
|
||||||
def test_message_equality(test_case_name: str) -> None:
|
|
||||||
plugin_module = importlib.import_module(
|
|
||||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
|
||||||
)
|
|
||||||
message1 = plugin_module.Test()
|
|
||||||
message2 = plugin_module.Test()
|
|
||||||
assert message1 == message2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
|
||||||
def test_message_json(test_case_name: str) -> None:
|
|
||||||
plugin_module = importlib.import_module(
|
|
||||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
|
||||||
)
|
|
||||||
message: betterproto.Message = plugin_module.Test()
|
|
||||||
reference_json_data = get_test_case_json_data(test_case_name)
|
|
||||||
|
|
||||||
message.from_json(reference_json_data)
|
|
||||||
message_json = message.to_json(0)
|
|
||||||
|
|
||||||
assert json.loads(reference_json_data) == json.loads(message_json)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
|
||||||
def test_binary_compatibility(test_case_name: str) -> None:
|
|
||||||
# Reset the internal symbol database so we can import the `Test` message
|
# Reset the internal symbol database so we can import the `Test` message
|
||||||
# multiple times. Ugh.
|
# multiple times. Ugh.
|
||||||
sym = symbol_database.Default()
|
sym = symbol_database.Default()
|
||||||
@ -74,33 +41,66 @@ def test_binary_compatibility(test_case_name: str) -> None:
|
|||||||
|
|
||||||
sys.path.append(reference_module_root)
|
sys.path.append(reference_module_root)
|
||||||
|
|
||||||
# import reference message
|
yield TestData(
|
||||||
reference_module = importlib.import_module(
|
plugin_module=importlib.import_module(
|
||||||
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
),
|
||||||
|
reference_module=importlib.import_module(
|
||||||
|
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
|
||||||
|
),
|
||||||
|
json_data=get_test_case_json_data(test_case_name),
|
||||||
)
|
)
|
||||||
plugin_module = importlib.import_module(
|
|
||||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
test_data = get_test_case_json_data(test_case_name)
|
|
||||||
|
|
||||||
reference_instance = Parse(test_data, reference_module.Test())
|
|
||||||
reference_binary_output = reference_instance.SerializeToString()
|
|
||||||
|
|
||||||
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
|
|
||||||
test_data
|
|
||||||
)
|
|
||||||
plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output)
|
|
||||||
|
|
||||||
# # Generally this can't be relied on, but here we are aiming to match the
|
|
||||||
# # existing Python implementation and aren't doing anything tricky.
|
|
||||||
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
|
|
||||||
assert plugin_instance_from_json == plugin_instance_from_binary
|
|
||||||
assert plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
|
|
||||||
|
|
||||||
sys.path.remove(reference_module_root)
|
sys.path.remove(reference_module_root)
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_can_instantiated(test_data: TestData) -> None:
|
||||||
|
plugin_module, *_ = test_data
|
||||||
|
plugin_module.Test()
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_equality(test_data: TestData) -> None:
|
||||||
|
plugin_module, *_ = test_data
|
||||||
|
message1 = plugin_module.Test()
|
||||||
|
message2 = plugin_module.Test()
|
||||||
|
assert message1 == message2
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_json(repeat, test_data: TestData) -> None:
|
||||||
|
plugin_module, _, json_data = test_data
|
||||||
|
|
||||||
|
for _ in range(repeat):
|
||||||
|
message: betterproto.Message = plugin_module.Test()
|
||||||
|
|
||||||
|
message.from_json(json_data)
|
||||||
|
message_json = message.to_json(0)
|
||||||
|
|
||||||
|
assert json.loads(json_data) == json.loads(message_json)
|
||||||
|
|
||||||
|
|
||||||
|
def test_binary_compatibility(repeat, test_data: TestData) -> None:
|
||||||
|
plugin_module, reference_module, json_data = test_data
|
||||||
|
|
||||||
|
reference_instance = Parse(json_data, reference_module.Test())
|
||||||
|
reference_binary_output = reference_instance.SerializeToString()
|
||||||
|
|
||||||
|
for _ in range(repeat):
|
||||||
|
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
|
||||||
|
json_data
|
||||||
|
)
|
||||||
|
plugin_instance_from_binary = plugin_module.Test.FromString(
|
||||||
|
reference_binary_output
|
||||||
|
)
|
||||||
|
|
||||||
|
# # Generally this can't be relied on, but here we are aiming to match the
|
||||||
|
# # existing Python implementation and aren't doing anything tricky.
|
||||||
|
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
|
||||||
|
assert plugin_instance_from_json == plugin_instance_from_binary
|
||||||
|
assert (
|
||||||
|
plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
helper methods
|
helper methods
|
||||||
"""
|
"""
|
||||||
|
10
conftest.py
Normal file
10
conftest.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser):
|
||||||
|
parser.addoption("--repeat", type=int, default=1, help="repeat the operation multiple times")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def repeat(request):
|
||||||
|
return request.config.getoption("repeat")
|
Loading…
x
Reference in New Issue
Block a user