Fix parameters missing from services (#381)
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							bc13e7070d
						
					
				
				
					commit
					3fd5a0d662
				
			| @@ -379,15 +379,10 @@ def _preprocess_single(proto_type: str, wraps: str, value: Any) -> bytes: | |||||||
|     elif proto_type == TYPE_MESSAGE: |     elif proto_type == TYPE_MESSAGE: | ||||||
|         if isinstance(value, datetime): |         if isinstance(value, datetime): | ||||||
|             # Convert the `datetime` to a timestamp message. |             # Convert the `datetime` to a timestamp message. | ||||||
|             seconds = int(value.timestamp()) |             value = _Timestamp.from_datetime(value) | ||||||
|             nanos = int(value.microsecond * 1e3) |  | ||||||
|             value = _Timestamp(seconds=seconds, nanos=nanos) |  | ||||||
|         elif isinstance(value, timedelta): |         elif isinstance(value, timedelta): | ||||||
|             # Convert the `timedelta` to a duration message. |             # Convert the `timedelta` to a duration message. | ||||||
|             total_ms = value // timedelta(microseconds=1) |             value = _Duration.from_timedelta(value) | ||||||
|             seconds = int(total_ms / 1e6) |  | ||||||
|             nanos = int((total_ms % 1e6) * 1e3) |  | ||||||
|             value = _Duration(seconds=seconds, nanos=nanos) |  | ||||||
|         elif wraps: |         elif wraps: | ||||||
|             if value is None: |             if value is None: | ||||||
|                 return b"" |                 return b"" | ||||||
| @@ -1505,6 +1500,15 @@ from .lib.google.protobuf import (  # noqa | |||||||
|  |  | ||||||
|  |  | ||||||
| class _Duration(Duration): | class _Duration(Duration): | ||||||
|  |     @classmethod | ||||||
|  |     def from_timedelta( | ||||||
|  |         cls, delta: timedelta, *, _1_microsecond: timedelta = timedelta(microseconds=1) | ||||||
|  |     ) -> "_Duration": | ||||||
|  |         total_ms = delta // _1_microsecond | ||||||
|  |         seconds = int(total_ms / 1e6) | ||||||
|  |         nanos = int((total_ms % 1e6) * 1e3) | ||||||
|  |         return cls(seconds, nanos) | ||||||
|  |  | ||||||
|     def to_timedelta(self) -> timedelta: |     def to_timedelta(self) -> timedelta: | ||||||
|         return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) |         return timedelta(seconds=self.seconds, microseconds=self.nanos / 1e3) | ||||||
|  |  | ||||||
| @@ -1518,6 +1522,12 @@ class _Duration(Duration): | |||||||
|  |  | ||||||
|  |  | ||||||
| class _Timestamp(Timestamp): | class _Timestamp(Timestamp): | ||||||
|  |     @classmethod | ||||||
|  |     def from_datetime(cls, dt: datetime) -> "_Timestamp": | ||||||
|  |         seconds = int(dt.timestamp()) | ||||||
|  |         nanos = int(dt.microsecond * 1e3) | ||||||
|  |         return cls(seconds, nanos) | ||||||
|  |  | ||||||
|     def to_datetime(self) -> datetime: |     def to_datetime(self) -> datetime: | ||||||
|         ts = self.seconds + (self.nanos / 1e9) |         ts = self.seconds + (self.nanos / 1e9) | ||||||
|         return datetime.fromtimestamp(ts, tz=timezone.utc) |         return datetime.fromtimestamp(ts, tz=timezone.utc) | ||||||
|   | |||||||
| @@ -43,7 +43,7 @@ def parse_source_type_name(field_type_name: str) -> Tuple[str, str]: | |||||||
|  |  | ||||||
|  |  | ||||||
| def get_type_reference( | def get_type_reference( | ||||||
|     package: str, imports: set, source_type: str, unwrap: bool = True |     *, package: str, imports: set, source_type: str, unwrap: bool = True | ||||||
| ) -> str: | ) -> str: | ||||||
|     """ |     """ | ||||||
|     Return a Python type name for a proto type reference. Adds the import if |     Return a Python type name for a proto type reference. Adds the import if | ||||||
|   | |||||||
| @@ -15,21 +15,22 @@ from typing import ( | |||||||
|  |  | ||||||
| import grpclib.const | import grpclib.const | ||||||
|  |  | ||||||
| from .._types import ( |  | ||||||
|     ST, |  | ||||||
|     T, |  | ||||||
| ) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|     from grpclib.client import Channel |     from grpclib.client import Channel | ||||||
|     from grpclib.metadata import Deadline |     from grpclib.metadata import Deadline | ||||||
|  |  | ||||||
|  |     from .._types import ( | ||||||
|  |         ST, | ||||||
|  |         IProtoMessage, | ||||||
|  |         Message, | ||||||
|  |         T, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |  | ||||||
| Value = Union[str, bytes] | Value = Union[str, bytes] | ||||||
| MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] | MetadataLike = Union[Mapping[str, Value], Collection[Tuple[str, Value]]] | ||||||
| MessageLike = Union[T, ST] | MessageSource = Union[Iterable["IProtoMessage"], AsyncIterable["IProtoMessage"]] | ||||||
| MessageSource = Union[Iterable[ST], AsyncIterable[ST]] |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class ServiceStub(ABC): | class ServiceStub(ABC): | ||||||
| @@ -65,13 +66,13 @@ class ServiceStub(ABC): | |||||||
|     async def _unary_unary( |     async def _unary_unary( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request: MessageLike, |         request: "IProtoMessage", | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> T: |     ) -> "T": | ||||||
|         """Make a unary request and return the response.""" |         """Make a unary request and return the response.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -88,13 +89,13 @@ class ServiceStub(ABC): | |||||||
|     async def _unary_stream( |     async def _unary_stream( | ||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request: MessageLike, |         request: "IProtoMessage", | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> AsyncIterator[T]: |     ) -> AsyncIterator["T"]: | ||||||
|         """Make a unary request and return the stream response iterator.""" |         """Make a unary request and return the stream response iterator.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -111,13 +112,13 @@ class ServiceStub(ABC): | |||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request_iterator: MessageSource, |         request_iterator: MessageSource, | ||||||
|         request_type: Type[ST], |         request_type: Type["IProtoMessage"], | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> T: |     ) -> "T": | ||||||
|         """Make a stream request and return the response.""" |         """Make a stream request and return the response.""" | ||||||
|         async with self.channel.request( |         async with self.channel.request( | ||||||
|             route, |             route, | ||||||
| @@ -135,13 +136,13 @@ class ServiceStub(ABC): | |||||||
|         self, |         self, | ||||||
|         route: str, |         route: str, | ||||||
|         request_iterator: MessageSource, |         request_iterator: MessageSource, | ||||||
|         request_type: Type[ST], |         request_type: Type["IProtoMessage"], | ||||||
|         response_type: Type[T], |         response_type: Type["T"], | ||||||
|         *, |         *, | ||||||
|         timeout: Optional[float] = None, |         timeout: Optional[float] = None, | ||||||
|         deadline: Optional["Deadline"] = None, |         deadline: Optional["Deadline"] = None, | ||||||
|         metadata: Optional[MetadataLike] = None, |         metadata: Optional[MetadataLike] = None, | ||||||
|     ) -> AsyncIterator[T]: |     ) -> AsyncIterator["T"]: | ||||||
|         """ |         """ | ||||||
|         Make a stream request and return an AsyncIterator to iterate over response |         Make a stream request and return an AsyncIterator to iterate over response | ||||||
|         messages. |         messages. | ||||||
|   | |||||||
| @@ -252,6 +252,7 @@ class OutputTemplate: | |||||||
|     enums: List["EnumDefinitionCompiler"] = field(default_factory=list) |     enums: List["EnumDefinitionCompiler"] = field(default_factory=list) | ||||||
|     services: List["ServiceCompiler"] = field(default_factory=list) |     services: List["ServiceCompiler"] = field(default_factory=list) | ||||||
|     imports_type_checking_only: Set[str] = field(default_factory=set) |     imports_type_checking_only: Set[str] = field(default_factory=set) | ||||||
|  |     output: bool = True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def package(self) -> str: |     def package(self) -> str: | ||||||
| @@ -704,6 +705,7 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|  |  | ||||||
|         # add imports required for request arguments timeout, deadline and metadata |         # add imports required for request arguments timeout, deadline and metadata | ||||||
|         self.output_file.typing_imports.add("Optional") |         self.output_file.typing_imports.add("Optional") | ||||||
|  |         self.output_file.imports_type_checking_only.add("import grpclib.server") | ||||||
|         self.output_file.imports_type_checking_only.add( |         self.output_file.imports_type_checking_only.add( | ||||||
|             "from betterproto.grpc.grpclib_client import MetadataLike" |             "from betterproto.grpc.grpclib_client import MetadataLike" | ||||||
|         ) |         ) | ||||||
| @@ -768,6 +770,7 @@ class ServiceMethodCompiler(ProtoContentBase): | |||||||
|             package=self.output_file.package, |             package=self.output_file.package, | ||||||
|             imports=self.output_file.imports, |             imports=self.output_file.imports, | ||||||
|             source_type=self.proto_obj.input_type, |             source_type=self.proto_obj.input_type, | ||||||
|  |             unwrap=False, | ||||||
|         ).strip('"') |         ).strip('"') | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|   | |||||||
| @@ -74,14 +74,6 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|     request_data = PluginRequestCompiler(plugin_request_obj=request) |     request_data = PluginRequestCompiler(plugin_request_obj=request) | ||||||
|     # Gather output packages |     # Gather output packages | ||||||
|     for proto_file in request.proto_file: |     for proto_file in request.proto_file: | ||||||
|         if ( |  | ||||||
|             proto_file.package == "google.protobuf" |  | ||||||
|             and "INCLUDE_GOOGLE" not in plugin_options |  | ||||||
|         ): |  | ||||||
|             # If not INCLUDE_GOOGLE, |  | ||||||
|             # skip re-compiling Google's well-known types |  | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         output_package_name = proto_file.package |         output_package_name = proto_file.package | ||||||
|         if output_package_name not in request_data.output_packages: |         if output_package_name not in request_data.output_packages: | ||||||
|             # Create a new output if there is no output for this package |             # Create a new output if there is no output for this package | ||||||
| @@ -91,6 +83,14 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|         # Add this input file to the output corresponding to this package |         # Add this input file to the output corresponding to this package | ||||||
|         request_data.output_packages[output_package_name].input_files.append(proto_file) |         request_data.output_packages[output_package_name].input_files.append(proto_file) | ||||||
|  |  | ||||||
|  |         if ( | ||||||
|  |             proto_file.package == "google.protobuf" | ||||||
|  |             and "INCLUDE_GOOGLE" not in plugin_options | ||||||
|  |         ): | ||||||
|  |             # If not INCLUDE_GOOGLE, | ||||||
|  |             # skip outputting Google's well-known types | ||||||
|  |             request_data.output_packages[output_package_name].output = False | ||||||
|  |  | ||||||
|     # Read Messages and Enums |     # Read Messages and Enums | ||||||
|     # We need to read Messages before Services in so that we can |     # We need to read Messages before Services in so that we can | ||||||
|     # get the references to input/output messages for each service |     # get the references to input/output messages for each service | ||||||
| @@ -113,6 +113,8 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | |||||||
|     # Generate output files |     # Generate output files | ||||||
|     output_paths: Set[pathlib.Path] = set() |     output_paths: Set[pathlib.Path] = set() | ||||||
|     for output_package_name, output_package in request_data.output_packages.items(): |     for output_package_name, output_package in request_data.output_packages.items(): | ||||||
|  |         if not output_package.output: | ||||||
|  |             continue | ||||||
|  |  | ||||||
|         # Add files to the response object |         # Add files to the response object | ||||||
|         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") |         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") | ||||||
|   | |||||||
| @@ -15,13 +15,14 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no | |||||||
| {% endif %} | {% endif %} | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
|  | {% if output_file.services %} | ||||||
| from betterproto.grpc.grpclib_server import ServiceBase | from betterproto.grpc.grpclib_server import ServiceBase | ||||||
|  | import grpclib | ||||||
|  | {% endif %} | ||||||
|  |  | ||||||
| {% for i in output_file.imports|sort %} | {% for i in output_file.imports|sort %} | ||||||
| {{ i }} | {{ i }} | ||||||
| {% endfor %} | {% endfor %} | ||||||
| {% if output_file.services %} |  | ||||||
| import grpclib |  | ||||||
| {% endif %} |  | ||||||
|  |  | ||||||
| {% if output_file.imports_type_checking_only %} | {% if output_file.imports_type_checking_only %} | ||||||
| from typing import TYPE_CHECKING | from typing import TYPE_CHECKING | ||||||
| @@ -96,9 +97,11 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | |||||||
|             {# Client streaming: need a request iterator instead #} |             {# Client streaming: need a request iterator instead #} | ||||||
|             , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] |             , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] | ||||||
|         {%- endif -%} |         {%- endif -%} | ||||||
|  |             , | ||||||
|  |             * | ||||||
|             , timeout: Optional[float] = None |             , timeout: Optional[float] = None | ||||||
|             , deadline: Optional["Deadline"] = None |             , deadline: Optional["Deadline"] = None | ||||||
|             , metadata: Optional["_MetadataLike"] = None |             , metadata: Optional["MetadataLike"] = None | ||||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: |             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||||
|         {% if method.comment %} |         {% if method.comment %} | ||||||
| {{ method.comment }} | {{ method.comment }} | ||||||
| @@ -179,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase): | |||||||
|     {% endfor %} |     {% endfor %} | ||||||
|  |  | ||||||
|     {% for method in service.methods %} |     {% for method in service.methods %} | ||||||
|     async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None: |     async def __rpc_{{ method.py_name }}(self, stream: "grpclib.server.Stream[{{ method.py_input_message_type }}, {{ method.py_output_message_type }}]") -> None: | ||||||
|         {% if not method.client_streaming %} |         {% if not method.client_streaming %} | ||||||
|         request = await stream.recv_message() |         request = await stream.recv_message() | ||||||
|         {% else %} |         {% else %} | ||||||
|   | |||||||
| @@ -9,6 +9,7 @@ xfail = { | |||||||
| } | } | ||||||
|  |  | ||||||
| services = { | services = { | ||||||
|  |     "googletypes_request", | ||||||
|     "googletypes_response", |     "googletypes_response", | ||||||
|     "googletypes_response_embedded", |     "googletypes_response_embedded", | ||||||
|     "service", |     "service", | ||||||
|   | |||||||
							
								
								
									
										29
									
								
								tests/inputs/googletypes_request/googletypes_request.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										29
									
								
								tests/inputs/googletypes_request/googletypes_request.proto
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,29 @@ | |||||||
|  | syntax = "proto3"; | ||||||
|  |  | ||||||
|  | package googletypes_request; | ||||||
|  |  | ||||||
|  | import "google/protobuf/duration.proto"; | ||||||
|  | import "google/protobuf/empty.proto"; | ||||||
|  | import "google/protobuf/timestamp.proto"; | ||||||
|  | import "google/protobuf/wrappers.proto"; | ||||||
|  |  | ||||||
|  | // Tests that google types can be used as params | ||||||
|  |  | ||||||
|  | service Test { | ||||||
|  |     rpc SendDouble (google.protobuf.DoubleValue) returns (Input); | ||||||
|  |     rpc SendFloat (google.protobuf.FloatValue) returns (Input); | ||||||
|  |     rpc SendInt64 (google.protobuf.Int64Value) returns (Input); | ||||||
|  |     rpc SendUInt64 (google.protobuf.UInt64Value) returns (Input); | ||||||
|  |     rpc SendInt32 (google.protobuf.Int32Value) returns (Input); | ||||||
|  |     rpc SendUInt32 (google.protobuf.UInt32Value) returns (Input); | ||||||
|  |     rpc SendBool (google.protobuf.BoolValue) returns (Input); | ||||||
|  |     rpc SendString (google.protobuf.StringValue) returns (Input); | ||||||
|  |     rpc SendBytes (google.protobuf.BytesValue) returns (Input); | ||||||
|  |     rpc SendDatetime (google.protobuf.Timestamp) returns (Input); | ||||||
|  |     rpc SendTimedelta (google.protobuf.Duration) returns (Input); | ||||||
|  |     rpc SendEmpty (google.protobuf.Empty) returns (Input); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | message Input { | ||||||
|  |  | ||||||
|  | } | ||||||
							
								
								
									
										47
									
								
								tests/inputs/googletypes_request/test_googletypes_request.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								tests/inputs/googletypes_request/test_googletypes_request.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | from datetime import ( | ||||||
|  |     datetime, | ||||||
|  |     timedelta, | ||||||
|  | ) | ||||||
|  | from typing import ( | ||||||
|  |     Any, | ||||||
|  |     Callable, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | import pytest | ||||||
|  |  | ||||||
|  | import betterproto.lib.google.protobuf as protobuf | ||||||
|  | from tests.mocks import MockChannel | ||||||
|  | from tests.output_betterproto.googletypes_request import ( | ||||||
|  |     Input, | ||||||
|  |     TestStub, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | test_cases = [ | ||||||
|  |     (TestStub.send_double, protobuf.DoubleValue, 2.5), | ||||||
|  |     (TestStub.send_float, protobuf.FloatValue, 2.5), | ||||||
|  |     (TestStub.send_int64, protobuf.Int64Value, -64), | ||||||
|  |     (TestStub.send_u_int64, protobuf.UInt64Value, 64), | ||||||
|  |     (TestStub.send_int32, protobuf.Int32Value, -32), | ||||||
|  |     (TestStub.send_u_int32, protobuf.UInt32Value, 32), | ||||||
|  |     (TestStub.send_bool, protobuf.BoolValue, True), | ||||||
|  |     (TestStub.send_string, protobuf.StringValue, "string"), | ||||||
|  |     (TestStub.send_bytes, protobuf.BytesValue, bytes(0xFF)[0:4]), | ||||||
|  |     (TestStub.send_datetime, protobuf.Timestamp, datetime(2038, 1, 19, 3, 14, 8)), | ||||||
|  |     (TestStub.send_timedelta, protobuf.Duration, timedelta(seconds=123456)), | ||||||
|  | ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.asyncio | ||||||
|  | @pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases) | ||||||
|  | async def test_channel_receives_wrapped_type( | ||||||
|  |     service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value | ||||||
|  | ): | ||||||
|  |     wrapped_value = wrapper_class() | ||||||
|  |     wrapped_value.value = value | ||||||
|  |     channel = MockChannel(responses=[Input()]) | ||||||
|  |     service = TestStub(channel) | ||||||
|  |  | ||||||
|  |     await service_method(service, wrapped_value) | ||||||
|  |  | ||||||
|  |     assert channel.requests[0]["request"] == type(wrapped_value) | ||||||
		Reference in New Issue
	
	Block a user