Detect entry-point of tests automatically

This commit is contained in:
boukeversteegh
2020-06-10 22:42:38 +02:00
parent 1a95a7988e
commit fb54917f2c
5 changed files with 61 additions and 36 deletions

View File

@@ -3,6 +3,7 @@ import json
import os
import sys
from collections import namedtuple
from types import ModuleType
from typing import Set
import pytest
@@ -10,7 +11,12 @@ import pytest
import betterproto
from betterproto.tests.inputs import config as test_input_config
from betterproto.tests.mocks import MockChannel
from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path
from betterproto.tests.util import (
find_module,
get_directories,
get_test_case_json_data,
inputs_path,
)
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
@@ -50,16 +56,17 @@ class TestCases:
test_cases = TestCases(
path=inputs_path,
services=test_input_config.services,
xfail=test_input_config.tests,
xfail=test_input_config.xfail,
)
plugin_output_package = "betterproto.tests.output_betterproto"
reference_output_package = "betterproto.tests.output_reference"
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
TestData = namedtuple(
"TestData", ["plugin_module", "reference_module", "json_data", "entry_point"]
)
def module_has_entry_point(module: ModuleType):
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
@pytest.fixture
@@ -77,18 +84,23 @@ def test_data(request):
sys.path.append(reference_module_root)
test_package = test_case_name + test_input_config.packages.get(test_case_name, "")
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
if not plugin_module_entry_point:
raise Exception(
f"Test case {repr(test_case_name)} has no entry point. "
+ "Please add a proto message or service called Test and recompile."
)
yield (
TestData(
plugin_module=importlib.import_module(
f"{plugin_output_package}.{test_package}"
),
plugin_module=plugin_module_entry_point,
reference_module=lambda: importlib.import_module(
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
),
json_data=get_test_case_json_data(test_case_name),
entry_point=test_package,
)
)
@@ -111,7 +123,7 @@ def test_message_equality(test_data: TestData) -> None:
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_message_json(repeat, test_data: TestData) -> None:
plugin_module, _, json_data, entry_point = test_data
plugin_module, _, json_data = test_data
for _ in range(repeat):
message: betterproto.Message = plugin_module.Test()
@@ -124,13 +136,13 @@ def test_message_json(repeat, test_data: TestData) -> None:
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
def test_service_can_be_instantiated(test_data: TestData) -> None:
plugin_module, _, json_data, entry_point = test_data
plugin_module, _, json_data = test_data
plugin_module.TestStub(MockChannel())
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
def test_binary_compatibility(repeat, test_data: TestData) -> None:
plugin_module, reference_module, json_data, entry_point = test_data
plugin_module, reference_module, json_data = test_data
reference_instance = Parse(json_data, reference_module().Test())
reference_binary_output = reference_instance.SerializeToString()