Support wrapper types

This commit is contained in:
Daniel G. Taylor
2019-10-27 14:55:25 -07:00
parent c79535b614
commit 035793aec3
7 changed files with 233 additions and 32 deletions

View File

@@ -105,6 +105,10 @@ WIRE_FIXED_64_TYPES = [TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64]
WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
DATETIME_ZERO = datetime(1970, 1, 1, tzinfo=timezone.utc)
class Casing(enum.Enum):
"""Casing constants for serialization."""
@@ -128,9 +132,11 @@ class FieldMetadata:
# Protobuf type name
proto_type: str
# Map information if the proto_type is a map
map_types: Optional[Tuple[str, str]]
map_types: Optional[Tuple[str, str]] = None
# Groups several "one-of" fields together
group: Optional[str]
group: Optional[str] = None
# Describes the wrapped type (e.g. when using google.protobuf.BoolValue)
wraps: Optional[str] = None
@staticmethod
def get(field: dataclasses.Field) -> "FieldMetadata":
@@ -144,11 +150,14 @@ def dataclass_field(
*,
map_types: Optional[Tuple[str, str]] = None,
group: Optional[str] = None,
wraps: Optional[str] = None,
) -> dataclasses.Field:
"""Creates a dataclass field with attached protobuf metadata."""
return dataclasses.field(
default=PLACEHOLDER,
metadata={"betterproto": FieldMetadata(number, proto_type, map_types, group)},
metadata={
"betterproto": FieldMetadata(number, proto_type, map_types, group, wraps)
},
)
@@ -221,8 +230,10 @@ def bytes_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_BYTES, group=group)
def message_field(number: int, group: Optional[str] = None) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group)
def message_field(
number: int, group: Optional[str] = None, wraps: Optional[str] = None
) -> Any:
return dataclass_field(number, TYPE_MESSAGE, group=group, wraps=wraps)
def map_field(
@@ -273,7 +284,7 @@ def encode_varint(value: int) -> bytes:
return bytes(b + [bits])
def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
"""Adjusts values before serialization."""
if proto_type in [
TYPE_ENUM,
@@ -307,6 +318,10 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
elif wraps:
if value is None:
return b""
value = _get_wrapper(wraps)(value=value)
return bytes(value)
@@ -314,10 +329,15 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
def _serialize_single(
field_number: int, proto_type: str, value: Any, *, serialize_empty: bool = False
field_number: int,
proto_type: str,
value: Any,
*,
serialize_empty: bool = False,
wraps: str = "",
) -> bytes:
"""Serializes a single field and value."""
value = _preprocess_single(proto_type, value)
value = _preprocess_single(proto_type, wraps, value)
output = b""
if proto_type in WIRE_VARINT_TYPES:
@@ -330,7 +350,7 @@ def _serialize_single(
key = encode_varint((field_number << 3) | 1)
output += key + value
elif proto_type in WIRE_LEN_DELIM_TYPES:
if len(value) or serialize_empty:
if len(value) or serialize_empty or wraps:
key = encode_varint((field_number << 3) | 2)
output += key + encode_varint(len(value)) + value
else:
@@ -370,7 +390,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
while i < len(value):
start = i
num_wire, i = decode_varint(value, i)
# print(num_wire, i)
number = num_wire >> 3
wire_type = num_wire & 0x7
@@ -386,8 +405,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
elif wire_type == 5:
decoded, i = value[i : i + 4], i + 4
# print(ParsedField(number=number, wire_type=wire_type, value=decoded))
yield ParsedField(
number=number, wire_type=wire_type, value=decoded, raw=value[start:i]
)
@@ -462,6 +479,11 @@ class Message(ABC):
meta = FieldMetadata.get(field)
value = getattr(self, field.name)
if value is None:
# Optional items should be skipped. This is used for the Google
# wrapper types.
continue
# Being selected in a a group means this field is the one that is
# currently set in a `oneof` group, so it must be serialized even
# if the value is the default zero value.
@@ -491,11 +513,13 @@ class Message(ABC):
# treat it like a field of raw bytes.
buf = b""
for item in value:
buf += _preprocess_single(meta.proto_type, item)
buf += _preprocess_single(meta.proto_type, "", item)
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(meta.number, meta.proto_type, item)
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps
)
elif isinstance(value, dict):
for k, v in value.items():
assert meta.map_types
@@ -504,7 +528,11 @@ class Message(ABC):
output += _serialize_single(meta.number, meta.proto_type, sk + sv)
else:
output += _serialize_single(
meta.number, meta.proto_type, value, serialize_empty=serialize_empty
meta.number,
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps,
)
return output + self._unknown_fields
@@ -546,7 +574,7 @@ class Message(ABC):
value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
value = DATETIME_ZERO
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
@@ -580,6 +608,10 @@ class Message(ABC):
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
elif meta.wraps:
# This is a Google wrapper value message around a single
# scalar type.
value = _get_wrapper(meta.wraps)().parse(value).value
else:
value = cls().parse(value)
value._serialized_on_wire = True
@@ -670,9 +702,14 @@ class Message(ABC):
cased_name = casing(field.name).rstrip("_")
if meta.proto_type == "message":
if isinstance(v, datetime):
output[cased_name] = _Timestamp.to_json(v)
if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.to_json(v)
elif isinstance(v, timedelta):
output[cased_name] = _Duration.to_json(v)
if v != timedelta(0):
output[cased_name] = _Duration.to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
elif isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
@@ -723,17 +760,20 @@ class Message(ABC):
if value[key] is not None:
if meta.proto_type == "message":
v = getattr(self, field.name)
# print(v, value[key])
if isinstance(v, list):
cls = self._cls_for(field)
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
v = datetime.fromisoformat(
value[key].replace("Z", "+00:00")
)
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
elif meta.wraps:
setattr(self, field.name, value[key])
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@@ -830,7 +870,6 @@ class _Timestamp(Message):
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
@@ -839,17 +878,90 @@ class _Timestamp(Message):
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 'Z'
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6)
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3)
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09dZ' % nanos
return result + ".%09dZ" % nanos
class _WrappedMessage(Message):
"""
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
def to_dict(self) -> Any:
return self.value
def from_dict(self, value: Any) -> None:
if value is not None:
self.value = value
@dataclasses.dataclass
class _BoolValue(_WrappedMessage):
value: bool = bool_field(1)
@dataclasses.dataclass
class _Int32Value(_WrappedMessage):
value: int = int32_field(1)
@dataclasses.dataclass
class _UInt32Value(_WrappedMessage):
value: int = uint32_field(1)
@dataclasses.dataclass
class _Int64Value(_WrappedMessage):
value: int = int64_field(1)
@dataclasses.dataclass
class _UInt64Value(_WrappedMessage):
value: int = uint64_field(1)
@dataclasses.dataclass
class _FloatValue(_WrappedMessage):
value: float = float_field(1)
@dataclasses.dataclass
class _DoubleValue(_WrappedMessage):
value: float = double_field(1)
@dataclasses.dataclass
class _StringValue(_WrappedMessage):
value: str = string_field(1)
@dataclasses.dataclass
class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> _WrappedMessage:
"""Get the wrapper message class for a wrapped type."""
return {
TYPE_BOOL: _BoolValue,
TYPE_INT32: _Int32Value,
TYPE_UINT32: _UInt32Value,
TYPE_INT64: _Int64Value,
TYPE_UINT64: _UInt64Value,
TYPE_FLOAT: _FloatValue,
TYPE_DOUBLE: _DoubleValue,
TYPE_STRING: _StringValue,
TYPE_BYTES: _BytesValue,
}[proto_type]
class ServiceStub(ABC):