Support Duration/Timestamp Google well-known types

This commit is contained in:
Daniel G. Taylor 2019-10-26 23:07:30 -07:00
parent 5daf61f64c
commit c79535b614
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22
3 changed files with 135 additions and 3 deletions

View File

@ -5,6 +5,7 @@ import json
import struct import struct
from abc import ABC from abc import ABC
from base64 import b64encode, b64decode from base64 import b64encode, b64decode
from datetime import datetime, timedelta, timezone
from typing import ( from typing import (
Any, Any,
AsyncGenerator, AsyncGenerator,
@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
elif proto_type == TYPE_STRING: elif proto_type == TYPE_STRING:
return value.encode("utf-8") return value.encode("utf-8")
elif proto_type == TYPE_MESSAGE: elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
return bytes(value) return bytes(value)
return value return value
@ -399,6 +412,7 @@ class Message(ABC):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
if meta.group: if meta.group:
# This is part of a one-of group.
group_map["fields"][field.name] = meta.group group_map["fields"][field.name] = meta.group
if meta.group not in group_map["groups"]: if meta.group not in group_map["groups"]:
@ -530,6 +544,9 @@ class Message(ABC):
elif issubclass(t, Enum): elif issubclass(t, Enum):
# Enums always default to zero. # Enums always default to zero.
value = 0 value = 0
elif t == datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
else: else:
# This is either a primitive scalar or another message type. Calling # This is either a primitive scalar or another message type. Calling
# it should result in its zero value. # it should result in its zero value.
@ -558,6 +575,12 @@ class Message(ABC):
value = value.decode("utf-8") value = value.decode("utf-8")
elif meta.proto_type == TYPE_MESSAGE: elif meta.proto_type == TYPE_MESSAGE:
cls = self._cls_for(field) cls = self._cls_for(field)
if cls == datetime:
value = _Timestamp().parse(value).to_datetime()
elif cls == timedelta:
value = _Duration().parse(value).to_timedelta()
else:
value = cls().parse(value) value = cls().parse(value)
value._serialized_on_wire = True value._serialized_on_wire = True
elif meta.proto_type == TYPE_MAP: elif meta.proto_type == TYPE_MAP:
@ -646,7 +669,11 @@ class Message(ABC):
v = getattr(self, field.name) v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") cased_name = casing(field.name).rstrip("_")
if meta.proto_type == "message": if meta.proto_type == "message":
if isinstance(v, list): if isinstance(v, datetime):
output[cased_name] = _Timestamp.to_json(v)
elif isinstance(v, timedelta):
output[cased_name] = _Duration.to_json(v)
elif isinstance(v, list):
# Convert each item. # Convert each item.
v = [i.to_dict() for i in v] v = [i.to_dict() for i in v]
output[cased_name] = v output[cased_name] = v
@ -701,6 +728,12 @@ class Message(ABC):
cls = self._cls_for(field) cls = self._cls_for(field)
for i in range(len(value[key])): for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i])) v.append(cls().from_dict(value[key][i]))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field.name, v)
elif isinstance(v, timedelta):
v = timedelta(seconds=float(value[key][:-1]))
setattr(self, field.name, v)
else: else:
v.from_dict(value[key]) v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
return (field.name, getattr(message, field.name)) return (field.name, getattr(message, field.name))
@dataclasses.dataclass
class _Duration(Message):
# Signed seconds of the span of time. Must be from -315,576,000,000 to
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
seconds: int = int64_field(1)
# Signed fractions of a second at nanosecond resolution of the span of time.
# Durations less than one second are represented with a 0 `seconds` field and
# a positive or negative `nanos` field. For durations of one second or more,
# a non-zero value for the `nanos` field must be of the same sign as the
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
nanos: int = int32_field(2)
def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod
def to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
@dataclasses.dataclass
class _Timestamp(Message):
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
seconds: int = int64_field(1)
# Non-negative fractions of a second at nanosecond resolution. Negative
# second values with fractions must still have non-negative nanos values that
# count forward in time. Must be from 0 to 999,999,999 inclusive.
nanos: int = int32_field(2)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
def to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + 'Z'
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + '.%03dZ' % (nanos / 1e6)
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + '.%06dZ' % (nanos / 1e3)
# Serialize 9 fractional digits.
return result + '.%09dZ' % nanos
class ServiceStub(ABC): class ServiceStub(ABC):
""" """
Base class for async gRPC service stubs. Base class for async gRPC service stubs.

View File

@ -30,6 +30,19 @@ from google.protobuf.descriptor_pb2 import (
from betterproto.casing import safe_snake_case from betterproto.casing import safe_snake_case
WRAPPER_TYPES = {
"google.protobuf.DoubleValue": "float",
"google.protobuf.FloatValue": "float",
"google.protobuf.Int64Value": "int",
"google.protobuf.UInt64Value": "int",
"google.protobuf.Int32Value": "int",
"google.protobuf.UInt32Value": "int",
"google.protobuf.BoolValue": "bool",
"google.protobuf.StringValue": "str",
"google.protobuf.BytesValue": "bytes",
}
def get_ref_type(package: str, imports: set, type_name: str) -> str: def get_ref_type(package: str, imports: set, type_name: str) -> str:
""" """
Return a Python type name for a proto type reference. Adds the import if Return a Python type name for a proto type reference. Adds the import if
@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
# because by convention packages are lowercase and message/enum types are # because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future. # pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".") type_name = type_name.lstrip(".")
if type_name in WRAPPER_TYPES:
return f"Optional[{WRAPPER_TYPES[type_name]}]"
if type_name == "google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
return "datetime"
if type_name.startswith(package): if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".") parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
@ -152,6 +175,9 @@ def generate_code(request, response):
output_map = {} output_map = {}
for proto_file in request.proto_file: for proto_file in request.proto_file:
out = proto_file.package out = proto_file.package
if out == "google.protobuf":
continue
if not out: if not out:
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".") out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
@ -169,6 +195,7 @@ def generate_code(request, response):
"package": package, "package": package,
"files": [f.name for f in options["files"]], "files": [f.name for f in options["files"]],
"imports": set(), "imports": set(),
"datetime_imports": set(),
"typing_imports": set(), "typing_imports": set(),
"messages": [], "messages": [],
"enums": [], "enums": [],
@ -258,6 +285,14 @@ def generate_code(request, response):
if f.HasField("oneof_index"): if f.HasField("oneof_index"):
one_of = item.oneof_decl[f.oneof_index].name one_of = item.oneof_decl[f.oneof_index].name
if "Optional[" in t:
output["typing_imports"].add("Optional")
if "timedelta" in t:
output["datetime_imports"].add("timedelta")
elif "datetime" in t:
output["datetime_imports"].add("datetime")
data["properties"].append( data["properties"].append(
{ {
"name": f.name, "name": f.name,
@ -346,6 +381,7 @@ def generate_code(request, response):
output["services"].append(data) output["services"].append(data)
output["imports"] = sorted(output["imports"]) output["imports"] = sorted(output["imports"])
output["datetime_imports"] = sorted(output["datetime_imports"])
output["typing_imports"] = sorted(output["typing_imports"]) output["typing_imports"] = sorted(output["typing_imports"])
# Fill response # Fill response

View File

@ -2,6 +2,10 @@
# sources: {{ ', '.join(description.files) }} # sources: {{ ', '.join(description.files) }}
# plugin: python-betterproto # plugin: python-betterproto
from dataclasses import dataclass from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %} {% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}