From 035793aec380fcf2c766212a378de674e1f94c9f Mon Sep 17 00:00:00 2001 From: "Daniel G. Taylor" Date: Sun, 27 Oct 2019 14:55:25 -0700 Subject: [PATCH] Support wrapper types --- README.md | 49 +++++++- betterproto/__init__.py | 170 +++++++++++++++++++++++----- betterproto/plugin.py | 9 ++ betterproto/templates/template.py | 4 +- betterproto/tests/googletypes.json | 5 + betterproto/tests/googletypes.proto | 12 ++ betterproto/tests/test_features.py | 16 +++ 7 files changed, 233 insertions(+), 32 deletions(-) create mode 100644 betterproto/tests/googletypes.json create mode 100644 betterproto/tests/googletypes.proto diff --git a/README.md b/README.md index 89f2edd..68c6700 100644 --- a/README.md +++ b/README.md @@ -238,6 +238,53 @@ Again this is a little different than the official Google code generator: ["foo", "foo's value"] ``` +### Well-Known Google Types + +Google provides several well-known message types like a timestamp, duration, and several wrappers used to provide optional zero value support. Each of these has a special JSON representation and is handled a little differently from normal messages. The Python mapping for these is as follows: + +| Google Message | Python Type | Default | +| --------------------------- | ---------------------------------------- | ---------------------- | +| `google.protobuf.Duration` | [`datetime.timedelta`][td] | `0` | +| `google.protobuf.Timestamp` | Timezone-aware [`datetime.datetime`][dt] | `1970-01-01T00:00:00Z` | +| `google.protobuf.*Value` | `Optional[...]` | `None` | + +[td]: https://docs.python.org/3/library/datetime.html#timedelta-objects +[dt]: https://docs.python.org/3/library/datetime.html#datetime.datetime + +For the wrapper types, the Python type corresponds to the wrapped type, e.g. `google.protobuf.BoolValue` becomes `Optional[bool]` while `google.protobuf.Int32Value` becomes `Optional[int]`. All of the optional values default to `None`, so don't forget to check for that possible state. 
Given: + +```protobuf +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +message Test { + google.protobuf.BoolValue maybe = 1; + google.protobuf.Timestamp ts = 2; + google.protobuf.Duration duration = 3; +} +``` + +You can do stuff like: + +```py +>>> t = Test().from_dict({"maybe": True, "ts": "2019-01-01T12:00:00Z", "duration": "1.200s"}) +>>> t +Test(maybe=True, ts=datetime.datetime(2019, 1, 1, 12, 0, tzinfo=datetime.timezone.utc), duration=datetime.timedelta(seconds=1, microseconds=200000)) + +>>> t.ts - t.duration +datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc) + +>>> t.ts.isoformat() +'2019-01-01T12:00:00+00:00' + +>>> t.maybe = None +>>> t.to_dict() +{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'} +``` + ## Development First, make sure you have Python 3.7+ and `pipenv` installed, along with the official [Protobuf Compiler](https://github.com/protocolbuffers/protobuf/releases) for your platform. Then: @@ -295,7 +342,7 @@ $ pipenv run tests - [x] Bytes as base64 - [ ] Any support - [x] Enum strings - - [ ] Well known types support (timestamp, duration, wrappers) + - [x] Well known types support (timestamp, duration, wrappers) - [x] Support different casing (orig vs. camel vs. others?) - [ ] Async service stubs - [x] Unary-unary diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 9caaa70..fc178b5 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -105,6 +105,10 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64] WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] +# Protobuf datetimes start at the Unix Epoch in 1970 in UTC. 
+DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc) + + class Casing(enum.Enum): """Casing constants for serialization.""" @@ -128,9 +132,11 @@ class FieldMetadata: # Protobuf type name proto_type: str # Map information if the proto_type is a map - map_types: Optional[Tuple[str, str]] + map_types: Optional[Tuple[str, str]] = None # Groups several "one-of" fields together - group: Optional[str] + group: Optional[str] = None + # Describes the wrapped type (e.g. when using google.protobuf.BoolValue) + wraps: Optional[str] = None @staticmethod def get(field: dataclasses.Field) -> "FieldMetadata": @@ -144,11 +150,14 @@ def dataclass_field( *, map_types: Optional[Tuple[str, str]] = None, group: Optional[str] = None, + wraps: Optional[str] = None, ) -> dataclasses.Field: """Creates a dataclass field with attached protobuf metadata.""" return dataclasses.field( default=PLACEHOLDER, - metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)}, + metadata={ + "betterproto": FieldMetadata(number, proto_type, map_types, group, wraps) + }, ) @@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any: return dataclass_field(number, TYPE_BYTES, group=group) -def message_field(number: int, group: Optional[str] = None) -> Any: - return dataclass_field(number, TYPE_MESSAGE, group=group) +def message_field( + number: int, group: Optional[str] = None, wraps: Optional[str] = None +) -> Any: + return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps) def map_field( @@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes: return bytes(b + [bits]) -def _preprocess_single(proto_type: str, value: Any) -> bytes: +def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: """Adjusts values before serialization.""" if proto_type in [ TYPE_ENUM, @@ -307,6 +318,10 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: seconds = int(total_ms / 1e6) nanos = int((total_ms % 1e6) * 1e3) value = 
_Duration(seconds=seconds, nanos=nanos) + elif wraps: + if value is None: + return b"" + value = _get_wrapper(wraps)(value=value) return bytes(value) @@ -314,10 +329,15 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: def _serialize_single( - field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False + field_number: int, + proto_type: str, + value: Any, + *, + serialize_empty: bool = False, + wraps: str = "", ) -> bytes: """Serializes a single field and value.""" - value = _preprocess_single(proto_type, value) + value = _preprocess_single(proto_type, wraps, value) output = b"" if proto_type in WIRE_VARINT_TYPES: @@ -330,7 +350,7 @@ def _serialize_single( key = encode_varint((field_number << 3) | 1) output += key + value elif proto_type in WIRE_LEN_DELIM_TYPES: - if len(value) or serialize_empty: + if len(value) or serialize_empty or wraps: key = encode_varint((field_number << 3) | 2) output += key + encode_varint(len(value)) + value else: @@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: while i < len(value): start = i num_wire, i = decode_varint(value, i) - # print(num_wire, i) number = num_wire >> 3 wire_type = num_wire & 0x7 @@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: elif wire_type == 5: decoded, i = value[i : i + 4], i + 4 - # print(ParsedField(number=number, wire_type=wire_type, value=decoded)) - yield ParsedField( number=number, wire_type=wire_type, value=decoded, raw=value[start:i] ) @@ -462,6 +479,11 @@ class Message(ABC): meta = FieldMetadata.get(field) value = getattr(self, field.name) + if value is None: + # Optional items should be skipped. This is used for the Google + # wrapper types. + continue + # Being selected in a a group means this field is the one that is # currently set in a `oneof` group, so it must be serialized even # if the value is the default zero value. 
@@ -491,11 +513,13 @@ class Message(ABC): # treat it like a field of raw bytes. buf = b"" for item in value: - buf += _preprocess_single(meta.proto_type, item) + buf += _preprocess_single(meta.proto_type, "", item) output += _serialize_single(meta.number, TYPE_BYTES, buf) else: for item in value: - output += _serialize_single(meta.number, meta.proto_type, item) + output += _serialize_single( + meta.number, meta.proto_type, item, wraps=meta.wraps + ) elif isinstance(value, dict): for k, v in value.items(): assert meta.map_types @@ -504,7 +528,11 @@ class Message(ABC): output += _serialize_single(meta.number, meta.proto_type, sk + sv) else: output += _serialize_single( - meta.number, meta.proto_type, value, serialize_empty=serialize_empty + meta.number, + meta.proto_type, + value, + serialize_empty=serialize_empty, + wraps=meta.wraps, ) return output + self._unknown_fields @@ -546,7 +574,7 @@ class Message(ABC): value = 0 elif t == datetime: # Offsets are relative to 1970-01-01T00:00:00Z - value = datetime(1970, 1, 1, tzinfo=timezone.utc) + value = DATETIME_ZERO else: # This is either a primitive scalar or another message type. Calling # it should result in its zero value. @@ -580,6 +608,10 @@ class Message(ABC): value = _Timestamp().parse(value).to_datetime() elif cls == timedelta: value = _Duration().parse(value).to_timedelta() + elif meta.wraps: + # This is a Google wrapper value message around a single + # scalar type. 
+ value = _get_wrapper(meta.wraps)().parse(value).value else: value = cls().parse(value) value._serialized_on_wire = True @@ -670,9 +702,14 @@ class Message(ABC): cased_name = casing(field.name).rstrip("_") if meta.proto_type == "message": if isinstance(v, datetime): - output[cased_name] = _Timestamp.to_json(v) + if v != DATETIME_ZERO: + output[cased_name] = _Timestamp.to_json(v) elif isinstance(v, timedelta): - output[cased_name] = _Duration.to_json(v) + if v != timedelta(0): + output[cased_name] = _Duration.to_json(v) + elif meta.wraps: + if v is not None: + output[cased_name] = v elif isinstance(v, list): # Convert each item. v = [i.to_dict() for i in v] @@ -723,17 +760,20 @@ class Message(ABC): if value[key] is not None: if meta.proto_type == "message": v = getattr(self, field.name) - # print(v, value[key]) if isinstance(v, list): cls = self._cls_for(field) for i in range(len(value[key])): v.append(cls().from_dict(value[key][i])) elif isinstance(v, datetime): - v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) + v = datetime.fromisoformat( + value[key].replace("Z", "+00:00") + ) setattr(self, field.name, v) elif isinstance(v, timedelta): v = timedelta(seconds=float(value[key][:-1])) setattr(self, field.name, v) + elif meta.wraps: + setattr(self, field.name, value[key]) else: v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: @@ -830,7 +870,6 @@ class _Timestamp(Message): def to_datetime(self) -> datetime: ts = self.seconds + (self.nanos / 1e9) - print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc)) return datetime.fromtimestamp(ts, tz=timezone.utc) @staticmethod @@ -839,17 +878,90 @@ class _Timestamp(Message): copy = dt.replace(microsecond=0, tzinfo=None) result = copy.isoformat() if (nanos % 1e9) == 0: - # If there are 0 fractional digits, the fractional - # point '.' should be omitted when serializing. - return result + 'Z' + # If there are 0 fractional digits, the fractional + # point '.' 
should be omitted when serializing. + return result + "Z" if (nanos % 1e6) == 0: - # Serialize 3 fractional digits. - return result + '.%03dZ' % (nanos / 1e6) + # Serialize 3 fractional digits. + return result + ".%03dZ" % (nanos / 1e6) if (nanos % 1e3) == 0: - # Serialize 6 fractional digits. - return result + '.%06dZ' % (nanos / 1e3) + # Serialize 6 fractional digits. + return result + ".%06dZ" % (nanos / 1e3) # Serialize 9 fractional digits. - return result + '.%09dZ' % nanos + return result + ".%09dZ" % nanos + + +class _WrappedMessage(Message): + """ + Google protobuf wrapper types base class. JSON representation is just the + value itself. + """ + def to_dict(self) -> Any: + return self.value + + def from_dict(self, value: Any) -> None: + if value is not None: + self.value = value + + +@dataclasses.dataclass +class _BoolValue(_WrappedMessage): + value: bool = bool_field(1) + + +@dataclasses.dataclass +class _Int32Value(_WrappedMessage): + value: int = int32_field(1) + + +@dataclasses.dataclass +class _UInt32Value(_WrappedMessage): + value: int = uint32_field(1) + + +@dataclasses.dataclass +class _Int64Value(_WrappedMessage): + value: int = int64_field(1) + + +@dataclasses.dataclass +class _UInt64Value(_WrappedMessage): + value: int = uint64_field(1) + + +@dataclasses.dataclass +class _FloatValue(_WrappedMessage): + value: float = float_field(1) + + +@dataclasses.dataclass +class _DoubleValue(_WrappedMessage): + value: float = double_field(1) + + +@dataclasses.dataclass +class _StringValue(_WrappedMessage): + value: str = string_field(1) + + +@dataclasses.dataclass +class _BytesValue(_WrappedMessage): + value: bytes = bytes_field(1) + + +def _get_wrapper(proto_type: str) -> _WrappedMessage: + """Get the wrapper message class for a wrapped type.""" + return { + TYPE_BOOL: _BoolValue, + TYPE_INT32: _Int32Value, + TYPE_UINT32: _UInt32Value, + TYPE_INT64: _Int64Value, + TYPE_UINT64: _UInt64Value, + TYPE_FLOAT: _FloatValue, + TYPE_DOUBLE: _DoubleValue, + 
TYPE_STRING: _StringValue, + TYPE_BYTES: _BytesValue, + }[proto_type] class ServiceStub(ABC): diff --git a/betterproto/plugin.py b/betterproto/plugin.py index bff2986..4d48612 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -236,6 +236,14 @@ def generate_code(request, response): packed = False field_type = f.Type.Name(f.type).lower()[5:] + + field_wraps = "" + if f.type_name.startswith( + ".google.protobuf" + ) and f.type_name.endswith("Value"): + w = f.type_name.split(".").pop()[:-5].upper() + field_wraps = f"betterproto.TYPE_{w}" + map_types = None if f.type == 11: # This might be a map... @@ -301,6 +309,7 @@ def generate_code(request, response): "comment": get_comment(proto_file, path + [2, i]), "proto_type": int(f.type), "field_type": field_type, + "field_wraps": field_wraps, "map_types": map_types, "type": t, "zero": zero, diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 630afc6..eac4595 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -48,7 +48,7 @@ class {{ message.py_name }}(betterproto.Message): {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}) + {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %}) {% endfor %} {% if not message.properties %} pass @@ -63,7 +63,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endif %} {% for method in service.methods %} - async def {{ 
method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: + async def {{ method.py_name }}(self{% if method.input_message and method.input_message.properties %}, *, {% for field in method.input_message.properties %}{{ field.name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") %}Optional[{{ field.type }}]{% else %}{{ field.type }}{% endif %} = {{ field.zero }}{% if not loop.last %}, {% endif %}{% endfor %}{% endif %}) -> {% if method.server_streaming %}AsyncGenerator[{{ method.output }}, None]{% else %}{{ method.output }}{% endif %}: {% if method.comment %} {{ method.comment }} diff --git a/betterproto/tests/googletypes.json b/betterproto/tests/googletypes.json new file mode 100644 index 0000000..5d86e1b --- /dev/null +++ b/betterproto/tests/googletypes.json @@ -0,0 +1,5 @@ +{ + "maybe": false, + "ts": "1972-01-01T10:00:20.021Z", + "duration": "1.200s" +} diff --git a/betterproto/tests/googletypes.proto b/betterproto/tests/googletypes.proto new file mode 100644 index 0000000..283b836 --- /dev/null +++ b/betterproto/tests/googletypes.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "google/protobuf/wrappers.proto"; + +message Test { + google.protobuf.BoolValue maybe = 1; + google.protobuf.Timestamp ts = 2; + google.protobuf.Duration duration = 3; + google.protobuf.Int32Value important = 4; +} diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index 03c3023..8e9aba3 100644 --- a/betterproto/tests/test_features.py +++ 
b/betterproto/tests/test_features.py @@ -1,5 +1,6 @@ import betterproto from dataclasses import dataclass +from typing import Optional def test_has_field(): @@ -146,3 +147,18 @@ def test_json_casing(): "snake_case": 3, "kabob_case": 4, } + + +def test_optional_flag(): + @dataclass + class Request(betterproto.Message): + flag: Optional[bool] = betterproto.message_field(1, wraps=betterproto.TYPE_BOOL) + + # Serialization of not passed vs. set vs. zero-value. + assert bytes(Request()) == b"" + assert bytes(Request(flag=True)) == b"\n\x02\x08\x01" + assert bytes(Request(flag=False)) == b"\n\x00" + + # Differentiate between not passed and the zero-value. + assert Request().parse(b"").flag == None + assert Request().parse(b"\n\x00").flag == False