diff --git a/betterproto/__init__.py b/betterproto/__init__.py index ee36708..a9bc4b6 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -420,11 +420,14 @@ class Message(ABC): register the message fields which get used by the serializers and parsers to go between Python, binary and JSON protobuf message representations. """ + _serialized_on_wire: bool + _unknown_fields: bytes + _group_map: Dict[str, dict] def __post_init__(self) -> None: # Set a default value for each field in the class after `__init__` has # already been run. - group_map = {"fields": {}, "groups": {}} + group_map: Dict[str, dict] = {"fields": {}, "groups": {}} for field in dataclasses.fields(self): meta = FieldMetadata.get(field) @@ -518,7 +521,7 @@ class Message(ABC): else: for item in value: output += _serialize_single( - meta.number, meta.proto_type, item, wraps=meta.wraps + meta.number, meta.proto_type, item, wraps=meta.wraps or "" ) elif isinstance(value, dict): for k, v in value.items(): @@ -532,7 +535,7 @@ class Message(ABC): meta.proto_type, value, serialize_empty=serialize_empty, - wraps=meta.wraps, + wraps=meta.wraps or "", ) return output + self._unknown_fields @@ -702,14 +705,14 @@ class Message(ABC): for field in dataclasses.fields(self): meta = FieldMetadata.get(field) v = getattr(self, field.name) - cased_name = casing(field.name).rstrip("_") + cased_name = casing(field.name).rstrip("_") # type: ignore if meta.proto_type == "message": if isinstance(v, datetime): if v != DATETIME_ZERO: - output[cased_name] = _Timestamp.to_json(v) + output[cased_name] = _Timestamp.timestamp_to_json(v) elif isinstance(v, timedelta): if v != timedelta(0): - output[cased_name] = _Duration.to_json(v) + output[cased_name] = _Duration.delta_to_json(v) elif meta.wraps: if v is not None: output[cased_name] = v @@ -738,7 +741,7 @@ class Message(ABC): else: output[cased_name] = b64encode(v).decode("utf8") elif meta.proto_type == TYPE_ENUM: - enum_values = list(self._cls_for(field)) + enum_values = list(self._cls_for(field)) # type: ignore if isinstance(v, list): output[cased_name] = [enum_values[e].name for e in v] else: @@ -853,7 +856,7 @@ class _Duration(Message): return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) @staticmethod - def to_json(delta: timedelta) -> str: + def delta_to_json(delta: timedelta) -> str: parts = str(delta.total_seconds()).split(".") if len(parts) > 1: while len(parts[1]) not in [3, 6, 9]: @@ -876,7 +879,7 @@ class _Timestamp(Message): return datetime.fromtimestamp(ts, tz=timezone.utc) @staticmethod - def to_json(dt: datetime) -> str: + def timestamp_to_json(dt: datetime) -> str: nanos = dt.microsecond * 1e3 copy = dt.replace(microsecond=0, tzinfo=None) result = copy.isoformat() @@ -899,12 +902,15 @@ class _WrappedMessage(Message): Google protobuf wrapper types base class. JSON representation is just the value itself. """ - def to_dict(self) -> Any: + value: Any + + def to_dict(self, casing: Casing = Casing.CAMEL) -> Any: return self.value - def from_dict(self, value: Any) -> None: + def from_dict(self: T, value: Any) -> T: if value is not None: self.value = value + return self @dataclasses.dataclass @@ -952,7 +958,7 @@ class _BytesValue(_WrappedMessage): value: bytes = bytes_field(1) -def _get_wrapper(proto_type: str) -> _WrappedMessage: +def _get_wrapper(proto_type: str) -> Type: """Get the wrapper message class for a wrapped type.""" return { TYPE_BOOL: _BoolValue,