Misc cleanup, see commit body (#227)
- Enable oneof_enum test case that passes now (removed the xfail) - Switch from toml to tomlkit as a dev dep for better toml support - upgrade poethepoet to latest stable release - use full table format for poe tasks to avoid long lines in pyproject.toml - remove redundant _WrappedMessage class - fix various Mypy warnings - reformat some comments for consistent line length
This commit is contained in:
@@ -15,6 +15,7 @@ from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
@@ -272,7 +273,7 @@ class Enum(enum.IntEnum):
|
||||
The member was not found in the Enum.
|
||||
"""
|
||||
try:
|
||||
return cls._member_map_[name]
|
||||
return cls._member_map_[name] # type: ignore
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
|
||||
|
||||
@@ -522,13 +523,13 @@ class ProtoClassMetadata:
|
||||
|
||||
@staticmethod
|
||||
def _get_default_gen(
|
||||
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||
cls: Type["Message"], fields: Iterable[dataclasses.Field]
|
||||
) -> Dict[str, Callable[[], Any]]:
|
||||
return {field.name: cls._get_field_default_gen(field) for field in fields}
|
||||
|
||||
@staticmethod
|
||||
def _get_cls_by_field(
|
||||
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||
cls: Type["Message"], fields: Iterable[dataclasses.Field]
|
||||
) -> Dict[str, Type]:
|
||||
field_cls = {}
|
||||
|
||||
@@ -687,7 +688,7 @@ class Message(ABC):
|
||||
meta = getattr(self.__class__, "_betterproto_meta", None)
|
||||
if not meta:
|
||||
meta = ProtoClassMetadata(self.__class__)
|
||||
self.__class__._betterproto_meta = meta
|
||||
self.__class__._betterproto_meta = meta # type: ignore
|
||||
return meta
|
||||
|
||||
def __bytes__(self) -> bytes:
|
||||
@@ -763,7 +764,7 @@ class Message(ABC):
|
||||
meta.number,
|
||||
meta.proto_type,
|
||||
value,
|
||||
serialize_empty=serialize_empty or selected_in_group,
|
||||
serialize_empty=serialize_empty or bool(selected_in_group),
|
||||
wraps=meta.wraps or "",
|
||||
)
|
||||
|
||||
@@ -1067,7 +1068,7 @@ class Message(ABC):
|
||||
output[cased_name] = b64encode(value).decode("utf8")
|
||||
elif meta.proto_type == TYPE_ENUM:
|
||||
if field_is_repeated:
|
||||
enum_class: Type[Enum] = field_types[field_name].__args__[0]
|
||||
enum_class = field_types[field_name].__args__[0]
|
||||
if isinstance(value, typing.Iterable) and not isinstance(
|
||||
value, str
|
||||
):
|
||||
@@ -1076,7 +1077,7 @@ class Message(ABC):
|
||||
# transparently upgrade single value to repeated
|
||||
output[cased_name] = [enum_class(value).name]
|
||||
else:
|
||||
enum_class: Type[Enum] = field_types[field_name] # noqa
|
||||
enum_class = field_types[field_name] # noqa
|
||||
output[cased_name] = enum_class(value).name
|
||||
elif meta.proto_type in (TYPE_FLOAT, TYPE_DOUBLE):
|
||||
if field_is_repeated:
|
||||
@@ -1293,23 +1294,6 @@ class _Timestamp(Timestamp):
|
||||
return f"{result}.{nanos:09d}"
|
||||
|
||||
|
||||
class _WrappedMessage(Message):
|
||||
"""
|
||||
Google protobuf wrapper types base class. JSON representation is just the
|
||||
value itself.
|
||||
"""
|
||||
|
||||
value: Any
|
||||
|
||||
def to_dict(self, casing: Casing = Casing.CAMEL) -> Any:
|
||||
return self.value
|
||||
|
||||
def from_dict(self: T, value: Any) -> T:
|
||||
if value is not None:
|
||||
self.value = value
|
||||
return self
|
||||
|
||||
|
||||
def _get_wrapper(proto_type: str) -> Type:
|
||||
"""Get the wrapper message class for a wrapped type."""
|
||||
|
||||
|
||||
@@ -100,8 +100,9 @@ def reference_descendent(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is a descendent of the current package,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
Returns a reference to a python type in a package that is a descendent of the
|
||||
current package, and adds the required import that is aliased to avoid name
|
||||
conflicts.
|
||||
"""
|
||||
importing_descendent = py_package[len(current_package) :]
|
||||
string_from = ".".join(importing_descendent[:-1])
|
||||
@@ -119,8 +120,9 @@ def reference_ancestor(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package which is an ancestor to the current package,
|
||||
and adds the required import that is aliased (if possible) to avoid name conflicts.
|
||||
Returns a reference to a python type in a package which is an ancestor to the
|
||||
current package, and adds the required import that is aliased (if possible) to avoid
|
||||
name conflicts.
|
||||
|
||||
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
|
||||
"""
|
||||
@@ -141,10 +143,10 @@ def reference_cousin(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
Returns a reference to a python type in a package that is not descendent, ancestor
|
||||
or sibling, and adds the required import that is aliased to avoid name conflicts.
|
||||
"""
|
||||
shared_ancestry = os.path.commonprefix([current_package, py_package])
|
||||
shared_ancestry = os.path.commonprefix([current_package, py_package]) # type: ignore
|
||||
distance_up = len(current_package) - len(shared_ancestry)
|
||||
string_from = f".{'.' * distance_up}" + ".".join(
|
||||
py_package[len(shared_ancestry) : -1]
|
||||
|
||||
@@ -70,7 +70,7 @@ class AsyncChannel(AsyncIterable[T]):
|
||||
"""
|
||||
|
||||
def __init__(self, *, buffer_limit: int = 0, close: bool = False):
|
||||
self._queue: asyncio.Queue[Union[T, object]] = asyncio.Queue(buffer_limit)
|
||||
self._queue: asyncio.Queue[T] = asyncio.Queue(buffer_limit)
|
||||
self._closed = False
|
||||
self._waiting_receivers: int = 0
|
||||
# Track whether flush has been invoked so it can only happen once
|
||||
|
||||
@@ -170,6 +170,8 @@ class ProtoContentBase:
|
||||
comment_indent: int = 4
|
||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||
|
||||
__dataclass_fields__: Dict[str, object]
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Checks that no fake default fields were left as placeholders."""
|
||||
for field_name, field_val in self.__dataclass_fields__.items():
|
||||
|
||||
@@ -13,7 +13,7 @@ from betterproto.lib.google.protobuf.compiler import (
|
||||
import itertools
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Iterator, List, Tuple, TYPE_CHECKING, Union
|
||||
from typing import Iterator, List, Set, Tuple, TYPE_CHECKING, Union
|
||||
from .compiler import outputfile_compiler
|
||||
from .models import (
|
||||
EnumDefinitionCompiler,
|
||||
@@ -38,7 +38,7 @@ def traverse(
|
||||
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
|
||||
# Todo: Keep information about nested hierarchy
|
||||
def _traverse(
|
||||
path: List[int], items: List["Descriptor"], prefix=""
|
||||
path: List[int], items: List["EnumDescriptorProto"], prefix=""
|
||||
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
|
||||
for i, item in enumerate(items):
|
||||
# Adjust the name since we flatten the hierarchy.
|
||||
|
||||
Reference in New Issue
Block a user