Support Duration/Timestamp Google well-known types
This commit is contained in:
parent
5daf61f64c
commit
c79535b614
@ -5,6 +5,7 @@ import json
|
||||
import struct
|
||||
from abc import ABC
|
||||
from base64 import b64encode, b64decode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
|
||||
elif proto_type == TYPE_STRING:
|
||||
return value.encode("utf-8")
|
||||
elif proto_type == TYPE_MESSAGE:
|
||||
if isinstance(value, datetime):
|
||||
# Convert the `datetime` to a timestamp message.
|
||||
seconds = int(value.timestamp())
|
||||
nanos = int(value.microsecond * 1e3)
|
||||
value = _Timestamp(seconds=seconds, nanos=nanos)
|
||||
elif isinstance(value, timedelta):
|
||||
# Convert the `timedelta` to a duration message.
|
||||
total_ms = value // timedelta(microseconds=1)
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
value = _Duration(seconds=seconds, nanos=nanos)
|
||||
|
||||
return bytes(value)
|
||||
|
||||
return value
|
||||
@ -399,6 +412,7 @@ class Message(ABC):
|
||||
meta = FieldMetadata.get(field)
|
||||
|
||||
if meta.group:
|
||||
# This is part of a one-of group.
|
||||
group_map["fields"][field.name] = meta.group
|
||||
|
||||
if meta.group not in group_map["groups"]:
|
||||
@ -530,6 +544,9 @@ class Message(ABC):
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
value = 0
|
||||
elif t == datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
value = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
else:
|
||||
# This is either a primitive scalar or another message type. Calling
|
||||
# it should result in its zero value.
|
||||
@ -558,6 +575,12 @@ class Message(ABC):
|
||||
value = value.decode("utf-8")
|
||||
elif meta.proto_type == TYPE_MESSAGE:
|
||||
cls = self._cls_for(field)
|
||||
|
||||
if cls == datetime:
|
||||
value = _Timestamp().parse(value).to_datetime()
|
||||
elif cls == timedelta:
|
||||
value = _Duration().parse(value).to_timedelta()
|
||||
else:
|
||||
value = cls().parse(value)
|
||||
value._serialized_on_wire = True
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
@ -646,7 +669,11 @@ class Message(ABC):
|
||||
v = getattr(self, field.name)
|
||||
cased_name = casing(field.name).rstrip("_")
|
||||
if meta.proto_type == "message":
|
||||
if isinstance(v, list):
|
||||
if isinstance(v, datetime):
|
||||
output[cased_name] = _Timestamp.to_json(v)
|
||||
elif isinstance(v, timedelta):
|
||||
output[cased_name] = _Duration.to_json(v)
|
||||
elif isinstance(v, list):
|
||||
# Convert each item.
|
||||
v = [i.to_dict() for i in v]
|
||||
output[cased_name] = v
|
||||
@ -701,6 +728,12 @@ class Message(ABC):
|
||||
cls = self._cls_for(field)
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
setattr(self, field.name, v)
|
||||
elif isinstance(v, timedelta):
|
||||
v = timedelta(seconds=float(value[key][:-1]))
|
||||
setattr(self, field.name, v)
|
||||
else:
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
return (field.name, getattr(message, field.name))
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Duration(Message):
|
||||
# Signed seconds of the span of time. Must be from -315,576,000,000 to
|
||||
# +315,576,000,000 inclusive. Note: these bounds are computed from: 60
|
||||
# sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
|
||||
seconds: int = int64_field(1)
|
||||
# Signed fractions of a second at nanosecond resolution of the span of time.
|
||||
# Durations less than one second are represented with a 0 `seconds` field and
|
||||
# a positive or negative `nanos` field. For durations of one second or more,
|
||||
# a non-zero value for the `nanos` field must be of the same sign as the
|
||||
# `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_timedelta(self) -> timedelta:
|
||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||
|
||||
@staticmethod
|
||||
def to_json(delta: timedelta) -> str:
|
||||
parts = str(delta.total_seconds()).split(".")
|
||||
if len(parts) > 1:
|
||||
while len(parts[1]) not in [3, 6, 9]:
|
||||
parts[1] = parts[1] + "0"
|
||||
return ".".join(parts) + "s"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class _Timestamp(Message):
|
||||
# Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
|
||||
# be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
|
||||
seconds: int = int64_field(1)
|
||||
# Non-negative fractions of a second at nanosecond resolution. Negative
|
||||
# second values with fractions must still have non-negative nanos values that
|
||||
# count forward in time. Must be from 0 to 999,999,999 inclusive.
|
||||
nanos: int = int32_field(2)
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
|
||||
@staticmethod
|
||||
def to_json(dt: datetime) -> str:
|
||||
nanos = dt.microsecond * 1e3
|
||||
copy = dt.replace(microsecond=0, tzinfo=None)
|
||||
result = copy.isoformat()
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + 'Z'
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + '.%03dZ' % (nanos / 1e6)
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + '.%06dZ' % (nanos / 1e3)
|
||||
# Serialize 9 fractional digits.
|
||||
return result + '.%09dZ' % nanos
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
"""
|
||||
Base class for async gRPC service stubs.
|
||||
|
@ -30,6 +30,19 @@ from google.protobuf.descriptor_pb2 import (
|
||||
from betterproto.casing import safe_snake_case
|
||||
|
||||
|
||||
WRAPPER_TYPES = {
|
||||
"google.protobuf.DoubleValue": "float",
|
||||
"google.protobuf.FloatValue": "float",
|
||||
"google.protobuf.Int64Value": "int",
|
||||
"google.protobuf.UInt64Value": "int",
|
||||
"google.protobuf.Int32Value": "int",
|
||||
"google.protobuf.UInt32Value": "int",
|
||||
"google.protobuf.BoolValue": "bool",
|
||||
"google.protobuf.StringValue": "str",
|
||||
"google.protobuf.BytesValue": "bytes",
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
if type_name in WRAPPER_TYPES:
|
||||
return f"Optional[{WRAPPER_TYPES[type_name]}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
@ -152,6 +175,9 @@ def generate_code(request, response):
|
||||
output_map = {}
|
||||
for proto_file in request.proto_file:
|
||||
out = proto_file.package
|
||||
if out == "google.protobuf":
|
||||
continue
|
||||
|
||||
if not out:
|
||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
||||
|
||||
@ -169,6 +195,7 @@ def generate_code(request, response):
|
||||
"package": package,
|
||||
"files": [f.name for f in options["files"]],
|
||||
"imports": set(),
|
||||
"datetime_imports": set(),
|
||||
"typing_imports": set(),
|
||||
"messages": [],
|
||||
"enums": [],
|
||||
@ -258,6 +285,14 @@ def generate_code(request, response):
|
||||
if f.HasField("oneof_index"):
|
||||
one_of = item.oneof_decl[f.oneof_index].name
|
||||
|
||||
if "Optional[" in t:
|
||||
output["typing_imports"].add("Optional")
|
||||
|
||||
if "timedelta" in t:
|
||||
output["datetime_imports"].add("timedelta")
|
||||
elif "datetime" in t:
|
||||
output["datetime_imports"].add("datetime")
|
||||
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
@ -346,6 +381,7 @@ def generate_code(request, response):
|
||||
output["services"].append(data)
|
||||
|
||||
output["imports"] = sorted(output["imports"])
|
||||
output["datetime_imports"] = sorted(output["datetime_imports"])
|
||||
output["typing_imports"] = sorted(output["typing_imports"])
|
||||
|
||||
# Fill response
|
||||
|
@ -2,6 +2,10 @@
|
||||
# sources: {{ ', '.join(description.files) }}
|
||||
# plugin: python-betterproto
|
||||
from dataclasses import dataclass
|
||||
{% if description.datetime_imports %}
|
||||
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif%}
|
||||
{% if description.typing_imports %}
|
||||
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user