QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
This commit is contained in:
		| @@ -26,7 +26,7 @@ from ._types import T | ||||
| from .casing import camel_case, safe_snake_case, snake_case | ||||
| from .grpc.grpclib_client import ServiceStub | ||||
|  | ||||
| if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): | ||||
| if sys.version_info[:2] < (3, 7): | ||||
|     # Apply backport of datetime.fromisoformat from 3.7 | ||||
|     from backports.datetime_fromisoformat import MonkeyPatch | ||||
|  | ||||
| @@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP] | ||||
|  | ||||
|  | ||||
| # Protobuf datetimes start at the Unix Epoch in 1970 in UTC. | ||||
| def datetime_default_gen(): | ||||
| def datetime_default_gen() -> datetime: | ||||
|     return datetime(1970, 1, 1, tzinfo=timezone.utc) | ||||
|  | ||||
|  | ||||
| @@ -256,8 +256,7 @@ class Enum(enum.IntEnum): | ||||
|  | ||||
|     @classmethod | ||||
|     def from_string(cls, name: str) -> "Enum": | ||||
|         """ | ||||
|         Return the value which corresponds to the string name. | ||||
|         """Return the value which corresponds to the string name. | ||||
|  | ||||
|         Parameters | ||||
|         ----------- | ||||
| @@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: | ||||
|         return encode_varint(value) | ||||
|     elif proto_type in [TYPE_SINT32, TYPE_SINT64]: | ||||
|         # Handle zig-zag encoding. | ||||
|         if value >= 0: | ||||
|             value = value << 1 | ||||
|         else: | ||||
|             value = (value << 1) ^ (~0) | ||||
|         return encode_varint(value) | ||||
|         return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0)) | ||||
|     elif proto_type in FIXED_TYPES: | ||||
|         return struct.pack(_pack_fmt(proto_type), value) | ||||
|     elif proto_type == TYPE_STRING: | ||||
| @@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: | ||||
|         wire_type = num_wire & 0x7 | ||||
|  | ||||
|         decoded: Any = None | ||||
|         if wire_type == 0: | ||||
|         if wire_type == WIRE_VARINT: | ||||
|             decoded, i = decode_varint(value, i) | ||||
|         elif wire_type == 1: | ||||
|         elif wire_type == WIRE_FIXED_64: | ||||
|             decoded, i = value[i : i + 8], i + 8 | ||||
|         elif wire_type == 2: | ||||
|         elif wire_type == WIRE_LEN_DELIM: | ||||
|             length, i = decode_varint(value, i) | ||||
|             decoded = value[i : i + length] | ||||
|             i += length | ||||
|         elif wire_type == 5: | ||||
|         elif wire_type == WIRE_FIXED_32: | ||||
|             decoded, i = value[i : i + 4], i + 4 | ||||
|  | ||||
|         yield ParsedField( | ||||
| @@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]: | ||||
|  | ||||
|  | ||||
| class ProtoClassMetadata: | ||||
|     oneof_group_by_field: Dict[str, str] | ||||
|     oneof_field_by_group: Dict[str, Set[dataclasses.Field]] | ||||
|     default_gen: Dict[str, Callable] | ||||
|     cls_by_field: Dict[str, Type] | ||||
|     field_name_by_number: Dict[int, str] | ||||
|     meta_by_field_name: Dict[str, FieldMetadata] | ||||
|     __slots__ = ( | ||||
|         "oneof_group_by_field", | ||||
|         "oneof_field_by_group", | ||||
| @@ -446,6 +435,14 @@ class ProtoClassMetadata: | ||||
|         "sorted_field_names", | ||||
|     ) | ||||
|  | ||||
|     oneof_group_by_field: Dict[str, str] | ||||
|     oneof_field_by_group: Dict[str, Set[dataclasses.Field]] | ||||
|     field_name_by_number: Dict[int, str] | ||||
|     meta_by_field_name: Dict[str, FieldMetadata] | ||||
|     sorted_field_names: Tuple[str, ...] | ||||
|     default_gen: Dict[str, Callable[[], Any]] | ||||
|     cls_by_field: Dict[str, Type] | ||||
|  | ||||
|     def __init__(self, cls: Type["Message"]): | ||||
|         by_field = {} | ||||
|         by_group: Dict[str, Set] = {} | ||||
| @@ -470,23 +467,21 @@ class ProtoClassMetadata: | ||||
|         self.field_name_by_number = by_field_number | ||||
|         self.meta_by_field_name = by_field_name | ||||
|         self.sorted_field_names = tuple( | ||||
|             by_field_number[number] for number in sorted(by_field_number.keys()) | ||||
|             by_field_number[number] for number in sorted(by_field_number) | ||||
|         ) | ||||
|  | ||||
|         self.default_gen = self._get_default_gen(cls, fields) | ||||
|         self.cls_by_field = self._get_cls_by_field(cls, fields) | ||||
|  | ||||
|     @staticmethod | ||||
|     def _get_default_gen(cls, fields): | ||||
|         default_gen = {} | ||||
|  | ||||
|         for field in fields: | ||||
|             default_gen[field.name] = cls._get_field_default_gen(field) | ||||
|  | ||||
|         return default_gen | ||||
|     def _get_default_gen( | ||||
|         cls: Type["Message"], fields: List[dataclasses.Field] | ||||
|     ) -> Dict[str, Callable[[], Any]]: | ||||
|         return {field.name: cls._get_field_default_gen(field) for field in fields} | ||||
|  | ||||
|     @staticmethod | ||||
|     def _get_cls_by_field(cls, fields): | ||||
|     def _get_cls_by_field( | ||||
|         cls: Type["Message"], fields: List[dataclasses.Field] | ||||
|     ) -> Dict[str, Type]: | ||||
|         field_cls = {} | ||||
|  | ||||
|         for field in fields: | ||||
| @@ -503,7 +498,7 @@ class ProtoClassMetadata: | ||||
|                     ], | ||||
|                     bases=(Message,), | ||||
|                 ) | ||||
|                 field_cls[field.name + ".value"] = vt | ||||
|                 field_cls[f"{field.name}.value"] = vt | ||||
|             else: | ||||
|                 field_cls[field.name] = cls._cls_for(field) | ||||
|  | ||||
| @@ -612,7 +607,7 @@ class Message(ABC): | ||||
|         super().__setattr__(attr, value) | ||||
|  | ||||
|     @property | ||||
|     def _betterproto(self): | ||||
|     def _betterproto(self) -> ProtoClassMetadata: | ||||
|         """ | ||||
|         Lazy initialize metadata for each protobuf class. | ||||
|         It may be initialized multiple times in a multi-threaded environment, | ||||
| @@ -726,9 +721,8 @@ class Message(ABC): | ||||
|  | ||||
|     @classmethod | ||||
|     def _type_hints(cls) -> Dict[str, Type]: | ||||
|         module = inspect.getmodule(cls) | ||||
|         type_hints = get_type_hints(cls, vars(module)) | ||||
|         return type_hints | ||||
|         module = sys.modules[cls.__module__] | ||||
|         return get_type_hints(cls, vars(module)) | ||||
|  | ||||
|     @classmethod | ||||
|     def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type: | ||||
| @@ -739,7 +733,7 @@ class Message(ABC): | ||||
|                 field_cls = field_cls.__args__[index] | ||||
|         return field_cls | ||||
|  | ||||
|     def _get_field_default(self, field_name): | ||||
|     def _get_field_default(self, field_name: str) -> Any: | ||||
|         return self._betterproto.default_gen[field_name]() | ||||
|  | ||||
|     @classmethod | ||||
| @@ -762,7 +756,7 @@ class Message(ABC): | ||||
|         elif issubclass(t, Enum): | ||||
|             # Enums always default to zero. | ||||
|             return int | ||||
|         elif t == datetime: | ||||
|         elif t is datetime: | ||||
|             # Offsets are relative to 1970-01-01T00:00:00Z | ||||
|             return datetime_default_gen | ||||
|         else: | ||||
| @@ -966,7 +960,7 @@ class Message(ABC): | ||||
|                     ) | ||||
|                 ): | ||||
|                     output[cased_name] = value.to_dict(casing, include_default_values) | ||||
|             elif meta.proto_type == "map": | ||||
|             elif meta.proto_type == TYPE_MAP: | ||||
|                 for k in value: | ||||
|                     if hasattr(value[k], "to_dict"): | ||||
|                         value[k] = value[k].to_dict(casing, include_default_values) | ||||
| @@ -1032,12 +1026,12 @@ class Message(ABC): | ||||
|                 continue | ||||
|  | ||||
|             if value[key] is not None: | ||||
|                 if meta.proto_type == "message": | ||||
|                 if meta.proto_type == TYPE_MESSAGE: | ||||
|                     v = getattr(self, field_name) | ||||
|                     if isinstance(v, list): | ||||
|                         cls = self._betterproto.cls_by_field[field_name] | ||||
|                         for i in range(len(value[key])): | ||||
|                             v.append(cls().from_dict(value[key][i])) | ||||
|                         for item in value[key]: | ||||
|                             v.append(cls().from_dict(item)) | ||||
|                     elif isinstance(v, datetime): | ||||
|                         v = datetime.fromisoformat(value[key].replace("Z", "+00:00")) | ||||
|                         setattr(self, field_name, v) | ||||
| @@ -1052,7 +1046,7 @@ class Message(ABC): | ||||
|                         v.from_dict(value[key]) | ||||
|                 elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE: | ||||
|                     v = getattr(self, field_name) | ||||
|                     cls = self._betterproto.cls_by_field[field_name + ".value"] | ||||
|                     cls = self._betterproto.cls_by_field[f"{field_name}.value"] | ||||
|                     for k in value[key]: | ||||
|                         v[k] = cls().from_dict(value[key][k]) | ||||
|                 else: | ||||
| @@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool: | ||||
|     return message._serialized_on_wire | ||||
|  | ||||
|  | ||||
| def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||
| def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]: | ||||
|     """ | ||||
|     Return the name and value of a message's one-of field group. | ||||
|  | ||||
| @@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]: | ||||
|     """ | ||||
|     field_name = message._group_current.get(group_name) | ||||
|     if not field_name: | ||||
|         return ("", None) | ||||
|     return (field_name, getattr(message, field_name)) | ||||
|         return "", None | ||||
|     return field_name, getattr(message, field_name) | ||||
|  | ||||
|  | ||||
| # Circular import workaround: google.protobuf depends on base classes defined above. | ||||
| from .lib.google.protobuf import (  # noqa | ||||
|     Duration, | ||||
|     Timestamp, | ||||
|     BoolValue, | ||||
|     BytesValue, | ||||
|     DoubleValue, | ||||
|     Duration, | ||||
|     FloatValue, | ||||
|     Int32Value, | ||||
|     Int64Value, | ||||
|     StringValue, | ||||
|     Timestamp, | ||||
|     UInt32Value, | ||||
|     UInt64Value, | ||||
| ) | ||||
| @@ -1174,8 +1168,8 @@ class _Duration(Duration): | ||||
|         parts = str(delta.total_seconds()).split(".") | ||||
|         if len(parts) > 1: | ||||
|             while len(parts[1]) not in [3, 6, 9]: | ||||
|                 parts[1] = parts[1] + "0" | ||||
|         return ".".join(parts) + "s" | ||||
|                 parts[1] = f"{parts[1]}0" | ||||
|         return f"{'.'.join(parts)}s" | ||||
|  | ||||
|  | ||||
| class _Timestamp(Timestamp): | ||||
| @@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp): | ||||
|         if (nanos % 1e9) == 0: | ||||
|             # If there are 0 fractional digits, the fractional | ||||
|             # point '.' should be omitted when serializing. | ||||
|             return result + "Z" | ||||
|             return f"{result}Z" | ||||
|         if (nanos % 1e6) == 0: | ||||
|             # Serialize 3 fractional digits. | ||||
|             return result + ".%03dZ" % (nanos / 1e6) | ||||
|             return f"{result}.{int(nanos // 1e6) :03d}Z" | ||||
|         if (nanos % 1e3) == 0: | ||||
|             # Serialize 6 fractional digits. | ||||
|             return result + ".%06dZ" % (nanos / 1e3) | ||||
|             return f"{result}.{int(nanos // 1e3) :06d}Z" | ||||
|         # Serialize 9 fractional digits. | ||||
|         return result + ".%09dZ" % nanos | ||||
|         return f"{result}.{nanos:09d}" | ||||
|  | ||||
|  | ||||
| class _WrappedMessage(Message): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user