Support wrapper types
This commit is contained in:
parent
c79535b614
commit
035793aec3
49
README.md
49
README.md
@ -238,6 +238,53 @@ Again this is a little different than the official Google code generator:
|
||||
["foo", "foo's value"]
|
||||
```
|
||||
|
||||
### Well-Known Google Types
|
||||
|
||||
Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows:
|
||||
|
||||
| Google Message | Python Type | Default |
|
||||
| --------------------------- | ---------------------------------------- | ---------------------- |
|
||||
| `google.protobuf.duration` | [`datetime.timedelta`][td] | `0` |
|
||||
| `google.protobuf.timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` |
|
||||
| `google.protobuf.*Value` | `Optional[...]` | `None` |
|
||||
|
||||
[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects
|
||||
[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime
|
||||
|
||||
For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. Given:
|
||||
|
||||
```protobuf
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.BoolValue maybe = 1;
|
||||
google.protobuf.Timestamp ts = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
}
|
||||
```
|
||||
|
||||
You can do stuff like:
|
||||
|
||||
```py
|
||||
>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"})
|
||||
>>> t
|
||||
st(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000))
|
||||
|
||||
>>> t.ts - t.duration
|
||||
datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
|
||||
|
||||
>>> t.ts.isoformat()
|
||||
'2019-01-01T12:00:00+00:00'
|
||||
|
||||
>>> t.maybe = None
|
||||
>>> t.to_dict()
|
||||
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then:
|
||||
@ -295,7 +342,7 @@ $ pipenv run tests
|
||||
- [x] Bytes as base64
|
||||
- [ ] Any support
|
||||
- [x] Enum strings
|
||||
- [ ] Well known types support (timestamp, duration, wrappers)
|
||||
- [x] Well known types support (timestamp, duration, wrappers)
|
||||
- [x] Support different casing (orig vs. camel vs. others?)
|
||||
- [ ] Async service stubs
|
||||
- [x] Unary-unary
|
||||
|
@ -105,6 +105,10 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
|
||||
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
@ -128,9 +132,11 @@ class FieldMetadata:
|
||||
# Protobuf type name
|
||||
proto_type: str
|
||||
# Map information if the proto_type is a map
|
||||
map_types: Optional[Tuple[str, str]]
|
||||
map_types: Optional[Tuple[str, str]] = None
|
||||
# Groups several "one-of" fields together
|
||||
group: Optional[str]
|
||||
group: Optional[str] = None
|
||||
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
|
||||
wraps: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def get(field: dataclasses.Field) -> "FieldMetadata":
|
||||
@ -144,11 +150,14 @@ def dataclass_field(
|
||||
*,
|
||||
map_types: Optional[Tuple[str, str]] = None,
|
||||
group: Optional[str] = None,
|
||||
wraps: Optional[str] = None,
|
||||
) -> dataclasses.Field:
|
||||
"""Creates a dataclass field with attached protobuf metadata."""
|
||||
return dataclasses.field(
|
||||
default=PLACEHOLDER,
|
||||
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
|
||||
metadata={
|
||||
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_BYTES, group=group)
|
||||
|
||||
|
||||
def message_field(number: int, group: Optional[str] = None) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, group=group)
|
||||
def message_field(
|
||||
number: int, group: Optional[str] = None, wraps: Optional[str] = None
|
||||
) -> Any:
|
||||
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
|
||||
|
||||
|
||||
def map_field(
|
||||
@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
|
||||
return bytes(b + [bits])
|
||||
|
||||
|
||||
def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
||||
"""Adjusts values before serialization."""
|
||||
if proto_type in [
|
||||
TYPE_ENUM,
|
||||
@ -307,6 +318,10 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
value = _Duration(seconds=seconds, nanos=nanos)
|
||||
elif wraps:
|
||||
if value is None:
|
||||
return b""
|
||||
value = _get_wrapper(wraps)(value=value)
|
||||
|
||||
return bytes(value)
|
||||
|
||||
@ -314,10 +329,15 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
|
||||
|
||||
def _serialize_single(
|
||||
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
|
||||
field_number: int,
|
||||
proto_type: str,
|
||||
value: Any,
|
||||
*,
|
||||
serialize_empty: bool = False,
|
||||
wraps: str = "",
|
||||
) -> bytes:
|
||||
"""Serializes a single field and value."""
|
||||
value = _preprocess_single(proto_type, value)
|
||||
value = _preprocess_single(proto_type, wraps, value)
|
||||
|
||||
output = b""
|
||||
if proto_type in WIRE_VARINT_TYPES:
|
||||
@ -330,7 +350,7 @@ def _serialize_single(
|
||||
key = encode_varint((field_number << 3) | 1)
|
||||
output += key + value
|
||||
elif proto_type in WIRE_LEN_DELIM_TYPES:
|
||||
if len(value) or serialize_empty:
|
||||
if len(value) or serialize_empty or wraps:
|
||||
key = encode_varint((field_number << 3) | 2)
|
||||
output += key + encode_varint(len(value)) + value
|
||||
else:
|
||||
@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
while i < len(value):
|
||||
start = i
|
||||
num_wire, i = decode_varint(value, i)
|
||||
# print(num_wire, i)
|
||||
number = num_wire >> 3
|
||||
wire_type = num_wire & 0x7
|
||||
|
||||
@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
elif wire_type == 5:
|
||||
decoded, i = value[i : i + 4], i + 4
|
||||
|
||||
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
|
||||
|
||||
yield ParsedField(
|
||||
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
|
||||
)
|
||||
@ -462,6 +479,11 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
value = getattr(self, field.name)
|
||||
|
||||
if value is None:
|
||||
# Optional items should be skipped. This is used for the Google
|
||||
# wrapper types.
|
||||
continue
|
||||
|
||||
# Being selected in a a group means this field is the one that is
|
||||
# currently set in a `oneof` group, so it must be serialized even
|
||||
# if the value is the default zero value.
|
||||
@ -491,11 +513,13 @@ class Message(ABC):
|
||||
# treat it like a field of raw bytes.
|
||||
buf = b""
|
||||
for item in value:
|
||||
buf += _preprocess_single(meta.proto_type, item)
|
||||
buf += _preprocess_single(meta.proto_type, "", item)
|
||||
output += _serialize_single(meta.number, TYPE_BYTES, buf)
|
||||
else:
|
||||
for item in value:
|
||||
output += _serialize_single(meta.number, meta.proto_type, item)
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, item, wraps=meta.wraps
|
||||
)
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
assert meta.map_types
|
||||
@ -504,7 +528,11 @@ class Message(ABC):
|
||||
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
|
||||
else:
|
||||
output += _serialize_single(
|
||||
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
|
||||
meta.number,
|
||||
meta.proto_type,
|
||||
value,
|
||||
serialize_empty=serialize_empty,
|
||||
wraps=meta.wraps,
|
||||
)
|
||||
|
||||
return output + self._unknown_fields
|
||||
@ -546,7 +574,7 @@ class Message(ABC):
|
||||
value = 0
|
||||
elif t == datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
value = DATETIME_ZERO
|
||||
else:
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
@ -580,6 +608,10 @@ class Message(ABC):
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
elif cls == timedelta:
|
||||
value = _Duration().parse(value).to_timedelta()
|
||||
elif meta.wraps:
|
||||
# This is a Google wrapper value message around a single
|
||||
# scalar type.
|
||||
value = _get_wrapper(meta.wraps)().parse(value).value
|
||||
else:
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
@ -670,9 +702,14 @@ class Message(ABC):
|
||||
cased_name = casing(field.name).rstrip("_")
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, datetime):
|
||||
output[cased_name] = _Timestamp.to_json(v)
|
||||
if v != DATETIME_ZERO:
|
||||
output[cased_name] = _Timestamp.to_json(v)
|
||||
elif isinstance(v, timedelta):
|
||||
output[cased_name] = _Duration.to_json(v)
|
||||
if v != timedelta(0):
|
||||
output[cased_name] = _Duration.to_json(v)
|
||||
elif meta.wraps:
|
||||
if v is not None:
|
||||
output[cased_name] = v
|
||||
elif isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
@ -723,17 +760,20 @@ class Message(ABC):
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
v = getattr(self, field.name)
|
||||
# print(v, value[key])
|
||||
if isinstance(v, list):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
v = datetime.fromisoformat(
|
||||
value[key].replace("Z", "+00:00")
|
||||
)
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
elif meta.wraps:
|
||||
setattr(self, field.name, value[key])
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
@ -830,7 +870,6 @@ class _Timestamp(Message):
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
@ -839,17 +878,90 @@ class _Timestamp(Message):
|
||||
copy = dt.replace(microsecond=0, tzinfo=None)
|
||||
result = copy.isoformat()
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + 'Z'
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + "Z"
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + '.%03dZ' % (nanos / 1e6)
|
||||
# Serialize 3 fractional digits.
|
||||
return result + ".%03dZ" % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + '.%06dZ' % (nanos / 1e3)
|
||||
# Serialize 6 fractional digits.
|
||||
return result + ".%06dZ" % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + '.%09dZ' % nanos
|
||||
return result + ".%09dZ" % nanos
|
||||
|
||||
|
||||
class _WrappedMessage(Message):
|
||||
"""
|
||||
Google protobuf wrapper types base class. JSON representation is just the
|
||||
value itself.
|
||||
"""
|
||||
def to_dict(self) -> Any:
|
||||
return self.value
|
||||
|
||||
def from_dict(self, value: Any) -> None:
|
||||
if value is not None:
|
||||
self.value = value
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BoolValue(_WrappedMessage):
|
||||
value: bool = bool_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int32Value(_WrappedMessage):
|
||||
value: int = int32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt32Value(_WrappedMessage):
|
||||
value: int = uint32_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Int64Value(_WrappedMessage):
|
||||
value: int = int64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _UInt64Value(_WrappedMessage):
|
||||
value: int = uint64_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _FloatValue(_WrappedMessage):
|
||||
value: float = float_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _DoubleValue(_WrappedMessage):
|
||||
value: float = double_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _StringValue(_WrappedMessage):
|
||||
value: str = string_field(1)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _BytesValue(_WrappedMessage):
|
||||
value: bytes = bytes_field(1)
|
||||
|
||||
|
||||
def _get_wrapper(proto_type: str) -> _WrappedMessage:
|
||||
"""Get the wrapper message class for a wrapped type."""
|
||||
return {
|
||||
TYPE_BOOL: _BoolValue,
|
||||
TYPE_INT32: _Int32Value,
|
||||
TYPE_UINT32: _UInt32Value,
|
||||
TYPE_INT64: _Int64Value,
|
||||
TYPE_UINT64: _UInt64Value,
|
||||
TYPE_FLOAT: _FloatValue,
|
||||
TYPE_DOUBLE: _DoubleValue,
|
||||
TYPE_STRING: _StringValue,
|
||||
TYPE_BYTES: _BytesValue,
|
||||
}[proto_type]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
|
@ -236,6 +236,14 @@ def generate_code(request, response):
|
||||
packed = False
|
||||
|
||||
field_type = f.Type.Name(f.type).lower()[5:]
|
||||
|
||||
field_wraps = ""
|
||||
if f.type_name.startswith(
|
||||
".google.protobuf"
|
||||
) and f.type_name.endswith("Value"):
|
||||
w = f.type_name.split(".").pop()[:-5].upper()
|
||||
field_wraps = f"betterproto.TYPE_{w}"
|
||||
|
||||
map_types = None
|
||||
if f.type == 11:
|
||||
# This might be a map...
|
||||
@ -301,6 +309,7 @@ def generate_code(request, response):
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
"field_type": field_type,
|
||||
"field_wraps": field_wraps,
|
||||
"map_types": map_types,
|
||||
"type": t,
|
||||
"zero": zero,
|
||||
|
@ -48,7 +48,7 @@ class {{ message.py_name }}(betterproto.Message):
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
{% endif %}
|
||||
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %})
|
||||
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
|
||||
{% endfor %}
|
||||
{% if not message.properties %}
|
||||
pass
|
||||
@ -63,7 +63,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
|
||||
{% endif %}
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
|
5
betterproto/tests/googletypes.json
Normal file
5
betterproto/tests/googletypes.json
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"maybe": false,
|
||||
"ts": "1972-01-01T10:00:20.021Z",
|
||||
"duration": "1.200s"
|
||||
}
|
12
betterproto/tests/googletypes.proto
Normal file
12
betterproto/tests/googletypes.proto
Normal file
@ -0,0 +1,12 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
message Test {
|
||||
google.protobuf.BoolValue maybe = 1;
|
||||
google.protobuf.Timestamp ts = 2;
|
||||
google.protobuf.Duration duration = 3;
|
||||
google.protobuf.Int32Value important = 4;
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
import betterproto
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def test_has_field():
|
||||
@ -146,3 +147,18 @@ def test_json_casing():
|
||||
"snake_case": 3,
|
||||
"kabob_case": 4,
|
||||
}
|
||||
|
||||
|
||||
def test_optional_flag():
|
||||
@dataclass
|
||||
class Request(betterproto.Message):
|
||||
flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL)
|
||||
|
||||
# Serialization of not passed vs. set vs. zero-value.
|
||||
assert bytes(Request()) == b""
|
||||
assert bytes(Request(flag=True)) == b"\n\x02\x08\x01"
|
||||
assert bytes(Request(flag=False)) == b"\n\x00"
|
||||
|
||||
# Differentiate between not passed and the zero-value.
|
||||
assert Request().parse(b"").flag == None
|
||||
assert Request().parse(b"\n\x00").flag == False
|
||||
|
Loading…
x
Reference in New Issue
Block a user