Fix parameters missing from services (#381)

This commit is contained in:
James Hilton-Balfe 2022-07-06 19:05:40 +01:00 committed by GitHub
parent bc13e7070d
commit 3fd5a0d662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 136 additions and 40 deletions

View File

@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes:
elif proto_type == TYPE_MESSAGE:
if isinstance(value, datetime):
# Convert the `datetime` to a timestamp message.
seconds = int(value.timestamp())
nanos = int(value.microsecond * 1e3)
value = _Timestamp(seconds=seconds, nanos=nanos)
value = _Timestamp.from_datetime(value)
elif isinstance(value, timedelta):
# Convert the `timedelta` to a duration message.
total_ms = value // timedelta(microseconds=1)
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
value = _Duration(seconds=seconds, nanos=nanos)
value = _Duration.from_timedelta(value)
elif wraps:
if value is None:
return b""
@ -1505,6 +1500,15 @@ from .lib.google.protobuf import ( # noqa
class _Duration(Duration):
@classmethod
def from_timedelta(
cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1)
) -> "_Duration":
total_ms = delta // _1_microsecond
seconds = int(total_ms / 1e6)
nanos = int((total_ms % 1e6) * 1e3)
return cls(seconds, nanos)
def to_timedelta(self) -> timedelta:
return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3)
@ -1518,6 +1522,12 @@ class _Duration(Duration):
class _Timestamp(Timestamp):
@classmethod
def from_datetime(cls, dt: datetime) -> "_Timestamp":
seconds = int(dt.timestamp())
nanos = int(dt.microsecond * 1e3)
return cls(seconds, nanos)
def to_datetime(self) -> datetime:
ts = self.seconds + (self.nanos / 1e9)
return datetime.fromtimestamp(ts, tz=timezone.utc)

View File

@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]:
def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True
*, package: str, imports: set, source_type: str, unwrap: bool = True
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if

View File

@ -15,21 +15,22 @@ from typing import (
import grpclib.const
from .._types import (
ST,
T,
)
if TYPE_CHECKING:
from grpclib.client import Channel
from grpclib.metadata import Deadline
from .._types import (
ST,
IProtoMessage,
Message,
T,
)
Value = Union[str, bytes]
MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]]
MessageLike = Union[T, ST]
MessageSource = Union[Iterable[ST], AsyncIterable[ST]]
MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]]
class ServiceStub(ABC):
@ -65,13 +66,13 @@ class ServiceStub(ABC):
async def _unary_unary(
self,
route: str,
request: MessageLike,
response_type: Type[T],
request: "IProtoMessage",
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> T:
) -> "T":
"""Make a unary request and return the response."""
async with self.channel.request(
route,
@ -88,13 +89,13 @@ class ServiceStub(ABC):
async def _unary_stream(
self,
route: str,
request: MessageLike,
response_type: Type[T],
request: "IProtoMessage",
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]:
) -> AsyncIterator["T"]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route,
@ -111,13 +112,13 @@ class ServiceStub(ABC):
self,
route: str,
request_iterator: MessageSource,
request_type: Type[ST],
response_type: Type[T],
request_type: Type["IProtoMessage"],
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> T:
) -> "T":
"""Make a stream request and return the response."""
async with self.channel.request(
route,
@ -135,13 +136,13 @@ class ServiceStub(ABC):
self,
route: str,
request_iterator: MessageSource,
request_type: Type[ST],
response_type: Type[T],
request_type: Type["IProtoMessage"],
response_type: Type["T"],
*,
timeout: Optional[float] = None,
deadline: Optional["Deadline"] = None,
metadata: Optional[MetadataLike] = None,
) -> AsyncIterator[T]:
) -> AsyncIterator["T"]:
"""
Make a stream request and return an AsyncIterator to iterate over response
messages.

View File

@ -252,6 +252,7 @@ class OutputTemplate:
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
output: bool = True
@property
def package(self) -> str:
@ -704,6 +705,7 @@ class ServiceMethodCompiler(ProtoContentBase):
# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
)
@ -768,6 +770,7 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
unwrap=False,
).strip('"')
@property

View File

@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages
for proto_file in request.proto_file:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip re-compiling Google's well-known types
continue
output_package_name = proto_file.package
if output_package_name not in request_data.output_packages:
# Create a new output if there is no output for this package
@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Add this input file to the output corresponding to this package
request_data.output_packages[output_package_name].input_files.append(proto_file)
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
# If not INCLUDE_GOOGLE,
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# Generate output files
output_paths: Set[pathlib.Path] = set()
for output_package_name, output_package in request_data.output_packages.items():
if not output_package.output:
continue
# Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")

View File

@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.services %}
import grpclib
{% endif %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
,
*
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["_MetadataLike"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %}
{% for method in service.methods %}
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None:
{% if not method.client_streaming %}
request = await stream.recv_message()
{% else %}

View File

@ -9,6 +9,7 @@ xfail = {
}
services = {
"googletypes_request",
"googletypes_response",
"googletypes_response_embedded",
"service",

View File

@ -0,0 +1,29 @@
syntax = "proto3";
package googletypes_request;
import "google/protobuf/duration.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
// Tests that google types can be used as params
service Test {
rpc SendDouble (google.protobuf.DoubleValue) returns (Input);
rpc SendFloat (google.protobuf.FloatValue) returns (Input);
rpc SendInt64 (google.protobuf.Int64Value) returns (Input);
rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input);
rpc SendInt32 (google.protobuf.Int32Value) returns (Input);
rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input);
rpc SendBool (google.protobuf.BoolValue) returns (Input);
rpc SendString (google.protobuf.StringValue) returns (Input);
rpc SendBytes (google.protobuf.BytesValue) returns (Input);
rpc SendDatetime (google.protobuf.Timestamp) returns (Input);
rpc SendTimedelta (google.protobuf.Duration) returns (Input);
rpc SendEmpty (google.protobuf.Empty) returns (Input);
}
message Input {
}

View File

@ -0,0 +1,47 @@
from datetime import (
datetime,
timedelta,
)
from typing import (
Any,
Callable,
)
import pytest
import betterproto.lib.google.protobuf as protobuf
from tests.mocks import MockChannel
from tests.output_betterproto.googletypes_request import (
Input,
TestStub,
)
test_cases = [
(TestStub.send_double, protobuf.DoubleValue, 2.5),
(TestStub.send_float, protobuf.FloatValue, 2.5),
(TestStub.send_int64, protobuf.Int64Value, -64),
(TestStub.send_u_int64, protobuf.UInt64Value, 64),
(TestStub.send_int32, protobuf.Int32Value, -32),
(TestStub.send_u_int32, protobuf.UInt32Value, 32),
(TestStub.send_bool, protobuf.BoolValue, True),
(TestStub.send_string, protobuf.StringValue, "string"),
(TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]),
(TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)),
(TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)),
]
@pytest.mark.asyncio
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
async def test_channel_receives_wrapped_type(
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
):
wrapped_value = wrapper_class()
wrapped_value.value = value
channel = MockChannel(responses=[Input()])
service = TestStub(channel)
await service_method(service, wrapped_value)
assert channel.requests[0]["request"] == type(wrapped_value)