Blacken
This commit is contained in:
parent
f6af077ffe
commit
5e2d9febea
@ -122,7 +122,7 @@ def get_py_zero(type_num: int) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def traverse(proto_file):
|
def traverse(proto_file):
|
||||||
def _traverse(path, items, prefix = ''):
|
def _traverse(path, items, prefix=""):
|
||||||
for i, item in enumerate(items):
|
for i, item in enumerate(items):
|
||||||
# Adjust the name since we flatten the heirarchy.
|
# Adjust the name since we flatten the heirarchy.
|
||||||
item.name = next_prefix = prefix + item.name
|
item.name = next_prefix = prefix + item.name
|
||||||
|
@ -3,8 +3,14 @@ import os
|
|||||||
import sys
|
import sys
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
from betterproto.tests.util import get_directories, inputs_path, output_path_betterproto, output_path_reference, \
|
from betterproto.tests.util import (
|
||||||
protoc_plugin, protoc_reference
|
get_directories,
|
||||||
|
inputs_path,
|
||||||
|
output_path_betterproto,
|
||||||
|
output_path_reference,
|
||||||
|
protoc_plugin,
|
||||||
|
protoc_reference,
|
||||||
|
)
|
||||||
|
|
||||||
# Force pure-python implementation instead of C++, otherwise imports
|
# Force pure-python implementation instead of C++, otherwise imports
|
||||||
# break things because we can't properly reset the symbol database.
|
# break things because we can't properly reset the symbol database.
|
||||||
@ -20,13 +26,19 @@ def generate(whitelist: Set[str]):
|
|||||||
for test_case_name in sorted(test_case_names):
|
for test_case_name in sorted(test_case_names):
|
||||||
test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name))
|
test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name))
|
||||||
|
|
||||||
if whitelist and test_case_path not in path_whitelist and test_case_name not in name_whitelist:
|
if (
|
||||||
|
whitelist
|
||||||
|
and test_case_path not in path_whitelist
|
||||||
|
and test_case_name not in name_whitelist
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
case_output_dir_reference = os.path.join(output_path_reference, test_case_name)
|
case_output_dir_reference = os.path.join(output_path_reference, test_case_name)
|
||||||
case_output_dir_betterproto = os.path.join(output_path_betterproto, test_case_name)
|
case_output_dir_betterproto = os.path.join(
|
||||||
|
output_path_betterproto, test_case_name
|
||||||
|
)
|
||||||
|
|
||||||
print(f'Generating output for {test_case_name}')
|
print(f"Generating output for {test_case_name}")
|
||||||
os.makedirs(case_output_dir_reference, exist_ok=True)
|
os.makedirs(case_output_dir_reference, exist_ok=True)
|
||||||
os.makedirs(case_output_dir_betterproto, exist_ok=True)
|
os.makedirs(case_output_dir_betterproto, exist_ok=True)
|
||||||
|
|
||||||
@ -34,21 +46,23 @@ def generate(whitelist: Set[str]):
|
|||||||
protoc_plugin(test_case_path, case_output_dir_betterproto)
|
protoc_plugin(test_case_path, case_output_dir_betterproto)
|
||||||
|
|
||||||
|
|
||||||
HELP = "\n".join([
|
HELP = "\n".join(
|
||||||
'Usage: python generate.py',
|
[
|
||||||
' python generate.py [DIRECTORIES or NAMES]',
|
"Usage: python generate.py",
|
||||||
'Generate python classes for standard tests.',
|
" python generate.py [DIRECTORIES or NAMES]",
|
||||||
'',
|
"Generate python classes for standard tests.",
|
||||||
'DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.',
|
"",
|
||||||
' python generate.py inputs/bool inputs/double inputs/enum',
|
"DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
|
||||||
'',
|
" python generate.py inputs/bool inputs/double inputs/enum",
|
||||||
'NAMES One or more test-case names to generate classes for.',
|
"",
|
||||||
' python generate.py bool double enums'
|
"NAMES One or more test-case names to generate classes for.",
|
||||||
])
|
" python generate.py bool double enums",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
if set(sys.argv).intersection({'-h', '--help'}):
|
if set(sys.argv).intersection({"-h", "--help"}):
|
||||||
print(HELP)
|
print(HELP)
|
||||||
return
|
return
|
||||||
whitelist = set(sys.argv[1:])
|
whitelist = set(sys.argv[1:])
|
||||||
|
@ -4,13 +4,19 @@ from betterproto.tests.output_betterproto.casing.casing import Test
|
|||||||
|
|
||||||
def test_message_attributes():
|
def test_message_attributes():
|
||||||
message = Test()
|
message = Test()
|
||||||
assert hasattr(message, 'snake_case_message'), 'snake_case field name is same in python'
|
assert hasattr(
|
||||||
assert hasattr(message, 'camel_case'), 'CamelCase field is snake_case in python'
|
message, "snake_case_message"
|
||||||
|
), "snake_case field name is same in python"
|
||||||
|
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
|
||||||
|
|
||||||
|
|
||||||
def test_message_casing():
|
def test_message_casing():
|
||||||
assert hasattr(casing, 'SnakeCaseMessage'), 'snake_case Message name is converted to CamelCase in python'
|
assert hasattr(
|
||||||
|
casing, "SnakeCaseMessage"
|
||||||
|
), "snake_case Message name is converted to CamelCase in python"
|
||||||
|
|
||||||
|
|
||||||
def test_enum_casing():
|
def test_enum_casing():
|
||||||
assert hasattr(casing, 'MyEnum'), 'snake_case Enum name is converted to CamelCase in python'
|
assert hasattr(
|
||||||
|
casing, "MyEnum"
|
||||||
|
), "snake_case Enum name is converted to CamelCase in python"
|
||||||
|
@ -2,7 +2,9 @@ from typing import Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import TestStub
|
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
|
||||||
|
TestStub
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestStubChild(TestStub):
|
class TestStubChild(TestStub):
|
||||||
@ -12,7 +14,7 @@ class TestStubChild(TestStub):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test():
|
async def test():
|
||||||
pytest.skip('todo')
|
pytest.skip("todo")
|
||||||
stub = TestStubChild(None)
|
stub = TestStubChild(None)
|
||||||
await stub.get_int64()
|
await stub.get_int64()
|
||||||
assert stub.response_type != Optional[int]
|
assert stub.response_type != Optional[int]
|
||||||
|
@ -256,7 +256,7 @@ def test_to_dict_default_values():
|
|||||||
some_double: float = betterproto.double_field(2)
|
some_double: float = betterproto.double_field(2)
|
||||||
some_message: TestChildMessage = betterproto.message_field(3)
|
some_message: TestChildMessage = betterproto.message_field(3)
|
||||||
|
|
||||||
test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2,})
|
test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2})
|
||||||
|
|
||||||
assert test.to_dict(include_default_values=True) == {
|
assert test.to_dict(include_default_values=True) == {
|
||||||
"someInt": 0,
|
"someInt": 0,
|
||||||
|
@ -15,27 +15,33 @@ from google.protobuf.descriptor_pool import DescriptorPool
|
|||||||
from google.protobuf.json_format import Parse
|
from google.protobuf.json_format import Parse
|
||||||
|
|
||||||
|
|
||||||
excluded_test_cases = {'googletypes_response', 'service'}
|
excluded_test_cases = {"googletypes_response", "service"}
|
||||||
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
|
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
|
||||||
|
|
||||||
plugin_output_package = 'betterproto.tests.output_betterproto'
|
plugin_output_package = "betterproto.tests.output_betterproto"
|
||||||
reference_output_package = 'betterproto.tests.output_reference'
|
reference_output_package = "betterproto.tests.output_reference"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
@pytest.mark.parametrize("test_case_name", test_case_names)
|
||||||
def test_message_can_be_imported(test_case_name: str) -> None:
|
def test_message_can_be_imported(test_case_name: str) -> None:
|
||||||
importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}")
|
importlib.import_module(
|
||||||
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
@pytest.mark.parametrize("test_case_name", test_case_names)
|
||||||
def test_message_can_instantiated(test_case_name: str) -> None:
|
def test_message_can_instantiated(test_case_name: str) -> None:
|
||||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}")
|
plugin_module = importlib.import_module(
|
||||||
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
)
|
||||||
plugin_module.Test()
|
plugin_module.Test()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
@pytest.mark.parametrize("test_case_name", test_case_names)
|
||||||
def test_message_equality(test_case_name: str) -> None:
|
def test_message_equality(test_case_name: str) -> None:
|
||||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}")
|
plugin_module = importlib.import_module(
|
||||||
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
)
|
||||||
message1 = plugin_module.Test()
|
message1 = plugin_module.Test()
|
||||||
message2 = plugin_module.Test()
|
message2 = plugin_module.Test()
|
||||||
assert message1 == message2
|
assert message1 == message2
|
||||||
@ -43,7 +49,9 @@ def test_message_equality(test_case_name: str) -> None:
|
|||||||
|
|
||||||
@pytest.mark.parametrize("test_case_name", test_case_names)
|
@pytest.mark.parametrize("test_case_name", test_case_names)
|
||||||
def test_message_json(test_case_name: str) -> None:
|
def test_message_json(test_case_name: str) -> None:
|
||||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}")
|
plugin_module = importlib.import_module(
|
||||||
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
)
|
||||||
message: betterproto.Message = plugin_module.Test()
|
message: betterproto.Message = plugin_module.Test()
|
||||||
reference_json_data = get_test_case_json_data(test_case_name)
|
reference_json_data = get_test_case_json_data(test_case_name)
|
||||||
|
|
||||||
@ -60,20 +68,28 @@ def test_binary_compatibility(test_case_name: str) -> None:
|
|||||||
sym = symbol_database.Default()
|
sym = symbol_database.Default()
|
||||||
sym.pool = DescriptorPool()
|
sym.pool = DescriptorPool()
|
||||||
|
|
||||||
reference_module_root = os.path.join(*reference_output_package.split('.'), test_case_name)
|
reference_module_root = os.path.join(
|
||||||
|
*reference_output_package.split("."), test_case_name
|
||||||
|
)
|
||||||
|
|
||||||
sys.path.append(reference_module_root)
|
sys.path.append(reference_module_root)
|
||||||
|
|
||||||
# import reference message
|
# import reference message
|
||||||
reference_module = importlib.import_module(f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2")
|
reference_module = importlib.import_module(
|
||||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}")
|
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
|
||||||
|
)
|
||||||
|
plugin_module = importlib.import_module(
|
||||||
|
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||||
|
)
|
||||||
|
|
||||||
test_data = get_test_case_json_data(test_case_name)
|
test_data = get_test_case_json_data(test_case_name)
|
||||||
|
|
||||||
reference_instance = Parse(test_data, reference_module.Test())
|
reference_instance = Parse(test_data, reference_module.Test())
|
||||||
reference_binary_output = reference_instance.SerializeToString()
|
reference_binary_output = reference_instance.SerializeToString()
|
||||||
|
|
||||||
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(test_data)
|
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
|
||||||
|
test_data
|
||||||
|
)
|
||||||
plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output)
|
plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output)
|
||||||
|
|
||||||
# # Generally this can't be relied on, but here we are aiming to match the
|
# # Generally this can't be relied on, but here we are aiming to match the
|
||||||
@ -85,13 +101,13 @@ def test_binary_compatibility(test_case_name: str) -> None:
|
|||||||
sys.path.remove(reference_module_root)
|
sys.path.remove(reference_module_root)
|
||||||
|
|
||||||
|
|
||||||
'''
|
"""
|
||||||
helper methods
|
helper methods
|
||||||
'''
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_test_case_json_data(test_case_name):
|
def get_test_case_json_data(test_case_name):
|
||||||
test_data_path = os.path.join(inputs_path, test_case_name, f'{test_case_name}.json')
|
test_data_path = os.path.join(inputs_path, test_case_name, f"{test_case_name}.json")
|
||||||
if not os.path.exists(test_data_path):
|
if not os.path.exists(test_data_path):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@ -4,7 +4,12 @@ from grpclib.testing import ChannelFor
|
|||||||
import pytest
|
import pytest
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
from betterproto.tests.output_betterproto.service.service import DoThingResponse, DoThingRequest, ExampleServiceStub
|
from betterproto.tests.output_betterproto.service.service import (
|
||||||
|
DoThingResponse,
|
||||||
|
DoThingRequest,
|
||||||
|
ExampleServiceStub,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExampleService:
|
class ExampleService:
|
||||||
def __init__(self, test_hook=None):
|
def __init__(self, test_hook=None):
|
||||||
@ -29,7 +34,7 @@ class ExampleService:
|
|||||||
grpclib.const.Cardinality.UNARY_UNARY,
|
grpclib.const.Cardinality.UNARY_UNARY,
|
||||||
DoThingRequest,
|
DoThingRequest,
|
||||||
DoThingResponse,
|
DoThingResponse,
|
||||||
),
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@ -94,7 +99,9 @@ async def test_service_call_lower_level_with_overrides():
|
|||||||
) as channel:
|
) as channel:
|
||||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||||
response = await stub._unary_unary(
|
response = await stub._unary_unary(
|
||||||
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse,
|
"/service.ExampleService/DoThing",
|
||||||
|
DoThingRequest(ITERATIONS),
|
||||||
|
DoThingResponse,
|
||||||
deadline=kwarg_deadline,
|
deadline=kwarg_deadline,
|
||||||
metadata=kwarg_metadata,
|
metadata=kwarg_metadata,
|
||||||
)
|
)
|
||||||
@ -116,7 +123,9 @@ async def test_service_call_lower_level_with_overrides():
|
|||||||
) as channel:
|
) as channel:
|
||||||
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
|
||||||
response = await stub._unary_unary(
|
response = await stub._unary_unary(
|
||||||
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse,
|
"/service.ExampleService/DoThing",
|
||||||
|
DoThingRequest(ITERATIONS),
|
||||||
|
DoThingResponse,
|
||||||
timeout=kwarg_timeout,
|
timeout=kwarg_timeout,
|
||||||
metadata=kwarg_metadata,
|
metadata=kwarg_metadata,
|
||||||
)
|
)
|
||||||
|
@ -5,14 +5,14 @@ from typing import Generator
|
|||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||||
|
|
||||||
root_path = os.path.dirname(os.path.realpath(__file__))
|
root_path = os.path.dirname(os.path.realpath(__file__))
|
||||||
inputs_path = os.path.join(root_path, 'inputs')
|
inputs_path = os.path.join(root_path, "inputs")
|
||||||
output_path_reference = os.path.join(root_path, 'output_reference')
|
output_path_reference = os.path.join(root_path, "output_reference")
|
||||||
output_path_betterproto = os.path.join(root_path, 'output_betterproto')
|
output_path_betterproto = os.path.join(root_path, "output_betterproto")
|
||||||
|
|
||||||
if os.name == 'nt':
|
if os.name == "nt":
|
||||||
plugin_path = os.path.join(root_path, '..', 'plugin.bat')
|
plugin_path = os.path.join(root_path, "..", "plugin.bat")
|
||||||
else:
|
else:
|
||||||
plugin_path = os.path.join(root_path, '..', 'plugin.py')
|
plugin_path = os.path.join(root_path, "..", "plugin.py")
|
||||||
|
|
||||||
|
|
||||||
def get_files(path, end: str) -> Generator[str, None, None]:
|
def get_files(path, end: str) -> Generator[str, None, None]:
|
||||||
@ -44,5 +44,7 @@ def protoc_plugin(path: str, output_dir: str):
|
|||||||
|
|
||||||
|
|
||||||
def protoc_reference(path: str, output_dir: str):
|
def protoc_reference(path: str, output_dir: str):
|
||||||
subprocess.run(f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", shell=True)
|
subprocess.run(
|
||||||
|
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
|
||||||
|
shell=True,
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user