Fix parameters missing from services (#381)
This commit is contained in:
parent
bc13e7070d
commit
3fd5a0d662
@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
|
||||
elif proto_type == TYPE_MESSAGE:
|
||||
if isinstance(value, datetime):
|
||||
# Convert the `datetime` to a timestamp message.
|
||||
seconds = int(value.timestamp())
|
||||
nanos = int(value.microsecond * 1e3)
|
||||
value = _Timestamp(seconds=seconds, nanos=nanos)
|
||||
value = _Timestamp.from_datetime(value)
|
||||
elif isinstance(value, timedelta):
|
||||
# Convert the `timedelta` to a duration message.
|
||||
total_ms = value // timedelta(microseconds=1)
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
value = _Duration(seconds=seconds, nanos=nanos)
|
||||
value = _Duration.from_timedelta(value)
|
||||
elif wraps:
|
||||
if value is None:
|
||||
return b""
|
||||
@ -1505,6 +1500,15 @@ from .lib.google.protobuf import ( # noqa
|
||||
|
||||
|
||||
class _Duration(Duration):
|
||||
@classmethod
|
||||
def from_timedelta(
|
||||
cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
|
||||
) -> "_Duration":
|
||||
total_ms = delta // _1_microsecond
|
||||
seconds = int(total_ms / 1e6)
|
||||
nanos = int((total_ms % 1e6) * 1e3)
|
||||
return cls(seconds, nanos)
|
||||
|
||||
def to_timedelta(self) -> timedelta:
|
||||
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
|
||||
|
||||
@ -1518,6 +1522,12 @@ class _Duration(Duration):
|
||||
|
||||
|
||||
class _Timestamp(Timestamp):
|
||||
@classmethod
|
||||
def from_datetime(cls, dt: datetime) -> "_Timestamp":
|
||||
seconds = int(dt.timestamp())
|
||||
nanos = int(dt.microsecond * 1e3)
|
||||
return cls(seconds, nanos)
|
||||
|
||||
def to_datetime(self) -> datetime:
|
||||
ts = self.seconds + (self.nanos / 1e9)
|
||||
return datetime.fromtimestamp(ts, tz=timezone.utc)
|
||||
|
@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
|
||||
|
||||
|
||||
def get_type_reference(
|
||||
package: str, imports: set, source_type: str, unwrap: bool = True
|
||||
*, package: str, imports: set, source_type: str, unwrap: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
|
@ -15,21 +15,22 @@ from typing import (
|
||||
|
||||
import grpclib.const
|
||||
|
||||
from .._types import (
|
||||
ST,
|
||||
T,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from grpclib.client import Channel
|
||||
from grpclib.metadata import Deadline
|
||||
|
||||
from .._types import (
|
||||
ST,
|
||||
IProtoMessage,
|
||||
Message,
|
||||
T,
|
||||
)
|
||||
|
||||
|
||||
Value = Union[str, bytes]
|
||||
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
|
||||
MessageLike = Union[T, ST]
|
||||
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
|
||||
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
|
||||
|
||||
|
||||
class ServiceStub(ABC):
|
||||
@ -65,13 +66,13 @@ class ServiceStub(ABC):
|
||||
async def _unary_unary(
|
||||
self,
|
||||
route: str,
|
||||
request: MessageLike,
|
||||
response_type: Type[T],
|
||||
request: "IProtoMessage",
|
||||
response_type: Type["T"],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[MetadataLike] = None,
|
||||
) -> T:
|
||||
) -> "T":
|
||||
"""Make a unary request and return the response."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
@ -88,13 +89,13 @@ class ServiceStub(ABC):
|
||||
async def _unary_stream(
|
||||
self,
|
||||
route: str,
|
||||
request: MessageLike,
|
||||
response_type: Type[T],
|
||||
request: "IProtoMessage",
|
||||
response_type: Type["T"],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[MetadataLike] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
) -> AsyncIterator["T"]:
|
||||
"""Make a unary request and return the stream response iterator."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
@ -111,13 +112,13 @@ class ServiceStub(ABC):
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
request_type: Type["IProtoMessage"],
|
||||
response_type: Type["T"],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[MetadataLike] = None,
|
||||
) -> T:
|
||||
) -> "T":
|
||||
"""Make a stream request and return the response."""
|
||||
async with self.channel.request(
|
||||
route,
|
||||
@ -135,13 +136,13 @@ class ServiceStub(ABC):
|
||||
self,
|
||||
route: str,
|
||||
request_iterator: MessageSource,
|
||||
request_type: Type[ST],
|
||||
response_type: Type[T],
|
||||
request_type: Type["IProtoMessage"],
|
||||
response_type: Type["T"],
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
deadline: Optional["Deadline"] = None,
|
||||
metadata: Optional[MetadataLike] = None,
|
||||
) -> AsyncIterator[T]:
|
||||
) -> AsyncIterator["T"]:
|
||||
"""
|
||||
Make a stream request and return an AsyncIterator to iterate over response
|
||||
messages.
|
||||
|
@ -252,6 +252,7 @@ class OutputTemplate:
|
||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||
output: bool = True
|
||||
|
||||
@property
|
||||
def package(self) -> str:
|
||||
@ -704,6 +705,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
# add imports required for request arguments timeout, deadline and metadata
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.output_file.imports_type_checking_only.add("import grpclib.server")
|
||||
self.output_file.imports_type_checking_only.add(
|
||||
"from betterproto.grpc.grpclib_client import MetadataLike"
|
||||
)
|
||||
@ -768,6 +770,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.input_type,
|
||||
unwrap=False,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
|
@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
request_data = PluginRequestCompiler(plugin_request_obj=request)
|
||||
# Gather output packages
|
||||
for proto_file in request.proto_file:
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
# If not INCLUDE_GOOGLE,
|
||||
# skip re-compiling Google's well-known types
|
||||
continue
|
||||
|
||||
output_package_name = proto_file.package
|
||||
if output_package_name not in request_data.output_packages:
|
||||
# Create a new output if there is no output for this package
|
||||
@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
# Add this input file to the output corresponding to this package
|
||||
request_data.output_packages[output_package_name].input_files.append(proto_file)
|
||||
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
# If not INCLUDE_GOOGLE,
|
||||
# skip outputting Google's well-known types
|
||||
request_data.output_packages[output_package_name].output = False
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
# Generate output files
|
||||
output_paths: Set[pathlib.Path] = set()
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
if not output_package.output:
|
||||
continue
|
||||
|
||||
# Add files to the response object
|
||||
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
|
||||
|
@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
{% if output_file.services %}
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
import grpclib
|
||||
{% endif %}
|
||||
|
||||
{% for i in output_file.imports|sort %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
{% if output_file.services %}
|
||||
import grpclib
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.imports_type_checking_only %}
|
||||
from typing import TYPE_CHECKING
|
||||
@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||
{%- endif -%}
|
||||
,
|
||||
*
|
||||
, timeout: Optional[float] = None
|
||||
, deadline: Optional["Deadline"] = None
|
||||
, metadata: Optional["_MetadataLike"] = None
|
||||
, metadata: Optional["MetadataLike"] = None
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
{% endfor %}
|
||||
|
||||
{% for method in service.methods %}
|
||||
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
||||
async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None:
|
||||
{% if not method.client_streaming %}
|
||||
request = await stream.recv_message()
|
||||
{% else %}
|
||||
|
@ -9,6 +9,7 @@ xfail = {
|
||||
}
|
||||
|
||||
services = {
|
||||
"googletypes_request",
|
||||
"googletypes_response",
|
||||
"googletypes_response_embedded",
|
||||
"service",
|
||||
|
29
tests/inputs/googletypes_request/googletypes_request.proto
Normal file
29
tests/inputs/googletypes_request/googletypes_request.proto
Normal file
@ -0,0 +1,29 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package googletypes_request;
|
||||
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
import "google/protobuf/wrappers.proto";
|
||||
|
||||
// Tests that google types can be used as params
|
||||
|
||||
service Test {
|
||||
rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
|
||||
rpc SendFloat (google.protobuf.FloatValue) returns (Input);
|
||||
rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
|
||||
rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
|
||||
rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
|
||||
rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
|
||||
rpc SendBool (google.protobuf.BoolValue) returns (Input);
|
||||
rpc SendString (google.protobuf.StringValue) returns (Input);
|
||||
rpc SendBytes (google.protobuf.BytesValue) returns (Input);
|
||||
rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
|
||||
rpc SendTimedelta (google.protobuf.Duration) returns (Input);
|
||||
rpc SendEmpty (google.protobuf.Empty) returns (Input);
|
||||
}
|
||||
|
||||
message Input {
|
||||
|
||||
}
|
47
tests/inputs/googletypes_request/test_googletypes_request.py
Normal file
47
tests/inputs/googletypes_request/test_googletypes_request.py
Normal file
@ -0,0 +1,47 @@
|
||||
from datetime import (
|
||||
datetime,
|
||||
timedelta,
|
||||
)
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
import betterproto.lib.google.protobuf as protobuf
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.googletypes_request import (
|
||||
Input,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
test_cases = [
|
||||
(TestStub.send_double, protobuf.DoubleValue, 2.5),
|
||||
(TestStub.send_float, protobuf.FloatValue, 2.5),
|
||||
(TestStub.send_int64, protobuf.Int64Value, -64),
|
||||
(TestStub.send_u_int64, protobuf.UInt64Value, 64),
|
||||
(TestStub.send_int32, protobuf.Int32Value, -32),
|
||||
(TestStub.send_u_int32, protobuf.UInt32Value, 32),
|
||||
(TestStub.send_bool, protobuf.BoolValue, True),
|
||||
(TestStub.send_string, protobuf.StringValue, "string"),
|
||||
(TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
|
||||
(TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
|
||||
(TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[Input()])
|
||||
service = TestStub(channel)
|
||||
|
||||
await service_method(service, wrapped_value)
|
||||
|
||||
assert channel.requests[0]["request"] == type(wrapped_value)
|
Loading…
x
Reference in New Issue
Block a user