Compare commits
4 Commits
v2.0.0b4
...
changelog1
Author | SHA1 | Date | |
---|---|---|---|
|
3eaff291c4 | ||
|
9b5594adbe | ||
|
d991040ff6 | ||
|
d260f071e0 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -17,3 +17,4 @@ output
|
||||
.venv
|
||||
.asv
|
||||
venv
|
||||
.devcontainer
|
||||
|
29
CHANGELOG.md
29
CHANGELOG.md
@@ -7,6 +7,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
- Versions suffixed with `b*` are in `beta` and can be installed with `pip install --pre betterproto`.
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
- fix: Format field comments also as docstrings (#304)
|
||||
- fix: Fix message text in NotImplementedError (#325)
|
||||
- **Breaking**: Client and Service Stubs take 1 request parameter, not one for each field (#311)
|
||||
Client and Service Stubs no longer pack and unpack the input message fields as parameters.
|
||||
|
||||
Update your client calls and server handlers as follows:
|
||||
|
||||
Clients before:
|
||||
```py
|
||||
response = await service.echo(value="hello", extra_times=1)
|
||||
```
|
||||
Clients after:
|
||||
```py
|
||||
response = await service.echo(EchoRequest(value="hello", extra_times=1))
|
||||
```
|
||||
Servers before:
|
||||
```py
|
||||
async def echo(self, value: str, extra_times: int) -> EchoResponse:
|
||||
```
|
||||
Servers after:
|
||||
```py
|
||||
async def echo(self, echo_request: EchoRequest) -> EchoResponse:
|
||||
# Use echo_request.value
|
||||
# Use echo_request.extra_times
|
||||
```
|
||||
|
||||
|
||||
## [2.0.0b4] - 2022-01-03
|
||||
|
||||
- **Breaking**: the minimum Python version has been bumped to `3.6.2`
|
||||
|
16
README.md
16
README.md
@@ -177,10 +177,10 @@ from grpclib.client import Channel
|
||||
async def main():
|
||||
channel = Channel(host="127.0.0.1", port=50051)
|
||||
service = echo.EchoStub(channel)
|
||||
response = await service.echo(value="hello", extra_times=1)
|
||||
response = await service.echo(echo.EchoRequest(value="hello", extra_times=1))
|
||||
print(response)
|
||||
|
||||
async for response in service.echo_stream(value="hello", extra_times=1):
|
||||
async for response in service.echo_stream(echo.EchoRequest(value="hello", extra_times=1)):
|
||||
print(response)
|
||||
|
||||
# don't forget to close the channel when done!
|
||||
@@ -206,18 +206,18 @@ service methods:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from echo import EchoBase, EchoResponse, EchoStreamResponse
|
||||
from echo import EchoBase, EchoRequest, EchoResponse, EchoStreamResponse
|
||||
from grpclib.server import Server
|
||||
from typing import AsyncIterator
|
||||
|
||||
|
||||
class EchoService(EchoBase):
|
||||
async def echo(self, value: str, extra_times: int) -> "EchoResponse":
|
||||
return EchoResponse([value for _ in range(extra_times)])
|
||||
async def echo(self, echo_request: "EchoRequest") -> "EchoResponse":
|
||||
return EchoResponse([echo_request.value for _ in range(echo_request.extra_times)])
|
||||
|
||||
async def echo_stream(self, value: str, extra_times: int) -> AsyncIterator["EchoStreamResponse"]:
|
||||
for _ in range(extra_times):
|
||||
yield EchoStreamResponse(value)
|
||||
async def echo_stream(self, echo_request: "EchoRequest") -> AsyncIterator["EchoStreamResponse"]:
|
||||
for _ in range(echo_request.extra_times):
|
||||
yield EchoStreamResponse(echo_request.value)
|
||||
|
||||
|
||||
async def main():
|
||||
|
@@ -111,7 +111,7 @@ omit = ["betterproto/tests/*"]
|
||||
legacy_tox_ini = """
|
||||
[tox]
|
||||
isolated_build = true
|
||||
envlist = py36, py37, py38
|
||||
envlist = py36, py37, py38, py310
|
||||
|
||||
[testenv]
|
||||
whitelist_externals = poetry
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from abc import ABC
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Callable, Any, Dict
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import grpclib
|
||||
import grpclib.server
|
||||
@@ -15,10 +15,10 @@ class ServiceBase(ABC):
|
||||
self,
|
||||
handler: Callable,
|
||||
stream: grpclib.server.Stream,
|
||||
request_kwargs: Dict[str, Any],
|
||||
request: Any,
|
||||
) -> None:
|
||||
|
||||
response_iter = handler(**request_kwargs)
|
||||
response_iter = handler(request)
|
||||
# check if response is actually an AsyncIterator
|
||||
# this might be false if the method just returns without
|
||||
# yielding at least once
|
||||
|
@@ -31,13 +31,15 @@ reference to `A` to `B`'s `fields` attribute.
|
||||
|
||||
|
||||
import builtins
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
|
||||
|
||||
import betterproto
|
||||
from betterproto import which_one_of
|
||||
from betterproto.casing import sanitize_name
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.importing import get_type_reference, parse_source_type_name
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
@@ -46,21 +48,15 @@ from betterproto.compile.naming import (
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
Field,
|
||||
FieldDescriptorProto,
|
||||
FieldDescriptorProtoType,
|
||||
FieldDescriptorProtoLabel,
|
||||
FieldDescriptorProtoType,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterable, Iterator, List, Optional, Set, Type, Union
|
||||
|
||||
from ..casing import sanitize_name
|
||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||
from ..compile.naming import (
|
||||
@@ -69,7 +65,6 @@ from ..compile.naming import (
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
PLACEHOLDER = object()
|
||||
@@ -147,11 +142,7 @@ def get_comment(
|
||||
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
)
|
||||
|
||||
if path[-2] == 2 and path[-4] != 6:
|
||||
# This is a field
|
||||
return f"{pad}# " + f"\n{pad}# ".join(lines)
|
||||
else:
|
||||
# This is a message, enum, service, or method
|
||||
# This is a field, message, enum, service, or method
|
||||
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
|
||||
lines[0] = lines[0].strip('"')
|
||||
return f'{pad}"""{lines[0]}"""'
|
||||
@@ -529,7 +520,7 @@ class FieldCompiler(MessageCompiler):
|
||||
source_type=self.proto_obj.type_name,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {field.type}")
|
||||
raise NotImplementedError(f"Unknown type {self.proto_obj.type}")
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
@@ -675,12 +666,8 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
self.parent.methods.append(self)
|
||||
|
||||
# Check for imports
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
f.add_imports_to(self.output_file)
|
||||
if "Optional" in self.py_output_message_type:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.mutable_default_args # ensure this is called before rendering
|
||||
|
||||
# Check for Async imports
|
||||
if self.client_streaming:
|
||||
@@ -694,37 +681,6 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def mutable_default_args(self) -> Dict[str, str]:
|
||||
"""Handle mutable default arguments.
|
||||
|
||||
Returns a list of tuples containing the name and default value
|
||||
for arguments to this message who's default value is mutable.
|
||||
The defaults are swapped out for None and replaced back inside
|
||||
the method's body.
|
||||
Reference:
|
||||
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, str]
|
||||
Name and actual default value (as a string)
|
||||
for each argument with mutable default values.
|
||||
"""
|
||||
mutable_default_args = {}
|
||||
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
if (
|
||||
not self.client_streaming
|
||||
and f.default_value_string != "None"
|
||||
and f.mutable
|
||||
):
|
||||
mutable_default_args[f.py_name] = f.default_value_string
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
|
||||
return mutable_default_args
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
"""Pythonized method name."""
|
||||
@@ -782,6 +738,17 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
source_type=self.proto_obj.input_type,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
def py_input_message_param(self) -> str:
|
||||
"""Param name corresponding to py_input_message_type.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Param name corresponding to py_input_message_type.
|
||||
"""
|
||||
return pythonize_field_name(self.py_input_message_type)
|
||||
|
||||
@property
|
||||
def py_output_message_type(self) -> str:
|
||||
"""String representation of the Python type corresponding to the
|
||||
|
@@ -28,10 +28,11 @@ class {{ enum.py_name }}(betterproto.Enum):
|
||||
|
||||
{% endif %}
|
||||
{% for entry in enum.entries %}
|
||||
{{ entry.name }} = {{ entry.value }}
|
||||
{% if entry.comment %}
|
||||
{{ entry.comment }}
|
||||
|
||||
{% endif %}
|
||||
{{ entry.name }} = {{ entry.value }}
|
||||
{% endfor %}
|
||||
|
||||
|
||||
@@ -45,10 +46,11 @@ class {{ message.py_name }}(betterproto.Message):
|
||||
|
||||
{% endif %}
|
||||
{% for field in message.fields %}
|
||||
{{ field.get_field_string() }}
|
||||
{% if field.comment %}
|
||||
{{ field.comment }}
|
||||
|
||||
{% endif %}
|
||||
{{ field.get_field_string() }}
|
||||
{% endfor %}
|
||||
{% if not message.fields %}
|
||||
pass
|
||||
@@ -79,51 +81,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self
|
||||
{%- if not method.client_streaming -%}
|
||||
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
|
||||
{%- for field in method.py_input_message.fields -%}
|
||||
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
||||
Optional[{{ field.annotation }}]
|
||||
{%- else -%}
|
||||
{{ field.annotation }}
|
||||
{%- endif -%} =
|
||||
{%- if field.py_name not in method.mutable_default_args -%}
|
||||
{{ field.default_value_string }}
|
||||
{%- else -%}
|
||||
None
|
||||
{% endif -%}
|
||||
{%- if not loop.last %}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
{% endif %}
|
||||
{%- for py_name, zero in method.mutable_default_args.items() %}
|
||||
{{ py_name }} = {{ py_name }} or {{ zero }}
|
||||
{% endfor %}
|
||||
|
||||
{% if not method.client_streaming %}
|
||||
request = {{ method.py_input_message_type }}()
|
||||
{% for field in method.py_input_message.fields %}
|
||||
{% if field.field_type == 'message' %}
|
||||
if {{ field.py_name }} is not None:
|
||||
request.{{ field.py_name }} = {{ field.py_name }}
|
||||
{% else %}
|
||||
request.{{ field.py_name }} = {{ field.py_name }}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if method.server_streaming %}
|
||||
{% if method.client_streaming %}
|
||||
async for response in self._stream_stream(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
{{ method.py_input_message_param }}_iterator,
|
||||
{{ method.py_input_message_type }},
|
||||
{{ method.py_output_message_type.strip('"') }},
|
||||
):
|
||||
@@ -131,7 +103,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
async for response in self._unary_stream(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
{{ method.py_input_message_param }},
|
||||
{{ method.py_output_message_type.strip('"') }},
|
||||
):
|
||||
yield response
|
||||
@@ -141,14 +113,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{% if method.client_streaming %}
|
||||
return await self._stream_unary(
|
||||
"{{ method.route }}",
|
||||
request_iterator,
|
||||
{{ method.py_input_message_param }}_iterator,
|
||||
{{ method.py_input_message_type }},
|
||||
{{ method.py_output_message_type.strip('"') }}
|
||||
)
|
||||
{% else %}{# i.e. not client streaming #}
|
||||
return await self._unary_unary(
|
||||
"{{ method.route }}",
|
||||
request,
|
||||
{{ method.py_input_message_param }},
|
||||
{{ method.py_output_message_type.strip('"') }}
|
||||
)
|
||||
{% endif %}{# client streaming #}
|
||||
@@ -167,19 +139,10 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
{% for method in service.methods %}
|
||||
async def {{ method.py_name }}(self
|
||||
{%- if not method.client_streaming -%}
|
||||
{%- if method.py_input_message and method.py_input_message.fields -%},
|
||||
{%- for field in method.py_input_message.fields -%}
|
||||
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
|
||||
Optional[{{ field.annotation }}]
|
||||
{%- else -%}
|
||||
{{ field.annotation }}
|
||||
{%- endif -%}
|
||||
{%- if not loop.last %}, {% endif -%}
|
||||
{%- endfor -%}
|
||||
{%- endif -%}
|
||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, request_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
@@ -194,25 +157,17 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
async def __rpc_{{ method.py_name }}(self, stream: grpclib.server.Stream) -> None:
|
||||
{% if not method.client_streaming %}
|
||||
request = await stream.recv_message()
|
||||
|
||||
request_kwargs = {
|
||||
{% for field in method.py_input_message.fields %}
|
||||
"{{ field.py_name }}": request.{{ field.py_name }},
|
||||
{% endfor %}
|
||||
}
|
||||
|
||||
{% else %}
|
||||
request_kwargs = {"request_iterator": stream.__aiter__()}
|
||||
request = stream.__aiter__()
|
||||
{% endif %}
|
||||
|
||||
{% if not method.server_streaming %}
|
||||
response = await self.{{ method.py_name }}(**request_kwargs)
|
||||
response = await self.{{ method.py_name }}(request)
|
||||
await stream.send_message(response)
|
||||
{% else %}
|
||||
await self._call_rpc_handler_server_stream(
|
||||
self.{{ method.py_name }},
|
||||
stream,
|
||||
request_kwargs,
|
||||
request,
|
||||
)
|
||||
{% endif %}
|
||||
|
||||
|
@@ -1,23 +1,24 @@
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import grpclib
|
||||
import grpclib.metadata
|
||||
import grpclib.server
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from grpclib.testing import ChannelFor
|
||||
from tests.output_betterproto.service.service import (
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
GetThingRequest,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
import grpclib.metadata
|
||||
import grpclib.server
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from tests.output_betterproto.service.service import TestStub as ThingServiceClient
|
||||
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client, name="clean room", **kwargs):
|
||||
response = await client.do_thing(name=name)
|
||||
async def _test_client(client: ThingServiceClient, name="clean room", **kwargs):
|
||||
response = await client.do_thing(DoThingRequest(name=name))
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
@@ -62,7 +63,7 @@ async def test_trailer_only_error_unary_unary(
|
||||
)
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_thing(name="something")
|
||||
await ThingServiceClient(channel).do_thing(DoThingRequest(name="something"))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
|
||||
|
||||
@@ -80,7 +81,7 @@ async def test_trailer_only_error_stream_unary(
|
||||
async with ChannelFor([service]) as channel:
|
||||
with pytest.raises(grpclib.exceptions.GRPCError) as e:
|
||||
await ThingServiceClient(channel).do_many_things(
|
||||
request_iterator=[DoThingRequest(name="something")]
|
||||
do_thing_request_iterator=[DoThingRequest(name="something")]
|
||||
)
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
assert e.value.status == grpclib.Status.UNAUTHENTICATED
|
||||
@@ -178,7 +179,9 @@ async def test_async_gen_for_unary_stream_request():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(name=thing_name):
|
||||
async for response in client.get_thing_versions(
|
||||
GetThingRequest(name=thing_name)
|
||||
):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
@@ -1,49 +1,48 @@
|
||||
from typing import AsyncIterator, AsyncIterable
|
||||
from typing import AsyncIterable, AsyncIterator
|
||||
|
||||
import pytest
|
||||
from grpclib.testing import ChannelFor
|
||||
|
||||
from tests.output_betterproto.example_service.example_service import (
|
||||
TestBase,
|
||||
TestStub,
|
||||
ExampleRequest,
|
||||
ExampleResponse,
|
||||
TestBase,
|
||||
TestStub,
|
||||
)
|
||||
|
||||
|
||||
class ExampleService(TestBase):
|
||||
async def example_unary_unary(
|
||||
self, example_string: str, example_integer: int
|
||||
self, example_request: ExampleRequest
|
||||
) -> "ExampleResponse":
|
||||
return ExampleResponse(
|
||||
example_string=example_string,
|
||||
example_integer=example_integer,
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
|
||||
async def example_unary_stream(
|
||||
self, example_string: str, example_integer: int
|
||||
self, example_request: ExampleRequest
|
||||
) -> AsyncIterator["ExampleResponse"]:
|
||||
response = ExampleResponse(
|
||||
example_string=example_string,
|
||||
example_integer=example_integer,
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
yield response
|
||||
yield response
|
||||
yield response
|
||||
|
||||
async def example_stream_unary(
|
||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||
) -> "ExampleResponse":
|
||||
async for example_request in request_iterator:
|
||||
async for example_request in example_request_iterator:
|
||||
return ExampleResponse(
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
)
|
||||
|
||||
async def example_stream_stream(
|
||||
self, request_iterator: AsyncIterator["ExampleRequest"]
|
||||
self, example_request_iterator: AsyncIterator["ExampleRequest"]
|
||||
) -> AsyncIterator["ExampleResponse"]:
|
||||
async for example_request in request_iterator:
|
||||
async for example_request in example_request_iterator:
|
||||
yield ExampleResponse(
|
||||
example_string=example_request.example_string,
|
||||
example_integer=example_request.example_integer,
|
||||
@@ -52,44 +51,32 @@ class ExampleService(TestBase):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_calls_with_different_cardinalities():
|
||||
test_string = "test string"
|
||||
test_int = 42
|
||||
example_request = ExampleRequest("test string", 42)
|
||||
|
||||
async with ChannelFor([ExampleService()]) as channel:
|
||||
stub = TestStub(channel)
|
||||
|
||||
# unary unary
|
||||
response = await stub.example_unary_unary(
|
||||
example_string="test string",
|
||||
example_integer=42,
|
||||
)
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
response = await stub.example_unary_unary(example_request)
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# unary stream
|
||||
async for response in stub.example_unary_stream(
|
||||
example_string="test string",
|
||||
example_integer=42,
|
||||
):
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
async for response in stub.example_unary_stream(example_request):
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# stream unary
|
||||
request = ExampleRequest(
|
||||
example_string=test_string,
|
||||
example_integer=42,
|
||||
)
|
||||
|
||||
async def request_iterator():
|
||||
yield request
|
||||
yield request
|
||||
yield request
|
||||
yield example_request
|
||||
yield example_request
|
||||
yield example_request
|
||||
|
||||
response = await stub.example_stream_unary(request_iterator())
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
||||
# stream stream
|
||||
async for response in stub.example_stream_stream(request_iterator()):
|
||||
assert response.example_string == test_string
|
||||
assert response.example_integer == test_int
|
||||
assert response.example_string == example_request.example_string
|
||||
assert response.example_integer == example_request.example_integer
|
||||
|
@@ -2,9 +2,8 @@ from typing import Any, Callable, Optional
|
||||
|
||||
import betterproto.lib.google.protobuf as protobuf
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.googletypes_response import TestStub
|
||||
from tests.output_betterproto.googletypes_response import Input, TestStub
|
||||
|
||||
test_cases = [
|
||||
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||
@@ -22,14 +21,15 @@ test_cases = [
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
channel = MockChannel(responses=[wrapped_value])
|
||||
service = TestStub(channel)
|
||||
method_param = Input()
|
||||
|
||||
await service_method(service)
|
||||
await service_method(service, method_param)
|
||||
|
||||
assert channel.requests[0]["response_type"] != Optional[type(value)]
|
||||
assert channel.requests[0]["response_type"] == type(wrapped_value)
|
||||
@@ -39,7 +39,7 @@ async def test_channel_receives_wrapped_type(
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_service_unwraps_response(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
service_method: Callable[[TestStub, Input], Any], wrapper_class: Callable, value
|
||||
):
|
||||
"""
|
||||
grpclib does not unwrap wrapper values returned by services
|
||||
@@ -47,8 +47,9 @@ async def test_service_unwraps_response(
|
||||
wrapped_value = wrapper_class()
|
||||
wrapped_value.value = value
|
||||
service = TestStub(MockChannel(responses=[wrapped_value]))
|
||||
method_param = Input()
|
||||
|
||||
response_value = await service_method(service)
|
||||
response_value = await service_method(service, method_param)
|
||||
|
||||
assert response_value == value
|
||||
assert type(response_value) == type(value)
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.googletypes_response_embedded import (
|
||||
Input,
|
||||
Output,
|
||||
TestStub,
|
||||
)
|
||||
@@ -26,7 +26,7 @@ async def test_service_passes_through_unwrapped_values_embedded_in_response():
|
||||
)
|
||||
|
||||
service = TestStub(MockChannel(responses=[output]))
|
||||
response = await service.get_output()
|
||||
response = await service.get_output(Input())
|
||||
|
||||
assert response.double_value == 10.0
|
||||
assert response.float_value == 12.0
|
||||
|
@@ -1,17 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from tests.mocks import MockChannel
|
||||
from tests.output_betterproto.import_service_input_message import (
|
||||
NestedRequestMessage,
|
||||
RequestMessage,
|
||||
RequestResponse,
|
||||
TestStub,
|
||||
)
|
||||
from tests.output_betterproto.import_service_input_message.child import (
|
||||
ChildRequestMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_correctly_imports_reference_message():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing(argument=1)
|
||||
response = await service.do_thing(RequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
||||
|
||||
@@ -19,7 +23,7 @@ async def test_service_correctly_imports_reference_message():
|
||||
async def test_service_correctly_imports_reference_message_from_child_package():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing2(child_argument=1)
|
||||
response = await service.do_thing2(ChildRequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
||||
|
||||
@@ -27,5 +31,5 @@ async def test_service_correctly_imports_reference_message_from_child_package():
|
||||
async def test_service_correctly_imports_nested_reference():
|
||||
mock_response = RequestResponse(value=10)
|
||||
service = TestStub(MockChannel([mock_response]))
|
||||
response = await service.do_thing3(nested_argument=1)
|
||||
response = await service.do_thing3(NestedRequestMessage(1))
|
||||
assert mock_response == response
|
||||
|
@@ -1,8 +1,9 @@
|
||||
import betterproto
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List, Dict
|
||||
from datetime import datetime
|
||||
from inspect import signature
|
||||
from inspect import Parameter, signature
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import betterproto
|
||||
|
||||
|
||||
def test_has_field():
|
||||
@@ -349,10 +350,8 @@ def test_recursive_message():
|
||||
|
||||
|
||||
def test_recursive_message_defaults():
|
||||
from tests.output_betterproto.recursivemessage import (
|
||||
Test as RecursiveMessage,
|
||||
Intermediate,
|
||||
)
|
||||
from tests.output_betterproto.recursivemessage import Intermediate
|
||||
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage
|
||||
|
||||
msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))
|
||||
|
||||
@@ -479,8 +478,10 @@ def test_iso_datetime_list():
|
||||
assert all([isinstance(item, datetime) for item in msg.timestamps])
|
||||
|
||||
|
||||
def test_enum_service_argument__expected_default_value():
|
||||
from tests.output_betterproto.service.service import ThingType, TestStub
|
||||
def test_service_argument__expected_parameter():
|
||||
from tests.output_betterproto.service.service import TestStub
|
||||
|
||||
sig = signature(TestStub.do_thing)
|
||||
assert sig.parameters["type"].default == ThingType.UNKNOWN
|
||||
do_thing_request_parameter = sig.parameters["do_thing_request"]
|
||||
assert do_thing_request_parameter.default is Parameter.empty
|
||||
assert do_thing_request_parameter.annotation == "DoThingRequest"
|
||||
|
Reference in New Issue
Block a user