Support Duration/Timestamp Google well-known types
parent 5daf61f64c
commit c79535b614
@@ -5,6 +5,7 @@ import json
 import struct
 from abc import ABC
 from base64 import b64encode, b64decode
+from datetime import datetime, timedelta, timezone
 from typing import (
     Any,
     AsyncGenerator,
@@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes:
     elif proto_type == TYPE_STRING:
         return value.encode("utf-8")
     elif proto_type == TYPE_MESSAGE:
+        if isinstance(value, datetime):
+            # Convert the `datetime` to a timestamp message.
+            seconds = int(value.timestamp())
+            nanos = int(value.microsecond * 1e3)
+            value = _Timestamp(seconds=seconds, nanos=nanos)
+        elif isinstance(value, timedelta):
+            # Convert the `timedelta` to a duration message.
+            total_ms = value // timedelta(microseconds=1)
+            seconds = int(total_ms / 1e6)
+            nanos = int((total_ms % 1e6) * 1e3)
+            value = _Duration(seconds=seconds, nanos=nanos)
+
         return bytes(value)

     return value
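For reference, a standalone sketch (not part of the commit) of the seconds/nanos split performed by the new branch above; the sample values are made up:

# Standalone sketch, not part of the commit: the seconds/nanos split performed
# by the new _preprocess_single branch, with made-up sample values.
from datetime import datetime, timedelta, timezone

dt = datetime(2019, 10, 12, 7, 20, 50, 520000, tzinfo=timezone.utc)
seconds = int(dt.timestamp())        # whole seconds since the Unix epoch
nanos = int(dt.microsecond * 1e3)    # 520000 us -> 520000000 ns

td = timedelta(seconds=1, milliseconds=500)
total_us = td // timedelta(microseconds=1)                 # 1500000 total microseconds
print(int(total_us / 1e6), int((total_us % 1e6) * 1e3))    # 1 500000000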
@@ -399,6 +412,7 @@ class Message(ABC):
             meta = FieldMetadata.get(field)

             if meta.group:
+                # This is part of a one-of group.
                 group_map["fields"][field.name] = meta.group

                 if meta.group not in group_map["groups"]:
@@ -530,6 +544,9 @@ class Message(ABC):
         elif issubclass(t, Enum):
             # Enums always default to zero.
             value = 0
+        elif t == datetime:
+            # Offsets are relative to 1970-01-01T00:00:00Z
+            value = datetime(1970, 1, 1, tzinfo=timezone.utc)
         else:
             # This is either a primitive scalar or another message type. Calling
             # it should result in its zero value.
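A quick standalone check (not part of the commit) that the chosen zero value is the Unix epoch, so an unset timestamp field serializes as 0 seconds:

# Standalone sketch, not part of the commit.
from datetime import datetime, timezone

epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
print(epoch.timestamp())   # 0.0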
@@ -558,8 +575,14 @@ class Message(ABC):
                 value = value.decode("utf-8")
             elif meta.proto_type == TYPE_MESSAGE:
                 cls = self._cls_for(field)
-                value = cls().parse(value)
-                value._serialized_on_wire = True
+
+                if cls == datetime:
+                    value = _Timestamp().parse(value).to_datetime()
+                elif cls == timedelta:
+                    value = _Duration().parse(value).to_timedelta()
+                else:
+                    value = cls().parse(value)
+                    value._serialized_on_wire = True
             elif meta.proto_type == TYPE_MAP:
                 # TODO: This is slow, use a cache to make it faster since each
                 # key/value pair will recreate the class.
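A standalone sketch (not part of the commit) of the reverse conversions used while parsing, mirroring _Duration.to_timedelta and _Timestamp.to_datetime; the seconds/nanos values are made up:

# Standalone sketch, not part of the commit.
from datetime import datetime, timedelta, timezone

seconds, nanos = 1, 500000000
print(timedelta(seconds=seconds, microseconds=nanos / 1e3))   # 0:00:01.500000

ts_seconds, ts_nanos = 0, 250000000
print(datetime.fromtimestamp(ts_seconds + ts_nanos / 1e9, tz=timezone.utc).isoformat())
# 1970-01-01T00:00:00.250000+00:00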
@@ -646,7 +669,11 @@ class Message(ABC):
             v = getattr(self, field.name)
             cased_name = casing(field.name).rstrip("_")
             if meta.proto_type == "message":
-                if isinstance(v, list):
+                if isinstance(v, datetime):
+                    output[cased_name] = _Timestamp.to_json(v)
+                elif isinstance(v, timedelta):
+                    output[cased_name] = _Duration.to_json(v)
+                elif isinstance(v, list):
                     # Convert each item.
                     v = [i.to_dict() for i in v]
                     output[cased_name] = v
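A standalone sketch (not part of the commit) of the RFC 3339 strings the new to_dict branch emits for datetimes; timestamp_to_json here is a local copy of the _Timestamp.to_json logic added later in this commit, not a betterproto API:

# Standalone sketch, not part of the commit: local copy of the _Timestamp.to_json logic.
from datetime import datetime, timezone

def timestamp_to_json(dt: datetime) -> str:
    nanos = dt.microsecond * 1e3
    result = dt.replace(microsecond=0, tzinfo=None).isoformat()
    if (nanos % 1e9) == 0:
        # No fractional digits: omit the '.' entirely.
        return result + "Z"
    if (nanos % 1e6) == 0:
        return result + ".%03dZ" % (nanos / 1e6)
    if (nanos % 1e3) == 0:
        return result + ".%06dZ" % (nanos / 1e3)
    return result + ".%09dZ" % nanos

print(timestamp_to_json(datetime(2019, 1, 1, 12, 0, 0, tzinfo=timezone.utc)))
# 2019-01-01T12:00:00Z
print(timestamp_to_json(datetime(2019, 1, 1, 12, 0, 0, 500000, tzinfo=timezone.utc)))
# 2019-01-01T12:00:00.500Z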
@@ -701,6 +728,12 @@ class Message(ABC):
                     cls = self._cls_for(field)
                     for i in range(len(value[key])):
                         v.append(cls().from_dict(value[key][i]))
+                elif isinstance(v, datetime):
+                    v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
+                    setattr(self, field.name, v)
+                elif isinstance(v, timedelta):
+                    v = timedelta(seconds=float(value[key][:-1]))
+                    setattr(self, field.name, v)
                 else:
                     v.from_dict(value[key])
             elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
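A standalone sketch (not part of the commit) of how the new from_dict branches decode the JSON forms; the sample strings are made up:

# Standalone sketch, not part of the commit.
from datetime import datetime, timedelta

stamp = "2019-01-01T12:00:00.500Z"
print(datetime.fromisoformat(stamp.replace("Z", "+00:00")))
# 2019-01-01 12:00:00.500000+00:00

duration = "1.500s"
print(timedelta(seconds=float(duration[:-1])))
# 0:00:01.500000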
@@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
     return (field.name, getattr(message, field.name))


+@dataclasses.dataclass
+class _Duration(Message):
+    # Signed seconds of the span of time. Must be from -315,576,000,000 to
+    # +315,576,000,000 inclusive. Note: these bounds are computed from: 60
+    # sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years
+    seconds: int = int64_field(1)
+    # Signed fractions of a second at nanosecond resolution of the span of time.
+    # Durations less than one second are represented with a 0 `seconds` field and
+    # a positive or negative `nanos` field. For durations of one second or more,
+    # a non-zero value for the `nanos` field must be of the same sign as the
+    # `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive.
+    nanos: int = int32_field(2)
+
+    def to_timedelta(self) -> timedelta:
+        return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
+
+    @staticmethod
+    def to_json(delta: timedelta) -> str:
+        parts = str(delta.total_seconds()).split(".")
+        if len(parts) > 1:
+            while len(parts[1]) not in [3, 6, 9]:
+                parts[1] = parts[1] + "0"
+        return ".".join(parts) + "s"
+
+
+@dataclasses.dataclass
+class _Timestamp(Message):
+    # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must
+    # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive.
+    seconds: int = int64_field(1)
+    # Non-negative fractions of a second at nanosecond resolution. Negative
+    # second values with fractions must still have non-negative nanos values that
+    # count forward in time. Must be from 0 to 999,999,999 inclusive.
+    nanos: int = int32_field(2)
+
+    def to_datetime(self) -> datetime:
+        ts = self.seconds + (self.nanos / 1e9)
+        print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc))
+        return datetime.fromtimestamp(ts, tz=timezone.utc)
+
+    @staticmethod
+    def to_json(dt: datetime) -> str:
+        nanos = dt.microsecond * 1e3
+        copy = dt.replace(microsecond=0, tzinfo=None)
+        result = copy.isoformat()
+        if (nanos % 1e9) == 0:
+            # If there are 0 fractional digits, the fractional
+            # point '.' should be omitted when serializing.
+            return result + 'Z'
+        if (nanos % 1e6) == 0:
+            # Serialize 3 fractional digits.
+            return result + '.%03dZ' % (nanos / 1e6)
+        if (nanos % 1e3) == 0:
+            # Serialize 6 fractional digits.
+            return result + '.%06dZ' % (nanos / 1e3)
+        # Serialize 9 fractional digits.
+        return result + '.%09dZ' % nanos
+
+
 class ServiceStub(ABC):
     """
     Base class for async gRPC service stubs.
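A standalone sketch (not part of the commit) of the fractional-digit padding in _Duration.to_json; duration_to_json here is a local copy of that logic, not a betterproto API:

# Standalone sketch, not part of the commit: local copy of the _Duration.to_json logic.
from datetime import timedelta

def duration_to_json(delta: timedelta) -> str:
    parts = str(delta.total_seconds()).split(".")
    if len(parts) > 1:
        # Pad the fraction to 3, 6, or 9 digits, as the proto3 JSON mapping expects.
        while len(parts[1]) not in [3, 6, 9]:
            parts[1] = parts[1] + "0"
    return ".".join(parts) + "s"

print(duration_to_json(timedelta(seconds=1, milliseconds=500)))   # 1.500s
print(duration_to_json(timedelta(seconds=10)))                    # 10.000s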
@@ -30,6 +30,19 @@ from google.protobuf.descriptor_pb2 import (
 from betterproto.casing import safe_snake_case


+WRAPPER_TYPES = {
+    "google.protobuf.DoubleValue": "float",
+    "google.protobuf.FloatValue": "float",
+    "google.protobuf.Int64Value": "int",
+    "google.protobuf.UInt64Value": "int",
+    "google.protobuf.Int32Value": "int",
+    "google.protobuf.UInt32Value": "int",
+    "google.protobuf.BoolValue": "bool",
+    "google.protobuf.StringValue": "str",
+    "google.protobuf.BytesValue": "bytes",
+}
+
+
 def get_ref_type(package: str, imports: set, type_name: str) -> str:
     """
     Return a Python type name for a proto type reference. Adds the import if
@@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str:
     # because by convention packages are lowercase and message/enum types are
     # pascal-cased. May require refactoring in the future.
     type_name = type_name.lstrip(".")
+
+    if type_name in WRAPPER_TYPES:
+        return f"Optional[{WRAPPER_TYPES[type_name]}]"
+
+    if type_name == "google.protobuf.Duration":
+        return "timedelta"
+
+    if type_name == "google.protobuf.Timestamp":
+        return "datetime"
+
     if type_name.startswith(package):
         parts = type_name.lstrip(package).lstrip(".").split(".")
         if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
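A standalone sketch (not part of the commit) of the new special cases in get_ref_type; resolve and the trimmed WRAPPER_TYPES table below are local stand-ins for illustration, not the plugin's actual function:

# Standalone sketch, not part of the commit: stand-in for the new get_ref_type branches.
WRAPPER_TYPES = {
    "google.protobuf.BoolValue": "bool",
    "google.protobuf.StringValue": "str",
}

def resolve(type_name: str) -> str:
    type_name = type_name.lstrip(".")
    if type_name in WRAPPER_TYPES:
        # Wrapper types are nullable, so they map to Optional[...].
        return f"Optional[{WRAPPER_TYPES[type_name]}]"
    if type_name == "google.protobuf.Duration":
        return "timedelta"
    if type_name == "google.protobuf.Timestamp":
        return "datetime"
    return type_name  # ordinary message/enum references keep their normal handling

print(resolve(".google.protobuf.Timestamp"))   # datetime
print(resolve(".google.protobuf.BoolValue"))   # Optional[bool]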
@@ -152,6 +175,9 @@ def generate_code(request, response):
     output_map = {}
     for proto_file in request.proto_file:
         out = proto_file.package
+        if out == "google.protobuf":
+            continue
+
         if not out:
             out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")

@@ -169,6 +195,7 @@ def generate_code(request, response):
             "package": package,
             "files": [f.name for f in options["files"]],
             "imports": set(),
+            "datetime_imports": set(),
             "typing_imports": set(),
             "messages": [],
             "enums": [],
@@ -258,6 +285,14 @@ def generate_code(request, response):
                 if f.HasField("oneof_index"):
                     one_of = item.oneof_decl[f.oneof_index].name

+                if "Optional[" in t:
+                    output["typing_imports"].add("Optional")
+
+                if "timedelta" in t:
+                    output["datetime_imports"].add("timedelta")
+                elif "datetime" in t:
+                    output["datetime_imports"].add("datetime")
+
                 data["properties"].append(
                     {
                         "name": f.name,
@@ -346,6 +381,7 @@ def generate_code(request, response):
             output["services"].append(data)

         output["imports"] = sorted(output["imports"])
+        output["datetime_imports"] = sorted(output["datetime_imports"])
         output["typing_imports"] = sorted(output["typing_imports"])

         # Fill response
|
@ -2,6 +2,10 @@
|
|||||||
# sources: {{ ', '.join(description.files) }}
|
# sources: {{ ', '.join(description.files) }}
|
||||||
# plugin: python-betterproto
|
# plugin: python-betterproto
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
{% if description.datetime_imports %}
|
||||||
|
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
|
{% endif%}
|
||||||
{% if description.typing_imports %}
|
{% if description.typing_imports %}
|
||||||
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
|