This commit is contained in:
Nat Noordanus 2020-05-23 23:37:22 +02:00
parent f6af077ffe
commit 5e2d9febea
8 changed files with 100 additions and 51 deletions

View File

@ -122,7 +122,7 @@ def get_py_zero(type_num: int) -> str:
def traverse(proto_file): def traverse(proto_file):
def _traverse(path, items, prefix = ''): def _traverse(path, items, prefix=""):
for i, item in enumerate(items): for i, item in enumerate(items):
# Adjust the name since we flatten the heirarchy. # Adjust the name since we flatten the heirarchy.
item.name = next_prefix = prefix + item.name item.name = next_prefix = prefix + item.name

View File

@ -3,8 +3,14 @@ import os
import sys import sys
from typing import Set from typing import Set
from betterproto.tests.util import get_directories, inputs_path, output_path_betterproto, output_path_reference, \ from betterproto.tests.util import (
protoc_plugin, protoc_reference get_directories,
inputs_path,
output_path_betterproto,
output_path_reference,
protoc_plugin,
protoc_reference,
)
# Force pure-python implementation instead of C++, otherwise imports # Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database. # break things because we can't properly reset the symbol database.
@ -20,13 +26,19 @@ def generate(whitelist: Set[str]):
for test_case_name in sorted(test_case_names): for test_case_name in sorted(test_case_names):
test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name)) test_case_path = os.path.realpath(os.path.join(inputs_path, test_case_name))
if whitelist and test_case_path not in path_whitelist and test_case_name not in name_whitelist: if (
whitelist
and test_case_path not in path_whitelist
and test_case_name not in name_whitelist
):
continue continue
case_output_dir_reference = os.path.join(output_path_reference, test_case_name) case_output_dir_reference = os.path.join(output_path_reference, test_case_name)
case_output_dir_betterproto = os.path.join(output_path_betterproto, test_case_name) case_output_dir_betterproto = os.path.join(
output_path_betterproto, test_case_name
)
print(f'Generating output for {test_case_name}') print(f"Generating output for {test_case_name}")
os.makedirs(case_output_dir_reference, exist_ok=True) os.makedirs(case_output_dir_reference, exist_ok=True)
os.makedirs(case_output_dir_betterproto, exist_ok=True) os.makedirs(case_output_dir_betterproto, exist_ok=True)
@ -34,21 +46,23 @@ def generate(whitelist: Set[str]):
protoc_plugin(test_case_path, case_output_dir_betterproto) protoc_plugin(test_case_path, case_output_dir_betterproto)
HELP = "\n".join([ HELP = "\n".join(
'Usage: python generate.py', [
' python generate.py [DIRECTORIES or NAMES]', "Usage: python generate.py",
'Generate python classes for standard tests.', " python generate.py [DIRECTORIES or NAMES]",
'', "Generate python classes for standard tests.",
'DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.', "",
' python generate.py inputs/bool inputs/double inputs/enum', "DIRECTORIES One or more relative or absolute directories of test-cases to generate classes for.",
'', " python generate.py inputs/bool inputs/double inputs/enum",
'NAMES One or more test-case names to generate classes for.', "",
' python generate.py bool double enums' "NAMES One or more test-case names to generate classes for.",
]) " python generate.py bool double enums",
]
)
def main(): def main():
if set(sys.argv).intersection({'-h', '--help'}): if set(sys.argv).intersection({"-h", "--help"}):
print(HELP) print(HELP)
return return
whitelist = set(sys.argv[1:]) whitelist = set(sys.argv[1:])

View File

@ -4,13 +4,19 @@ from betterproto.tests.output_betterproto.casing.casing import Test
def test_message_attributes(): def test_message_attributes():
message = Test() message = Test()
assert hasattr(message, 'snake_case_message'), 'snake_case field name is same in python' assert hasattr(
assert hasattr(message, 'camel_case'), 'CamelCase field is snake_case in python' message, "snake_case_message"
), "snake_case field name is same in python"
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
def test_message_casing(): def test_message_casing():
assert hasattr(casing, 'SnakeCaseMessage'), 'snake_case Message name is converted to CamelCase in python' assert hasattr(
casing, "SnakeCaseMessage"
), "snake_case Message name is converted to CamelCase in python"
def test_enum_casing(): def test_enum_casing():
assert hasattr(casing, 'MyEnum'), 'snake_case Enum name is converted to CamelCase in python' assert hasattr(
casing, "MyEnum"
), "snake_case Enum name is converted to CamelCase in python"

View File

@ -2,7 +2,9 @@ from typing import Optional
import pytest import pytest
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import TestStub from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub
)
class TestStubChild(TestStub): class TestStubChild(TestStub):
@ -12,7 +14,7 @@ class TestStubChild(TestStub):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test(): async def test():
pytest.skip('todo') pytest.skip("todo")
stub = TestStubChild(None) stub = TestStubChild(None)
await stub.get_int64() await stub.get_int64()
assert stub.response_type != Optional[int] assert stub.response_type != Optional[int]

View File

@ -256,7 +256,7 @@ def test_to_dict_default_values():
some_double: float = betterproto.double_field(2) some_double: float = betterproto.double_field(2)
some_message: TestChildMessage = betterproto.message_field(3) some_message: TestChildMessage = betterproto.message_field(3)
test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2,}) test = TestParentMessage().from_dict({"someInt": 0, "someDouble": 1.2})
assert test.to_dict(include_default_values=True) == { assert test.to_dict(include_default_values=True) == {
"someInt": 0, "someInt": 0,

View File

@ -15,27 +15,33 @@ from google.protobuf.descriptor_pool import DescriptorPool
from google.protobuf.json_format import Parse from google.protobuf.json_format import Parse
excluded_test_cases = {'googletypes_response', 'service'} excluded_test_cases = {"googletypes_response", "service"}
test_case_names = {*get_directories(inputs_path)} - excluded_test_cases test_case_names = {*get_directories(inputs_path)} - excluded_test_cases
plugin_output_package = 'betterproto.tests.output_betterproto' plugin_output_package = "betterproto.tests.output_betterproto"
reference_output_package = 'betterproto.tests.output_reference' reference_output_package = "betterproto.tests.output_reference"
@pytest.mark.parametrize("test_case_name", test_case_names) @pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_can_be_imported(test_case_name: str) -> None: def test_message_can_be_imported(test_case_name: str) -> None:
importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
@pytest.mark.parametrize("test_case_name", test_case_names) @pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_can_instantiated(test_case_name: str) -> None: def test_message_can_instantiated(test_case_name: str) -> None:
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
plugin_module.Test() plugin_module.Test()
@pytest.mark.parametrize("test_case_name", test_case_names) @pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_equality(test_case_name: str) -> None: def test_message_equality(test_case_name: str) -> None:
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
message1 = plugin_module.Test() message1 = plugin_module.Test()
message2 = plugin_module.Test() message2 = plugin_module.Test()
assert message1 == message2 assert message1 == message2
@ -43,7 +49,9 @@ def test_message_equality(test_case_name: str) -> None:
@pytest.mark.parametrize("test_case_name", test_case_names) @pytest.mark.parametrize("test_case_name", test_case_names)
def test_message_json(test_case_name: str) -> None: def test_message_json(test_case_name: str) -> None:
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
message: betterproto.Message = plugin_module.Test() message: betterproto.Message = plugin_module.Test()
reference_json_data = get_test_case_json_data(test_case_name) reference_json_data = get_test_case_json_data(test_case_name)
@ -60,20 +68,28 @@ def test_binary_compatibility(test_case_name: str) -> None:
sym = symbol_database.Default() sym = symbol_database.Default()
sym.pool = DescriptorPool() sym.pool = DescriptorPool()
reference_module_root = os.path.join(*reference_output_package.split('.'), test_case_name) reference_module_root = os.path.join(
*reference_output_package.split("."), test_case_name
)
sys.path.append(reference_module_root) sys.path.append(reference_module_root)
# import reference message # import reference message
reference_module = importlib.import_module(f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2") reference_module = importlib.import_module(
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}.{test_case_name}") f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
)
plugin_module = importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
)
test_data = get_test_case_json_data(test_case_name) test_data = get_test_case_json_data(test_case_name)
reference_instance = Parse(test_data, reference_module.Test()) reference_instance = Parse(test_data, reference_module.Test())
reference_binary_output = reference_instance.SerializeToString() reference_binary_output = reference_instance.SerializeToString()
plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(test_data) plugin_instance_from_json: betterproto.Message = plugin_module.Test().from_json(
test_data
)
plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output) plugin_instance_from_binary = plugin_module.Test.FromString(reference_binary_output)
# # Generally this can't be relied on, but here we are aiming to match the # # Generally this can't be relied on, but here we are aiming to match the
@ -85,13 +101,13 @@ def test_binary_compatibility(test_case_name: str) -> None:
sys.path.remove(reference_module_root) sys.path.remove(reference_module_root)
''' """
helper methods helper methods
''' """
def get_test_case_json_data(test_case_name): def get_test_case_json_data(test_case_name):
test_data_path = os.path.join(inputs_path, test_case_name, f'{test_case_name}.json') test_data_path = os.path.join(inputs_path, test_case_name, f"{test_case_name}.json")
if not os.path.exists(test_data_path): if not os.path.exists(test_data_path):
return None return None

View File

@ -4,7 +4,12 @@ from grpclib.testing import ChannelFor
import pytest import pytest
from typing import Dict from typing import Dict
from betterproto.tests.output_betterproto.service.service import DoThingResponse, DoThingRequest, ExampleServiceStub from betterproto.tests.output_betterproto.service.service import (
DoThingResponse,
DoThingRequest,
ExampleServiceStub,
)
class ExampleService: class ExampleService:
def __init__(self, test_hook=None): def __init__(self, test_hook=None):
@ -29,7 +34,7 @@ class ExampleService:
grpclib.const.Cardinality.UNARY_UNARY, grpclib.const.Cardinality.UNARY_UNARY,
DoThingRequest, DoThingRequest,
DoThingResponse, DoThingResponse,
), )
} }
@ -94,7 +99,9 @@ async def test_service_call_lower_level_with_overrides():
) as channel: ) as channel:
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary( response = await stub._unary_unary(
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, "/service.ExampleService/DoThing",
DoThingRequest(ITERATIONS),
DoThingResponse,
deadline=kwarg_deadline, deadline=kwarg_deadline,
metadata=kwarg_metadata, metadata=kwarg_metadata,
) )
@ -116,7 +123,9 @@ async def test_service_call_lower_level_with_overrides():
) as channel: ) as channel:
stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata) stub = ExampleServiceStub(channel, deadline=deadline, metadata=metadata)
response = await stub._unary_unary( response = await stub._unary_unary(
"/service.ExampleService/DoThing", DoThingRequest(ITERATIONS), DoThingResponse, "/service.ExampleService/DoThing",
DoThingRequest(ITERATIONS),
DoThingResponse,
timeout=kwarg_timeout, timeout=kwarg_timeout,
metadata=kwarg_metadata, metadata=kwarg_metadata,
) )

View File

@ -5,14 +5,14 @@ from typing import Generator
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
root_path = os.path.dirname(os.path.realpath(__file__)) root_path = os.path.dirname(os.path.realpath(__file__))
inputs_path = os.path.join(root_path, 'inputs') inputs_path = os.path.join(root_path, "inputs")
output_path_reference = os.path.join(root_path, 'output_reference') output_path_reference = os.path.join(root_path, "output_reference")
output_path_betterproto = os.path.join(root_path, 'output_betterproto') output_path_betterproto = os.path.join(root_path, "output_betterproto")
if os.name == 'nt': if os.name == "nt":
plugin_path = os.path.join(root_path, '..', 'plugin.bat') plugin_path = os.path.join(root_path, "..", "plugin.bat")
else: else:
plugin_path = os.path.join(root_path, '..', 'plugin.py') plugin_path = os.path.join(root_path, "..", "plugin.py")
def get_files(path, end: str) -> Generator[str, None, None]: def get_files(path, end: str) -> Generator[str, None, None]:
@ -44,5 +44,7 @@ def protoc_plugin(path: str, output_dir: str):
def protoc_reference(path: str, output_dir: str): def protoc_reference(path: str, output_dir: str):
subprocess.run(f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto", shell=True) subprocess.run(
f"protoc --python_out={output_dir} --proto_path={path} {path}/*.proto",
shell=True,
)