Typing fixes

This commit is contained in:
Daniel G. Taylor 2019-10-27 15:13:51 -07:00
parent eb5020db2a
commit 16687211a2
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22

View File

@ -420,11 +420,14 @@ class Message(ABC):
register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations.
"""
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has
# already been run.
group_map = {"fields": {}, "groups": {}}
group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
@ -518,7 +521,7 @@ class Message(ABC):
else:
for item in value:
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
)
elif isinstance(value, dict):
for k, v in value.items():
@ -532,7 +535,7 @@ class Message(ABC):
meta.proto_type,
value,
serialize_empty=serialize_empty,
wraps=meta.wraps,
wraps=meta.wraps or "",
)
return output + self._unknown_fields
@ -702,14 +705,14 @@ class Message(ABC):
for field in dataclasses.fields(self):
meta = FieldMetadata.get(field)
v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_")
cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message":
if isinstance(v, datetime):
if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.to_json(v)
output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta):
if v != timedelta(0):
output[cased_name] = _Duration.to_json(v)
output[cased_name] = _Duration.delta_to_json(v)
elif meta.wraps:
if v is not None:
output[cased_name] = v
@ -738,7 +741,7 @@ class Message(ABC):
else:
output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field))
enum_values = list(self._cls_for(field)) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
else:
@ -853,7 +856,7 @@ class _Duration(Message):
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod
def to_json(delta: timedelta) -> str:
def delta_to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
@ -876,7 +879,7 @@ class _Timestamp(Message):
return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod
def to_json(dt: datetime) -> str:
def timestamp_to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat()
@ -899,12 +902,15 @@ class _WrappedMessage(Message):
Google protobuf wrapper types base class. JSON representation is just the
value itself.
"""
def to_dict(self) -> Any:
value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
return self.value
def from_dict(self, value: Any) -> None:
def from_dict(self: T, value: Any) -> T:
if value is not None:
self.value = value
return self
@dataclasses.dataclass
@ -952,7 +958,7 @@ class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> _WrappedMessage:
def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type."""
return {
TYPE_BOOL: _BoolValue,