Merge remote-tracking branch 'daniel/master' into pr/wrapper-as-output

This commit is contained in:
boukeversteegh
2020-05-24 10:41:40 +02:00
64 changed files with 1092 additions and 127 deletions

View File

@@ -11,10 +11,12 @@ from typing import (
Any,
AsyncGenerator,
Callable,
Collection,
Dict,
Generator,
Iterable,
List,
Mapping,
Optional,
SupportsBytes,
Tuple,
@@ -1000,20 +1002,57 @@ def _get_wrapper(proto_type: str) -> Type:
}[proto_type]
_Value = Union[str, bytes]
_MetadataLike = Union[Mapping[str, _Value], Collection[Tuple[str, _Value]]]
class ServiceStub(ABC):
"""
Base class for async gRPC service stubs.
"""
def __init__(self, channel: grpclib.client.Channel) -> None:
def __init__(
self,
channel: grpclib.client.Channel,
*,
timeout: Optional[float] = None,
deadline: Optional[grpclib.metadata.Deadline] = None,
metadata: Optional[_MetadataLike] = None,
) -> None:
self.channel = channel
self.timeout = timeout
self.deadline = deadline
self.metadata = metadata
def __resolve_request_kwargs(
self,
timeout: Optional[float],
deadline: Optional[grpclib.metadata.Deadline],
metadata: Optional[_MetadataLike],
):
return {
"timeout": self.timeout if timeout is None else timeout,
"deadline": self.deadline if deadline is None else deadline,
"metadata": self.metadata if metadata is None else metadata,
}
async def _unary_unary(
self, route: str, request: "IProtoMessage", response_type: Type[T]
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional[grpclib.metadata.Deadline] = None,
metadata: Optional[_MetadataLike] = None,
) -> T:
"""Make a unary request and return the response."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_UNARY, type(request), response_type
route,
grpclib.const.Cardinality.UNARY_UNARY,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
response = await stream.recv_message()
@@ -1021,11 +1060,22 @@ class ServiceStub(ABC):
return response
async def _unary_stream(
self, route: str, request: "IProtoMessage", response_type: Type[T]
self,
route: str,
request: "IProtoMessage",
response_type: Type[T],
*,
timeout: Optional[float] = None,
deadline: Optional[grpclib.metadata.Deadline] = None,
metadata: Optional[_MetadataLike] = None,
) -> AsyncGenerator[T, None]:
"""Make a unary request and return the stream response iterator."""
async with self.channel.request(
route, grpclib.const.Cardinality.UNARY_STREAM, type(request), response_type
route,
grpclib.const.Cardinality.UNARY_STREAM,
type(request),
response_type,
**self.__resolve_request_kwargs(timeout, deadline, metadata),
) as stream:
await stream.send_message(request, end=True)
async for message in stream:

2
betterproto/plugin.bat Normal file
View File

@@ -0,0 +1,2 @@
@SET plugin_dir=%~dp0
@python %plugin_dir%/plugin.py %*

View File

@@ -138,7 +138,7 @@ def get_py_zero(type_num: int) -> str:
def traverse(proto_file):
def _traverse(path, items, prefix = ''):
def _traverse(path, items, prefix=""):
for i, item in enumerate(items):
# Adjust the name since we flatten the heirarchy.
item.name = next_prefix = prefix + item.name

View File

@@ -0,0 +1,75 @@
# Standard Tests Development Guide
Standard test cases are found in [betterproto/tests/inputs](inputs), where each subdirectory represents a testcase, that is verified in isolation.
```
inputs/
bool/
double/
int32/
...
```
## Test case directory structure
Each testcase has a `<name>.proto` file with a message called `Test`, a matching `.json` file and optionally a custom test file called `test_*.py`.
```bash
bool/
bool.proto
bool.json
test_bool.py # optional
```
### proto
`<name>.proto` &mdash; *The protobuf message to test*
```protobuf
syntax = "proto3";
message Test {
bool value = 1;
}
```
You can add multiple `.proto` files to the test case, as long as one file matches the directory name.
### json
`<name>.json` &mdash; *Test-data to validate the message with*
```json
{
"value": true
}
```
### pytest
`test_<name>.py` &mdash; *Custom test to validate specific aspects of the generated class*
```python
from betterproto.tests.output_betterproto.bool.bool import Test
def test_value():
message = Test()
assert not message.value, "Boolean is False by default"
```
## Standard tests
The following tests are automatically executed for all cases:
- [x] Can the generated python code imported?
- [x] Can the generated message class be instantiated?
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
## Running the tests
- `pipenv run generate`
This generates
- `betterproto/tests/output_betterproto` &mdash; *the plugin generated python classes*
- `betterproto/tests/output_reference` &mdash; *reference implementation classes*
- `pipenv run test`

View File

@@ -1,84 +1,74 @@
#!/usr/bin/env python
import os
import sys
from typing import Set
from betterproto.tests.util import (
get_directories,
inputs_path,
output_path_betterproto,
output_path_reference,
protoc_plugin,
protoc_reference,
)
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
import importlib
import json
import subprocess
import sys
from typing import Generator, Tuple
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import MessageToJson, Parse
def generate(whitelist: Set[str]):
path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)}
name_whitelist = {e for e in whitelist if not os.path.exists(e)}
test_case_names = set(get_directories(inputs_path))
for test_case_name in sorted(test_case_names):
test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name))
if (
whitelist
and test_case_path not in path_whitelist
and test_case_name not in name_whitelist
):
continue
case_output_dir_reference = os.path.join(output_path_reference, test_case_name)
case_output_dir_betterproto = os.path.join(
output_path_betterproto, test_case_name
)
print(f"Generating output for {test_case_name}")
os.makedirs(case_output_dir_reference, exist_ok=True)
os.makedirs(case_output_dir_betterproto, exist_ok=True)
protoc_reference(test_case_path, case_output_dir_reference)
protoc_plugin(test_case_path, case_output_dir_betterproto)
root = os.path.dirname(os.path.realpath(__file__))
HELP = "\n".join(
[
"Usage: python generate.py",
" python generate.py [DIRECTORIES or NAMES]",
"Generate python classes for standard tests.",
"",
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
" python generate.py inputs/bool inputs/double inputs/enum",
"",
"NAMES One or more test-case names to generate classes for.",
" python generate.py bool double enums",
]
)
def get_files(end: str) -> Generator[str, None, None]:
for r, dirs, files in os.walk(root):
for filename in [f for f in files if f.endswith(end)]:
yield os.path.join(r, filename)
def main():
if set(sys.argv).intersection({"-h", "--help"}):
print(HELP)
return
whitelist = set(sys.argv[1:])
def get_base(filename: str) -> str:
return os.path.splitext(os.path.basename(filename))[0]
def ensure_ext(filename: str, ext: str) -> str:
if not filename.endswith(ext):
return filename + ext
return filename
generate(whitelist)
if __name__ == "__main__":
os.chdir(root)
if len(sys.argv) > 1:
proto_files = [ensure_ext(f, ".proto") for f in sys.argv[1:]]
bases = {get_base(f) for f in proto_files}
json_files = [
f for f in get_files(".json") if get_base(f).split("-")[0] in bases
]
else:
proto_files = get_files(".proto")
json_files = get_files(".json")
for filename in proto_files:
print(f"Generating code for {os.path.basename(filename)}")
subprocess.run(
f"protoc --python_out=. {os.path.basename(filename)}", shell=True
)
subprocess.run(
f"protoc --plugin=protoc-gen-custom=../plugin.py --custom_out=. {os.path.basename(filename)}",
shell=True,
)
for filename in json_files:
# Reset the internal symbol database so we can import the `Test` message
# multiple times. Ugh.
sym = symbol_database.Default()
sym.pool = DescriptorPool()
parts = get_base(filename).split("-")
out = filename.replace(".json", ".bin")
print(f"Using {parts[0]}_pb2 to generate {os.path.basename(out)}")
imported = importlib.import_module(f"{parts[0]}_pb2")
input_json = open(filename).read()
parsed = Parse(input_json, imported.Test())
serialized = parsed.SerializeToString()
preserve = "casing" not in filename
serialized_json = MessageToJson(parsed, preserving_proto_field_name=preserve)
s_loaded = json.loads(serialized_json)
in_loaded = json.loads(input_json)
if s_loaded != in_loaded:
raise AssertionError("Expected JSON to be equal:", s_loaded, in_loaded)
open(out, "wb").write(serialized)
main()

View File

@@ -0,0 +1,6 @@
from betterproto.tests.output_betterproto.bool.bool import Test
def test_value():
message = Test()
assert not message.value, "Boolean is False by default"

View File

@@ -9,4 +9,9 @@ enum my_enum {
message Test {
int32 camelCase = 1;
my_enum snake_case = 2;
snake_case_message snake_case_message = 3;
}
message snake_case_message {
}

View File

@@ -0,0 +1,22 @@
import betterproto.tests.output_betterproto.casing.casing as casing
from betterproto.tests.output_betterproto.casing.casing import Test
def test_message_attributes():
message = Test()
assert hasattr(
message, "snake_case_message"
), "snake_case field name is same in python"
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
def test_message_casing():
assert hasattr(
casing, "SnakeCaseMessage"
), "snake_case Message name is converted to CamelCase in python"
def test_enum_casing():
assert hasattr(
casing, "MyEnum"
), "snake_case Enum name is converted to CamelCase in python"

View File

@@ -0,0 +1,18 @@
syntax = "proto3";
import "google/protobuf/wrappers.proto";
service Test {
rpc GetInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetAnotherInt32 (Input) returns (google.protobuf.Int32Value);
rpc GetInt64 (Input) returns (google.protobuf.Int64Value);
rpc GetOutput (Input) returns (Output);
}
message Input {
}
message Output {
google.protobuf.Int64Value int64 = 1;
}

View File

@@ -0,0 +1,20 @@
from typing import Optional
import pytest
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub
)
class TestStubChild(TestStub):
async def _unary_unary(self, route, request, response_type, **kwargs):
self.response_type = response_type
@pytest.mark.asyncio
async def test():
pytest.skip("todo")
stub = TestStubChild(None)
await stub.get_int64()
assert stub.response_type != Optional[int]

View File

@@ -0,0 +1,4 @@
{
"positive": 150,
"negative": -150
}

View File

@@ -3,5 +3,6 @@ syntax = "proto3";
// Some documentation about the Test message.
message Test {
// Some documentation about the count.
int32 count = 1;
int32 positive = 1;
int32 negative = 2;
}

View File

@@ -0,0 +1,11 @@
syntax = "proto3";
package repeatedmessage;
message Test {
repeated Sub greetings = 1;
}
message Sub {
string greeting = 1;
}

View File

@@ -0,0 +1,15 @@
syntax = "proto3";
package service;
message DoThingRequest {
int32 iterations = 1;
}
message DoThingResponse {
int32 successfulIterations = 1;
}
service ExampleService {
rpc DoThing (DoThingRequest) returns (DoThingResponse);
}

View File

@@ -0,0 +1,6 @@
{
"signed32": 150,
"negative32": -150,
"string64": "150",
"negative64": "-150"
}

View File

@@ -0,0 +1,9 @@
syntax = "proto3";
message Test {
// todo: rename fields after fixing bug where 'signed_32_positive' will map to 'signed_32Positive' as output json
sint32 signed32 = 1; // signed_32_positive
sint32 negative32 = 2; // signed_32_negative
sint64 string64 = 3; // signed_64_positive
sint64 negative64 = 4; // signed_64_negative
}

View File

@@ -1,3 +0,0 @@
{
"count": -150
}

View File

@@ -1,3 +0,0 @@
{
"count": 150
}

View File

@@ -1,4 +0,0 @@
{
"signed_32": -150,
"signed_64": "-150"
}

View File

@@ -1,4 +0,0 @@
{
"signed_32": 150,
"signed_64": "150"
}

View File

@@ -1,6 +0,0 @@
syntax = "proto3";
message Test {
sint32 signed_32 = 1;
sint64 signed_64 = 2;
}

View File

@@ -256,7 +256,7 @@ def test_to_dict_default_values():
some_double: float = betterproto.double_field(2)
some_message: TestChildMessage = betterproto.message_field(3)
test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2,})
test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2})
assert test.to_dict(include_default_values=True) == {
"someInt": 0,

View File

@@ -1,32 +1,115 @@
import importlib
import json
import os
import sys
import pytest
import betterproto
from betterproto.tests.util import get_directories, inputs_path
from .generate import get_base, get_files
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
inputs = get_files(".bin")
from google.protobuf import symbol_database
from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import Parse
@pytest.mark.parametrize("filename", inputs)
def test_sample(filename: str) -> None:
module = get_base(filename).split("-")[0]
imported = importlib.import_module(f"betterproto.tests.{module}")
data_binary = open(filename, "rb").read()
data_dict = json.loads(open(filename.replace(".bin", ".json")).read())
t1 = imported.Test().parse(data_binary)
t2 = imported.Test().from_dict(data_dict)
print(t1)
print(t2)
excluded_test_cases = {"googletypes_response", "service"}
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
# Equality should automagically work for dataclasses!
assert t1 == t2
plugin_output_package = "betterproto.tests.output_betterproto"
reference_output_package = "betterproto.tests.output_reference"
# Generally this can't be relied on, but here we are aiming to match the
# existing Python implementation and aren't doing anything tricky.
# https://developers.google.com/protocol-buffers/docs/encoding#implications
assert bytes(t1) == data_binary
assert bytes(t2) == data_binary
assert t1.to_dict() == data_dict
assert t2.to_dict() == data_dict
@pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_can_be_imported(test_case_name: str) -> None:
importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
@pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_can_instantiated(test_case_name: str) -> None:
plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
plugin_module.Test()
@pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_equality(test_case_name: str) -> None:
plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
message1 = plugin_module.Test()
message2 = plugin_module.Test()
assert message1 == message2
@pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_json(test_case_name: str) -> None:
plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
message: betterproto.Message = plugin_module.Test()
reference_json_data = get_test_case_json_data(test_case_name)
message.from_json(reference_json_data)
message_json = message.to_json(0)
assert json.loads(reference_json_data) == json.loads(message_json)
@pytest.mark.parametrize("test_case_name", test_case_names)
def test_binary_compatibility(test_case_name: str) -> None:
# Reset the internal symbol database so we can import the `Test` message
# multiple times. Ugh.
sym = symbol_database.Default()
sym.pool = DescriptorPool()
reference_module_root = os.path.join(
*reference_output_package.split("."), test_case_name
)
sys.path.append(reference_module_root)
# import reference message
reference_module = importlib.import_module(
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
)
plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
test_data = get_test_case_json_data(test_case_name)
reference_instance = Parse(test_data, reference_module.Test())
reference_binary_output = reference_instance.SerializeToString()
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
test_data
)
plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output)
# # Generally this can't be relied on, but here we are aiming to match the
# # existing Python implementation and aren't doing anything tricky.
# # https://developers.google.com/protocol-buffers/docs/encoding#implications
assert plugin_instance_from_json == plugin_instance_from_binary
assert plugin_instance_from_json.to_dict() == plugin_instance_from_binary.to_dict()
sys.path.remove(reference_module_root)
"""
helper methods
"""
def get_test_case_json_data(test_case_name):
test_data_path = os.path.join(inputs_path, test_case_name, f"{test_case_name}.json")
if not os.path.exists(test_data_path):
return None
with open(test_data_path) as fh:
return fh.read()

View File

@@ -0,0 +1,132 @@
import betterproto
import grpclib
from grpclib.testing import ChannelFor
import pytest
from typing import Dict
from betterproto.tests.output_betterproto.service.service import (
DoThingResponse,
DoThingRequest,
ExampleServiceStub,
)
class ExampleService:
def __init__(self, test_hook=None):
# This lets us pass assertions to the servicer ;)
self.test_hook = test_hook
async def DoThing(
self, stream: "grpclib.server.Stream[DoThingRequest, DoThingResponse]"
):
request = await stream.recv_message()
print("self.test_hook", self.test_hook)
if self.test_hook is not None:
self.test_hook(stream)
for iteration in range(request.iterations):
pass
await stream.send_message(DoThingResponse(request.iterations))
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
return {
"/service.ExampleService/DoThing": grpclib.const.Handler(
self.DoThing,
grpclib.const.Cardinality.UNARY_UNARY,
DoThingRequest,
DoThingResponse,
)
}
async def _test_stub(stub, iterations=42, **kwargs):
response = await stub.do_thing(iterations=iterations)
assert response.successful_iterations == iterations
def _get_server_side_test(deadline, metadata):
def server_side_test(stream):
assert stream.deadline._timestamp == pytest.approx(
deadline._timestamp, 1
), "The provided deadline should be recieved serverside"
assert (
stream.metadata["authorization"] == metadata["authorization"]
), "The provided authorization metadata should be recieved serverside"
return server_side_test
@pytest.mark.asyncio
async def test_simple_service_call():
async with ChannelFor([ExampleService()]) as channel:
await _test_stub(ExampleServiceStub(channel))
@pytest.mark.asyncio
async def test_service_call_with_upfront_request_params():
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
) as channel:
await _test_stub(
ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
)
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
async with ChannelFor(
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
) as channel:
await _test_stub(
ExampleServiceStub(channel, timeout=timeout, metadata=metadata)
)
@pytest.mark.asyncio
async def test_service_call_lower_level_with_overrides():
ITERATIONS = 99
# Setting deadline
deadline = grpclib.metadata.Deadline.from_timeout(22)
metadata = {"authorization": "12345"}
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(28)
kwarg_metadata = {"authorization": "12345"}
async with ChannelFor(
[ExampleService(test_hook=_get_server_side_test(deadline, metadata))]
) as channel:
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary(
"/service.ExampleService/DoThing",
DoThingRequest(ITERATIONS),
DoThingResponse,
deadline=kwarg_deadline,
metadata=kwarg_metadata,
)
assert response.successful_iterations == ITERATIONS
# Setting timeout
timeout = 99
deadline = grpclib.metadata.Deadline.from_timeout(timeout)
metadata = {"authorization": "12345"}
kwarg_timeout = 9000
kwarg_deadline = grpclib.metadata.Deadline.from_timeout(kwarg_timeout)
kwarg_metadata = {"authorization": "09876"}
async with ChannelFor(
[
ExampleService(
test_hook=_get_server_side_test(kwarg_deadline, kwarg_metadata)
)
]
) as channel:
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary(
"/service.ExampleService/DoThing",
DoThingRequest(ITERATIONS),
DoThingResponse,
timeout=kwarg_timeout,
metadata=kwarg_metadata,
)
assert response.successful_iterations == ITERATIONS

50
betterproto/tests/util.py Normal file
View File

@@ -0,0 +1,50 @@
import os
import subprocess
from typing import Generator
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
root_path = os.path.dirname(os.path.realpath(__file__))
inputs_path = os.path.join(root_path, "inputs")
output_path_reference = os.path.join(root_path, "output_reference")
output_path_betterproto = os.path.join(root_path, "output_betterproto")
if os.name == "nt":
plugin_path = os.path.join(root_path, "..", "plugin.bat")
else:
plugin_path = os.path.join(root_path, "..", "plugin.py")
def get_files(path, end: str) -> Generator[str, None, None]:
for r, dirs, files in os.walk(path):
for filename in [f for f in files if f.endswith(end)]:
yield os.path.join(r, filename)
def get_directories(path):
for root, directories, files in os.walk(path):
for directory in directories:
yield directory
def relative(file: str, path: str):
return os.path.join(os.path.dirname(file), path)
def read_relative(file: str, path: str):
with open(relative(file, path)) as fh:
return fh.read()
def protoc_plugin(path: str, output_dir: str):
subprocess.run(
f"protoc --plugin=protoc-gen-custom={plugin_path} --custom_out={output_dir} --proto_path={path} {path}/*.proto",
shell=True,
)
def protoc_reference(path: str, output_dir: str):
subprocess.run(
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
shell=True,
)