Merge remote-tracking branch 'daniel/master' into fix/imports
# Conflicts: # Pipfile # README.md # betterproto/__init__.py # betterproto/plugin.py # betterproto/tests/util.py
This commit is contained in:
114
betterproto/tests/generate.py
Normal file → Executable file
114
betterproto/tests/generate.py
Normal file → Executable file
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
import glob
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
@@ -20,58 +21,63 @@ from betterproto.tests.util import (
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
|
||||
def clear_directory(path: str):
|
||||
for file_or_directory in glob.glob(os.path.join(path, "*")):
|
||||
if os.path.isdir(file_or_directory):
|
||||
def clear_directory(dir_path: Path):
|
||||
for file_or_directory in dir_path.glob("*"):
|
||||
if file_or_directory.is_dir():
|
||||
shutil.rmtree(file_or_directory)
|
||||
else:
|
||||
os.remove(file_or_directory)
|
||||
file_or_directory.unlink()
|
||||
|
||||
|
||||
def generate(whitelist: Set[str]):
|
||||
path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)}
|
||||
name_whitelist = {e for e in whitelist if not os.path.exists(e)}
|
||||
async def generate(whitelist: Set[str], verbose: bool):
|
||||
test_case_names = set(get_directories(inputs_path)) - {"__pycache__"}
|
||||
|
||||
test_case_names = set(get_directories(inputs_path))
|
||||
|
||||
failed_test_cases = []
|
||||
path_whitelist = set()
|
||||
name_whitelist = set()
|
||||
for item in whitelist:
|
||||
if item in test_case_names:
|
||||
name_whitelist.add(item)
|
||||
continue
|
||||
path_whitelist.add(item)
|
||||
|
||||
generation_tasks = []
|
||||
for test_case_name in sorted(test_case_names):
|
||||
test_case_input_path = os.path.realpath(
|
||||
os.path.join(inputs_path, test_case_name)
|
||||
)
|
||||
|
||||
test_case_input_path = inputs_path.joinpath(test_case_name).resolve()
|
||||
if (
|
||||
whitelist
|
||||
and test_case_input_path not in path_whitelist
|
||||
and str(test_case_input_path) not in path_whitelist
|
||||
and test_case_name not in name_whitelist
|
||||
):
|
||||
continue
|
||||
generation_tasks.append(
|
||||
generate_test_case_output(test_case_input_path, test_case_name, verbose)
|
||||
)
|
||||
|
||||
print(f"Generating output for {test_case_name}")
|
||||
try:
|
||||
generate_test_case_output(test_case_name, test_case_input_path)
|
||||
except subprocess.CalledProcessError as e:
|
||||
failed_test_cases = []
|
||||
# Wait for all subprocs and match any failures to names to report
|
||||
for test_case_name, result in zip(
|
||||
sorted(test_case_names), await asyncio.gather(*generation_tasks)
|
||||
):
|
||||
if result != 0:
|
||||
failed_test_cases.append(test_case_name)
|
||||
|
||||
if failed_test_cases:
|
||||
sys.stderr.write("\nFailed to generate the following test cases:\n")
|
||||
sys.stderr.write(
|
||||
"\n\033[31;1;4mFailed to generate the following test cases:\033[0m\n"
|
||||
)
|
||||
for failed_test_case in failed_test_cases:
|
||||
sys.stderr.write(f"- {failed_test_case}\n")
|
||||
|
||||
|
||||
def generate_test_case_output(test_case_name, test_case_input_path=None):
|
||||
if not test_case_input_path:
|
||||
test_case_input_path = os.path.realpath(
|
||||
os.path.join(inputs_path, test_case_name)
|
||||
)
|
||||
async def generate_test_case_output(
|
||||
test_case_input_path: Path, test_case_name: str, verbose: bool
|
||||
) -> int:
|
||||
"""
|
||||
Returns the max of the subprocess return values
|
||||
"""
|
||||
|
||||
test_case_output_path_reference = os.path.join(
|
||||
output_path_reference, test_case_name
|
||||
)
|
||||
test_case_output_path_betterproto = os.path.join(
|
||||
output_path_betterproto, test_case_name
|
||||
)
|
||||
test_case_output_path_reference = output_path_reference.joinpath(test_case_name)
|
||||
test_case_output_path_betterproto = output_path_betterproto.joinpath(test_case_name)
|
||||
|
||||
os.makedirs(test_case_output_path_reference, exist_ok=True)
|
||||
os.makedirs(test_case_output_path_betterproto, exist_ok=True)
|
||||
@@ -79,14 +85,36 @@ def generate_test_case_output(test_case_name, test_case_input_path=None):
|
||||
clear_directory(test_case_output_path_reference)
|
||||
clear_directory(test_case_output_path_betterproto)
|
||||
|
||||
protoc_reference(test_case_input_path, test_case_output_path_reference)
|
||||
protoc_plugin(test_case_input_path, test_case_output_path_betterproto)
|
||||
(
|
||||
(ref_out, ref_err, ref_code),
|
||||
(plg_out, plg_err, plg_code),
|
||||
) = await asyncio.gather(
|
||||
protoc_reference(test_case_input_path, test_case_output_path_reference),
|
||||
protoc_plugin(test_case_input_path, test_case_output_path_betterproto),
|
||||
)
|
||||
|
||||
message = f"Generated output for {test_case_name!r}"
|
||||
if verbose:
|
||||
print(f"\033[31;1;4m{message}\033[0m")
|
||||
if ref_out:
|
||||
sys.stdout.buffer.write(ref_out)
|
||||
if ref_err:
|
||||
sys.stderr.buffer.write(ref_err)
|
||||
if plg_out:
|
||||
sys.stdout.buffer.write(plg_out)
|
||||
if plg_err:
|
||||
sys.stderr.buffer.write(plg_err)
|
||||
sys.stdout.buffer.flush()
|
||||
sys.stderr.buffer.flush()
|
||||
else:
|
||||
print(message)
|
||||
|
||||
return max(ref_code, plg_code)
|
||||
|
||||
|
||||
HELP = "\n".join(
|
||||
[
|
||||
"Usage: python generate.py",
|
||||
" python generate.py [DIRECTORIES or NAMES]",
|
||||
(
|
||||
"Usage: python generate.py [-h] [-v] [DIRECTORIES or NAMES]",
|
||||
"Generate python classes for standard tests.",
|
||||
"",
|
||||
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
||||
@@ -94,7 +122,7 @@ HELP = "\n".join(
|
||||
"",
|
||||
"NAMES One or more test-case names to generate classes for.",
|
||||
" python generate.py bool double enums",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -102,9 +130,13 @@ def main():
|
||||
if set(sys.argv).intersection({"-h", "--help"}):
|
||||
print(HELP)
|
||||
return
|
||||
whitelist = set(sys.argv[1:])
|
||||
|
||||
generate(whitelist)
|
||||
if sys.argv[1:2] == ["-v"]:
|
||||
verbose = True
|
||||
whitelist = set(sys.argv[2:])
|
||||
else:
|
||||
verbose = False
|
||||
whitelist = set(sys.argv[1:])
|
||||
asyncio.get_event_loop().run_until_complete(generate(whitelist, verbose))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
0
betterproto/tests/grpc/__init__.py
Normal file
0
betterproto/tests/grpc/__init__.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
154
betterproto/tests/grpc/test_grpclib_client.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import asyncio
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from .thing_service import ThingService
|
||||
|
||||
|
||||
async def _test_client(client, name="clean room", **kwargs):
|
||||
response = await client.do_thing(name=name)
|
||||
assert response.names == [name]
|
||||
|
||||
|
||||
def _assert_request_meta_recieved(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
await _test_client(ThingServiceClient(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
await _test_client(
|
||||
ThingServiceClient(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
THING_TO_DO = "get milk"
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ThingService(test_hook=_assert_request_meta_recieved(deadline, metadata),)]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ThingService(
|
||||
test_hook=_assert_request_meta_recieved(kwarg_deadline, kwarg_metadata),
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
client = ThingServiceClient(channel, deadline=deadline, metadata=metadata)
|
||||
response = await client._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(THING_TO_DO),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.names == [THING_TO_DO]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_unary_stream_request():
|
||||
thing_name = "my milkshakes"
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
expected_versions = [5, 4, 3, 2, 1]
|
||||
async for response in client.get_thing_versions(name=thing_name):
|
||||
assert response.name == thing_name
|
||||
assert response.version == expected_versions.pop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_gen_for_stream_stream_request():
|
||||
some_things = ["cake", "cricket", "coral reef"]
|
||||
more_things = ["ball", "that", "56kmodem", "liberal humanism", "cheesesticks"]
|
||||
expected_things = (*some_things, *more_things)
|
||||
|
||||
async with ChannelFor([ThingService()]) as channel:
|
||||
client = ThingServiceClient(channel)
|
||||
# Use an AsyncChannel to decouple sending and recieving, it'll send some_things
|
||||
# immediately and we'll use it to send more_things later, after recieving some
|
||||
# results
|
||||
request_chan = AsyncChannel()
|
||||
send_initial_requests = asyncio.ensure_future(
|
||||
request_chan.send_from(GetThingRequest(name) for name in some_things)
|
||||
)
|
||||
response_index = 0
|
||||
async for response in client.get_different_things(request_chan):
|
||||
assert response.name == expected_things[response_index]
|
||||
assert response.version == response_index + 1
|
||||
response_index += 1
|
||||
if more_things:
|
||||
# Send some more requests as we recieve reponses to be sure coordination of
|
||||
# send/recieve events doesn't matter
|
||||
await request_chan.send(GetThingRequest(more_things.pop(0)))
|
||||
elif not send_initial_requests.done():
|
||||
# Make sure the sending task it completed
|
||||
await send_initial_requests
|
||||
else:
|
||||
# No more things to send make sure channel is closed
|
||||
request_chan.close()
|
||||
assert response_index == len(
|
||||
expected_things
|
||||
), "Didn't recieve all exptected responses"
|
||||
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
100
betterproto/tests/grpc/test_stream_stream.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import asyncio
|
||||
import betterproto
|
||||
from betterproto.grpc.util.async_channel import AsyncChannel
|
||||
from dataclasses import dataclass
|
||||
import pytest
|
||||
from typing import AsyncIterator
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message(betterproto.Message):
|
||||
body: str = betterproto.string_field(1)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def expected_responses():
|
||||
return [Message("Hello world 1"), Message("Hello world 2"), Message("Done")]
|
||||
|
||||
|
||||
class ClientStub:
|
||||
async def connect(self, requests: AsyncIterator):
|
||||
await asyncio.sleep(0.1)
|
||||
async for request in requests:
|
||||
await asyncio.sleep(0.1)
|
||||
yield request
|
||||
await asyncio.sleep(0.1)
|
||||
yield Message("Done")
|
||||
|
||||
|
||||
async def to_list(generator: AsyncIterator):
|
||||
result = []
|
||||
async for value in generator:
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
# channel = Channel(host='127.0.0.1', port=50051)
|
||||
# return ClientStub(channel)
|
||||
return ClientStub()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_before_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_after_connect_and_close_automatically(
|
||||
client, expected_responses
|
||||
):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=True
|
||||
)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_from_close_manually_immediately(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
responses = client.connect(requests)
|
||||
await requests.send_from(
|
||||
[Message(body="Hello world 1"), Message(body="Hello world 2")], close=False
|
||||
)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_before_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
requests.close()
|
||||
responses = client.connect(requests)
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_individually_and_close_after_connect(client, expected_responses):
|
||||
requests = AsyncChannel()
|
||||
await requests.send(Message(body="Hello world 1"))
|
||||
await requests.send(Message(body="Hello world 2"))
|
||||
responses = client.connect(requests)
|
||||
requests.close()
|
||||
|
||||
assert await to_list(responses) == expected_responses
|
||||
83
betterproto/tests/grpc/thing_service.py
Normal file
83
betterproto/tests/grpc/thing_service.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
TestStub as ThingServiceClient,
|
||||
)
|
||||
import grpclib
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class ThingService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def do_thing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse([request.name]))
|
||||
|
||||
async def do_many_things(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
thing_names = [request.name for request in stream]
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
await stream.send_message(DoThingResponse(thing_names))
|
||||
|
||||
async def get_thing_versions(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for version_num in range(1, 6):
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=version_num)
|
||||
)
|
||||
|
||||
async def get_different_things(
|
||||
self, stream: "grpclib.server.Stream[GetThingRequest, GetThingResponse]"
|
||||
):
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
# Respond to each input item immediately
|
||||
response_num = 0
|
||||
async for request in stream:
|
||||
response_num += 1
|
||||
await stream.send_message(
|
||||
GetThingResponse(name=request.name, version=response_num)
|
||||
)
|
||||
|
||||
def __mapping__(self) -> Dict[str, "grpclib.const.Handler"]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.do_thing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/DoManyThings": grpclib.const.Handler(
|
||||
self.do_many_things,
|
||||
grpclib.const.Cardinality.STREAM_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
),
|
||||
"/service.Test/GetThingVersions": grpclib.const.Handler(
|
||||
self.get_thing_versions,
|
||||
grpclib.const.Cardinality.UNARY_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
"/service.Test/GetDifferentThings": grpclib.const.Handler(
|
||||
self.get_different_things,
|
||||
grpclib.const.Cardinality.STREAM_STREAM,
|
||||
GetThingRequest,
|
||||
GetThingResponse,
|
||||
),
|
||||
}
|
||||
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
6
betterproto/tests/inputs/fixed/fixed.json
Normal file
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"foo": 4294967295,
|
||||
"bar": -2147483648,
|
||||
"baz": "18446744073709551615",
|
||||
"qux": "-9223372036854775808"
|
||||
}
|
||||
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
8
betterproto/tests/inputs/fixed/fixed.proto
Normal file
@@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
fixed32 foo = 1;
|
||||
sfixed32 bar = 2;
|
||||
fixed64 baz = 3;
|
||||
sfixed64 qux = 4;
|
||||
}
|
||||
@@ -21,7 +21,7 @@ test_cases = [
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(["service_method", "wrapper_class", "value"], test_cases)
|
||||
async def test_channel_receives_wrapped_type(
|
||||
async def test_channel_recieves_wrapped_type(
|
||||
service_method: Callable[[TestStub], Any], wrapper_class: Callable, value
|
||||
):
|
||||
wrapped_value = wrapper_class()
|
||||
|
||||
@@ -3,13 +3,25 @@ syntax = "proto3";
|
||||
package service;
|
||||
|
||||
message DoThingRequest {
|
||||
int32 iterations = 1;
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message DoThingResponse {
|
||||
int32 successfulIterations = 1;
|
||||
repeated string names = 1;
|
||||
}
|
||||
|
||||
message GetThingRequest {
|
||||
string name = 1;
|
||||
}
|
||||
|
||||
message GetThingResponse {
|
||||
string name = 1;
|
||||
int32 version = 2;
|
||||
}
|
||||
|
||||
service Test {
|
||||
rpc DoThing (DoThingRequest) returns (DoThingResponse);
|
||||
rpc DoManyThings (stream DoThingRequest) returns (DoThingResponse);
|
||||
rpc GetThingVersions (GetThingRequest) returns (stream GetThingResponse);
|
||||
rpc GetDifferentThings (stream GetThingRequest) returns (stream GetThingResponse);
|
||||
}
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
import betterproto
|
||||
import grpclib
|
||||
from grpclib.testing import ChannelFor
|
||||
import pytest
|
||||
from typing import Dict
|
||||
|
||||
from betterproto.tests.output_betterproto.service.service import (
|
||||
DoThingResponse,
|
||||
DoThingRequest,
|
||||
TestStub as ExampleServiceStub,
|
||||
)
|
||||
|
||||
|
||||
class ExampleService:
|
||||
def __init__(self, test_hook=None):
|
||||
# This lets us pass assertions to the servicer ;)
|
||||
self.test_hook = test_hook
|
||||
|
||||
async def DoThing(
|
||||
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
|
||||
):
|
||||
request = await stream.recv_message()
|
||||
print("self.test_hook", self.test_hook)
|
||||
if self.test_hook is not None:
|
||||
self.test_hook(stream)
|
||||
for iteration in range(request.iterations):
|
||||
pass
|
||||
await stream.send_message(DoThingResponse(request.iterations))
|
||||
|
||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||
return {
|
||||
"/service.Test/DoThing": grpclib.const.Handler(
|
||||
self.DoThing,
|
||||
grpclib.const.Cardinality.UNARY_UNARY,
|
||||
DoThingRequest,
|
||||
DoThingResponse,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
async def _test_stub(stub, iterations=42, **kwargs):
|
||||
response = await stub.do_thing(iterations=iterations)
|
||||
assert response.successful_iterations == iterations
|
||||
|
||||
|
||||
def _get_server_side_test(deadline, metadata):
|
||||
def server_side_test(stream):
|
||||
assert stream.deadline._timestamp == pytest.approx(
|
||||
deadline._timestamp, 1
|
||||
), "The provided deadline should be recieved serverside"
|
||||
assert (
|
||||
stream.metadata["authorization"] == metadata["authorization"]
|
||||
), "The provided authorization metadata should be recieved serverside"
|
||||
|
||||
return server_side_test
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_service_call():
|
||||
async with ChannelFor([ExampleService()]) as channel:
|
||||
await _test_stub(ExampleServiceStub(channel))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_with_upfront_request_params():
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
)
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
await _test_stub(
|
||||
ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_service_call_lower_level_with_overrides():
|
||||
ITERATIONS = 99
|
||||
|
||||
# Setting deadline
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(22)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
|
||||
kwarg_metadata = {"authorization": "12345"}
|
||||
async with ChannelFor(
|
||||
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
|
||||
) as channel:
|
||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(ITERATIONS),
|
||||
DoThingResponse,
|
||||
deadline=kwarg_deadline,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.successful_iterations == ITERATIONS
|
||||
|
||||
# Setting timeout
|
||||
timeout = 99
|
||||
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
|
||||
metadata = {"authorization": "12345"}
|
||||
kwarg_timeout = 9000
|
||||
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
|
||||
kwarg_metadata = {"authorization": "09876"}
|
||||
async with ChannelFor(
|
||||
[
|
||||
ExampleService(
|
||||
test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
|
||||
)
|
||||
]
|
||||
) as channel:
|
||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||
response = await stub._unary_unary(
|
||||
"/service.Test/DoThing",
|
||||
DoThingRequest(ITERATIONS),
|
||||
DoThingResponse,
|
||||
timeout=kwarg_timeout,
|
||||
metadata=kwarg_metadata,
|
||||
)
|
||||
assert response.successful_iterations == ITERATIONS
|
||||
@@ -29,9 +29,9 @@ from google.protobuf.json_format import Parse
|
||||
|
||||
class TestCases:
|
||||
def __init__(self, path, services: Set[str], xfail: Set[str]):
|
||||
_all = set(get_directories(path))
|
||||
_all = set(get_directories(path)) - {"__pycache__"}
|
||||
_services = services
|
||||
_messages = _all - services
|
||||
_messages = (_all - services) - {"__pycache__"}
|
||||
_messages_with_json = {
|
||||
test for test in _messages if get_test_case_json_data(test)
|
||||
}
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Callable, Generator, Optional
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
root_path = os.path.dirname(os.path.realpath(__file__))
|
||||
inputs_path = os.path.join(root_path, "inputs")
|
||||
output_path_reference = os.path.join(root_path, "output_reference")
|
||||
output_path_betterproto = os.path.join(root_path, "output_betterproto")
|
||||
root_path = Path(__file__).resolve().parent
|
||||
inputs_path = root_path.joinpath("inputs")
|
||||
output_path_reference = root_path.joinpath("output_reference")
|
||||
output_path_betterproto = root_path.joinpath("output_betterproto")
|
||||
|
||||
if os.name == "nt":
|
||||
plugin_path = os.path.join(root_path, "..", "plugin.bat")
|
||||
plugin_path = root_path.joinpath("..", "plugin.bat").resolve()
|
||||
else:
|
||||
plugin_path = os.path.join(root_path, "..", "plugin.py")
|
||||
plugin_path = root_path.joinpath("..", "plugin.py").resolve()
|
||||
|
||||
|
||||
def get_files(path, end: str) -> Generator[str, None, None]:
|
||||
def get_files(path, suffix: str) -> Generator[str, None, None]:
|
||||
for r, dirs, files in os.walk(path):
|
||||
for filename in [f for f in files if f.endswith(end)]:
|
||||
for filename in [f for f in files if f.endswith(suffix)]:
|
||||
yield os.path.join(r, filename)
|
||||
|
||||
|
||||
@@ -30,38 +31,32 @@ def get_directories(path):
|
||||
yield directory
|
||||
|
||||
|
||||
def relative(file: str, path: str):
|
||||
return os.path.join(os.path.dirname(file), path)
|
||||
|
||||
|
||||
def read_relative(file: str, path: str):
|
||||
with open(relative(file, path)) as fh:
|
||||
return fh.read()
|
||||
|
||||
|
||||
def protoc_plugin(path: str, output_dir: str) -> subprocess.CompletedProcess:
|
||||
return subprocess.run(
|
||||
async def protoc_plugin(path: str, output_dir: str):
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
return (*(await proc.communicate()), proc.returncode)
|
||||
|
||||
|
||||
def protoc_reference(path: str, output_dir: str):
|
||||
subprocess.run(
|
||||
async def protoc_reference(path: str, output_dir: str):
|
||||
proc = await asyncio.create_subprocess_shell(
|
||||
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
|
||||
shell=True,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
return (*(await proc.communicate()), proc.returncode)
|
||||
|
||||
|
||||
def get_test_case_json_data(test_case_name, json_file_name=None):
|
||||
def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = None):
|
||||
test_data_file_name = json_file_name if json_file_name else f"{test_case_name}.json"
|
||||
test_data_file_path = os.path.join(inputs_path, test_case_name, test_data_file_name)
|
||||
test_data_file_path = inputs_path.joinpath(test_case_name, test_data_file_name)
|
||||
|
||||
if not os.path.exists(test_data_file_path):
|
||||
if not test_data_file_path.exists():
|
||||
return None
|
||||
|
||||
with open(test_data_file_path) as fh:
|
||||
with test_data_file_path.open("r") as fh:
|
||||
return fh.read()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user