QOL fixes (#141)
- Add missing type annotations - Various style improvements - Use constants more consistently - enforce black on benchmark code
This commit is contained in:
parent
bf9412e083
commit
8f7af272cc
2
.github/workflows/code-quality.yml
vendored
2
.github/workflows/code-quality.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
|||||||
- name: Run Black
|
- name: Run Black
|
||||||
uses: lgeiger/black-action@master
|
uses: lgeiger/black-action@master
|
||||||
with:
|
with:
|
||||||
args: --check src/ tests/
|
args: --check src/ tests/ benchmarks/
|
||||||
|
|
||||||
- name: Install rST dependcies
|
- name: Install rST dependcies
|
||||||
run: python -m pip install doc8
|
run: python -m pip install doc8
|
||||||
|
3
.gitignore
vendored
3
.gitignore
vendored
@ -15,4 +15,5 @@ output
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
.tox
|
.tox
|
||||||
.venv
|
.venv
|
||||||
.asv
|
.asv
|
||||||
|
venv
|
||||||
|
@ -8,9 +8,9 @@ class TestMessage(betterproto.Message):
|
|||||||
bar: str = betterproto.string_field(1)
|
bar: str = betterproto.string_field(1)
|
||||||
baz: float = betterproto.float_field(2)
|
baz: float = betterproto.float_field(2)
|
||||||
|
|
||||||
|
|
||||||
class BenchMessage:
|
class BenchMessage:
|
||||||
"""Test creation and usage a proto message.
|
"""Test creation and usage a proto message."""
|
||||||
"""
|
|
||||||
|
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.cls = TestMessage
|
self.cls = TestMessage
|
||||||
@ -18,8 +18,8 @@ class BenchMessage:
|
|||||||
self.instance_filled = TestMessage(0, "test", 0.0)
|
self.instance_filled = TestMessage(0, "test", 0.0)
|
||||||
|
|
||||||
def time_overhead(self):
|
def time_overhead(self):
|
||||||
"""Overhead in class definition.
|
"""Overhead in class definition."""
|
||||||
"""
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Message(betterproto.Message):
|
class Message(betterproto.Message):
|
||||||
foo: int = betterproto.uint32_field(0)
|
foo: int = betterproto.uint32_field(0)
|
||||||
@ -27,29 +27,25 @@ class BenchMessage:
|
|||||||
baz: float = betterproto.float_field(2)
|
baz: float = betterproto.float_field(2)
|
||||||
|
|
||||||
def time_instantiation(self):
|
def time_instantiation(self):
|
||||||
"""Time instantiation
|
"""Time instantiation"""
|
||||||
"""
|
|
||||||
self.cls()
|
self.cls()
|
||||||
|
|
||||||
def time_attribute_access(self):
|
def time_attribute_access(self):
|
||||||
"""Time to access an attribute
|
"""Time to access an attribute"""
|
||||||
"""
|
|
||||||
self.instance.foo
|
self.instance.foo
|
||||||
self.instance.bar
|
self.instance.bar
|
||||||
self.instance.baz
|
self.instance.baz
|
||||||
|
|
||||||
def time_init_with_values(self):
|
def time_init_with_values(self):
|
||||||
"""Time to set an attribute
|
"""Time to set an attribute"""
|
||||||
"""
|
|
||||||
self.cls(0, "test", 0.0)
|
self.cls(0, "test", 0.0)
|
||||||
|
|
||||||
def time_attribute_setting(self):
|
def time_attribute_setting(self):
|
||||||
"""Time to set attributes
|
"""Time to set attributes"""
|
||||||
"""
|
|
||||||
self.instance.foo = 0
|
self.instance.foo = 0
|
||||||
self.instance.bar = "test"
|
self.instance.bar = "test"
|
||||||
self.instance.baz = 0.0
|
self.instance.baz = 0.0
|
||||||
|
|
||||||
def time_serialize(self):
|
def time_serialize(self):
|
||||||
"""Time serializing a message to wire."""
|
"""Time serializing a message to wire."""
|
||||||
bytes(self.instance_filled)
|
bytes(self.instance_filled)
|
||||||
@ -58,6 +54,6 @@ class BenchMessage:
|
|||||||
class MemSuite:
|
class MemSuite:
|
||||||
def setup(self):
|
def setup(self):
|
||||||
self.cls = TestMessage
|
self.cls = TestMessage
|
||||||
|
|
||||||
def mem_instance(self):
|
def mem_instance(self):
|
||||||
return self.cls()
|
return self.cls()
|
||||||
|
@ -26,7 +26,7 @@ from ._types import T
|
|||||||
from .casing import camel_case, safe_snake_case, snake_case
|
from .casing import camel_case, safe_snake_case, snake_case
|
||||||
from .grpc.grpclib_client import ServiceStub
|
from .grpc.grpclib_client import ServiceStub
|
||||||
|
|
||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
if sys.version_info[:2] < (3, 7):
|
||||||
# Apply backport of datetime.fromisoformat from 3.7
|
# Apply backport of datetime.fromisoformat from 3.7
|
||||||
from backports.datetime_fromisoformat import MonkeyPatch
|
from backports.datetime_fromisoformat import MonkeyPatch
|
||||||
|
|
||||||
@ -110,7 +110,7 @@ WIRE_LEN_DELIM_TYPES = [TYPE_STRING, TYPE_BYTES, TYPE_MESSAGE, TYPE_MAP]
|
|||||||
|
|
||||||
|
|
||||||
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
# Protobuf datetimes start at the Unix Epoch in 1970 in UTC.
|
||||||
def datetime_default_gen():
|
def datetime_default_gen() -> datetime:
|
||||||
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
return datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
@ -256,8 +256,7 @@ class Enum(enum.IntEnum):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, name: str) -> "Enum":
|
def from_string(cls, name: str) -> "Enum":
|
||||||
"""
|
"""Return the value which corresponds to the string name.
|
||||||
Return the value which corresponds to the string name.
|
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
-----------
|
-----------
|
||||||
@ -316,11 +315,7 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
|||||||
return encode_varint(value)
|
return encode_varint(value)
|
||||||
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
elif proto_type in [TYPE_SINT32, TYPE_SINT64]:
|
||||||
# Handle zig-zag encoding.
|
# Handle zig-zag encoding.
|
||||||
if value >= 0:
|
return encode_varint(value << 1 if value >= 0 else (value << 1) ^ (~0))
|
||||||
value = value << 1
|
|
||||||
else:
|
|
||||||
value = (value << 1) ^ (~0)
|
|
||||||
return encode_varint(value)
|
|
||||||
elif proto_type in FIXED_TYPES:
|
elif proto_type in FIXED_TYPES:
|
||||||
return struct.pack(_pack_fmt(proto_type), value)
|
return struct.pack(_pack_fmt(proto_type), value)
|
||||||
elif proto_type == TYPE_STRING:
|
elif proto_type == TYPE_STRING:
|
||||||
@ -413,15 +408,15 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
wire_type = num_wire & 0x7
|
wire_type = num_wire & 0x7
|
||||||
|
|
||||||
decoded: Any = None
|
decoded: Any = None
|
||||||
if wire_type == 0:
|
if wire_type == WIRE_VARINT:
|
||||||
decoded, i = decode_varint(value, i)
|
decoded, i = decode_varint(value, i)
|
||||||
elif wire_type == 1:
|
elif wire_type == WIRE_FIXED_64:
|
||||||
decoded, i = value[i : i + 8], i + 8
|
decoded, i = value[i : i + 8], i + 8
|
||||||
elif wire_type == 2:
|
elif wire_type == WIRE_LEN_DELIM:
|
||||||
length, i = decode_varint(value, i)
|
length, i = decode_varint(value, i)
|
||||||
decoded = value[i : i + length]
|
decoded = value[i : i + length]
|
||||||
i += length
|
i += length
|
||||||
elif wire_type == 5:
|
elif wire_type == WIRE_FIXED_32:
|
||||||
decoded, i = value[i : i + 4], i + 4
|
decoded, i = value[i : i + 4], i + 4
|
||||||
|
|
||||||
yield ParsedField(
|
yield ParsedField(
|
||||||
@ -430,12 +425,6 @@ def parse_fields(value: bytes) -> Generator[ParsedField, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
class ProtoClassMetadata:
|
class ProtoClassMetadata:
|
||||||
oneof_group_by_field: Dict[str, str]
|
|
||||||
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
|
||||||
default_gen: Dict[str, Callable]
|
|
||||||
cls_by_field: Dict[str, Type]
|
|
||||||
field_name_by_number: Dict[int, str]
|
|
||||||
meta_by_field_name: Dict[str, FieldMetadata]
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"oneof_group_by_field",
|
"oneof_group_by_field",
|
||||||
"oneof_field_by_group",
|
"oneof_field_by_group",
|
||||||
@ -446,6 +435,14 @@ class ProtoClassMetadata:
|
|||||||
"sorted_field_names",
|
"sorted_field_names",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
oneof_group_by_field: Dict[str, str]
|
||||||
|
oneof_field_by_group: Dict[str, Set[dataclasses.Field]]
|
||||||
|
field_name_by_number: Dict[int, str]
|
||||||
|
meta_by_field_name: Dict[str, FieldMetadata]
|
||||||
|
sorted_field_names: Tuple[str, ...]
|
||||||
|
default_gen: Dict[str, Callable[[], Any]]
|
||||||
|
cls_by_field: Dict[str, Type]
|
||||||
|
|
||||||
def __init__(self, cls: Type["Message"]):
|
def __init__(self, cls: Type["Message"]):
|
||||||
by_field = {}
|
by_field = {}
|
||||||
by_group: Dict[str, Set] = {}
|
by_group: Dict[str, Set] = {}
|
||||||
@ -470,23 +467,21 @@ class ProtoClassMetadata:
|
|||||||
self.field_name_by_number = by_field_number
|
self.field_name_by_number = by_field_number
|
||||||
self.meta_by_field_name = by_field_name
|
self.meta_by_field_name = by_field_name
|
||||||
self.sorted_field_names = tuple(
|
self.sorted_field_names = tuple(
|
||||||
by_field_number[number] for number in sorted(by_field_number.keys())
|
by_field_number[number] for number in sorted(by_field_number)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.default_gen = self._get_default_gen(cls, fields)
|
self.default_gen = self._get_default_gen(cls, fields)
|
||||||
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
self.cls_by_field = self._get_cls_by_field(cls, fields)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_default_gen(cls, fields):
|
def _get_default_gen(
|
||||||
default_gen = {}
|
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||||
|
) -> Dict[str, Callable[[], Any]]:
|
||||||
for field in fields:
|
return {field.name: cls._get_field_default_gen(field) for field in fields}
|
||||||
default_gen[field.name] = cls._get_field_default_gen(field)
|
|
||||||
|
|
||||||
return default_gen
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_cls_by_field(cls, fields):
|
def _get_cls_by_field(
|
||||||
|
cls: Type["Message"], fields: List[dataclasses.Field]
|
||||||
|
) -> Dict[str, Type]:
|
||||||
field_cls = {}
|
field_cls = {}
|
||||||
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
@ -503,7 +498,7 @@ class ProtoClassMetadata:
|
|||||||
],
|
],
|
||||||
bases=(Message,),
|
bases=(Message,),
|
||||||
)
|
)
|
||||||
field_cls[field.name + ".value"] = vt
|
field_cls[f"{field.name}.value"] = vt
|
||||||
else:
|
else:
|
||||||
field_cls[field.name] = cls._cls_for(field)
|
field_cls[field.name] = cls._cls_for(field)
|
||||||
|
|
||||||
@ -612,7 +607,7 @@ class Message(ABC):
|
|||||||
super().__setattr__(attr, value)
|
super().__setattr__(attr, value)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _betterproto(self):
|
def _betterproto(self) -> ProtoClassMetadata:
|
||||||
"""
|
"""
|
||||||
Lazy initialize metadata for each protobuf class.
|
Lazy initialize metadata for each protobuf class.
|
||||||
It may be initialized multiple times in a multi-threaded environment,
|
It may be initialized multiple times in a multi-threaded environment,
|
||||||
@ -726,9 +721,8 @@ class Message(ABC):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _type_hints(cls) -> Dict[str, Type]:
|
def _type_hints(cls) -> Dict[str, Type]:
|
||||||
module = inspect.getmodule(cls)
|
module = sys.modules[cls.__module__]
|
||||||
type_hints = get_type_hints(cls, vars(module))
|
return get_type_hints(cls, vars(module))
|
||||||
return type_hints
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
|
||||||
@ -739,7 +733,7 @@ class Message(ABC):
|
|||||||
field_cls = field_cls.__args__[index]
|
field_cls = field_cls.__args__[index]
|
||||||
return field_cls
|
return field_cls
|
||||||
|
|
||||||
def _get_field_default(self, field_name):
|
def _get_field_default(self, field_name: str) -> Any:
|
||||||
return self._betterproto.default_gen[field_name]()
|
return self._betterproto.default_gen[field_name]()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -762,7 +756,7 @@ class Message(ABC):
|
|||||||
elif issubclass(t, Enum):
|
elif issubclass(t, Enum):
|
||||||
# Enums always default to zero.
|
# Enums always default to zero.
|
||||||
return int
|
return int
|
||||||
elif t == datetime:
|
elif t is datetime:
|
||||||
# Offsets are relative to 1970-01-01T00:00:00Z
|
# Offsets are relative to 1970-01-01T00:00:00Z
|
||||||
return datetime_default_gen
|
return datetime_default_gen
|
||||||
else:
|
else:
|
||||||
@ -966,7 +960,7 @@ class Message(ABC):
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
output[cased_name] = value.to_dict(casing, include_default_values)
|
output[cased_name] = value.to_dict(casing, include_default_values)
|
||||||
elif meta.proto_type == "map":
|
elif meta.proto_type == TYPE_MAP:
|
||||||
for k in value:
|
for k in value:
|
||||||
if hasattr(value[k], "to_dict"):
|
if hasattr(value[k], "to_dict"):
|
||||||
value[k] = value[k].to_dict(casing, include_default_values)
|
value[k] = value[k].to_dict(casing, include_default_values)
|
||||||
@ -1032,12 +1026,12 @@ class Message(ABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if value[key] is not None:
|
if value[key] is not None:
|
||||||
if meta.proto_type == "message":
|
if meta.proto_type == TYPE_MESSAGE:
|
||||||
v = getattr(self, field_name)
|
v = getattr(self, field_name)
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
cls = self._betterproto.cls_by_field[field_name]
|
cls = self._betterproto.cls_by_field[field_name]
|
||||||
for i in range(len(value[key])):
|
for item in value[key]:
|
||||||
v.append(cls().from_dict(value[key][i]))
|
v.append(cls().from_dict(item))
|
||||||
elif isinstance(v, datetime):
|
elif isinstance(v, datetime):
|
||||||
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
v = datetime.fromisoformat(value[key].replace("Z", "+00:00"))
|
||||||
setattr(self, field_name, v)
|
setattr(self, field_name, v)
|
||||||
@ -1052,7 +1046,7 @@ class Message(ABC):
|
|||||||
v.from_dict(value[key])
|
v.from_dict(value[key])
|
||||||
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
elif meta.map_types and meta.map_types[1] == TYPE_MESSAGE:
|
||||||
v = getattr(self, field_name)
|
v = getattr(self, field_name)
|
||||||
cls = self._betterproto.cls_by_field[field_name + ".value"]
|
cls = self._betterproto.cls_by_field[f"{field_name}.value"]
|
||||||
for k in value[key]:
|
for k in value[key]:
|
||||||
v[k] = cls().from_dict(value[key][k])
|
v[k] = cls().from_dict(value[key][k])
|
||||||
else:
|
else:
|
||||||
@ -1134,7 +1128,7 @@ def serialized_on_wire(message: Message) -> bool:
|
|||||||
return message._serialized_on_wire
|
return message._serialized_on_wire
|
||||||
|
|
||||||
|
|
||||||
def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]:
|
||||||
"""
|
"""
|
||||||
Return the name and value of a message's one-of field group.
|
Return the name and value of a message's one-of field group.
|
||||||
|
|
||||||
@ -1145,21 +1139,21 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Any]:
|
|||||||
"""
|
"""
|
||||||
field_name = message._group_current.get(group_name)
|
field_name = message._group_current.get(group_name)
|
||||||
if not field_name:
|
if not field_name:
|
||||||
return ("", None)
|
return "", None
|
||||||
return (field_name, getattr(message, field_name))
|
return field_name, getattr(message, field_name)
|
||||||
|
|
||||||
|
|
||||||
# Circular import workaround: google.protobuf depends on base classes defined above.
|
# Circular import workaround: google.protobuf depends on base classes defined above.
|
||||||
from .lib.google.protobuf import ( # noqa
|
from .lib.google.protobuf import ( # noqa
|
||||||
Duration,
|
|
||||||
Timestamp,
|
|
||||||
BoolValue,
|
BoolValue,
|
||||||
BytesValue,
|
BytesValue,
|
||||||
DoubleValue,
|
DoubleValue,
|
||||||
|
Duration,
|
||||||
FloatValue,
|
FloatValue,
|
||||||
Int32Value,
|
Int32Value,
|
||||||
Int64Value,
|
Int64Value,
|
||||||
StringValue,
|
StringValue,
|
||||||
|
Timestamp,
|
||||||
UInt32Value,
|
UInt32Value,
|
||||||
UInt64Value,
|
UInt64Value,
|
||||||
)
|
)
|
||||||
@ -1174,8 +1168,8 @@ class _Duration(Duration):
|
|||||||
parts = str(delta.total_seconds()).split(".")
|
parts = str(delta.total_seconds()).split(".")
|
||||||
if len(parts) > 1:
|
if len(parts) > 1:
|
||||||
while len(parts[1]) not in [3, 6, 9]:
|
while len(parts[1]) not in [3, 6, 9]:
|
||||||
parts[1] = parts[1] + "0"
|
parts[1] = f"{parts[1]}0"
|
||||||
return ".".join(parts) + "s"
|
return f"{'.'.join(parts)}s"
|
||||||
|
|
||||||
|
|
||||||
class _Timestamp(Timestamp):
|
class _Timestamp(Timestamp):
|
||||||
@ -1191,15 +1185,15 @@ class _Timestamp(Timestamp):
|
|||||||
if (nanos % 1e9) == 0:
|
if (nanos % 1e9) == 0:
|
||||||
# If there are 0 fractional digits, the fractional
|
# If there are 0 fractional digits, the fractional
|
||||||
# point '.' should be omitted when serializing.
|
# point '.' should be omitted when serializing.
|
||||||
return result + "Z"
|
return f"{result}Z"
|
||||||
if (nanos % 1e6) == 0:
|
if (nanos % 1e6) == 0:
|
||||||
# Serialize 3 fractional digits.
|
# Serialize 3 fractional digits.
|
||||||
return result + ".%03dZ" % (nanos / 1e6)
|
return f"{result}.{int(nanos // 1e6) :03d}Z"
|
||||||
if (nanos % 1e3) == 0:
|
if (nanos % 1e3) == 0:
|
||||||
# Serialize 6 fractional digits.
|
# Serialize 6 fractional digits.
|
||||||
return result + ".%06dZ" % (nanos / 1e3)
|
return f"{result}.{int(nanos // 1e3) :06d}Z"
|
||||||
# Serialize 9 fractional digits.
|
# Serialize 9 fractional digits.
|
||||||
return result + ".%09dZ" % nanos
|
return f"{result}.{nanos:09d}"
|
||||||
|
|
||||||
|
|
||||||
class _WrappedMessage(Message):
|
class _WrappedMessage(Message):
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TYPE_CHECKING, TypeVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from . import Message
|
|
||||||
from grpclib._typing import IProtoMessage
|
from grpclib._typing import IProtoMessage
|
||||||
|
from . import Message
|
||||||
|
|
||||||
# Bound type variable to allow methods to return `self` of subclasses
|
# Bound type variable to allow methods to return `self` of subclasses
|
||||||
T = TypeVar("T", bound="Message")
|
T = TypeVar("T", bound="Message")
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List, Set, Type
|
from typing import Dict, List, Set, Tuple, Type
|
||||||
|
|
||||||
from betterproto import safe_snake_case
|
from ..casing import safe_snake_case
|
||||||
from betterproto.compile.naming import pythonize_class_name
|
from ..lib.google import protobuf as google_protobuf
|
||||||
from betterproto.lib.google import protobuf as google_protobuf
|
from .naming import pythonize_class_name
|
||||||
|
|
||||||
WRAPPER_TYPES: Dict[str, Type] = {
|
WRAPPER_TYPES: Dict[str, Type] = {
|
||||||
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||||
@ -19,7 +19,7 @@ WRAPPER_TYPES: Dict[str, Type] = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def parse_source_type_name(field_type_name):
|
def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Split full source type name into package and type name.
|
Split full source type name into package and type name.
|
||||||
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
||||||
@ -50,7 +50,7 @@ def get_type_reference(
|
|||||||
if source_type == ".google.protobuf.Duration":
|
if source_type == ".google.protobuf.Duration":
|
||||||
return "timedelta"
|
return "timedelta"
|
||||||
|
|
||||||
if source_type == ".google.protobuf.Timestamp":
|
elif source_type == ".google.protobuf.Timestamp":
|
||||||
return "datetime"
|
return "datetime"
|
||||||
|
|
||||||
source_package, source_type = parse_source_type_name(source_type)
|
source_package, source_type = parse_source_type_name(source_type)
|
||||||
@ -79,7 +79,7 @@ def get_type_reference(
|
|||||||
return reference_cousin(current_package, imports, py_package, py_type)
|
return reference_cousin(current_package, imports, py_package, py_type)
|
||||||
|
|
||||||
|
|
||||||
def reference_absolute(imports, py_package, py_type):
|
def reference_absolute(imports: Set[str], py_package: List[str], py_type: str) -> str:
|
||||||
"""
|
"""
|
||||||
Returns a reference to a python type located in the root, i.e. sys.path.
|
Returns a reference to a python type located in the root, i.e. sys.path.
|
||||||
"""
|
"""
|
||||||
|
@ -1,13 +1,13 @@
|
|||||||
from betterproto import casing
|
from betterproto import casing
|
||||||
|
|
||||||
|
|
||||||
def pythonize_class_name(name):
|
def pythonize_class_name(name: str) -> str:
|
||||||
return casing.pascal_case(name)
|
return casing.pascal_case(name)
|
||||||
|
|
||||||
|
|
||||||
def pythonize_field_name(name: str):
|
def pythonize_field_name(name: str) -> str:
|
||||||
return casing.safe_snake_case(name)
|
return casing.safe_snake_case(name)
|
||||||
|
|
||||||
|
|
||||||
def pythonize_method_name(name: str):
|
def pythonize_method_name(name: str) -> str:
|
||||||
return casing.safe_snake_case(name)
|
return casing.safe_snake_case(name)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from abc import ABC
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import grpclib.const
|
from abc import ABC
|
||||||
from typing import (
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
AsyncIterable,
|
AsyncIterable,
|
||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Collection,
|
Collection,
|
||||||
@ -9,11 +9,13 @@ from typing import (
|
|||||||
Mapping,
|
Mapping,
|
||||||
Optional,
|
Optional,
|
||||||
Tuple,
|
Tuple,
|
||||||
TYPE_CHECKING,
|
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
from betterproto._types import ST, T
|
|
||||||
|
import grpclib.const
|
||||||
|
|
||||||
|
from .._types import ST, T
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from grpclib.client import Channel
|
from grpclib.client import Channel
|
||||||
|
@ -1,12 +1,5 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from typing import (
|
from typing import AsyncIterable, AsyncIterator, Iterable, Optional, TypeVar, Union
|
||||||
AsyncIterable,
|
|
||||||
AsyncIterator,
|
|
||||||
Iterable,
|
|
||||||
Optional,
|
|
||||||
TypeVar,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@ -16,8 +9,6 @@ class ChannelClosed(Exception):
|
|||||||
An exception raised on an attempt to send through a closed channel
|
An exception raised on an attempt to send through a closed channel
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ChannelDone(Exception):
|
class ChannelDone(Exception):
|
||||||
"""
|
"""
|
||||||
@ -25,8 +16,6 @@ class ChannelDone(Exception):
|
|||||||
and empty.
|
and empty.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncChannel(AsyncIterable[T]):
|
class AsyncChannel(AsyncIterable[T]):
|
||||||
"""
|
"""
|
||||||
|
@ -5,10 +5,9 @@ try:
|
|||||||
import black
|
import black
|
||||||
import jinja2
|
import jinja2
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
missing_import = err.args[0][17:-1]
|
|
||||||
print(
|
print(
|
||||||
"\033[31m"
|
"\033[31m"
|
||||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||||
"Please ensure that you've installed betterproto as "
|
"Please ensure that you've installed betterproto as "
|
||||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||||
"are included."
|
"are included."
|
||||||
@ -16,7 +15,7 @@ except ImportError as err:
|
|||||||
)
|
)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
from betterproto.plugin.models import OutputTemplate
|
from .models import OutputTemplate
|
||||||
|
|
||||||
|
|
||||||
def outputfile_compiler(output_file: OutputTemplate) -> str:
|
def outputfile_compiler(output_file: OutputTemplate) -> str:
|
||||||
@ -32,9 +31,7 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
|||||||
)
|
)
|
||||||
template = env.get_template("template.py.j2")
|
template = env.get_template("template.py.j2")
|
||||||
|
|
||||||
res = black.format_str(
|
return black.format_str(
|
||||||
template.render(output_file=output_file),
|
template.render(output_file=output_file),
|
||||||
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
|
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
|
||||||
)
|
)
|
||||||
|
|
||||||
return res
|
|
||||||
|
@ -1,13 +1,14 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||||
|
|
||||||
from betterproto.plugin.parser import generate_code
|
from betterproto.plugin.parser import generate_code
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
"""The plugin's main entry point."""
|
"""The plugin's main entry point."""
|
||||||
# Read request message from stdin
|
# Read request message from stdin
|
||||||
data = sys.stdin.buffer.read()
|
data = sys.stdin.buffer.read()
|
||||||
@ -33,7 +34,7 @@ def main():
|
|||||||
sys.stdout.buffer.write(output)
|
sys.stdout.buffer.write(output)
|
||||||
|
|
||||||
|
|
||||||
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
|
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
|
||||||
"""
|
"""
|
||||||
For developers: Supports running plugin.py standalone so its possible to debug it.
|
For developers: Supports running plugin.py standalone so its possible to debug it.
|
||||||
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
|
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
"""Plugin model dataclasses.
|
"""Plugin model dataclasses.
|
||||||
|
|
||||||
These classes are meant to be an intermediate representation
|
These classes are meant to be an intermediate representation
|
||||||
of protbuf objects. They are used to organize the data collected during parsing.
|
of protobuf objects. They are used to organize the data collected during parsing.
|
||||||
|
|
||||||
The general intention is to create a doubly-linked tree-like structure
|
The general intention is to create a doubly-linked tree-like structure
|
||||||
with the following types of references:
|
with the following types of references:
|
||||||
- Downwards references: from message -> fields, from output package -> messages
|
- Downwards references: from message -> fields, from output package -> messages
|
||||||
or from service -> service methods
|
or from service -> service methods
|
||||||
- Upwards references: from field -> message, message -> package.
|
- Upwards references: from field -> message, message -> package.
|
||||||
- Input/ouput message references: from a service method to it's corresponding
|
- Input/output message references: from a service method to it's corresponding
|
||||||
input/output messages, which may even be in another package.
|
input/output messages, which may even be in another package.
|
||||||
|
|
||||||
There are convenience methods to allow climbing up and down this tree, for
|
There are convenience methods to allow climbing up and down this tree, for
|
||||||
@ -26,36 +26,24 @@ such as a pythonized name, that will be calculated from proto_obj.
|
|||||||
The instantiation should also attach a reference to the new object
|
The instantiation should also attach a reference to the new object
|
||||||
into the corresponding place within it's parent object. For example,
|
into the corresponding place within it's parent object. For example,
|
||||||
instantiating field `A` with parent message `B` should add a
|
instantiating field `A` with parent message `B` should add a
|
||||||
reference to `A` to `B`'s `fields` attirbute.
|
reference to `A` to `B`'s `fields` attribute.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
|
||||||
from dataclasses import field
|
|
||||||
from typing import (
|
|
||||||
Iterator,
|
|
||||||
Union,
|
|
||||||
Type,
|
|
||||||
List,
|
|
||||||
Dict,
|
|
||||||
Set,
|
|
||||||
Text,
|
|
||||||
)
|
|
||||||
import textwrap
|
import textwrap
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto.compile.importing import (
|
|
||||||
get_type_reference,
|
from ..casing import sanitize_name
|
||||||
parse_source_type_name,
|
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||||
)
|
from ..compile.naming import (
|
||||||
from betterproto.compile.naming import (
|
|
||||||
pythonize_class_name,
|
pythonize_class_name,
|
||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
pythonize_method_name,
|
pythonize_method_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ..casing import sanitize_name
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# betterproto[compiler] specific dependencies
|
# betterproto[compiler] specific dependencies
|
||||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||||
@ -67,10 +55,9 @@ try:
|
|||||||
MethodDescriptorProto,
|
MethodDescriptorProto,
|
||||||
)
|
)
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
|
|
||||||
print(
|
print(
|
||||||
"\033[31m"
|
"\033[31m"
|
||||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||||
"Please ensure that you've installed betterproto as "
|
"Please ensure that you've installed betterproto as "
|
||||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||||
"are included."
|
"are included."
|
||||||
@ -124,10 +111,11 @@ PROTO_PACKED_TYPES = (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
|
def get_comment(
|
||||||
|
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
|
||||||
|
) -> str:
|
||||||
pad = " " * indent
|
pad = " " * indent
|
||||||
for sci in proto_file.source_code_info.location:
|
for sci in proto_file.source_code_info.location:
|
||||||
# print(list(sci.path), path, file=sys.stderr)
|
|
||||||
if list(sci.path) == path and sci.leading_comments:
|
if list(sci.path) == path and sci.leading_comments:
|
||||||
lines = textwrap.wrap(
|
lines = textwrap.wrap(
|
||||||
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
|
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||||
@ -153,9 +141,9 @@ class ProtoContentBase:
|
|||||||
|
|
||||||
path: List[int]
|
path: List[int]
|
||||||
comment_indent: int = 4
|
comment_indent: int = 4
|
||||||
parent: Union["Messsage", "OutputTemplate"]
|
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
"""Checks that no fake default fields were left as placeholders."""
|
"""Checks that no fake default fields were left as placeholders."""
|
||||||
for field_name, field_val in self.__dataclass_fields__.items():
|
for field_name, field_val in self.__dataclass_fields__.items():
|
||||||
if field_val is PLACEHOLDER:
|
if field_val is PLACEHOLDER:
|
||||||
@ -273,7 +261,7 @@ class MessageCompiler(ProtoContentBase):
|
|||||||
)
|
)
|
||||||
deprecated: bool = field(default=False, init=False)
|
deprecated: bool = field(default=False, init=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
# Add message to output file
|
# Add message to output file
|
||||||
if isinstance(self.parent, OutputTemplate):
|
if isinstance(self.parent, OutputTemplate):
|
||||||
if isinstance(self, EnumDefinitionCompiler):
|
if isinstance(self, EnumDefinitionCompiler):
|
||||||
@ -314,17 +302,17 @@ def is_map(
|
|||||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||||
if message_type == map_entry:
|
if message_type == map_entry:
|
||||||
for nested in parent_message.nested_type: # parent message
|
for nested in parent_message.nested_type: # parent message
|
||||||
if nested.name.replace("_", "").lower() == map_entry:
|
if (
|
||||||
if nested.options.map_entry:
|
nested.name.replace("_", "").lower() == map_entry
|
||||||
return True
|
and nested.options.map_entry
|
||||||
|
):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||||
"""True if proto_field_obj is a OneOf, otherwise False."""
|
"""True if proto_field_obj is a OneOf, otherwise False."""
|
||||||
if proto_field_obj.HasField("oneof_index"):
|
return proto_field_obj.HasField("oneof_index")
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -332,7 +320,7 @@ class FieldCompiler(MessageCompiler):
|
|||||||
parent: MessageCompiler = PLACEHOLDER
|
parent: MessageCompiler = PLACEHOLDER
|
||||||
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
# Add field to message
|
# Add field to message
|
||||||
self.parent.fields.append(self)
|
self.parent.fields.append(self)
|
||||||
# Check for new imports
|
# Check for new imports
|
||||||
@ -357,11 +345,9 @@ class FieldCompiler(MessageCompiler):
|
|||||||
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
|
([""] + self.betterproto_field_args) if self.betterproto_field_args else []
|
||||||
)
|
)
|
||||||
betterproto_field_type = (
|
betterproto_field_type = (
|
||||||
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
|
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
|
||||||
+ field_args
|
|
||||||
+ ")"
|
|
||||||
)
|
)
|
||||||
return name + annotations + " = " + betterproto_field_type
|
return f"{name}{annotations} = {betterproto_field_type}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def betterproto_field_args(self) -> List[str]:
|
def betterproto_field_args(self) -> List[str]:
|
||||||
@ -371,7 +357,7 @@ class FieldCompiler(MessageCompiler):
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_wraps(self) -> Union[str, None]:
|
def field_wraps(self) -> Optional[str]:
|
||||||
"""Returns betterproto wrapped field type or None."""
|
"""Returns betterproto wrapped field type or None."""
|
||||||
match_wrapper = re.match(
|
match_wrapper = re.match(
|
||||||
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
||||||
@ -384,17 +370,15 @@ class FieldCompiler(MessageCompiler):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def repeated(self) -> bool:
|
def repeated(self) -> bool:
|
||||||
if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
|
return (
|
||||||
self.proto_obj, self.parent
|
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
|
||||||
):
|
and not is_map(self.proto_obj, self.parent)
|
||||||
return True
|
)
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def mutable(self) -> bool:
|
def mutable(self) -> bool:
|
||||||
"""True if the field is a mutable type, otherwise False."""
|
"""True if the field is a mutable type, otherwise False."""
|
||||||
annotation = self.annotation
|
return self.annotation.startswith(("List[", "Dict["))
|
||||||
return annotation.startswith("List[") or annotation.startswith("Dict[")
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_type(self) -> str:
|
def field_type(self) -> str:
|
||||||
@ -425,9 +409,7 @@ class FieldCompiler(MessageCompiler):
|
|||||||
@property
|
@property
|
||||||
def packed(self) -> bool:
|
def packed(self) -> bool:
|
||||||
"""True if the wire representation is a packed format."""
|
"""True if the wire representation is a packed format."""
|
||||||
if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
|
return self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def py_name(self) -> str:
|
def py_name(self) -> str:
|
||||||
@ -486,22 +468,24 @@ class MapEntryCompiler(FieldCompiler):
|
|||||||
proto_k_type: str = PLACEHOLDER
|
proto_k_type: str = PLACEHOLDER
|
||||||
proto_v_type: str = PLACEHOLDER
|
proto_v_type: str = PLACEHOLDER
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
"""Explore nested types and set k_type and v_type if unset."""
|
"""Explore nested types and set k_type and v_type if unset."""
|
||||||
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
|
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
|
||||||
for nested in self.parent.proto_obj.nested_type:
|
for nested in self.parent.proto_obj.nested_type:
|
||||||
if nested.name.replace("_", "").lower() == map_entry:
|
if (
|
||||||
if nested.options.map_entry:
|
nested.name.replace("_", "").lower() == map_entry
|
||||||
# Get Python types
|
and nested.options.map_entry
|
||||||
self.py_k_type = FieldCompiler(
|
):
|
||||||
parent=self, proto_obj=nested.field[0] # key
|
# Get Python types
|
||||||
).py_type
|
self.py_k_type = FieldCompiler(
|
||||||
self.py_v_type = FieldCompiler(
|
parent=self, proto_obj=nested.field[0] # key
|
||||||
parent=self, proto_obj=nested.field[1] # value
|
).py_type
|
||||||
).py_type
|
self.py_v_type = FieldCompiler(
|
||||||
# Get proto types
|
parent=self, proto_obj=nested.field[1] # value
|
||||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
).py_type
|
||||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
# Get proto types
|
||||||
|
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||||
|
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -513,11 +497,11 @@ class MapEntryCompiler(FieldCompiler):
|
|||||||
return "map"
|
return "map"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def annotation(self):
|
def annotation(self) -> str:
|
||||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repeated(self):
|
def repeated(self) -> bool:
|
||||||
return False # maps cannot be repeated
|
return False # maps cannot be repeated
|
||||||
|
|
||||||
|
|
||||||
@ -536,7 +520,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
|||||||
value: int
|
value: int
|
||||||
comment: str
|
comment: str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
# Get entries/allowed values for this Enum
|
# Get entries/allowed values for this Enum
|
||||||
self.entries = [
|
self.entries = [
|
||||||
self.EnumEntry(
|
self.EnumEntry(
|
||||||
@ -551,7 +535,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
|||||||
super().__post_init__() # call MessageCompiler __post_init__
|
super().__post_init__() # call MessageCompiler __post_init__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_value_string(self) -> int:
|
def default_value_string(self) -> str:
|
||||||
"""Python representation of the default value for Enums.
|
"""Python representation of the default value for Enums.
|
||||||
|
|
||||||
As per the spec, this is the first value of the Enum.
|
As per the spec, this is the first value of the Enum.
|
||||||
@ -572,11 +556,11 @@ class ServiceCompiler(ProtoContentBase):
|
|||||||
super().__post_init__() # check for unset fields
|
super().__post_init__() # check for unset fields
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def proto_name(self):
|
def proto_name(self) -> str:
|
||||||
return self.proto_obj.name
|
return self.proto_obj.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def py_name(self):
|
def py_name(self) -> str:
|
||||||
return pythonize_class_name(self.proto_name)
|
return pythonize_class_name(self.proto_name)
|
||||||
|
|
||||||
|
|
||||||
@ -628,7 +612,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
Name and actual default value (as a string)
|
Name and actual default value (as a string)
|
||||||
for each argument with mutable default values.
|
for each argument with mutable default values.
|
||||||
"""
|
"""
|
||||||
mutable_default_args = dict()
|
mutable_default_args = {}
|
||||||
|
|
||||||
if self.py_input_message:
|
if self.py_input_message:
|
||||||
for f in self.py_input_message.fields:
|
for f in self.py_input_message.fields:
|
||||||
@ -654,18 +638,15 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def route(self) -> str:
|
def route(self) -> str:
|
||||||
return (
|
return f"/{self.output_file.package}.{self.parent.proto_name}/{self.proto_name}"
|
||||||
f"/{self.output_file.package}."
|
|
||||||
f"{self.parent.proto_name}/{self.proto_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def py_input_message(self) -> Union[None, MessageCompiler]:
|
def py_input_message(self) -> Optional[MessageCompiler]:
|
||||||
"""Find the input message object.
|
"""Find the input message object.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Union[None, MessageCompiler]
|
Optional[MessageCompiler]
|
||||||
Method instance representing the input message.
|
Method instance representing the input message.
|
||||||
If not input message could be found or there are no
|
If not input message could be found or there are no
|
||||||
input messages, None is returned.
|
input messages, None is returned.
|
||||||
@ -685,14 +666,13 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def py_input_message_type(self) -> str:
|
def py_input_message_type(self) -> str:
|
||||||
"""String representation of the Python type correspoding to the
|
"""String representation of the Python type corresponding to the
|
||||||
input message.
|
input message.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str
|
str
|
||||||
String representation of the Python type correspoding to the
|
String representation of the Python type corresponding to the input message.
|
||||||
input message.
|
|
||||||
"""
|
"""
|
||||||
return get_type_reference(
|
return get_type_reference(
|
||||||
package=self.output_file.package,
|
package=self.output_file.package,
|
||||||
@ -702,14 +682,13 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def py_output_message_type(self) -> str:
|
def py_output_message_type(self) -> str:
|
||||||
"""String representation of the Python type correspoding to the
|
"""String representation of the Python type corresponding to the
|
||||||
output message.
|
output message.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
str
|
str
|
||||||
String representation of the Python type correspoding to the
|
String representation of the Python type corresponding to the output message.
|
||||||
output message.
|
|
||||||
"""
|
"""
|
||||||
return get_type_reference(
|
return get_type_reference(
|
||||||
package=self.output_file.package,
|
package=self.output_file.package,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import itertools
|
import itertools
|
||||||
import pathlib
|
import pathlib
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Iterator
|
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# betterproto[compiler] specific dependencies
|
# betterproto[compiler] specific dependencies
|
||||||
@ -13,10 +13,9 @@ try:
|
|||||||
ServiceDescriptorProto,
|
ServiceDescriptorProto,
|
||||||
)
|
)
|
||||||
except ImportError as err:
|
except ImportError as err:
|
||||||
missing_import = err.args[0][17:-1]
|
|
||||||
print(
|
print(
|
||||||
"\033[31m"
|
"\033[31m"
|
||||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||||
"Please ensure that you've installed betterproto as "
|
"Please ensure that you've installed betterproto as "
|
||||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||||
"are included."
|
"are included."
|
||||||
@ -24,26 +23,32 @@ except ImportError as err:
|
|||||||
)
|
)
|
||||||
raise SystemExit(1)
|
raise SystemExit(1)
|
||||||
|
|
||||||
from betterproto.plugin.models import (
|
from .compiler import outputfile_compiler
|
||||||
PluginRequestCompiler,
|
from .models import (
|
||||||
OutputTemplate,
|
|
||||||
MessageCompiler,
|
|
||||||
FieldCompiler,
|
|
||||||
OneOfFieldCompiler,
|
|
||||||
MapEntryCompiler,
|
|
||||||
EnumDefinitionCompiler,
|
EnumDefinitionCompiler,
|
||||||
|
FieldCompiler,
|
||||||
|
MapEntryCompiler,
|
||||||
|
MessageCompiler,
|
||||||
|
OneOfFieldCompiler,
|
||||||
|
OutputTemplate,
|
||||||
|
PluginRequestCompiler,
|
||||||
ServiceCompiler,
|
ServiceCompiler,
|
||||||
ServiceMethodCompiler,
|
ServiceMethodCompiler,
|
||||||
is_map,
|
is_map,
|
||||||
is_oneof,
|
is_oneof,
|
||||||
)
|
)
|
||||||
|
|
||||||
from betterproto.plugin.compiler import outputfile_compiler
|
if TYPE_CHECKING:
|
||||||
|
from google.protobuf.descriptor import Descriptor
|
||||||
|
|
||||||
|
|
||||||
def traverse(proto_file: FieldDescriptorProto) -> Iterator:
|
def traverse(
|
||||||
|
proto_file: FieldDescriptorProto,
|
||||||
|
) -> "itertools.chain[Tuple[Union[str, EnumDescriptorProto], List[int]]]":
|
||||||
# Todo: Keep information about nested hierarchy
|
# Todo: Keep information about nested hierarchy
|
||||||
def _traverse(path, items, prefix=""):
|
def _traverse(
|
||||||
|
path: List[int], items: List["Descriptor"], prefix=""
|
||||||
|
) -> Iterator[Tuple[Union[str, EnumDescriptorProto], List[int]]]:
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
# Adjust the name since we flatten the hierarchy.
|
# Adjust the name since we flatten the hierarchy.
|
||||||
# Todo: don't change the name, but include full name in returned tuple
|
# Todo: don't change the name, but include full name in returned tuple
|
||||||
@ -104,7 +109,7 @@ def generate_code(
|
|||||||
read_protobuf_service(service, index, output_package)
|
read_protobuf_service(service, index, output_package)
|
||||||
|
|
||||||
# Generate output files
|
# Generate output files
|
||||||
output_paths: pathlib.Path = set()
|
output_paths: Set[pathlib.Path] = set()
|
||||||
for output_package_name, output_package in request_data.output_packages.items():
|
for output_package_name, output_package in request_data.output_packages.items():
|
||||||
|
|
||||||
# Add files to the response object
|
# Add files to the response object
|
||||||
@ -112,20 +117,17 @@ def generate_code(
|
|||||||
output_paths.add(output_path)
|
output_paths.add(output_path)
|
||||||
|
|
||||||
f: response.File = response.file.add()
|
f: response.File = response.file.add()
|
||||||
f.name: str = str(output_path)
|
f.name = str(output_path)
|
||||||
|
|
||||||
# Render and then format the output file
|
# Render and then format the output file
|
||||||
f.content: str = outputfile_compiler(output_file=output_package)
|
f.content = outputfile_compiler(output_file=output_package)
|
||||||
|
|
||||||
# Make each output directory a package with __init__ file
|
# Make each output directory a package with __init__ file
|
||||||
init_files = (
|
init_files = {
|
||||||
set(
|
directory.joinpath("__init__.py")
|
||||||
directory.joinpath("__init__.py")
|
for path in output_paths
|
||||||
for path in output_paths
|
for directory in path.parents
|
||||||
for directory in path.parents
|
} - output_paths
|
||||||
)
|
|
||||||
- output_paths
|
|
||||||
)
|
|
||||||
|
|
||||||
for init_file in init_files:
|
for init_file in init_files:
|
||||||
init = response.file.add()
|
init = response.file.add()
|
||||||
|
@ -27,10 +27,7 @@ class ClientStub:
|
|||||||
|
|
||||||
|
|
||||||
async def to_list(generator: AsyncIterator):
|
async def to_list(generator: AsyncIterator):
|
||||||
result = []
|
return [value async for value in generator]
|
||||||
async for value in generator:
|
|
||||||
result.append(value)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -6,7 +6,7 @@ from grpclib.client import Channel
|
|||||||
class MockChannel(Channel):
|
class MockChannel(Channel):
|
||||||
# noinspection PyMissingConstructor
|
# noinspection PyMissingConstructor
|
||||||
def __init__(self, responses=None) -> None:
|
def __init__(self, responses=None) -> None:
|
||||||
self.responses = responses if responses else []
|
self.responses = responses or []
|
||||||
self.requests = []
|
self.requests = []
|
||||||
self._loop = None
|
self._loop = None
|
||||||
|
|
||||||
|
@ -23,8 +23,7 @@ def get_files(path, suffix: str) -> Generator[str, None, None]:
|
|||||||
|
|
||||||
def get_directories(path):
|
def get_directories(path):
|
||||||
for root, directories, files in os.walk(path):
|
for root, directories, files in os.walk(path):
|
||||||
for directory in directories:
|
yield from directories
|
||||||
yield directory
|
|
||||||
|
|
||||||
|
|
||||||
async def protoc(
|
async def protoc(
|
||||||
@ -49,7 +48,7 @@ async def protoc(
|
|||||||
|
|
||||||
|
|
||||||
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
|
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
|
||||||
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
|
test_data_file_name = json_file_name or f"{test_case_name}.json"
|
||||||
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
|
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
|
||||||
|
|
||||||
if not test_data_file_path.exists():
|
if not test_data_file_path.exists():
|
||||||
@ -77,7 +76,7 @@ def find_module(
|
|||||||
|
|
||||||
module_path = pathlib.Path(*module.__path__)
|
module_path = pathlib.Path(*module.__path__)
|
||||||
|
|
||||||
for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
|
for sub in [sub.parent for sub in module_path.glob("**/__init__.py")]:
|
||||||
if sub == module_path:
|
if sub == module_path:
|
||||||
continue
|
continue
|
||||||
sub_module_path = sub.relative_to(module_path)
|
sub_module_path = sub.relative_to(module_path)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user