diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 8fa5819..9caaa70 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -5,6 +5,7 @@ import json import struct from abc import ABC from base64 import b64encode, b64decode +from datetime import datetime, timedelta, timezone from typing import ( Any, AsyncGenerator, @@ -295,6 +296,18 @@ def _preprocess_single(proto_type: str, value: Any) -> bytes: elif proto_type == TYPE_STRING: return value.encode("utf-8") elif proto_type == TYPE_MESSAGE: + if isinstance(value, datetime): + # Convert the `datetime` to a timestamp message. + seconds = int(value.timestamp()) + nanos = int(value.microsecond * 1e3) + value = _Timestamp(seconds=seconds, nanos=nanos) + elif isinstance(value, timedelta): + # Convert the `timedelta` to a duration message. + total_ms = value // timedelta(microseconds=1) + seconds = int(total_ms / 1e6) + nanos = int((total_ms % 1e6) * 1e3) + value = _Duration(seconds=seconds, nanos=nanos) + return bytes(value) return value @@ -399,6 +412,7 @@ class Message(ABC): meta = FieldMetadata.get(field) if meta.group: + # This is part of a one-of group. group_map["fields"][field.name] = meta.group if meta.group not in group_map["groups"]: @@ -530,6 +544,9 @@ class Message(ABC): elif issubclass(t, Enum): # Enums always default to zero. value = 0 + elif t == datetime: + # Offsets are relative to 1970-01-01T00:00:00Z + value = datetime(1970, 1, 1, tzinfo=timezone.utc) else: # This is either a primitive scalar or another message type. Calling # it should result in its zero value. @@ -558,8 +575,14 @@ class Message(ABC): value = value.decode("utf-8") elif meta.proto_type == TYPE_MESSAGE: cls = self._cls_for(field) - value = cls().parse(value) - value._serialized_on_wire = True + + if cls == datetime: + value = _Timestamp().parse(value).to_datetime() + elif cls == timedelta: + value = _Duration().parse(value).to_timedelta() + else: + value = cls().parse(value) + value._serialized_on_wire = True elif meta.proto_type == TYPE_MAP: # TODO: This is slow, use a cache to make it faster since each # key/value pair will recreate the class. @@ -646,7 +669,11 @@ class Message(ABC): v = getattr(self, field.name) cased_name = casing(field.name).rstrip("_") if meta.proto_type == "message": - if isinstance(v, list): + if isinstance(v, datetime): + output[cased_name] = _Timestamp.to_json(v) + elif isinstance(v, timedelta): + output[cased_name] = _Duration.to_json(v) + elif isinstance(v, list): # Convert each item. v = [i.to_dict() for i in v] output[cased_name] = v @@ -701,6 +728,12 @@ class Message(ABC): cls = self._cls_for(field) for i in range(len(value[key])): v.append(cls().from_dict(value[key][i])) + elif isinstance(v, datetime): + v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) + setattr(self, field.name, v) + elif isinstance(v, timedelta): + v = timedelta(seconds=float(value[key][:-1])) + setattr(self, field.name, v) else: v.from_dict(value[key]) elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: @@ -760,6 +793,65 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: return (field.name, getattr(message, field.name)) +@dataclasses.dataclass +class _Duration(Message): + # Signed seconds of the span of time. Must be from -315,576,000,000 to + # +315,576,000,000 inclusive. Note: these bounds are computed from: 60 + # sec/min * 60 min/hr * 24 hr/day * 365.25 days/year * 10000 years + seconds: int = int64_field(1) + # Signed fractions of a second at nanosecond resolution of the span of time. + # Durations less than one second are represented with a 0 `seconds` field and + # a positive or negative `nanos` field. For durations of one second or more, + # a non-zero value for the `nanos` field must be of the same sign as the + # `seconds` field. Must be from -999,999,999 to +999,999,999 inclusive. + nanos: int = int32_field(2) + + def to_timedelta(self) -> timedelta: + return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) + + @staticmethod + def to_json(delta: timedelta) -> str: + parts = str(delta.total_seconds()).split(".") + if len(parts) > 1: + while len(parts[1]) not in [3, 6, 9]: + parts[1] = parts[1] + "0" + return ".".join(parts) + "s" + + +@dataclasses.dataclass +class _Timestamp(Message): + # Represents seconds of UTC time since Unix epoch 1970-01-01T00:00:00Z. Must + # be from 0001-01-01T00:00:00Z to 9999-12-31T23:59:59Z inclusive. + seconds: int = int64_field(1) + # Non-negative fractions of a second at nanosecond resolution. Negative + # second values with fractions must still have non-negative nanos values that + # count forward in time. Must be from 0 to 999,999,999 inclusive. + nanos: int = int32_field(2) + + def to_datetime(self) -> datetime: + ts = self.seconds + (self.nanos / 1e9) + print('to-datetime', ts, datetime.fromtimestamp(ts, tz=timezone.utc)) + return datetime.fromtimestamp(ts, tz=timezone.utc) + + @staticmethod + def to_json(dt: datetime) -> str: + nanos = dt.microsecond * 1e3 + copy = dt.replace(microsecond=0, tzinfo=None) + result = copy.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 'Z' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03dZ' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06dZ' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09dZ' % nanos + + class ServiceStub(ABC): """ Base class for async gRPC service stubs. diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 38c08d4..bff2986 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -30,6 +30,19 @@ from google.protobuf.descriptor_pb2 import ( from betterproto.casing import safe_snake_case +WRAPPER_TYPES = { + "google.protobuf.DoubleValue": "float", + "google.protobuf.FloatValue": "float", + "google.protobuf.Int64Value": "int", + "google.protobuf.UInt64Value": "int", + "google.protobuf.Int32Value": "int", + "google.protobuf.UInt32Value": "int", + "google.protobuf.BoolValue": "bool", + "google.protobuf.StringValue": "str", + "google.protobuf.BytesValue": "bytes", +} + + def get_ref_type(package: str, imports: set, type_name: str) -> str: """ Return a Python type name for a proto type reference. Adds the import if @@ -39,6 +52,16 @@ def get_ref_type(package: str, imports: set, type_name: str) -> str: # because by convention packages are lowercase and message/enum types are # pascal-cased. May require refactoring in the future. type_name = type_name.lstrip(".") + + if type_name in WRAPPER_TYPES: + return f"Optional[{WRAPPER_TYPES[type_name]}]" + + if type_name == "google.protobuf.Duration": + return "timedelta" + + if type_name == "google.protobuf.Timestamp": + return "datetime" + if type_name.startswith(package): parts = type_name.lstrip(package).lstrip(".").split(".") if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): @@ -152,6 +175,9 @@ def generate_code(request, response): output_map = {} for proto_file in request.proto_file: out = proto_file.package + if out == "google.protobuf": + continue + if not out: out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".") @@ -169,6 +195,7 @@ def generate_code(request, response): "package": package, "files": [f.name for f in options["files"]], "imports": set(), + "datetime_imports": set(), "typing_imports": set(), "messages": [], "enums": [], @@ -258,6 +285,14 @@ def generate_code(request, response): if f.HasField("oneof_index"): one_of = item.oneof_decl[f.oneof_index].name + if "Optional[" in t: + output["typing_imports"].add("Optional") + + if "timedelta" in t: + output["datetime_imports"].add("timedelta") + elif "datetime" in t: + output["datetime_imports"].add("datetime") + data["properties"].append( { "name": f.name, @@ -346,6 +381,7 @@ def generate_code(request, response): output["services"].append(data) output["imports"] = sorted(output["imports"]) + output["datetime_imports"] = sorted(output["datetime_imports"]) output["typing_imports"] = sorted(output["typing_imports"]) # Fill response diff --git a/betterproto/templates/template.py b/betterproto/templates/template.py index 5ae5857..630afc6 100644 --- a/betterproto/templates/template.py +++ b/betterproto/templates/template.py @@ -2,6 +2,10 @@ # sources: {{ ', '.join(description.files) }} # plugin: python-betterproto from dataclasses import dataclass +{% if description.datetime_imports %} +from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} + +{% endif%} {% if description.typing_imports %} from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}