diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 597bf1a..5bce411 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -122,7 +122,7 @@ def get_py_zero(type_num: int) -> str: def traverse(proto_file): - def _traverse(path, items, prefix = ''): + def _traverse(path, items, prefix=""): for i, item in enumerate(items): # Adjust the name since we flatten the heirarchy. item.name = next_prefix = prefix + item.name diff --git a/betterproto/tests/generate.py b/betterproto/tests/generate.py index fdcb220..fc3c4cd 100644 --- a/betterproto/tests/generate.py +++ b/betterproto/tests/generate.py @@ -3,8 +3,14 @@ import os import sys from typing import Set -from betterproto.tests.util import get_directories, inputs_path, output_path_betterproto, output_path_reference, \ - protoc_plugin, protoc_reference +from betterproto.tests.util import ( + get_directories, + inputs_path, + output_path_betterproto, + output_path_reference, + protoc_plugin, + protoc_reference, +) # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. @@ -20,13 +26,19 @@ def generate(whitelist: Set[str]): for test_case_name in sorted(test_case_names): test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name)) - if whitelist and test_case_path not in path_whitelist and test_case_name not in name_whitelist: + if ( + whitelist + and test_case_path not in path_whitelist + and test_case_name not in name_whitelist + ): continue case_output_dir_reference = os.path.join(output_path_reference, test_case_name) - case_output_dir_betterproto = os.path.join(output_path_betterproto, test_case_name) + case_output_dir_betterproto = os.path.join( + output_path_betterproto, test_case_name + ) - print(f'Generating output for {test_case_name}') + print(f"Generating output for {test_case_name}") os.makedirs(case_output_dir_reference, exist_ok=True) os.makedirs(case_output_dir_betterproto, exist_ok=True) @@ -34,21 +46,23 @@ def generate(whitelist: Set[str]): protoc_plugin(test_case_path, case_output_dir_betterproto) -HELP = "\n".join([ - 'Usage: python generate.py', - ' python generate.py [DIRECTORIES or NAMES]', - 'Generate python classes for standard tests.', - '', - 'DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.', - ' python generate.py inputs/bool inputs/double inputs/enum', - '', - 'NAMES One or more test-case names to generate classes for.', - ' python generate.py bool double enums' -]) +HELP = "\n".join( + [ + "Usage: python generate.py", + " python generate.py [DIRECTORIES or NAMES]", + "Generate python classes for standard tests.", + "", + "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.", + " python generate.py inputs/bool inputs/double inputs/enum", + "", + "NAMES One or more test-case names to generate classes for.", + " python generate.py bool double enums", + ] +) def main(): - if set(sys.argv).intersection({'-h', '--help'}): + if set(sys.argv).intersection({"-h", "--help"}): print(HELP) return whitelist = set(sys.argv[1:]) diff --git a/betterproto/tests/inputs/casing/test_casing.py b/betterproto/tests/inputs/casing/test_casing.py index 3d1ac3d..3255c4e 100644 --- a/betterproto/tests/inputs/casing/test_casing.py +++ b/betterproto/tests/inputs/casing/test_casing.py @@ -4,13 +4,19 @@ from betterproto.tests.output_betterproto.casing.casing import Test def test_message_attributes(): message = Test() - assert hasattr(message, 'snake_case_message'), 'snake_case field name is same in python' - assert hasattr(message, 'camel_case'), 'CamelCase field is snake_case in python' + assert hasattr( + message, "snake_case_message" + ), "snake_case field name is same in python" + assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" def test_message_casing(): - assert hasattr(casing, 'SnakeCaseMessage'), 'snake_case Message name is converted to CamelCase in python' + assert hasattr( + casing, "SnakeCaseMessage" + ), "snake_case Message name is converted to CamelCase in python" def test_enum_casing(): - assert hasattr(casing, 'MyEnum'), 'snake_case Enum name is converted to CamelCase in python' + assert hasattr( + casing, "MyEnum" + ), "snake_case Enum name is converted to CamelCase in python" diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index 9e8e454..fba2070 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -2,7 +2,9 @@ from typing import Optional import pytest -from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import TestStub +from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( + TestStub +) class TestStubChild(TestStub): @@ -12,7 +14,7 @@ class TestStubChild(TestStub): @pytest.mark.asyncio async def test(): - pytest.skip('todo') + pytest.skip("todo") stub = TestStubChild(None) await stub.get_int64() assert stub.response_type != Optional[int] diff --git a/betterproto/tests/test_features.py b/betterproto/tests/test_features.py index c0b40c1..47019e1 100644 --- a/betterproto/tests/test_features.py +++ b/betterproto/tests/test_features.py @@ -256,7 +256,7 @@ def test_to_dict_default_values(): some_double: float = betterproto.double_field(2) some_message: TestChildMessage = betterproto.message_field(3) - test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2,}) + test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2}) assert test.to_dict(include_default_values=True) == { "someInt": 0, diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index d819b2e..c8fb7d3 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -15,27 +15,33 @@ from google.protobuf.descriptor_pool import DescriptorPool from google.protobuf.json_format import Parse -excluded_test_cases = {'googletypes_response', 'service'} +excluded_test_cases = {"googletypes_response", "service"} test_case_names = {*get_directories(inputs_path)} - excluded_test_cases -plugin_output_package = 'betterproto.tests.output_betterproto' -reference_output_package = 'betterproto.tests.output_reference' +plugin_output_package = "betterproto.tests.output_betterproto" +reference_output_package = "betterproto.tests.output_reference" @pytest.mark.parametrize("test_case_name", test_case_names) def test_message_can_be_imported(test_case_name: str) -> None: - importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") + importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ) @pytest.mark.parametrize("test_case_name", test_case_names) def test_message_can_instantiated(test_case_name: str) -> None: - plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") + plugin_module = importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ) plugin_module.Test() @pytest.mark.parametrize("test_case_name", test_case_names) def test_message_equality(test_case_name: str) -> None: - plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") + plugin_module = importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ) message1 = plugin_module.Test() message2 = plugin_module.Test() assert message1 == message2 @@ -43,7 +49,9 @@ def test_message_equality(test_case_name: str) -> None: @pytest.mark.parametrize("test_case_name", test_case_names) def test_message_json(test_case_name: str) -> None: - plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") + plugin_module = importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ) message: betterproto.Message = plugin_module.Test() reference_json_data = get_test_case_json_data(test_case_name) @@ -60,20 +68,28 @@ def test_binary_compatibility(test_case_name: str) -> None: sym = symbol_database.Default() sym.pool = DescriptorPool() - reference_module_root = os.path.join(*reference_output_package.split('.'), test_case_name) + reference_module_root = os.path.join( + *reference_output_package.split("."), test_case_name + ) sys.path.append(reference_module_root) # import reference message - reference_module = importlib.import_module(f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2") - plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") + reference_module = importlib.import_module( + f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" + ) + plugin_module = importlib.import_module( + f"{plugin_output_package}.{test_case_name}.{test_case_name}" + ) test_data = get_test_case_json_data(test_case_name) reference_instance = Parse(test_data, reference_module.Test()) reference_binary_output = reference_instance.SerializeToString() - plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(test_data) + plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json( + test_data + ) plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output) # # Generally this can't be relied on, but here we are aiming to match the @@ -85,13 +101,13 @@ def test_binary_compatibility(test_case_name: str) -> None: sys.path.remove(reference_module_root) -''' +""" helper methods -''' +""" def get_test_case_json_data(test_case_name): - test_data_path = os.path.join(inputs_path, test_case_name, f'{test_case_name}.json') + test_data_path = os.path.join(inputs_path, test_case_name, f"{test_case_name}.json") if not os.path.exists(test_data_path): return None diff --git a/betterproto/tests/test_service_stub.py b/betterproto/tests/test_service_stub.py index 86377e2..b614e82 100644 --- a/betterproto/tests/test_service_stub.py +++ b/betterproto/tests/test_service_stub.py @@ -4,7 +4,12 @@ from grpclib.testing import ChannelFor import pytest from typing import Dict -from betterproto.tests.output_betterproto.service.service import DoThingResponse, DoThingRequest, ExampleServiceStub +from betterproto.tests.output_betterproto.service.service import ( + DoThingResponse, + DoThingRequest, + ExampleServiceStub, +) + class ExampleService: def __init__(self, test_hook=None): @@ -29,7 +34,7 @@ class ExampleService: grpclib.const.Cardinality.UNARY_UNARY, DoThingRequest, DoThingResponse, - ), + ) } @@ -94,7 +99,9 @@ async def test_service_call_lower_level_with_overrides(): ) as channel: stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) response = await stub._unary_unary( - "/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, + "/service.ExampleService/DoThing", + DoThingRequest(ITERATIONS), + DoThingResponse, deadline=kwarg_deadline, metadata=kwarg_metadata, ) @@ -116,7 +123,9 @@ async def test_service_call_lower_level_with_overrides(): ) as channel: stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) response = await stub._unary_unary( - "/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, + "/service.ExampleService/DoThing", + DoThingRequest(ITERATIONS), + DoThingResponse, timeout=kwarg_timeout, metadata=kwarg_metadata, ) diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index b627a23..83cfd98 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -5,14 +5,14 @@ from typing import Generator os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" root_path = os.path.dirname(os.path.realpath(__file__)) -inputs_path = os.path.join(root_path, 'inputs') -output_path_reference = os.path.join(root_path, 'output_reference') -output_path_betterproto = os.path.join(root_path, 'output_betterproto') +inputs_path = os.path.join(root_path, "inputs") +output_path_reference = os.path.join(root_path, "output_reference") +output_path_betterproto = os.path.join(root_path, "output_betterproto") -if os.name == 'nt': - plugin_path = os.path.join(root_path, '..', 'plugin.bat') +if os.name == "nt": + plugin_path = os.path.join(root_path, "..", "plugin.bat") else: - plugin_path = os.path.join(root_path, '..', 'plugin.py') + plugin_path = os.path.join(root_path, "..", "plugin.py") def get_files(path, end: str) -> Generator[str, None, None]: @@ -44,5 +44,7 @@ def protoc_plugin(path: str, output_dir: str): def protoc_reference(path: str, output_dir: str): - subprocess.run(f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", shell=True) - + subprocess.run( + f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", + shell=True, + )