QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
This commit is contained in:
@@ -26,7 +26,7 @@ from ._types import T
|
||||
from .casing import camel_case, safe_snake_case, snake_case
|
||||
from .grpc.grpclib_client import ServiceStub
|
||||
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||
if sys.version_info[:2] < (3, 7):
|
||||
# Apply backport of datetime.fromisoformat from 3.7
|
||||
from backports.datetime_fromisoformat import MonkeyPatch
|
||||
|
||||
@@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
||||
|
||||
|
||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||
def datetime_default_gen():
|
||||
def datetime_default_gen() -> datetime:
|
||||
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, name: str) -> "Enum":
|
||||
"""
|
||||
Return the value which corresponds to the string name.
|
||||
"""Return the value which corresponds to the string name.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
@@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
||||
return encode_varint(value)
|
||||
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||
# Handle zig-zag encoding.
|
||||
if value >= 0:
|
||||
value = value << 1
|
||||
else:
|
||||
value = (value << 1) ^ (~0)
|
||||
return encode_varint(value)
|
||||
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
|
||||
elif proto_type in FIXED_TYPES:
|
||||
return struct.pack(_pack_fmt(proto_type), value)
|
||||
elif proto_type == TYPE_STRING:
|
||||
@@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
wire_type = num_wire & 0x7
|
||||
|
||||
decoded: Any = None
|
||||
if wire_type == 0:
|
||||
if wire_type == WIRE_VARINT:
|
||||
decoded, i = decode_varint(value, i)
|
||||
elif wire_type == 1:
|
||||
elif wire_type == WIRE_FIXED_64:
|
||||
decoded, i = value[i : i + 8], i + 8
|
||||
elif wire_type == 2:
|
||||
elif wire_type == WIRE_LEN_DELIM:
|
||||
length, i = decode_varint(value, i)
|
||||
decoded = value[i : i + length]
|
||||
i += length
|
||||
elif wire_type == 5:
|
||||
elif wire_type == WIRE_FIXED_32:
|
||||
decoded, i = value[i : i + 4], i + 4
|
||||
|
||||
yield ParsedField(
|
||||
@@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
||||
|
||||
|
||||
class ProtoClassMetadata:
|
||||
oneof_group_by_field: Dict[str, str]
|
||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||
default_gen: Dict[str, Callable]
|
||||
cls_by_field: Dict[str, Type]
|
||||
field_name_by_number: Dict[int, str]
|
||||
meta_by_field_name: Dict[str, FieldMetadata]
|
||||
__slots__ = (
|
||||
"oneof_group_by_field",
|
||||
"oneof_field_by_group",
|
||||
@@ -446,6 +435,14 @@ class ProtoClassMetadata:
|
||||
"sorted_field_names",
|
||||
)
|
||||
|
||||
oneof_group_by_field: Dict[str, str]
|
||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||
field_name_by_number: Dict[int, str]
|
||||
meta_by_field_name: Dict[str, FieldMetadata]
|
||||
sorted_field_names: Tuple[str, ...]
|
||||
default_gen: Dict[str, Callable[[], Any]]
|
||||
cls_by_field: Dict[str, Type]
|
||||
|
||||
def __init__(self, cls: Type["Message"]):
|
||||
by_field = {}
|
||||
by_group: Dict[str, Set] = {}
|
||||
@@ -470,23 +467,21 @@ class ProtoClassMetadata:
|
||||
self.field_name_by_number = by_field_number
|
||||
self.meta_by_field_name = by_field_name
|
||||
self.sorted_field_names = tuple(
|
||||
by_field_number[number] for number in sorted(by_field_number.keys())
|
||||
by_field_number[number] for number in sorted(by_field_number)
|
||||
)
|
||||
|
||||
self.default_gen = self._get_default_gen(cls, fields)
|
||||
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||
|
||||
@staticmethod
|
||||
def _get_default_gen(cls, fields):
|
||||
default_gen = {}
|
||||
|
||||
for field in fields:
|
||||
default_gen[field.name] = cls._get_field_default_gen(field)
|
||||
|
||||
return default_gen
|
||||
def _get_default_gen(
|
||||
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||
) -> Dict[str, Callable[[], Any]]:
|
||||
return {field.name: cls._get_field_default_gen(field) for field in fields}
|
||||
|
||||
@staticmethod
|
||||
def _get_cls_by_field(cls, fields):
|
||||
def _get_cls_by_field(
|
||||
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||
) -> Dict[str, Type]:
|
||||
field_cls = {}
|
||||
|
||||
for field in fields:
|
||||
@@ -503,7 +498,7 @@ class ProtoClassMetadata:
|
||||
],
|
||||
bases=(Message,),
|
||||
)
|
||||
field_cls[field.name + ".value"] = vt
|
||||
field_cls[f"{field.name}.value"] = vt
|
||||
else:
|
||||
field_cls[field.name] = cls._cls_for(field)
|
||||
|
||||
@@ -612,7 +607,7 @@ class Message(ABC):
|
||||
super().__setattr__(attr, value)
|
||||
|
||||
@property
|
||||
def _betterproto(self):
|
||||
def _betterproto(self) -> ProtoClassMetadata:
|
||||
"""
|
||||
Lazy initialize metadata for each protobuf class.
|
||||
It may be initialized multiple times in a multi-threaded environment,
|
||||
@@ -726,9 +721,8 @@ class Message(ABC):
|
||||
|
||||
@classmethod
|
||||
def _type_hints(cls) -> Dict[str, Type]:
|
||||
module = inspect.getmodule(cls)
|
||||
type_hints = get_type_hints(cls, vars(module))
|
||||
return type_hints
|
||||
module = sys.modules[cls.__module__]
|
||||
return get_type_hints(cls, vars(module))
|
||||
|
||||
@classmethod
|
||||
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||
@@ -739,7 +733,7 @@ class Message(ABC):
|
||||
field_cls = field_cls.__args__[index]
|
||||
return field_cls
|
||||
|
||||
def _get_field_default(self, field_name):
|
||||
def _get_field_default(self, field_name: str) -> Any:
|
||||
return self._betterproto.default_gen[field_name]()
|
||||
|
||||
@classmethod
|
||||
@@ -762,7 +756,7 @@ class Message(ABC):
|
||||
elif issubclass(t, Enum):
|
||||
# Enums always default to zero.
|
||||
return int
|
||||
elif t == datetime:
|
||||
elif t is datetime:
|
||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||
return datetime_default_gen
|
||||
else:
|
||||
@@ -966,7 +960,7 @@ class Message(ABC):
|
||||
)
|
||||
):
|
||||
output[cased_name] = value.to_dict(casing, include_default_values)
|
||||
elif meta.proto_type == "map":
|
||||
elif meta.proto_type == TYPE_MAP:
|
||||
for k in value:
|
||||
if hasattr(value[k], "to_dict"):
|
||||
value[k] = value[k].to_dict(casing, include_default_values)
|
||||
@@ -1032,12 +1026,12 @@ class Message(ABC):
|
||||
continue
|
||||
|
||||
if value[key] is not None:
|
||||
if meta.proto_type == "message":
|
||||
if meta.proto_type == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
if isinstance(v, list):
|
||||
cls = self._betterproto.cls_by_field[field_name]
|
||||
for i in range(len(value[key])):
|
||||
v.append(cls().from_dict(value[key][i]))
|
||||
for item in value[key]:
|
||||
v.append(cls().from_dict(item))
|
||||
elif isinstance(v, datetime):
|
||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||
setattr(self, field_name, v)
|
||||
@@ -1052,7 +1046,7 @@ class Message(ABC):
|
||||
v.from_dict(value[key])
|
||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||
v = getattr(self, field_name)
|
||||
cls = self._betterproto.cls_by_field[field_name + ".value"]
|
||||
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
||||
for k in value[key]:
|
||||
v[k] = cls().from_dict(value[key][k])
|
||||
else:
|
||||
@@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
|
||||
return message._serialized_on_wire
|
||||
|
||||
|
||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
|
||||
"""
|
||||
Return the name and value of a message's one-of field group.
|
||||
|
||||
@@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
||||
"""
|
||||
field_name = message._group_current.get(group_name)
|
||||
if not field_name:
|
||||
return ("", None)
|
||||
return (field_name, getattr(message, field_name))
|
||||
return "", None
|
||||
return field_name, getattr(message, field_name)
|
||||
|
||||
|
||||
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||
from .lib.google.protobuf import ( # noqa
|
||||
Duration,
|
||||
Timestamp,
|
||||
BoolValue,
|
||||
BytesValue,
|
||||
DoubleValue,
|
||||
Duration,
|
||||
FloatValue,
|
||||
Int32Value,
|
||||
Int64Value,
|
||||
StringValue,
|
||||
Timestamp,
|
||||
UInt32Value,
|
||||
UInt64Value,
|
||||
)
|
||||
@@ -1174,8 +1168,8 @@ class _Duration(Duration):
|
||||
parts = str(delta.total_seconds()).split(".")
|
||||
if len(parts) > 1:
|
||||
while len(parts[1]) not in [3, 6, 9]:
|
||||
parts[1] = parts[1] + "0"
|
||||
return ".".join(parts) + "s"
|
||||
parts[1] = f"{parts[1]}0"
|
||||
return f"{'.'.join(parts)}s"
|
||||
|
||||
|
||||
class _Timestamp(Timestamp):
|
||||
@@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
|
||||
if (nanos % 1e9) == 0:
|
||||
# If there are 0 fractional digits, the fractional
|
||||
# point '.' should be omitted when serializing.
|
||||
return result + "Z"
|
||||
return f"{result}Z"
|
||||
if (nanos % 1e6) == 0:
|
||||
# Serialize 3 fractional digits.
|
||||
return result + ".%03dZ" % (nanos / 1e6)
|
||||
return f"{result}.{int(nanos // 1e6) :03d}Z"
|
||||
if (nanos % 1e3) == 0:
|
||||
# Serialize 6 fractional digits.
|
||||
return result + ".%06dZ" % (nanos / 1e3)
|
||||
return f"{result}.{int(nanos // 1e3) :06d}Z"
|
||||
# Serialize 9 fractional digits.
|
||||
return result + ".%09dZ" % nanos
|
||||
return f"{result}.{nanos:09d}"
|
||||
|
||||
|
||||
class _WrappedMessage(Message):
|
||||
|
||||
Reference in New Issue
Block a user