Support wrapper types

This commit is contained in:
Daniel G. Taylor 2019-10-27 14:55:25 -07:00
parent c79535b614
commit 035793aec3
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
7 changed files with 233 additions and 32 deletions

View File

@ -238,6 +238,53 @@ Again this is a little different than the official Google code generator:
["foo", "foo's value"] ["foo", "foo's value"]
``` ```
### Well-Known Google Types
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
| Google Message | Python Type | Default |
| --------------------------- | ---------------------------------------- | ---------------------- |
| `google.protobuf.Duration`  | [`datetime.timedelta`][td]               | `0`                    |
| `google.protobuf.Timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
| `google.protobuf.*Value` | `Optional[...]` | `None` |
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
```protobuf
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
}
```
You can do stuff like:
```py
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
>>> t
Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
>>> t.ts - t.duration
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
>>> t.ts.isoformat()
'2019-01-01T12:00:00+00:00'
>>> t.maybe = None
>>> t.to_dict()
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```
## Development ## Development
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then: First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
@ -295,7 +342,7 @@ $ pipenv run tests
- [x] Bytes as base64 - [x] Bytes as base64
- [ ] Any support - [ ] Any support
- [x] Enum strings - [x] Enum strings
- [ ] Well known types support (timestamp, duration, wrappers) - [x] Well known types support (timestamp, duration, wrappers)
- [x] Support different casing (orig vs. camel vs. others?) - [x] Support different casing (orig vs. camel vs. others?)
- [ ] Async service stubs - [ ] Async service stubs
- [x] Unary-unary - [x] Unary-unary

View File

@ -105,6 +105,10 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
class Casing(enum.Enum): class Casing(enum.Enum):
"""Casing constants for serialization.""" """Casing constants for serialization."""
@ -128,9 +132,11 @@ class FieldMetadata:
# Protobuf type name # Protobuf type name
proto_type: str proto_type: str
# Map information if the proto_type is a map # Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]] map_types: Optional[Tuple[str, str]] = None
# Groups several "one-of" fields together # Groups several "one-of" fields together
group: Optional[str] group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
@staticmethod @staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata": def get(field: dataclasses.Field) -> "FieldMetadata":
@ -144,11 +150,14 @@ def dataclass_field(
*, *,
map_types: Optional[Tuple[str, str]] = None, map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None, group: Optional[str] = None,
wraps: Optional[str] = None,
) -> dataclasses.Field: ) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata.""" """Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field( return dataclasses.field(
default=PLACEHOLDER, default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)}, metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
},
) )
@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group) return dataclass_field(number, TYPE_BYTES, group=group)
def message_field(number: int, group: Optional[str] = None) -> Any: def message_field(
return dataclass_field(number, TYPE_MESSAGE, group=group) number: int, group: Optional[str] = None, wraps: Optional[str] = None
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
def map_field( def map_field(
@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
return bytes(b + [bits]) return bytes(b + [bits])
def _preprocess_single(proto_type: str, value: Any) -> bytes: def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization.""" """Adjusts values before serialization."""
if proto_type in [ if proto_type in [
TYPE_ENUM, TYPE_ENUM,
@ -307,6 +318,10 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
seconds = int(total_ms / 1e6) seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3) nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos) value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)
return bytes(value) return bytes(value)
@ -314,10 +329,15 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _serialize_single( def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes: ) -> bytes:
"""Serializes a single field and value.""" """Serializes a single field and value."""
value = _preprocess_single(proto_type, value) value = _preprocess_single(proto_type, wraps, value)
output = b"" output = b""
if proto_type in WIRE_VARINT_TYPES: if proto_type in WIRE_VARINT_TYPES:
@ -330,7 +350,7 @@ def _serialize_single(
key = encode_varint((field_number << 3) | 1) key = encode_varint((field_number << 3) | 1)
output += key + value output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES: elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty: if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2) key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value output += key + encode_varint(len(value)) + value
else: else:
@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
while i < len(value): while i < len(value):
start = i start = i
num_wire, i = decode_varint(value, i) num_wire, i = decode_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3 number = num_wire >> 3
wire_type = num_wire & 0x7 wire_type = num_wire & 0x7
@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
elif wire_type == 5: elif wire_type == 5:
decoded, i = value[i : i + 4], i + 4 decoded, i = value[i : i + 4], i + 4
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield ParsedField( yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i] number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
) )
@ -462,6 +479,11 @@ class Message(ABC):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
value = getattr(self, field.name) value = getattr(self, field.name)
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
continue
# Being selected in a group means this field is the one that is # Being selected in a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even # currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value. # if the value is the default zero value.
@ -491,11 +513,13 @@ class Message(ABC):
# treat it like a field of raw bytes. # treat it like a field of raw bytes.
buf = b"" buf = b""
for item in value: for item in value:
buf += _preprocess_single(meta.proto_type, item) buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf) output += _serialize_single(meta.number, TYPE_BYTES, buf)
else: else:
for item in value: for item in value:
output += _serialize_single(meta.number, meta.proto_type, item) output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps
)
elif isinstance(value, dict): elif isinstance(value, dict):
for k, v in value.items(): for k, v in value.items():
assert meta.map_types assert meta.map_types
@ -504,7 +528,11 @@ class Message(ABC):
output += _serialize_single(meta.number, meta.proto_type, sk + sv) output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else: else:
output += _serialize_single( output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps,
) )
return output + self._unknown_fields return output + self._unknown_fields
@ -546,7 +574,7 @@ class Message(ABC):
value = 0 value = 0
elif t == datetime: elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z # Offsets are relative to 1970-01-01T00:00:00Z
value = datetime(1970, 1, 1, tzinfo=timezone.utc) value = DATETIME_ZERO
else: else:
# This is either a primitive scalar or another message type. Calling # This is either a primitive scalar or another message type. Calling
# it should result in its zero value. # it should result in its zero value.
@ -580,6 +608,10 @@ class Message(ABC):
value = _Timestamp().parse(value).to_datetime() value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta: elif cls == timedelta:
value = _Duration().parse(value).to_timedelta() value = _Duration().parse(value).to_timedelta()
elif meta.wraps:
# This is a Google wrapper value message around a single
# scalar type.
value = _get_wrapper(meta.wraps)().parse(value).value
else: else:
value = cls().parse(value) value = cls().parse(value)
value._serialized_on_wire = True value._serialized_on_wire = True
@ -670,9 +702,14 @@ class Message(ABC):
cased_name = casing(field.name).rstrip("_") cased_name = casing(field.name).rstrip("_")
if meta.proto_type == "message": if meta.proto_type == "message":
if isinstance(v, datetime): if isinstance(v, datetime):
output[cased_name] = _Timestamp.to_json(v) if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.to_json(v)
elif isinstance(v, timedelta): elif isinstance(v, timedelta):
output[cased_name] = _Duration.to_json(v) if v != timedelta(0):
output[cased_name] = _Duration.to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
elif isinstance(v, list): elif isinstance(v, list):
# Convert each item. # Convert each item.
v = [i.to_dict() for i in v] v = [i.to_dict() for i in v]
@ -723,17 +760,20 @@ class Message(ABC):
if value[key] is not None: if value[key] is not None:
if meta.proto_type == "message": if meta.proto_type == "message":
v = getattr(self, field.name) v = getattr(self, field.name)
# print(v, value[key])
if isinstance(v, list): if isinstance(v, list):
cls = self._cls_for(field) cls = self._cls_for(field)
for i in range(len(value[key])): for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i])) v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime): elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v) setattr(self, field.name, v)
elif isinstance(v, timedelta): elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1])) v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v) setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else: else:
v.from_dict(value[key]) v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@ -830,7 +870,6 @@ class _Timestamp(Message):
def to_datetime(self) -> datetime: def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9) ts = self.seconds + (self.nanos / 1e9)
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
return datetime.fromtimestamp(ts, tz=timezone.utc) return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod @staticmethod
@ -839,17 +878,90 @@ class _Timestamp(Message):
copy = dt.replace(microsecond=0, tzinfo=None) copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat() result = copy.isoformat()
if (nanos % 1e9) == 0: if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional # If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing. # point '.' should be omitted when serializing.
return result + 'Z' return result + "Z"
if (nanos % 1e6) == 0: if (nanos % 1e6) == 0:
# Serialize 3 fractional digits. # Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6) return result + ".%03dZ" % (nanos / 1e6)
if (nanos % 1e3) == 0: if (nanos % 1e3) == 0:
# Serialize 6 fractional digits. # Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3) return result + ".%06dZ" % (nanos / 1e3)
# Serialize 9 fractional digits. # Serialize 9 fractional digits.
return result + '.%09dZ' % nanos return result + ".%09dZ" % nanos
class _WrappedMessage(Message):
    """
    Google protobuf wrapper types base class.

    Each concrete wrapper declares a single ``value`` field. The JSON
    representation of a wrapper is just that value itself rather than a
    nested object.
    """

    def to_dict(self) -> Any:
        """Return the bare wrapped value as the JSON representation."""
        return self.value

    def from_dict(self, value: Any) -> "_WrappedMessage":
        """
        Set the wrapped value from its JSON representation.

        A ``None`` input leaves the current value untouched. The instance is
        returned so that the result can be used directly, e.g.
        ``cls().from_dict(data)``, instead of yielding ``None``.
        """
        if value is not None:
            self.value = value
        return self
# Concrete wrapper messages. Each one carries a single scalar ``value`` as
# field number 1 and is looked up by wrapped type in ``_get_wrapper``.


# Wrapper around a bool value.
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
    value: bool = bool_field(1)


# Wrapper around an int32 value.
@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
    value: int = int32_field(1)


# Wrapper around a uint32 value.
@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
    value: int = uint32_field(1)


# Wrapper around an int64 value.
@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
    value: int = int64_field(1)


# Wrapper around a uint64 value.
@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
    value: int = uint64_field(1)


# Wrapper around a float value.
@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
    value: float = float_field(1)


# Wrapper around a double value.
@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
    value: float = double_field(1)


# Wrapper around a string value.
@dataclasses.dataclass
class _StringValue(_WrappedMessage):
    value: str = string_field(1)


# Wrapper around a bytes value.
@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
    value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> _WrappedMessage:
    """
    Look up the wrapper message class for a wrapped scalar type.

    ``proto_type`` is one of the ``TYPE_*`` constants listed below; any other
    value raises ``KeyError``. The class itself is returned, not an instance:
    callers instantiate the result themselves.
    """
    # NOTE(review): the return annotation names the base class, but what is
    # handed back is a subclass object; confirm before tightening the hint.
    wrapper_pairs = [
        (TYPE_BOOL, _BoolValue),
        (TYPE_INT32, _Int32Value),
        (TYPE_UINT32, _UInt32Value),
        (TYPE_INT64, _Int64Value),
        (TYPE_UINT64, _UInt64Value),
        (TYPE_FLOAT, _FloatValue),
        (TYPE_DOUBLE, _DoubleValue),
        (TYPE_STRING, _StringValue),
        (TYPE_BYTES, _BytesValue),
    ]
    wrappers_by_type = dict(wrapper_pairs)
    return wrappers_by_type[proto_type]
class ServiceStub(ABC): class ServiceStub(ABC):

View File

@ -236,6 +236,14 @@ def generate_code(request, response):
packed = False packed = False
field_type = f.Type.Name(f.type).lower()[5:] field_type = f.Type.Name(f.type).lower()[5:]
field_wraps = ""
if f.type_name.startswith(
".google.protobuf"
) and f.type_name.endswith("Value"):
w = f.type_name.split(".").pop()[:-5].upper()
field_wraps = f"betterproto.TYPE_{w}"
map_types = None map_types = None
if f.type == 11: if f.type == 11:
# This might be a map... # This might be a map...
@ -301,6 +309,7 @@ def generate_code(request, response):
"comment": get_comment(proto_file, path + [2, i]), "comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type), "proto_type": int(f.type),
"field_type": field_type, "field_type": field_type,
"field_wraps": field_wraps,
"map_types": map_types, "map_types": map_types,
"type": t, "type": t,
"zero": zero, "zero": zero,

View File

@ -48,7 +48,7 @@ class {{ message.py_name }}(betterproto.Message):
{% if field.comment %} {% if field.comment %}
{{ field.comment }} {{ field.comment }}
{% endif %} {% endif %}
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{% endfor %} {% endfor %}
{% if not message.properties %} {% if not message.properties %}
pass pass
@ -63,7 +63,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endif %} {% endif %}
{% for method in service.methods %} {% for method in service.methods %}
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
{% if method.comment %} {% if method.comment %}
{{ method.comment }} {{ method.comment }}

View File

@ -0,0 +1,5 @@
{
"maybe": false,
"ts": "1972-01-01T10:00:20.021Z",
"duration": "1.200s"
}

View File

@ -0,0 +1,12 @@
syntax = "proto3";
import "google/protobuf/duration.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
message Test {
google.protobuf.BoolValue maybe = 1;
google.protobuf.Timestamp ts = 2;
google.protobuf.Duration duration = 3;
google.protobuf.Int32Value important = 4;
}

View File

@ -1,5 +1,6 @@
import betterproto import betterproto
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
def test_has_field(): def test_has_field():
@ -146,3 +147,18 @@ def test_json_casing():
"snake_case": 3, "snake_case": 3,
"kabob_case": 4, "kabob_case": 4,
} }
def test_optional_flag():
    """Wrapper-typed fields must distinguish "not set" from the zero value."""

    @dataclass
    class Request(betterproto.Message):
        flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)

    # Serialization of not passed vs. set vs. zero-value.
    assert bytes(Request()) == b""
    assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
    assert bytes(Request(flag=False)) == b"\n\x00"

    # Differentiate between not passed and the zero-value. Identity checks
    # are used on purpose: `== False` would also accept 0, and `== None`
    # can be fooled by a custom __eq__.
    assert Request().parse(b"").flag is None
    assert Request().parse(b"\n\x00").flag is False