QOL fixes (#141)

- Add missing type annotations
- Various style improvements
- Use constants more consistently
- enforce black on benchmark code
This commit is contained in:
James
2020-10-17 18:27:11 +01:00
committed by GitHub
parent bf9412e083
commit 8f7af272cc
16 changed files with 177 additions and 220 deletions

View File

@@ -26,7 +26,7 @@ from ._types import T
from .casing import camel_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
if sys.version_info[:2] < (3, 7):
# Apply backport of datetime.fromisoformat from 3.7
from backports.datetime_fromisoformat import MonkeyPatch
@@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
def datetime_default_gen():
def datetime_default_gen() -> datetime:
return datetime(1970, 1, 1, tzinfo=timezone.utc)
@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
@classmethod
def from_string(cls, name: str) -> "Enum":
"""
Return the value which corresponds to the string name.
"""Return the value which corresponds to the string name.
Parameters
-----------
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
return encode_varint(value)
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
# Handle zig-zag encoding.
if value >= 0:
value = value << 1
else:
value = (value << 1) ^ (~0)
return encode_varint(value)
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
elif proto_type in FIXED_TYPES:
return struct.pack(_pack_fmt(proto_type), value)
elif proto_type == TYPE_STRING:
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
wire_type = num_wire & 0x7
decoded: Any = None
if wire_type == 0:
if wire_type == WIRE_VARINT:
decoded, i = decode_varint(value, i)
elif wire_type == 1:
elif wire_type == WIRE_FIXED_64:
decoded, i = value[i : i + 8], i + 8
elif wire_type == 2:
elif wire_type == WIRE_LEN_DELIM:
length, i = decode_varint(value, i)
decoded = value[i : i + length]
i += length
elif wire_type == 5:
elif wire_type == WIRE_FIXED_32:
decoded, i = value[i : i + 4], i + 4
yield ParsedField(
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
class ProtoClassMetadata:
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
default_gen: Dict[str, Callable]
cls_by_field: Dict[str, Type]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
__slots__ = (
"oneof_group_by_field",
"oneof_field_by_group",
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
"sorted_field_names",
)
oneof_group_by_field: Dict[str, str]
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
field_name_by_number: Dict[int, str]
meta_by_field_name: Dict[str, FieldMetadata]
sorted_field_names: Tuple[str, ...]
default_gen: Dict[str, Callable[[], Any]]
cls_by_field: Dict[str, Type]
def __init__(self, cls: Type["Message"]):
by_field = {}
by_group: Dict[str, Set] = {}
@@ -470,23 +467,21 @@ class ProtoClassMetadata:
self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name
self.sorted_field_names = tuple(
by_field_number[number] for number in sorted(by_field_number.keys())
by_field_number[number] for number in sorted(by_field_number)
)
self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields)
@staticmethod
def _get_default_gen(cls, fields):
default_gen = {}
for field in fields:
default_gen[field.name] = cls._get_field_default_gen(field)
return default_gen
def _get_default_gen(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Callable[[], Any]]:
return {field.name: cls._get_field_default_gen(field) for field in fields}
@staticmethod
def _get_cls_by_field(cls, fields):
def _get_cls_by_field(
cls: Type["Message"], fields: List[dataclasses.Field]
) -> Dict[str, Type]:
field_cls = {}
for field in fields:
@@ -503,7 +498,7 @@ class ProtoClassMetadata:
],
bases=(Message,),
)
field_cls[field.name + ".value"] = vt
field_cls[f"{field.name}.value"] = vt
else:
field_cls[field.name] = cls._cls_for(field)
@@ -612,7 +607,7 @@ class Message(ABC):
super().__setattr__(attr, value)
@property
def _betterproto(self):
def _betterproto(self) -> ProtoClassMetadata:
"""
Lazy initialize metadata for each protobuf class.
It may be initialized multiple times in a multi-threaded environment,
@@ -726,9 +721,8 @@ class Message(ABC):
@classmethod
def _type_hints(cls) -> Dict[str, Type]:
module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module))
return type_hints
module = sys.modules[cls.__module__]
return get_type_hints(cls, vars(module))
@classmethod
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
@@ -739,7 +733,7 @@ class Message(ABC):
field_cls = field_cls.__args__[index]
return field_cls
def _get_field_default(self, field_name):
def _get_field_default(self, field_name: str) -> Any:
return self._betterproto.default_gen[field_name]()
@classmethod
@@ -762,7 +756,7 @@ class Message(ABC):
elif issubclass(t, Enum):
# Enums always default to zero.
return int
elif t == datetime:
elif t is datetime:
# Offsets are relative to 1970-01-01T00:00:00Z
return datetime_default_gen
else:
@@ -966,7 +960,7 @@ class Message(ABC):
)
):
output[cased_name] = value.to_dict(casing, include_default_values)
elif meta.proto_type == "map":
elif meta.proto_type == TYPE_MAP:
for k in value:
if hasattr(value[k], "to_dict"):
value[k] = value[k].to_dict(casing, include_default_values)
@@ -1032,12 +1026,12 @@ class Message(ABC):
continue
if value[key] is not None:
if meta.proto_type == "message":
if meta.proto_type == TYPE_MESSAGE:
v = getattr(self, field_name)
if isinstance(v, list):
cls = self._betterproto.cls_by_field[field_name]
for i in range(len(value[key])):
v.append(cls().from_dict(value[key][i]))
for item in value[key]:
v.append(cls().from_dict(item))
elif isinstance(v, datetime):
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
setattr(self, field_name, v)
@@ -1052,7 +1046,7 @@ class Message(ABC):
v.from_dict(value[key])
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
v = getattr(self, field_name)
cls = self._betterproto.cls_by_field[field_name + ".value"]
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
for k in value[key]:
v[k] = cls().from_dict(value[key][k])
else:
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
return message._serialized_on_wire
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
"""
Return the name and value of a message's one-of field group.
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
"""
field_name = message._group_current.get(group_name)
if not field_name:
return ("", None)
return (field_name, getattr(message, field_name))
return "", None
return field_name, getattr(message, field_name)
# Circular import workaround: google.protobuf depends on base classes defined above.
from .lib.google.protobuf import ( # noqa
Duration,
Timestamp,
BoolValue,
BytesValue,
DoubleValue,
Duration,
FloatValue,
Int32Value,
Int64Value,
StringValue,
Timestamp,
UInt32Value,
UInt64Value,
)
@@ -1174,8 +1168,8 @@ class _Duration(Duration):
parts = str(delta.total_seconds()).split(".")
if len(parts) > 1:
while len(parts[1]) not in [3, 6, 9]:
parts[1] = parts[1] + "0"
return ".".join(parts) + "s"
parts[1] = f"{parts[1]}0"
return f"{'.'.join(parts)}s"
class _Timestamp(Timestamp):
@@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
if (nanos % 1e9) == 0:
# If there are 0 fractional digits, the fractional
# point '.' should be omitted when serializing.
return result + "Z"
return f"{result}Z"
if (nanos % 1e6) == 0:
# Serialize 3 fractional digits.
return result + ".%03dZ" % (nanos / 1e6)
return f"{result}.{int(nanos // 1e6) :03d}Z"
if (nanos % 1e3) == 0:
# Serialize 6 fractional digits.
return result + ".%06dZ" % (nanos / 1e3)
return f"{result}.{int(nanos // 1e3) :06d}Z"
# Serialize 9 fractional digits.
return result + ".%09dZ" % nanos
return f"{result}.{nanos:09d}"
class _WrappedMessage(Message):