diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index 03a5c2a..c33f30b 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -1,30 +1,62 @@ #!/usr/bin/env python import os +import sys +from typing import Set + +from betterproto.tests.util import get_directories, inputs_path, output_path_betterproto, output_path_reference, \ + protoc_plugin, protoc_reference # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" -from betterproto.tests.util import get_directories, protoc_plugin, protoc_reference, root_path +def generate(whitelist: Set[str]): + path_whitelist = {os.path.realpath(e) for e in whitelist if os.path.exists(e)} + name_whitelist = {e for e in whitelist if not os.path.exists(e)} -def main(): - os.chdir(root_path) - test_cases_directory = os.path.join(root_path, 'inputs') - test_case = get_directories(test_cases_directory) + test_case_names = set(get_directories(inputs_path)) - for test_case_name in test_case: - test_case_path = os.path.join(test_cases_directory, test_case_name) + for test_case_name in test_case_names: + test_case_path = os.path.join(inputs_path, test_case_name) - case_reference_output_dir = os.path.join(root_path, 'output_reference', test_case_name) - case_plugin_output_dir = os.path.join(root_path, 'output_betterproto', test_case_name) + is_path_whitelisted = path_whitelist and os.path.realpath(test_case_path) in path_whitelist + is_name_whitelisted = name_whitelist and test_case_name in name_whitelist + + if whitelist and not is_path_whitelisted and not is_name_whitelisted: + continue + + case_output_dir_reference = os.path.join(output_path_reference, test_case_name) + case_output_dir_betterproto = os.path.join(output_path_betterproto, test_case_name) print(f'Generating output for {test_case_name}') - os.makedirs(case_reference_output_dir, exist_ok=True) - os.makedirs(case_plugin_output_dir, exist_ok=True) + os.makedirs(case_output_dir_reference, exist_ok=True) + os.makedirs(case_output_dir_betterproto, exist_ok=True) - protoc_reference(test_case_path, case_reference_output_dir) - protoc_plugin(test_case_path, case_plugin_output_dir) + protoc_reference(test_case_path, case_output_dir_reference) + protoc_plugin(test_case_path, case_output_dir_betterproto) + + +HELP = "\n".join([ + 'Usage: python generate.py', + ' python generate.py [DIRECTORIES or NAMES]', + 'Generate python classes for standard tests.', + '', + 'DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.', + ' python generate.py inputs/bool inputs/double inputs/enum', + '', + 'NAMES One or more test-case names to generate classes for.', + ' python generate.py bool double enums' +]) + + +def main(): + if sys.argv[1] in ('-h', '--help'): + print(HELP) + return + whitelist = set(sys.argv[1:]) + + generate(whitelist) if __name__ == "__main__": diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index c65e2a6..b627a23 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -6,6 +6,8 @@ os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" root_path = os.path.dirname(os.path.realpath(__file__)) inputs_path = os.path.join(root_path, 'inputs') +output_path_reference = os.path.join(root_path, 'output_reference') +output_path_betterproto = os.path.join(root_path, 'output_betterproto') if os.name == 'nt': plugin_path = os.path.join(root_path, '..', 'plugin.bat')