Typing fixes

This commit is contained in:
Daniel G. Taylor 2019-10-27 15:13:51 -07:00
parent eb5020db2a
commit 16687211a2
No known key found for this signature in database
GPG Key ID: 7BD6DC99C9A87E22

View File

@ -420,11 +420,14 @@ class Message(ABC):
register the message fields which get used by the serializers and parsers register the message fields which get used by the serializers and parsers
to go between Python, binary and JSON protobuf message representations. to go between Python, binary and JSON protobuf message representations.
""" """
_serialized_on_wire: bool
_unknown_fields: bytes
_group_map: Dict[str, dict]
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Set a default value for each field in the class after `__init__` has # Set a default value for each field in the class after `__init__` has
# already been run. # already been run.
group_map = {"fields": {}, "groups": {}} group_map: Dict[str, dict] = {"fields": {}, "groups": {}}
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
@ -518,7 +521,7 @@ class Message(ABC):
else: else:
for item in value: for item in value:
output += _serialize_single( output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps meta.number, meta.proto_type, item, wraps=meta.wraps or ""
) )
elif isinstance(value, dict): elif isinstance(value, dict):
for k, v in value.items(): for k, v in value.items():
@ -532,7 +535,7 @@ class Message(ABC):
meta.proto_type, meta.proto_type,
value, value,
serialize_empty=serialize_empty, serialize_empty=serialize_empty,
wraps=meta.wraps, wraps=meta.wraps or "",
) )
return output + self._unknown_fields return output + self._unknown_fields
@ -702,14 +705,14 @@ class Message(ABC):
for field in dataclasses.fields(self): for field in dataclasses.fields(self):
meta = FieldMetadata.get(field) meta = FieldMetadata.get(field)
v = getattr(self, field.name) v = getattr(self, field.name)
cased_name = casing(field.name).rstrip("_") cased_name = casing(field.name).rstrip("_") # type: ignore
if meta.proto_type == "message": if meta.proto_type == "message":
if isinstance(v, datetime): if isinstance(v, datetime):
if v != DATETIME_ZERO: if v != DATETIME_ZERO:
output[cased_name] = _Timestamp.to_json(v) output[cased_name] = _Timestamp.timestamp_to_json(v)
elif isinstance(v, timedelta): elif isinstance(v, timedelta):
if v != timedelta(0): if v != timedelta(0):
output[cased_name] = _Duration.to_json(v) output[cased_name] = _Duration.delta_to_json(v)
elif meta.wraps: elif meta.wraps:
if v is not None: if v is not None:
output[cased_name] = v output[cased_name] = v
@ -738,7 +741,7 @@ class Message(ABC):
else: else:
output[cased_name] = b64encode(v).decode("utf8") output[cased_name] = b64encode(v).decode("utf8")
elif meta.proto_type == TYPE_ENUM: elif meta.proto_type == TYPE_ENUM:
enum_values = list(self._cls_for(field)) enum_values = list(self._cls_for(field)) # type: ignore
if isinstance(v, list): if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v] output[cased_name] = [enum_values[e].name for e in v]
else: else:
@ -853,7 +856,7 @@ class _Duration(Message):
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@staticmethod @staticmethod
def to_json(delta: timedelta) -> str: def delta_to_json(delta: timedelta) -> str:
parts = str(delta.total_seconds()).split(".") parts = str(delta.total_seconds()).split(".")
if len(parts) > 1: if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]: while len(parts[1]) not in [3, 6, 9]:
@ -876,7 +879,7 @@ class _Timestamp(Message):
return datetime.fromtimestamp(ts, tz=timezone.utc) return datetime.fromtimestamp(ts, tz=timezone.utc)
@staticmethod @staticmethod
def to_json(dt: datetime) -> str: def timestamp_to_json(dt: datetime) -> str:
nanos = dt.microsecond * 1e3 nanos = dt.microsecond * 1e3
copy = dt.replace(microsecond=0, tzinfo=None) copy = dt.replace(microsecond=0, tzinfo=None)
result = copy.isoformat() result = copy.isoformat()
@ -899,12 +902,15 @@ class _WrappedMessage(Message):
Google protobuf wrapper types base class. JSON representation is just the Google protobuf wrapper types base class. JSON representation is just the
value itself. value itself.
""" """
def to_dict(self) -> Any: value: Any
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
return self.value return self.value
def from_dict(self, value: Any) -> None: def from_dict(self: T, value: Any) -> T:
if value is not None: if value is not None:
self.value = value self.value = value
return self
@dataclasses.dataclass @dataclasses.dataclass
@ -952,7 +958,7 @@ class _BytesValue(_WrappedMessage):
value: bytes = bytes_field(1) value: bytes = bytes_field(1)
def _get_wrapper(proto_type: str) -> _WrappedMessage: def _get_wrapper(proto_type: str) -> Type:
"""Get the wrapper message class for a wrapped type.""" """Get the wrapper message class for a wrapped type."""
return { return {
TYPE_BOOL: _BoolValue, TYPE_BOOL: _BoolValue,