Detect entry-point of tests automatically
This commit is contained in:
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from types import ModuleType
|
||||
from typing import Set
|
||||
|
||||
import pytest
|
||||
@@ -10,7 +11,12 @@ import pytest
|
||||
import betterproto
|
||||
from betterproto.tests.inputs import config as test_input_config
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path
|
||||
from betterproto.tests.util import (
|
||||
find_module,
|
||||
get_directories,
|
||||
get_test_case_json_data,
|
||||
inputs_path,
|
||||
)
|
||||
|
||||
# Force pure-python implementation instead of C++, otherwise imports
|
||||
# break things because we can't properly reset the symbol database.
|
||||
@@ -50,16 +56,17 @@ class TestCases:
|
||||
test_cases = TestCases(
|
||||
path=inputs_path,
|
||||
services=test_input_config.services,
|
||||
xfail=test_input_config.tests,
|
||||
xfail=test_input_config.xfail,
|
||||
)
|
||||
|
||||
plugin_output_package = "betterproto.tests.output_betterproto"
|
||||
reference_output_package = "betterproto.tests.output_reference"
|
||||
|
||||
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
|
||||
|
||||
TestData = namedtuple(
|
||||
"TestData", ["plugin_module", "reference_module", "json_data", "entry_point"]
|
||||
)
|
||||
|
||||
def module_has_entry_point(module: ModuleType):
|
||||
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -77,18 +84,23 @@ def test_data(request):
|
||||
|
||||
sys.path.append(reference_module_root)
|
||||
|
||||
test_package = test_case_name + test_input_config.packages.get(test_case_name, "")
|
||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
|
||||
|
||||
plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
|
||||
|
||||
if not plugin_module_entry_point:
|
||||
raise Exception(
|
||||
f"Test case {repr(test_case_name)} has no entry point. "
|
||||
+ "Please add a proto message or service called Test and recompile."
|
||||
)
|
||||
|
||||
yield (
|
||||
TestData(
|
||||
plugin_module=importlib.import_module(
|
||||
f"{plugin_output_package}.{test_package}"
|
||||
),
|
||||
plugin_module=plugin_module_entry_point,
|
||||
reference_module=lambda: importlib.import_module(
|
||||
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
|
||||
),
|
||||
json_data=get_test_case_json_data(test_case_name),
|
||||
entry_point=test_package,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -111,7 +123,7 @@ def test_message_equality(test_data: TestData) -> None:
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
|
||||
def test_message_json(repeat, test_data: TestData) -> None:
|
||||
plugin_module, _, json_data, entry_point = test_data
|
||||
plugin_module, _, json_data = test_data
|
||||
|
||||
for _ in range(repeat):
|
||||
message: betterproto.Message = plugin_module.Test()
|
||||
@@ -124,13 +136,13 @@ def test_message_json(repeat, test_data: TestData) -> None:
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_cases.services, indirect=True)
|
||||
def test_service_can_be_instantiated(test_data: TestData) -> None:
|
||||
plugin_module, _, json_data, entry_point = test_data
|
||||
plugin_module, _, json_data = test_data
|
||||
plugin_module.TestStub(MockChannel())
|
||||
|
||||
|
||||
@pytest.mark.parametrize("test_data", test_cases.messages_with_json, indirect=True)
|
||||
def test_binary_compatibility(repeat, test_data: TestData) -> None:
|
||||
plugin_module, reference_module, json_data, entry_point = test_data
|
||||
plugin_module, reference_module, json_data = test_data
|
||||
|
||||
reference_instance = Parse(json_data, reference_module().Test())
|
||||
reference_binary_output = reference_instance.SerializeToString()
|
||||
|
||||
Reference in New Issue
Block a user