Support Duration/Timestamp Google well-known types

This commit is contained in:
Daniel G. Taylor
2019-10-26 23:07:30 -07:00
parent 5daf61f64c
commit c79535b614
3 changed files with 135 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import json
import struct
from abc import ABC
from base64 import b64encode, b64decode
from datetime import datetime, timedelta, timezone
from typing import (
Any,
AsyncGenerator,
@@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
elif proto_type == TYPE_STRING:
return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
return bytes(value)
return value
@@ -399,6 +412,7 @@ class Message(ABC):
meta = FieldMetadata.get(field)
if meta.group:
# This is part of a one-of group.
group_map["fields"][field.name] = meta.group
if meta.group not in group_map["groups"]:
@@ -530,6 +544,9 @@ class Message(ABC):
elif issubclass(t, Enum):
# Enums always default to zero.
value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
else:
# This is either a primitive scalar or another message type. Calling
# it should result in its zero value.
@@ -558,8 +575,14 @@ class Message(ABC):
value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field)
value = cls().parse(value)
value._serialized_on_wire = True
if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
else:
value = cls().parse(value)
value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP:
# TODO: This is slow, use a cache to make it faster since each
# key/value pair will recreate the class.
@@ -646,7 +669,11 @@ class Message(ABC):
v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_")
if meta.proto_type == "message":
if isinstance(v, list):
if isinstance(v, datetime):
output[cased_name] = _Timestamp.to_json(v)
elif isinstance(v, timedelta):
output[cased_name] = _Duration.to_json(v)
elif isinstance(v, list):
# Convert each item.
v = [i.to_dict() for i in v]
output[cased_name] = v
@@ -701,6 +728,12 @@ class Message(ABC):
cls = self._cls_for(field)
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
else:
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field.name, getattr(message, field.name))
@dataclasses.dataclass
class _Duration(Message):
# Signed seconds of the span of time. Must be from -315,576,000,000 to
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
seconds: int = int64_field(1)
# Signed fractions of a second at nanosecond resolution of the span of time.
# Durations less than one second are represented with a 0 `seconds` field and
# a positive or negative `nanos` field. For durations of one second or more,
# a non-zero value for the `nanos` field must be of the same sign as the
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
nanos: int = int32_field(2)
def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod
def to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
@dataclasses.dataclass
class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
seconds: int = int64_field(1)
# Non-negative fractions of a second at nanosecond resolution. Negative
# second values with fractions must still have non-negative nanos values that
# count forward in time. Must be from 0 to 999,999,999 inclusive.
nanos: int = int32_field(2)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
def to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 'Z'
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09dZ' % nanos
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.