Implement imports, simplified default value handling

This commit is contained in:
Daniel G. Taylor
2019-10-12 09:48:03 -07:00
parent 55be5eed69
commit dcb7102d92
9 changed files with 232 additions and 123 deletions

View File

@@ -92,6 +92,18 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
def get_default(proto_type: int) -> Any:
"""Get the default (zero value) for a given type."""
return {
TYPE_BOOL: False,
TYPE_FLOAT: 0.0,
TYPE_DOUBLE: 0.0,
TYPE_STRING: "",
TYPE_BYTES: b"",
TYPE_MAP: {},
}.get(proto_type, 0)
@dataclasses.dataclass(frozen=True)
class FieldMetadata:
"""Stores internal metadata used for parsing & serialization."""
@@ -114,7 +126,7 @@ class FieldMetadata:
def dataclass_field(
number: int,
proto_type: str,
default: Any,
default: Any = None,
map_types: Optional[Tuple[str, str]] = None,
**kwargs: dict,
) -> dataclasses.Field:
@@ -141,6 +153,10 @@ def enum_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return dataclass_field(number, TYPE_ENUM, default=default)
def bool_field(number: int, default: Union[bool, Type[Iterable]] = 0) -> Any:
return dataclass_field(number, TYPE_BOOL, default=default)
def int32_field(number: int, default: Union[int, Type[Iterable]] = 0) -> Any:
return dataclass_field(number, TYPE_INT32, default=default)
@@ -193,8 +209,8 @@ def string_field(number: int, default: str = "") -> Any:
return dataclass_field(number, TYPE_STRING, default=default)
def message_field(number: int, default: Type["Message"]) -> Any:
return dataclass_field(number, TYPE_MESSAGE, default=default)
def message_field(number: int) -> Any:
return dataclass_field(number, TYPE_MESSAGE)
def map_field(number: int, key_type: str, value_type: str) -> Any:
@@ -345,6 +361,29 @@ class Message(ABC):
to go between Python, binary and JSON protobuf message representations.
"""
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
t = self._cls_for(field, index=-1)
value = 0
if meta.proto_type == TYPE_MAP:
# Maps cannot be repeated, so we check these first.
value = {}
elif hasattr(t, "__args__") and len(t.__args__) == 1:
# Anything else with type args is a list.
value = []
elif meta.proto_type == TYPE_MESSAGE:
# Message means creating an instance of the right type.
value = t()
else:
value = get_default(meta.proto_type)
setattr(self, field.name, value)
def __bytes__(self) -> bytes:
"""
Get the binary encoded Protobuf representation of this instance.
@@ -356,6 +395,7 @@ class Message(ABC):
if isinstance(value, list):
if not len(value):
# Empty values are not serialized
continue
if meta.proto_type in PACKED_TYPES:
@@ -371,6 +411,7 @@ class Message(ABC):
output += _serialize_single(meta.number, meta.proto_type, item)
elif isinstance(value, dict):
if not len(value):
# Empty values are not serialized
continue
for k, v in value.items():
@@ -378,7 +419,8 @@ class Message(ABC):
sv = _serialize_single(2, meta.map_types[1], v)
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
if value == field.default:
if value == get_default(meta.proto_type):
# Default (zero) values are not serialized
continue
output += _serialize_single(meta.number, meta.proto_type, value)
@@ -390,7 +432,7 @@ class Message(ABC):
module = inspect.getmodule(self)
type_hints = get_type_hints(self, vars(module))
cls = type_hints[field.name]
if hasattr(cls, "__args__"):
if hasattr(cls, "__args__") and index >= 0:
cls = type_hints[field.name].__args__[index]
return cls
@@ -522,7 +564,7 @@ class Message(ABC):
"""
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
if field.name in value:
if field.name in value and value[field.name] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[field.name])

View File

@@ -7,6 +7,10 @@ from dataclasses import dataclass
from typing import Dict, List
import betterproto
{% for i in description.imports %}
{{ i }}
{% endfor %}
{% if description.enums %}{% for enum in description.enums %}
@@ -21,9 +25,9 @@ class {{ enum.name }}(enum.IntEnum):
{% endif %}
{{ entry.name }} = {{ entry.value }}
{% endfor %}
{% endfor %}
{% endif %}
{% for message in description.messages %}
@dataclass
@@ -36,8 +40,11 @@ class {{ message.name }}(betterproto.Message):
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.zero and field.field_type != 'map' %}, default={{ field.zero }}{% endif %}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
{{ field.name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %})
{% endfor %}
{% if not message.properties %}
pass
{% endif %}
{% endfor %}

View File

@@ -0,0 +1,5 @@
{
"greeting": {
"greeting": "hello"
}
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
package ref;
import "repeatedmessage.proto";
message Test {
repeatedmessage.Sub greeting = 1;
}

View File

@@ -1,5 +1,7 @@
syntax = "proto3";
package repeatedmessage;
message Test {
repeated Sub greetings = 1;
}

View File

@@ -2,7 +2,7 @@ import importlib
import pytest
import json
from generate import get_files, get_base
from .generate import get_files, get_base
inputs = get_files(".bin")
@@ -10,7 +10,7 @@ inputs = get_files(".bin")
@pytest.mark.parametrize("filename", inputs)
def test_sample(filename: str) -> None:
module = get_base(filename).split("-")[0]
imported = importlib.import_module(module)
imported = importlib.import_module(f"betterproto.tests.{module}")
data_binary = open(filename, "rb").read()
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
t1 = imported.Test().parse(data_binary)