Client and Service Stubs take 1 request parameter, not one for each field (#311)
This commit is contained in:
parent
6dd7baa26c
commit
d260f071e0
1
.gitignore
vendored
1
.gitignore
vendored
@ -17,3 +17,4 @@ output
|
|||||||
.venv
|
.venv
|
||||||
.asv
|
.asv
|
||||||
venv
|
venv
|
||||||
|
.devcontainer
|
||||||
|
26
CHANGELOG.md
26
CHANGELOG.md
@ -7,6 +7,32 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
|
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
- **Breaking**: Client and Service Stubs no longer pack and unpack the input message fields as parameters.
|
||||||
|
|
||||||
|
Update your client calls and server handlers as follows:
|
||||||
|
|
||||||
|
Clients before:
|
||||||
|
```py
|
||||||
|
response = await service.echo(value="hello", extra_times=1)
|
||||||
|
```
|
||||||
|
Clients after:
|
||||||
|
```py
|
||||||
|
response = await service.echo(EchoRequest(value="hello", extra_times=1))
|
||||||
|
```
|
||||||
|
Servers before:
|
||||||
|
```py
|
||||||
|
async def echo(self, value: str, extra_times: int) -> EchoResponse:
|
||||||
|
```
|
||||||
|
Servers after:
|
||||||
|
```py
|
||||||
|
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
|
||||||
|
# Use echo_request.value
|
||||||
|
# Use echo_request.extra_times
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
## [2.0.0b4] - 2022-01-03
|
## [2.0.0b4] - 2022-01-03
|
||||||
|
|
||||||
- **Breaking**: the minimum Python version has been bumped to `3.6.2`
|
- **Breaking**: the minimum Python version has been bumped to `3.6.2`
|
||||||
|
16
README.md
16
README.md
@ -177,10 +177,10 @@ from grpclib.client import Channel
|
|||||||
async def main():
|
async def main():
|
||||||
channel = Channel(host="127.0.0.1", port=50051)
|
channel = Channel(host="127.0.0.1", port=50051)
|
||||||
service = echo.EchoStub(channel)
|
service = echo.EchoStub(channel)
|
||||||
response = await service.echo(value="hello", extra_times=1)
|
response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
async for response in service.echo_stream(value="hello", extra_times=1):
|
async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
# don't forget to close the channel when done!
|
# don't forget to close the channel when done!
|
||||||
@ -206,18 +206,18 @@ service methods:
|
|||||||
|
|
||||||
```python
|
```python
|
||||||
import asyncio
|
import asyncio
|
||||||
from echo import EchoBase, EchoResponse, EchoStreamResponse
|
from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
|
||||||
from grpclib.server import Server
|
from grpclib.server import Server
|
||||||
from typing import AsyncIterator
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
|
||||||
class EchoService(EchoBase):
|
class EchoService(EchoBase):
|
||||||
async def echo(self, value: str, extra_times: int) -> "EchoResponse":
|
async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
|
||||||
return EchoResponse([value for _ in range(extra_times)])
|
return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
|
||||||
|
|
||||||
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
|
async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
|
||||||
for _ in range(extra_times):
|
for _ in range(echo_request.extra_times):
|
||||||
yield EchoStreamResponse(value)
|
yield EchoStreamResponse(echo_request.value)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
@ -111,7 +111,7 @@ omit = ["betterproto/tests/*"]
|
|||||||
legacy_tox_ini = """
|
legacy_tox_ini = """
|
||||||
[tox]
|
[tox]
|
||||||
isolated_build = true
|
isolated_build = true
|
||||||
envlist = py36, py37, py38
|
envlist = py36, py37, py38, py310
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
whitelist_externals = poetry
|
whitelist_externals = poetry
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections.abc import AsyncIterable
|
from collections.abc import AsyncIterable
|
||||||
from typing import Callable, Any, Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
import grpclib
|
import grpclib
|
||||||
import grpclib.server
|
import grpclib.server
|
||||||
@ -15,10 +15,10 @@ class ServiceBase(ABC):
|
|||||||
self,
|
self,
|
||||||
handler: Callable,
|
handler: Callable,
|
||||||
stream: grpclib.server.Stream,
|
stream: grpclib.server.Stream,
|
||||||
request_kwargs: Dict[str, Any],
|
request: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
||||||
response_iter = handler(**request_kwargs)
|
response_iter = handler(request)
|
||||||
# check if response is actually an AsyncIterator
|
# check if response is actually an AsyncIterator
|
||||||
# this might be false if the method just returns without
|
# this might be false if the method just returns without
|
||||||
# yielding at least once
|
# yielding at least once
|
||||||
|
@ -31,13 +31,15 @@ reference to `A` to `B`'s `fields` attribute.
|
|||||||
|
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import re
|
||||||
|
import textwrap
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto import which_one_of
|
from betterproto import which_one_of
|
||||||
from betterproto.casing import sanitize_name
|
from betterproto.casing import sanitize_name
|
||||||
from betterproto.compile.importing import (
|
from betterproto.compile.importing import get_type_reference, parse_source_type_name
|
||||||
get_type_reference,
|
|
||||||
parse_source_type_name,
|
|
||||||
)
|
|
||||||
from betterproto.compile.naming import (
|
from betterproto.compile.naming import (
|
||||||
pythonize_class_name,
|
pythonize_class_name,
|
||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
@ -46,21 +48,15 @@ from betterproto.compile.naming import (
|
|||||||
from betterproto.lib.google.protobuf import (
|
from betterproto.lib.google.protobuf import (
|
||||||
DescriptorProto,
|
DescriptorProto,
|
||||||
EnumDescriptorProto,
|
EnumDescriptorProto,
|
||||||
FileDescriptorProto,
|
|
||||||
MethodDescriptorProto,
|
|
||||||
Field,
|
Field,
|
||||||
FieldDescriptorProto,
|
FieldDescriptorProto,
|
||||||
FieldDescriptorProtoType,
|
|
||||||
FieldDescriptorProtoLabel,
|
FieldDescriptorProtoLabel,
|
||||||
|
FieldDescriptorProtoType,
|
||||||
|
FileDescriptorProto,
|
||||||
|
MethodDescriptorProto,
|
||||||
)
|
)
|
||||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||||
|
|
||||||
|
|
||||||
import re
|
|
||||||
import textwrap
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
|
|
||||||
|
|
||||||
from ..casing import sanitize_name
|
from ..casing import sanitize_name
|
||||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||||
from ..compile.naming import (
|
from ..compile.naming import (
|
||||||
@ -69,7 +65,6 @@ from ..compile.naming import (
|
|||||||
pythonize_method_name,
|
pythonize_method_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Create a unique placeholder to deal with
|
# Create a unique placeholder to deal with
|
||||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||||
PLACEHOLDER = object()
|
PLACEHOLDER = object()
|
||||||
@ -675,12 +670,8 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
self.parent.methods.append(self)
|
self.parent.methods.append(self)
|
||||||
|
|
||||||
# Check for imports
|
# Check for imports
|
||||||
if self.py_input_message:
|
|
||||||
for f in self.py_input_message.fields:
|
|
||||||
f.add_imports_to(self.output_file)
|
|
||||||
if "Optional" in self.py_output_message_type:
|
if "Optional" in self.py_output_message_type:
|
||||||
self.output_file.typing_imports.add("Optional")
|
self.output_file.typing_imports.add("Optional")
|
||||||
self.mutable_default_args # ensure this is called before rendering
|
|
||||||
|
|
||||||
# Check for Async imports
|
# Check for Async imports
|
||||||
if self.client_streaming:
|
if self.client_streaming:
|
||||||
@ -694,37 +685,6 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
|
|
||||||
super().__post_init__() # check for unset fields
|
super().__post_init__() # check for unset fields
|
||||||
|
|
||||||
@property
|
|
||||||
def mutable_default_args(self) -> Dict[str, str]:
|
|
||||||
"""Handle mutable default arguments.
|
|
||||||
|
|
||||||
Returns a list of tuples containing the name and default value
|
|
||||||
for arguments to this message who's default value is mutable.
|
|
||||||
The defaults are swapped out for None and replaced back inside
|
|
||||||
the method's body.
|
|
||||||
Reference:
|
|
||||||
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Dict[str, str]
|
|
||||||
Name and actual default value (as a string)
|
|
||||||
for each argument with mutable default values.
|
|
||||||
"""
|
|
||||||
mutable_default_args = {}
|
|
||||||
|
|
||||||
if self.py_input_message:
|
|
||||||
for f in self.py_input_message.fields:
|
|
||||||
if (
|
|
||||||
not self.client_streaming
|
|
||||||
and f.default_value_string != "None"
|
|
||||||
and f.mutable
|
|
||||||
):
|
|
||||||
mutable_default_args[f.py_name] = f.default_value_string
|
|
||||||
self.output_file.typing_imports.add("Optional")
|
|
||||||
|
|
||||||
return mutable_default_args
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def py_name(self) -> str:
|
def py_name(self) -> str:
|
||||||
"""Pythonized method name."""
|
"""Pythonized method name."""
|
||||||
@ -782,6 +742,17 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
source_type=self.proto_obj.input_type,
|
source_type=self.proto_obj.input_type,
|
||||||
).strip('"')
|
).strip('"')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def py_input_message_param(self) -> str:
|
||||||
|
"""Param name corresponding to py_input_message_type.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
Param name corresponding to py_input_message_type.
|
||||||
|
"""
|
||||||
|
return pythonize_field_name(self.py_input_message_type)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def py_output_message_type(self) -> str:
|
def py_output_message_type(self) -> str:
|
||||||
"""String representation of the Python type corresponding to the
|
"""String representation of the Python type corresponding to the
|
||||||
|
@ -79,51 +79,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
async def {{ method.py_name }}(self
|
async def {{ method.py_name }}(self
|
||||||
{%- if not method.client_streaming -%}
|
{%- if not method.client_streaming -%}
|
||||||
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
|
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||||
{%- for field in method.py_input_message.fields -%}
|
|
||||||
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
|
||||||
Optional[{{ field.annotation }}]
|
|
||||||
{%- else -%}
|
|
||||||
{{ field.annotation }}
|
|
||||||
{%- endif -%} =
|
|
||||||
{%- if field.py_name not in method.mutable_default_args -%}
|
|
||||||
{{ field.default_value_string }}
|
|
||||||
{%- else -%}
|
|
||||||
None
|
|
||||||
{% endif -%}
|
|
||||||
{%- if not loop.last %}, {% endif -%}
|
|
||||||
{%- endfor -%}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{# Client streaming: need a request iterator instead #}
|
{# Client streaming: need a request iterator instead #}
|
||||||
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{%- for py_name, zero in method.mutable_default_args.items() %}
|
|
||||||
{{ py_name }} = {{ py_name }} or {{ zero }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
{% if not method.client_streaming %}
|
|
||||||
request = {{ method.py_input_message_type }}()
|
|
||||||
{% for field in method.py_input_message.fields %}
|
|
||||||
{% if field.field_type == 'message' %}
|
|
||||||
if {{ field.py_name }} is not None:
|
|
||||||
request.{{ field.py_name }} = {{ field.py_name }}
|
|
||||||
{% else %}
|
|
||||||
request.{{ field.py_name }} = {{ field.py_name }}
|
|
||||||
{% endif %}
|
|
||||||
{% endfor %}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% if method.server_streaming %}
|
{% if method.server_streaming %}
|
||||||
{% if method.client_streaming %}
|
{% if method.client_streaming %}
|
||||||
async for response in self._stream_stream(
|
async for response in self._stream_stream(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request_iterator,
|
{{ method.py_input_message_param }}_iterator,
|
||||||
{{ method.py_input_message_type }},
|
{{ method.py_input_message_type }},
|
||||||
{{ method.py_output_message_type.strip('"') }},
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
):
|
):
|
||||||
@ -131,7 +101,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{% else %}{# i.e. not client streaming #}
|
{% else %}{# i.e. not client streaming #}
|
||||||
async for response in self._unary_stream(
|
async for response in self._unary_stream(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
{{ method.py_input_message_param }},
|
||||||
{{ method.py_output_message_type.strip('"') }},
|
{{ method.py_output_message_type.strip('"') }},
|
||||||
):
|
):
|
||||||
yield response
|
yield response
|
||||||
@ -141,14 +111,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{% if method.client_streaming %}
|
{% if method.client_streaming %}
|
||||||
return await self._stream_unary(
|
return await self._stream_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request_iterator,
|
{{ method.py_input_message_param }}_iterator,
|
||||||
{{ method.py_input_message_type }},
|
{{ method.py_input_message_type }},
|
||||||
{{ method.py_output_message_type.strip('"') }}
|
{{ method.py_output_message_type.strip('"') }}
|
||||||
)
|
)
|
||||||
{% else %}{# i.e. not client streaming #}
|
{% else %}{# i.e. not client streaming #}
|
||||||
return await self._unary_unary(
|
return await self._unary_unary(
|
||||||
"{{ method.route }}",
|
"{{ method.route }}",
|
||||||
request,
|
{{ method.py_input_message_param }},
|
||||||
{{ method.py_output_message_type.strip('"') }}
|
{{ method.py_output_message_type.strip('"') }}
|
||||||
)
|
)
|
||||||
{% endif %}{# client streaming #}
|
{% endif %}{# client streaming #}
|
||||||
@ -167,19 +137,10 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
async def {{ method.py_name }}(self
|
async def {{ method.py_name }}(self
|
||||||
{%- if not method.client_streaming -%}
|
{%- if not method.client_streaming -%}
|
||||||
{%- if method.py_input_message and method.py_input_message.fields -%},
|
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||||
{%- for field in method.py_input_message.fields -%}
|
|
||||||
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
|
||||||
Optional[{{ field.annotation }}]
|
|
||||||
{%- else -%}
|
|
||||||
{{ field.annotation }}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- if not loop.last %}, {% endif -%}
|
|
||||||
{%- endfor -%}
|
|
||||||
{%- endif -%}
|
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{# Client streaming: need a request iterator instead #}
|
{# Client streaming: need a request iterator instead #}
|
||||||
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
@ -194,25 +155,17 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
||||||
{% if not method.client_streaming %}
|
{% if not method.client_streaming %}
|
||||||
request = await stream.recv_message()
|
request = await stream.recv_message()
|
||||||
|
|
||||||
request_kwargs = {
|
|
||||||
{% for field in method.py_input_message.fields %}
|
|
||||||
"{{ field.py_name }}": request.{{ field.py_name }},
|
|
||||||
{% endfor %}
|
|
||||||
}
|
|
||||||
|
|
||||||
{% else %}
|
{% else %}
|
||||||
request_kwargs = {"request_iterator": stream.__aiter__()}
|
request = stream.__aiter__()
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if not method.server_streaming %}
|
{% if not method.server_streaming %}
|
||||||
response = await self.{{ method.py_name }}(**request_kwargs)
|
response = await self.{{ method.py_name }}(request)
|
||||||
await stream.send_message(response)
|
await stream.send_message(response)
|
||||||
{% else %}
|
{% else %}
|
||||||
await self._call_rpc_handler_server_stream(
|
await self._call_rpc_handler_server_stream(
|
||||||
self.{{ method.py_name }},
|
self.{{ method.py_name }},
|
||||||
stream,
|
stream,
|
||||||
request_kwargs,
|
request,
|
||||||
)
|
)
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
@ -1,23 +1,24 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import grpclib
|
||||||
|
import grpclib.metadata
|
||||||
|
import grpclib.server
|
||||||
|
import pytest
|
||||||
|
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||||
|
from grpclib.testing import ChannelFor
|
||||||
from tests.output_betterproto.service.service import (
|
from tests.output_betterproto.service.service import (
|
||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
DoThingResponse,
|
DoThingResponse,
|
||||||
GetThingRequest,
|
GetThingRequest,
|
||||||
TestStub as ThingServiceClient,
|
|
||||||
)
|
)
|
||||||
import grpclib
|
from tests.output_betterproto.service.service import TestStub as ThingServiceClient
|
||||||
import grpclib.metadata
|
|
||||||
import grpclib.server
|
|
||||||
from grpclib.testing import ChannelFor
|
|
||||||
import pytest
|
|
||||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
|
||||||
from .thing_service import ThingService
|
from .thing_service import ThingService
|
||||||
|
|
||||||
|
|
||||||
async def _test_client(client, name="clean room", **kwargs):
|
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
||||||
response = await client.do_thing(name=name)
|
response = await client.do_thing(DoThingRequest(name=name))
|
||||||
assert response.names == [name]
|
assert response.names == [name]
|
||||||
|
|
||||||
|
|
||||||
@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
|
|||||||
)
|
)
|
||||||
async with ChannelFor([service]) as channel:
|
async with ChannelFor([service]) as channel:
|
||||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||||
await ThingServiceClient(channel).do_thing(name="something")
|
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
|
||||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||||
|
|
||||||
|
|
||||||
@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
|
|||||||
async with ChannelFor([service]) as channel:
|
async with ChannelFor([service]) as channel:
|
||||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||||
await ThingServiceClient(channel).do_many_things(
|
await ThingServiceClient(channel).do_many_things(
|
||||||
request_iterator=[DoThingRequest(name="something")]
|
do_thing_request_iterator=[DoThingRequest(name="something")]
|
||||||
)
|
)
|
||||||
await _test_client(ThingServiceClient(channel))
|
await _test_client(ThingServiceClient(channel))
|
||||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||||
@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
|
|||||||
async with ChannelFor([ThingService()]) as channel:
|
async with ChannelFor([ThingService()]) as channel:
|
||||||
client = ThingServiceClient(channel)
|
client = ThingServiceClient(channel)
|
||||||
expected_versions = [5, 4, 3, 2, 1]
|
expected_versions = [5, 4, 3, 2, 1]
|
||||||
async for response in client.get_thing_versions(name=thing_name):
|
async for response in client.get_thing_versions(
|
||||||
|
GetThingRequest(name=thing_name)
|
||||||
|
):
|
||||||
assert response.name == thing_name
|
assert response.name == thing_name
|
||||||
assert response.version == expected_versions.pop()
|
assert response.version == expected_versions.pop()
|
||||||
|
|
||||||
|
@ -1,49 +1,48 @@
|
|||||||
from typing import AsyncIterator, AsyncIterable
|
from typing import AsyncIterable, AsyncIterator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from grpclib.testing import ChannelFor
|
from grpclib.testing import ChannelFor
|
||||||
|
|
||||||
from tests.output_betterproto.example_service.example_service import (
|
from tests.output_betterproto.example_service.example_service import (
|
||||||
TestBase,
|
|
||||||
TestStub,
|
|
||||||
ExampleRequest,
|
ExampleRequest,
|
||||||
ExampleResponse,
|
ExampleResponse,
|
||||||
|
TestBase,
|
||||||
|
TestStub,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExampleService(TestBase):
|
class ExampleService(TestBase):
|
||||||
async def example_unary_unary(
|
async def example_unary_unary(
|
||||||
self, example_string: str, example_integer: int
|
self, example_request: ExampleRequest
|
||||||
) -> "ExampleResponse":
|
) -> "ExampleResponse":
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
example_string=example_string,
|
example_string=example_request.example_string,
|
||||||
example_integer=example_integer,
|
example_integer=example_request.example_integer,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def example_unary_stream(
|
async def example_unary_stream(
|
||||||
self, example_string: str, example_integer: int
|
self, example_request: ExampleRequest
|
||||||
) -> AsyncIterator["ExampleResponse"]:
|
) -> AsyncIterator["ExampleResponse"]:
|
||||||
response = ExampleResponse(
|
response = ExampleResponse(
|
||||||
example_string=example_string,
|
example_string=example_request.example_string,
|
||||||
example_integer=example_integer,
|
example_integer=example_request.example_integer,
|
||||||
)
|
)
|
||||||
yield response
|
yield response
|
||||||
yield response
|
yield response
|
||||||
yield response
|
yield response
|
||||||
|
|
||||||
async def example_stream_unary(
|
async def example_stream_unary(
|
||||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||||
) -> "ExampleResponse":
|
) -> "ExampleResponse":
|
||||||
async for example_request in request_iterator:
|
async for example_request in example_request_iterator:
|
||||||
return ExampleResponse(
|
return ExampleResponse(
|
||||||
example_string=example_request.example_string,
|
example_string=example_request.example_string,
|
||||||
example_integer=example_request.example_integer,
|
example_integer=example_request.example_integer,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def example_stream_stream(
|
async def example_stream_stream(
|
||||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||||
) -> AsyncIterator["ExampleResponse"]:
|
) -> AsyncIterator["ExampleResponse"]:
|
||||||
async for example_request in request_iterator:
|
async for example_request in example_request_iterator:
|
||||||
yield ExampleResponse(
|
yield ExampleResponse(
|
||||||
example_string=example_request.example_string,
|
example_string=example_request.example_string,
|
||||||
example_integer=example_request.example_integer,
|
example_integer=example_request.example_integer,
|
||||||
@ -52,44 +51,32 @@ class ExampleService(TestBase):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calls_with_different_cardinalities():
|
async def test_calls_with_different_cardinalities():
|
||||||
test_string = "test string"
|
example_request = ExampleRequest("test string", 42)
|
||||||
test_int = 42
|
|
||||||
|
|
||||||
async with ChannelFor([ExampleService()]) as channel:
|
async with ChannelFor([ExampleService()]) as channel:
|
||||||
stub = TestStub(channel)
|
stub = TestStub(channel)
|
||||||
|
|
||||||
# unary unary
|
# unary unary
|
||||||
response = await stub.example_unary_unary(
|
response = await stub.example_unary_unary(example_request)
|
||||||
example_string="test string",
|
assert response.example_string == example_request.example_string
|
||||||
example_integer=42,
|
assert response.example_integer == example_request.example_integer
|
||||||
)
|
|
||||||
assert response.example_string == test_string
|
|
||||||
assert response.example_integer == test_int
|
|
||||||
|
|
||||||
# unary stream
|
# unary stream
|
||||||
async for response in stub.example_unary_stream(
|
async for response in stub.example_unary_stream(example_request):
|
||||||
example_string="test string",
|
assert response.example_string == example_request.example_string
|
||||||
example_integer=42,
|
assert response.example_integer == example_request.example_integer
|
||||||
):
|
|
||||||
assert response.example_string == test_string
|
|
||||||
assert response.example_integer == test_int
|
|
||||||
|
|
||||||
# stream unary
|
# stream unary
|
||||||
request = ExampleRequest(
|
|
||||||
example_string=test_string,
|
|
||||||
example_integer=42,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def request_iterator():
|
async def request_iterator():
|
||||||
yield request
|
yield example_request
|
||||||
yield request
|
yield example_request
|
||||||
yield request
|
yield example_request
|
||||||
|
|
||||||
response = await stub.example_stream_unary(request_iterator())
|
response = await stub.example_stream_unary(request_iterator())
|
||||||
assert response.example_string == test_string
|
assert response.example_string == example_request.example_string
|
||||||
assert response.example_integer == test_int
|
assert response.example_integer == example_request.example_integer
|
||||||
|
|
||||||
# stream stream
|
# stream stream
|
||||||
async for response in stub.example_stream_stream(request_iterator()):
|
async for response in stub.example_stream_stream(request_iterator()):
|
||||||
assert response.example_string == test_string
|
assert response.example_string == example_request.example_string
|
||||||
assert response.example_integer == test_int
|
assert response.example_integer == example_request.example_integer
|
||||||
|
@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
|
|||||||
|
|
||||||
import betterproto.lib.google.protobuf as protobuf
|
import betterproto.lib.google.protobuf as protobuf
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.mocks import MockChannel
|
from tests.mocks import MockChannel
|
||||||
from tests.output_betterproto.googletypes_response import TestStub
|
from tests.output_betterproto.googletypes_response import Input, TestStub
|
||||||
|
|
||||||
test_cases = [
|
test_cases = [
|
||||||
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||||
@ -22,14 +21,15 @@ test_cases = [
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
async def test_channel_receives_wrapped_type(
|
async def test_channel_receives_wrapped_type(
|
||||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||||
):
|
):
|
||||||
wrapped_value = wrapper_class()
|
wrapped_value = wrapper_class()
|
||||||
wrapped_value.value = value
|
wrapped_value.value = value
|
||||||
channel = MockChannel(responses=[wrapped_value])
|
channel = MockChannel(responses=[wrapped_value])
|
||||||
service = TestStub(channel)
|
service = TestStub(channel)
|
||||||
|
method_param = Input()
|
||||||
|
|
||||||
await service_method(service)
|
await service_method(service, method_param)
|
||||||
|
|
||||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||||
@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
|
|||||||
@pytest.mark.xfail
|
@pytest.mark.xfail
|
||||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||||
async def test_service_unwraps_response(
|
async def test_service_unwraps_response(
|
||||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
grpclib does not unwrap wrapper values returned by services
|
grpclib does not unwrap wrapper values returned by services
|
||||||
@ -47,8 +47,9 @@ async def test_service_unwraps_response(
|
|||||||
wrapped_value = wrapper_class()
|
wrapped_value = wrapper_class()
|
||||||
wrapped_value.value = value
|
wrapped_value.value = value
|
||||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||||
|
method_param = Input()
|
||||||
|
|
||||||
response_value = await service_method(service)
|
response_value = await service_method(service, method_param)
|
||||||
|
|
||||||
assert response_value == value
|
assert response_value == value
|
||||||
assert type(response_value) == type(value)
|
assert type(response_value) == type(value)
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.mocks import MockChannel
|
from tests.mocks import MockChannel
|
||||||
from tests.output_betterproto.googletypes_response_embedded import (
|
from tests.output_betterproto.googletypes_response_embedded import (
|
||||||
|
Input,
|
||||||
Output,
|
Output,
|
||||||
TestStub,
|
TestStub,
|
||||||
)
|
)
|
||||||
@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
|||||||
)
|
)
|
||||||
|
|
||||||
service = TestStub(MockChannel(responses=[output]))
|
service = TestStub(MockChannel(responses=[output]))
|
||||||
response = await service.get_output()
|
response = await service.get_output(Input())
|
||||||
|
|
||||||
assert response.double_value == 10.0
|
assert response.double_value == 10.0
|
||||||
assert response.float_value == 12.0
|
assert response.float_value == 12.0
|
||||||
|
@ -1,17 +1,21 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from tests.mocks import MockChannel
|
from tests.mocks import MockChannel
|
||||||
from tests.output_betterproto.import_service_input_message import (
|
from tests.output_betterproto.import_service_input_message import (
|
||||||
|
NestedRequestMessage,
|
||||||
|
RequestMessage,
|
||||||
RequestResponse,
|
RequestResponse,
|
||||||
TestStub,
|
TestStub,
|
||||||
)
|
)
|
||||||
|
from tests.output_betterproto.import_service_input_message.child import (
|
||||||
|
ChildRequestMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_service_correctly_imports_reference_message():
|
async def test_service_correctly_imports_reference_message():
|
||||||
mock_response = RequestResponse(value=10)
|
mock_response = RequestResponse(value=10)
|
||||||
service = TestStub(MockChannel([mock_response]))
|
service = TestStub(MockChannel([mock_response]))
|
||||||
response = await service.do_thing(argument=1)
|
response = await service.do_thing(RequestMessage(1))
|
||||||
assert mock_response == response
|
assert mock_response == response
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
|
|||||||
async def test_service_correctly_imports_reference_message_from_child_package():
|
async def test_service_correctly_imports_reference_message_from_child_package():
|
||||||
mock_response = RequestResponse(value=10)
|
mock_response = RequestResponse(value=10)
|
||||||
service = TestStub(MockChannel([mock_response]))
|
service = TestStub(MockChannel([mock_response]))
|
||||||
response = await service.do_thing2(child_argument=1)
|
response = await service.do_thing2(ChildRequestMessage(1))
|
||||||
assert mock_response == response
|
assert mock_response == response
|
||||||
|
|
||||||
|
|
||||||
@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
|
|||||||
async def test_service_correctly_imports_nested_reference():
|
async def test_service_correctly_imports_nested_reference():
|
||||||
mock_response = RequestResponse(value=10)
|
mock_response = RequestResponse(value=10)
|
||||||
service = TestStub(MockChannel([mock_response]))
|
service = TestStub(MockChannel([mock_response]))
|
||||||
response = await service.do_thing3(nested_argument=1)
|
response = await service.do_thing3(NestedRequestMessage(1))
|
||||||
assert mock_response == response
|
assert mock_response == response
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import betterproto
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional, List, Dict
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from inspect import signature
|
from inspect import Parameter, signature
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
|
||||||
|
|
||||||
def test_has_field():
|
def test_has_field():
|
||||||
@ -349,10 +350,8 @@ def test_recursive_message():
|
|||||||
|
|
||||||
|
|
||||||
def test_recursive_message_defaults():
|
def test_recursive_message_defaults():
|
||||||
from tests.output_betterproto.recursivemessage import (
|
from tests.output_betterproto.recursivemessage import Intermediate
|
||||||
Test as RecursiveMessage,
|
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||||
Intermediate,
|
|
||||||
)
|
|
||||||
|
|
||||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||||
|
|
||||||
@ -479,8 +478,10 @@ def test_iso_datetime_list():
|
|||||||
assert all([isinstance(item, datetime) for item in msg.timestamps])
|
assert all([isinstance(item, datetime) for item in msg.timestamps])
|
||||||
|
|
||||||
|
|
||||||
def test_enum_service_argument__expected_default_value():
|
def test_service_argument__expected_parameter():
|
||||||
from tests.output_betterproto.service.service import ThingType, TestStub
|
from tests.output_betterproto.service.service import TestStub
|
||||||
|
|
||||||
sig = signature(TestStub.do_thing)
|
sig = signature(TestStub.do_thing)
|
||||||
assert sig.parameters["type"].default == ThingType.UNKNOWN
|
do_thing_request_parameter = sig.parameters["do_thing_request"]
|
||||||
|
assert do_thing_request_parameter.default is Parameter.empty
|
||||||
|
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user