Implement imports, simplified default value handling
This commit is contained in:
@@ -92,6 +92,18 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
def get_default(proto_type: int) -> Any:
|
||||
"""Get the default (zero value) for a given type."""
|
||||
return {
|
||||
TYPE_BOOL: False,
|
||||
TYPE_FLOAT: 0.0,
|
||||
TYPE_DOUBLE: 0.0,
|
||||
TYPE_STRING: "",
|
||||
TYPE_BYTES: b"",
|
||||
TYPE_MAP: {},
|
||||
}.get(proto_type, 0)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class FieldMetadata:
|
||||
"""Stores internal metadata used for parsing & serialization."""
|
||||
@@ -114,7 +126,7 @@ class FieldMetadata:
|
||||
def dataclass_field(
|
||||
number: int,
|
||||
proto_type: str,
|
||||
default: Any,
|
||||
default: Any = None,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
**kwargs: dict,
|
||||
) -> dataclasses.Field:
|
||||
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_ENUM, default=default)
|
||||
|
||||
|
||||
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_BOOL, default=default)
|
||||
|
||||
|
||||
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
|
||||
return dataclass_field(number, TYPE_INT32, default=default)
|
||||
|
||||
@@ -193,8 +209,8 @@ def string_field(number: int, default: str = "") -> Any:
|
||||
return dataclass_field(number, TYPE_STRING, default=default)
|
||||
|
||||
|
||||
def message_field(number: int, default: Type["Message"]) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, default=default)
|
||||
def message_field(number: int) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE)
|
||||
|
||||
|
||||
def map_field(number: int, key_type: str, value_type: str) -> Any:
|
||||
@@ -345,6 +361,29 @@ class Message(ABC):
|
||||
to go between Python, binary and JSON protobuf message representations.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Set a default value for each field in the class after `__init__` has
|
||||
# already been run.
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
t = self._cls_for(field, index=-1)
|
||||
|
||||
value = 0
|
||||
if meta.proto_type == TYPE_MAP:
|
||||
# Maps cannot be repeated, so we check these first.
|
||||
value = {}
|
||||
elif hasattr(t, "__args__") and len(t.__args__) == 1:
|
||||
# Anything else with type args is a list.
|
||||
value = []
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
# Message means creating an instance of the right type.
|
||||
value = t()
|
||||
else:
|
||||
value = get_default(meta.proto_type)
|
||||
|
||||
setattr(self, field.name, value)
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
"""
|
||||
Get the binary encoded Protobuf representation of this instance.
|
||||
@@ -356,6 +395,7 @@ class Message(ABC):
|
||||
|
||||
if isinstance(value, list):
|
||||
if not len(value):
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
if meta.proto_type in PACKED_TYPES:
|
||||
@@ -371,6 +411,7 @@ class Message(ABC):
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
elif isinstance(value, dict):
|
||||
if not len(value):
|
||||
# Empty values are not serialized
|
||||
continue
|
||||
|
||||
for k, v in value.items():
|
||||
@@ -378,7 +419,8 @@ class Message(ABC):
|
||||
sv = _serialize_single(2, meta.map_types[1], v)
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
if value == field.default:
|
||||
if value == get_default(meta.proto_type):
|
||||
# Default (zero) values are not serialized
|
||||
continue
|
||||
|
||||
output += _serialize_single(meta.number, meta.proto_type, value)
|
||||
@@ -390,7 +432,7 @@ class Message(ABC):
|
||||
module = inspect.getmodule(self)
|
||||
type_hints = get_type_hints(self, vars(module))
|
||||
cls = type_hints[field.name]
|
||||
if hasattr(cls, "__args__"):
|
||||
if hasattr(cls, "__args__") and index >= 0:
|
||||
cls = type_hints[field.name].__args__[index]
|
||||
return cls
|
||||
|
||||
@@ -522,7 +564,7 @@ class Message(ABC):
|
||||
"""
|
||||
for field in dataclasses.fields(self):
|
||||
meta = FieldMetadata.get(field)
|
||||
if field.name in value:
|
||||
if field.name in value and value[field.name] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
# print(v, value[field.name])
|
||||
|
||||
@@ -7,6 +7,10 @@ from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import betterproto
|
||||
{% for i in description.imports %}
|
||||
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% if description.enums %}{% for enum in description.enums %}
|
||||
@@ -21,9 +25,9 @@ class {{ enum.name }}(enum.IntEnum):
|
||||
{% endif %}
|
||||
{{ entry.name }} = {{ entry.value }}
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% endfor %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% for message in description.messages %}
|
||||
@dataclass
|
||||
@@ -36,8 +40,11 @@ class {{ message.name }}(betterproto.Message):
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
|
||||
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
|
||||
{% endfor %}
|
||||
{% if not message.properties %}
|
||||
pass
|
||||
{% endif %}
|
||||
|
||||
|
||||
{% endfor %}
|
||||
|
||||
5
betterproto/tests/ref.json
Normal file
5
betterproto/tests/ref.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"greeting": {
|
||||
"greeting": "hello"
|
||||
}
|
||||
}
|
||||
9
betterproto/tests/ref.proto
Normal file
9
betterproto/tests/ref.proto
Normal file
@@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ref;
|
||||
|
||||
import "repeatedmessage.proto";
|
||||
|
||||
message Test {
|
||||
repeatedmessage.Sub greeting = 1;
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package repeatedmessage;
|
||||
|
||||
message Test {
|
||||
repeated Sub greetings = 1;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ import importlib
|
||||
import pytest
|
||||
import json
|
||||
|
||||
from generate import get_files, get_base
|
||||
from .generate import get_files, get_base
|
||||
|
||||
inputs = get_files(".bin")
|
||||
|
||||
@@ -10,7 +10,7 @@ inputs = get_files(".bin")
|
||||
@pytest.mark.parametrize("filename", inputs)
|
||||
def test_sample(filename: str) -> None:
|
||||
module = get_base(filename).split("-")[0]
|
||||
imported = importlib.import_module(module)
|
||||
imported = importlib.import_module(f"betterproto.tests.{module}")
|
||||
data_binary = open(filename, "rb").read()
|
||||
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
|
||||
t1 = imported.Test().parse(data_binary)
|
||||
|
||||
Reference in New Issue
Block a user