Support Duration/Timestamp Google well-known types
This commit is contained in:
@@ -5,6 +5,7 @@ import json
|
||||
import struct
|
||||
from abc import ABC
|
||||
from base64 import b64encode, b64decode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
elif proto_type == TYPE_STRING:
|
||||
return value.encode("utf-8")
|
||||
elif proto_type == TYPE_MESSAGE:
|
||||
if isinstance(value, datetime):
|
||||
# Convert the `datetime` to a timestamp message.
|
||||
seconds = int(value.timestamp())
|
||||
nanos = int(value.microsecond * 1e3)
|
||||
value = _Timestamp(seconds=seconds, nanos=nanos)
|
||||
elif isinstance(value, timedelta):
|
||||
# Convert the `timedelta` to a duration message.
|
||||
total_ms = value // timedelta(microseconds=1)
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
value = _Duration(seconds=seconds, nanos=nanos)
|
||||
|
||||
return bytes(value)
|
||||
|
||||
return value
|
||||
@@ -399,6 +412,7 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if meta.group:
|
||||
# This is part of a one-of group.
|
||||
group_map["fields"][field.name] = meta.group
|
||||
|
||||
if meta.group not in group_map["groups"]:
|
||||
@@ -530,6 +544,9 @@ class Message(ABC):
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
value = 0
|
||||
elif t == datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
else:
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
@@ -558,8 +575,14 @@ class Message(ABC):
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
elif cls == timedelta:
|
||||
value = _Duration().parse(value).to_timedelta()
|
||||
else:
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
# TODO: This is slow, use a cache to make it faster since each
|
||||
# key/value pair will recreate the class.
|
||||
@@ -646,7 +669,11 @@ class Message(ABC):
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_")
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, list):
|
||||
if isinstance(v, datetime):
|
||||
output[cased_name] = _Timestamp.to_json(v)
|
||||
elif isinstance(v, timedelta):
|
||||
output[cased_name] = _Duration.to_json(v)
|
||||
elif isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
output[cased_name] = v
|
||||
@@ -701,6 +728,12 @@ class Message(ABC):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
@@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
return (field.name, getattr(message, field.name))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Duration(Message):
|
||||
# Signed seconds of the span of time. Must be from -315,576,000,000 to
|
||||
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
|
||||
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
|
||||
seconds: int = int64_field(1)
|
||||
# Signed fractions of a second at nanosecond resolution of the span of time.
|
||||
# Durations less than one second are represented with a 0 `seconds` field and
|
||||
# a positive or negative `nanos` field. For durations of one second or more,
|
||||
# a non-zero value for the `nanos` field must be of the same sign as the
|
||||
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_timedelta(self) -> timedelta:
|
||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||
|
||||
@staticmethod
|
||||
def to_json(delta: timedelta) -> str:
|
||||
parts = str(delta.total_seconds()).split(".")
|
||||
if len(parts) > 1:
|
||||
while len(parts[1]) not in [3, 6, 9]:
|
||||
parts[1] = parts[1] + "0"
|
||||
return ".".join(parts) + "s"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Timestamp(Message):
|
||||
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
|
||||
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
|
||||
seconds: int = int64_field(1)
|
||||
# Non-negative fractions of a second at nanosecond resolution. Negative
|
||||
# second values with fractions must still have non-negative nanos values that
|
||||
# count forward in time. Must be from 0 to 999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
def to_json(dt: datetime) -> str:
|
||||
nanos = dt.microsecond * 1e3
|
||||
copy = dt.replace(microsecond=0, tzinfo=None)
|
||||
result = copy.isoformat()
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + 'Z'
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + '.%03dZ' % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + '.%06dZ' % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + '.%09dZ' % nanos
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
||||
Reference in New Issue
Block a user