Typing fixes
This commit is contained in:
parent
eb5020db2a
commit
16687211a2
@ -420,11 +420,14 @@ class Message(ABC):
|
|||||||
register the message fields which get used by the serializers and parsers
|
register the message fields which get used by the serializers and parsers
|
||||||
to go between Python, binary and JSON protobuf message representations.
|
to go between Python, binary and JSON protobuf message representations.
|
||||||
"""
|
"""
|
||||||
|
_serialized_on_wire: bool
|
||||||
|
_unknown_fields: bytes
|
||||||
|
_group_map: Dict[str, dict]
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Set a default value for each field in the class after `__init__` has
|
# Set a default value for each field in the class after `__init__` has
|
||||||
# already been run.
|
# already been run.
|
||||||
group_map = {"fields": {}, "groups": {}}
|
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
|
||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
|
|
||||||
@ -518,7 +521,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
for item in value:
|
for item in value:
|
||||||
output += _serialize_single(
|
output += _serialize_single(
|
||||||
meta.number, meta.proto_type, item, wraps=meta.wraps
|
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
|
||||||
)
|
)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
@ -532,7 +535,7 @@ class Message(ABC):
|
|||||||
meta.proto_type,
|
meta.proto_type,
|
||||||
value,
|
value,
|
||||||
serialize_empty=serialize_empty,
|
serialize_empty=serialize_empty,
|
||||||
wraps=meta.wraps,
|
wraps=meta.wraps or "",
|
||||||
)
|
)
|
||||||
|
|
||||||
return output + self._unknown_fields
|
return output + self._unknown_fields
|
||||||
@ -702,14 +705,14 @@ class Message(ABC):
|
|||||||
for field in dataclasses.fields(self):
|
for field in dataclasses.fields(self):
|
||||||
meta = FieldMetadata.get(field)
|
meta = FieldMetadata.get(field)
|
||||||
v = getattr(self, field.name)
|
v = getattr(self, field.name)
|
||||||
cased_name = casing(field.name).rstrip("_")
|
cased_name = casing(field.name).rstrip("_") # type: ignore
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == "message":
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
if v != DATETIME_ZERO:
|
if v != DATETIME_ZERO:
|
||||||
output[cased_name] = _Timestamp.to_json(v)
|
output[cased_name] = _Timestamp.timestamp_to_json(v)
|
||||||
elif isinstance(v, timedelta):
|
elif isinstance(v, timedelta):
|
||||||
if v != timedelta(0):
|
if v != timedelta(0):
|
||||||
output[cased_name] = _Duration.to_json(v)
|
output[cased_name] = _Duration.delta_to_json(v)
|
||||||
elif meta.wraps:
|
elif meta.wraps:
|
||||||
if v is not None:
|
if v is not None:
|
||||||
output[cased_name] = v
|
output[cased_name] = v
|
||||||
@ -738,7 +741,7 @@ class Message(ABC):
|
|||||||
else:
|
else:
|
||||||
output[cased_name] = b64encode(v).decode("utf8")
|
output[cased_name] = b64encode(v).decode("utf8")
|
||||||
elif meta.proto_type == TYPE_ENUM:
|
elif meta.proto_type == TYPE_ENUM:
|
||||||
enum_values = list(self._cls_for(field))
|
enum_values = list(self._cls_for(field)) # type: ignore
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
output[cased_name] = [enum_values[e].name for e in v]
|
output[cased_name] = [enum_values[e].name for e in v]
|
||||||
else:
|
else:
|
||||||
@ -853,7 +856,7 @@ class _Duration(Message):
|
|||||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_json(delta: timedelta) -> str:
|
def delta_to_json(delta: timedelta) -> str:
|
||||||
parts = str(delta.total_seconds()).split(".")
|
parts = str(delta.total_seconds()).split(".")
|
||||||
if len(parts) > 1:
|
if len(parts) > 1:
|
||||||
while len(parts[1]) not in [3, 6, 9]:
|
while len(parts[1]) not in [3, 6, 9]:
|
||||||
@ -876,7 +879,7 @@ class _Timestamp(Message):
|
|||||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def to_json(dt: datetime) -> str:
|
def timestamp_to_json(dt: datetime) -> str:
|
||||||
nanos = dt.microsecond * 1e3
|
nanos = dt.microsecond * 1e3
|
||||||
copy = dt.replace(microsecond=0, tzinfo=None)
|
copy = dt.replace(microsecond=0, tzinfo=None)
|
||||||
result = copy.isoformat()
|
result = copy.isoformat()
|
||||||
@ -899,12 +902,15 @@ class _WrappedMessage(Message):
|
|||||||
Google protobuf wrapper types base class. JSON representation is just the
|
Google protobuf wrapper types base class. JSON representation is just the
|
||||||
value itself.
|
value itself.
|
||||||
"""
|
"""
|
||||||
def to_dict(self) -> Any:
|
value: Any
|
||||||
|
|
||||||
|
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
def from_dict(self, value: Any) -> None:
|
def from_dict(self: T, value: Any) -> T:
|
||||||
if value is not None:
|
if value is not None:
|
||||||
self.value = value
|
self.value = value
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@ -952,7 +958,7 @@ class _BytesValue(_WrappedMessage):
|
|||||||
value: bytes = bytes_field(1)
|
value: bytes = bytes_field(1)
|
||||||
|
|
||||||
|
|
||||||
def _get_wrapper(proto_type: str) -> _WrappedMessage:
|
def _get_wrapper(proto_type: str) -> Type:
|
||||||
"""Get the wrapper message class for a wrapped type."""
|
"""Get the wrapper message class for a wrapped type."""
|
||||||
return {
|
return {
|
||||||
TYPE_BOOL: _BoolValue,
|
TYPE_BOOL: _BoolValue,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user