Support wrapper types

This commit is contained in:
Daniel G. Taylor 2019-10-27 14:55:25 -07:00
parent c79535b614
commit 035793aec3
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
7 changed files with 233 additions and 32 deletions

View File

@ -238,6 +238,53 @@ Again this is a little different than the official Google code generator:
["foo", "foo's value"]
```
### Well-Known Google Types
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
| Google Message | Python Type | Default |
| --------------------------- | ---------------------------------------- | ---------------------- |
| `google.protobuf.Duration`  | [`datetime.timedelta`][td]               | `0`                    |
| `google.protobuf.Timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
| `google.protobuf.*Value` | `Optional[...]` | `None` |
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
```protobuf
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
}
```
You can do stuff like:
```py
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
>>> t
Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
>>> t.ts - t.duration
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
>>> t.ts.isoformat()
'2019-01-01T12:00:00+00:00'
>>> t.maybe = None
>>> t.to_dict()
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```
## Development
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
@ -295,7 +342,7 @@ $ pipenv run tests
- [x] Bytes as base64
- [ ] Any support
- [x] Enum strings
- [ ] Well known types support (timestamp, duration, wrappers)
- [x] Well known types support (timestamp, duration, wrappers)
- [x] Support different casing (orig vs. camel vs. others?)
- [ ] Async service stubs
- [x] Unary-unary

View File

@ -105,6 +105,10 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
class Casing(enum.Enum):
"""Casing constants for serialization."""
@ -128,9 +132,11 @@ class FieldMetadata:
# Protobuf type name
proto_type: str
# Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]]
map_types: Optional[Tuple[str, str]] = None
# Groups several "one-of" fields together
group: Optional[str]
group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
@staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata":
@ -144,11 +150,14 @@ def dataclass_field(
*,
map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None,
wraps: Optional[str] = None,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
},
)
@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group)
def message_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group)
def message_field(
number: int, group: Optional[str] = None, wraps: Optional[str] = None
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
def map_field(
@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
return bytes(b + [bits])
def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in [
TYPE_ENUM,
@ -307,6 +318,10 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)
return bytes(value)
@ -314,10 +329,15 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
value = _preprocess_single(proto_type, wraps, value)
output = b""
if proto_type in WIRE_VARINT_TYPES:
@ -330,7 +350,7 @@ def _serialize_single(
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty:
if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
while i < len(value):
start = i
num_wire, i = decode_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3
wire_type = num_wire & 0x7
@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
elif wire_type == 5:
decoded, i = value[i : i + 4], i + 4
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
)
@ -462,6 +479,11 @@ class Message(ABC):
meta = FieldMetadata.get(field)
value = getattr(self, field.name)
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
continue
# Being selected in a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
@ -491,11 +513,13 @@ class Message(ABC):
# treat it like a field of raw bytes.
buf = b""
for item in value:
buf += _preprocess_single(meta.proto_type, item)
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(meta.number, meta.proto_type, item)
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps
)
elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types
@ -504,7 +528,11 @@ class Message(ABC):
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps,
)
return output + self._unknown_fields
@ -546,7 +574,7 @@ class Message(ABC):
value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
value = DATETIME_ZERO
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
@ -580,6 +608,10 @@ class Message(ABC):
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
elif meta.wraps:
# This is a Google wrapper value message around a single
# scalar type.
value = _get_wrapper(meta.wraps)().parse(value).value
else:
value = cls().parse(value)
value._serialized_on_wire = True
@ -670,9 +702,14 @@ class Message(ABC):
cased_name = casing(field.name).rstrip("_")
if meta.proto_type == "message":
if isinstance(v, datetime):
output[cased_name] = _Timestamp.to_json(v)
if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.to_json(v)
elif isinstance(v, timedelta):
output[cased_name] = _Duration.to_json(v)
if v != timedelta(0):
output[cased_name] = _Duration.to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
elif isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
@ -723,17 +760,20 @@ class Message(ABC):
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[key])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@ -830,7 +870,6 @@ class _Timestamp(Message):
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
@ -839,17 +878,90 @@ class _Timestamp(Message):
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 'Z'
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6)
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3)
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09dZ' % nanos
return result + ".%09dZ" % nanos
class _WrappedMessage(Message):
    """
    Google protobuf wrapper types base class. JSON representation is just the
    value itself.

    Subclasses declare a single ``value`` field holding the wrapped scalar.
    """

    def to_dict(self) -> Any:
        """Return the wrapped scalar itself instead of a nested dict."""
        return self.value

    def from_dict(self, value: Any) -> "_WrappedMessage":
        """
        Set the wrapped scalar from its JSON representation.

        A ``None`` input leaves the current value untouched. Returns ``self``
        so the call can be chained (``cls().from_dict(...)``), which is how
        ``from_dict`` results are consumed elsewhere in this module.
        """
        if value is not None:
            self.value = value
        return self
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
    """Wrapper message carrying a single ``bool`` in field 1."""

    value: bool = bool_field(1)


@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
    """Wrapper message carrying a single int32 in field 1."""

    value: int = int32_field(1)


@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
    """Wrapper message carrying a single uint32 in field 1."""

    value: int = uint32_field(1)


@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
    """Wrapper message carrying a single int64 in field 1."""

    value: int = int64_field(1)


@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
    """Wrapper message carrying a single uint64 in field 1."""

    value: int = uint64_field(1)


@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
    """Wrapper message carrying a single float in field 1."""

    value: float = float_field(1)


@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
    """Wrapper message carrying a single double in field 1."""

    value: float = double_field(1)


@dataclasses.dataclass
class _StringValue(_WrappedMessage):
    """Wrapper message carrying a single string in field 1."""

    value: str = string_field(1)


@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
    """Wrapper message carrying a single bytes value in field 1."""

    value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> _WrappedMessage:
    """
    Look up the wrapper message class for a wrapped scalar type.

    ``proto_type`` must be one of the ``TYPE_*`` constants listed in the
    table below; any other value raises ``KeyError``.
    """
    # NOTE(review): the value returned is the wrapper *class* (callers
    # instantiate it themselves), although the annotation names the
    # instance type.
    wrapper_classes = {
        TYPE_BOOL: _BoolValue,
        TYPE_INT32: _Int32Value,
        TYPE_UINT32: _UInt32Value,
        TYPE_INT64: _Int64Value,
        TYPE_UINT64: _UInt64Value,
        TYPE_FLOAT: _FloatValue,
        TYPE_DOUBLE: _DoubleValue,
        TYPE_STRING: _StringValue,
        TYPE_BYTES: _BytesValue,
    }
    return wrapper_classes[proto_type]
class ServiceStub(ABC):

View File

@ -236,6 +236,14 @@ def generate_code(request, response):
packed = False
field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = ""
if f.type_name.startswith(
".google.protobuf"
) and f.type_name.endswith("Value"):
w = f.type_name.split(".").pop()[:-5].upper()
field_wraps = f"betterproto.TYPE_{w}"
map_types = None
if f.type == 11:
# This might be a map...
@ -301,6 +309,7 @@ def generate_code(request, response):
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
"field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types,
"type": t,
"zero": zero,

View File

@ -48,7 +48,7 @@ class {{ message.py_name }}(betterproto.Message):
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %}
{% if not message.properties %}
pass
@ -63,7 +63,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endif %}
{% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %}
{{ method.comment }}

View File

@ -0,0 +1,5 @@
{
"maybe": false,
"ts": "1972-01-01T10:00:20.021Z",
"duration": "1.200s"
}

View File

@ -0,0 +1,12 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
google.protobuf.Int32Value important = 4;
}

View File

@ -1,5 +1,6 @@
import betterproto
from dataclasses import dataclass
from typing import Optional
def test_has_field():
@ -146,3 +147,18 @@ def test_json_casing():
"snake_case": 3,
"kabob_case": 4,
}
def test_optional_flag():
    """Wrapper-backed optional fields distinguish unset, zero and set values."""

    @dataclass
    class Request(betterproto.Message):
        flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)

    # Serialization of not passed vs. set vs. zero-value.
    assert bytes(Request()) == b""
    assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
    assert bytes(Request(flag=False)) == b"\n\x00"

    # Differentiate between not passed and the zero-value. Identity checks
    # are used so that another falsy value (e.g. ``0``) cannot pass for
    # ``False`` or ``None``.
    assert Request().parse(b"").flag is None
    assert Request().parse(b"\n\x00").flag is False