diff --git a/README.md b/README.md index c6001b3..15e85a7 100644 --- a/README.md +++ b/README.md @@ -68,14 +68,15 @@ message Greeting { You can run the following: ```sh -protoc -I . --python_betterproto_out=. example.proto +mkdir lib +protoc -I . --python_betterproto_out=lib example.proto ``` -This will generate `hello.py` which looks like: +This will generate `lib/hello/__init__.py` which looks like: -```py +```python # Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: hello.proto +# sources: example.proto # plugin: python-betterproto from dataclasses import dataclass @@ -83,7 +84,7 @@ import betterproto @dataclass -class Hello(betterproto.Message): +class Greeting(betterproto.Message): """Greeting represents a message you can tell a user.""" message: str = betterproto.string_field(1) @@ -91,23 +92,23 @@ class Hello(betterproto.Message): Now you can use it! -```py ->>> from hello import Hello ->>> test = Hello() +```python +>>> from lib.hello import Greeting +>>> test = Greeting() >>> test -Hello(message='') +Greeting(message='') >>> test.message = "Hey!" >>> test -Hello(message="Hey!") +Greeting(message="Hey!") >>> serialized = bytes(test) >>> serialized b'\n\x04Hey!' ->>> another = Hello().parse(serialized) +>>> another = Greeting().parse(serialized) >>> another -Hello(message="Hey!") +Greeting(message="Hey!") >>> another.to_dict() {"message": "Hey!"} @@ -315,7 +316,7 @@ To benefit from the collection of standard development tasks ensure you have mak This project enforces [black](https://github.com/psf/black) python code formatting. -Before commiting changes run: +Before committing changes run: ```sh make format @@ -336,7 +337,7 @@ Adding a standard test case is easy. - Create a new directory `betterproto/tests/inputs/` - add `.proto` with a message called `Test` - - add `.json` with some test data + - add `.json` with some test data (optional) It will be picked up automatically when you run the tests. diff --git a/betterproto/__init__.py b/betterproto/__init__.py index c1e60ea..6c07feb 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -7,27 +7,22 @@ import sys from abc import ABC from base64 import b64decode, b64encode from datetime import datetime, timedelta, timezone -import stringcase from typing import ( Any, - AsyncGenerator, Callable, - Collection, Dict, Generator, - Iterator, List, - Mapping, Optional, Set, - SupportsBytes, Tuple, Type, Union, get_type_hints, ) -from ._types import ST, T -from .casing import safe_snake_case + +from ._types import T +from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case from .grpc.grpclib_client import ServiceStub if not (sys.version_info.major == 3 and sys.version_info.minor >= 7): @@ -124,8 +119,8 @@ DATETIME_ZERO = datetime_default_gen() class Casing(enum.Enum): """Casing constants for serialization.""" - CAMEL = stringcase.camelcase - SNAKE = stringcase.snakecase + CAMEL = camel_case + SNAKE = snake_case class _PLACEHOLDER: diff --git a/betterproto/casing.py b/betterproto/casing.py index 67ca9a2..543df68 100644 --- a/betterproto/casing.py +++ b/betterproto/casing.py @@ -1,9 +1,21 @@ -import stringcase +import re + +# Word delimiters and symbols that will not be preserved when re-casing. +# language=PythonRegExp +SYMBOLS = "[^a-zA-Z0-9]*" + +# Optionally capitalized word. +# language=PythonRegExp +WORD = "[A-Z]*[a-z]*[0-9]*" + +# Uppercase word, not followed by lowercase letters. +# language=PythonRegExp +WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*" def safe_snake_case(value: str) -> str: """Snake case a value taking into account Python keywords.""" - value = stringcase.snakecase(value) + value = snake_case(value) if value in [ "and", "as", @@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str: # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles value += "_" return value + + +def snake_case(value: str, strict: bool = True): + """ + Join words with an underscore into lowercase and remove symbols. + @param value: value to convert + @param strict: force single underscores + """ + + def substitute_word(symbols, word, is_start): + if not word: + return "" + if strict: + delimiter_count = 0 if is_start else 1 # Single underscore if strict. + elif is_start: + delimiter_count = len(symbols) + elif word.isupper() or word.islower(): + delimiter_count = max( + 1, len(symbols) + ) # Preserve all delimiters if not strict. + else: + delimiter_count = len(symbols) + 1 # Extra underscore for leading capital. + + return ("_" * delimiter_count) + word.lower() + + snake = re.sub( + f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})", + lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None), + value, + ) + return snake + + +def pascal_case(value: str, strict: bool = True): + """ + Capitalize each word and remove symbols. + @param value: value to convert + @param strict: output only alphanumeric characters + """ + + def substitute_word(symbols, word): + if strict: + return word.capitalize() # Remove all delimiters + + if word.islower(): + delimiter_length = len(symbols[:-1]) # Lose one delimiter + else: + delimiter_length = len(symbols) # Preserve all delimiters + + return ("_" * delimiter_length) + word.capitalize() + + return re.sub( + f"({SYMBOLS})({WORD_UPPER}|{WORD})", + lambda groups: substitute_word(groups[1], groups[2]), + value, + ) + + +def camel_case(value: str, strict: bool = True): + """ + Capitalize all words except first and remove symbols. + """ + return lowercase_first(pascal_case(value, strict=strict)) + + +def lowercase_first(value: str): + return value[0:1].lower() + value[1:] diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 0c53e0b..40441f8 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -1,70 +1,160 @@ -from typing import Dict, Type - -import stringcase +import os +import re +from typing import Dict, List, Set, Type from betterproto import safe_snake_case +from betterproto.compile.naming import pythonize_class_name from betterproto.lib.google import protobuf as google_protobuf WRAPPER_TYPES: Dict[str, Type] = { - "google.protobuf.DoubleValue": google_protobuf.DoubleValue, - "google.protobuf.FloatValue": google_protobuf.FloatValue, - "google.protobuf.Int32Value": google_protobuf.Int32Value, - "google.protobuf.Int64Value": google_protobuf.Int64Value, - "google.protobuf.UInt32Value": google_protobuf.UInt32Value, - "google.protobuf.UInt64Value": google_protobuf.UInt64Value, - "google.protobuf.BoolValue": google_protobuf.BoolValue, - "google.protobuf.StringValue": google_protobuf.StringValue, - "google.protobuf.BytesValue": google_protobuf.BytesValue, + ".google.protobuf.DoubleValue": google_protobuf.DoubleValue, + ".google.protobuf.FloatValue": google_protobuf.FloatValue, + ".google.protobuf.Int32Value": google_protobuf.Int32Value, + ".google.protobuf.Int64Value": google_protobuf.Int64Value, + ".google.protobuf.UInt32Value": google_protobuf.UInt32Value, + ".google.protobuf.UInt64Value": google_protobuf.UInt64Value, + ".google.protobuf.BoolValue": google_protobuf.BoolValue, + ".google.protobuf.StringValue": google_protobuf.StringValue, + ".google.protobuf.BytesValue": google_protobuf.BytesValue, } -def get_ref_type( - package: str, imports: set, type_name: str, unwrap: bool = True +def parse_source_type_name(field_type_name): + """ + Split full source type name into package and type name. + E.g. 'root.package.Message' -> ('root.package', 'Message') + 'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum') + """ + package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name) + if package_match: + package = package_match.group(1) + name = package_match.group(2) + else: + package = "" + name = field_type_name.lstrip(".") + return package, name + + +def get_type_reference( + package: str, imports: set, source_type: str, unwrap: bool = True, ) -> str: """ Return a Python type name for a proto type reference. Adds the import if necessary. Unwraps well known type if required. """ - # If the package name is a blank string, then this should still work - # because by convention packages are lowercase and message/enum types are - # pascal-cased. May require refactoring in the future. - type_name = type_name.lstrip(".") - - is_wrapper = type_name in WRAPPER_TYPES - if unwrap: - if is_wrapper: - wrapped_type = type(WRAPPER_TYPES[type_name]().value) + if source_type in WRAPPER_TYPES: + wrapped_type = type(WRAPPER_TYPES[source_type]().value) return f"Optional[{wrapped_type.__name__}]" - if type_name == "google.protobuf.Duration": + if source_type == ".google.protobuf.Duration": return "timedelta" - if type_name == "google.protobuf.Timestamp": + if source_type == ".google.protobuf.Timestamp": return "datetime" - if type_name.startswith(package): - parts = type_name.lstrip(package).lstrip(".").split(".") - if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()): - # This is the current package, which has nested types flattened. - # foo.bar_thing => FooBarThing - cased = [stringcase.pascalcase(part) for part in parts] - type_name = f'"{"".join(cased)}"' + source_package, source_type = parse_source_type_name(source_type) - # Use precompiled classes for google.protobuf.* objects - if type_name.startswith("google.protobuf.") and type_name.count(".") == 2: - type_name = type_name.rsplit(".", maxsplit=1)[1] - import_package = "betterproto.lib.google.protobuf" - import_alias = safe_snake_case(import_package) - imports.add(f"import {import_package} as {import_alias}") - return f"{import_alias}.{type_name}" + current_package: List[str] = package.split(".") if package else [] + py_package: List[str] = source_package.split(".") if source_package else [] + py_type: str = pythonize_class_name(source_type) - if "." in type_name: - # This is imported from another package. No need - # to use a forward ref and we need to add the import. - parts = type_name.split(".") - parts[-1] = stringcase.pascalcase(parts[-1]) - imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}") - type_name = f"{parts[-2]}.{parts[-1]}" + compiling_google_protobuf = current_package == ["google", "protobuf"] + importing_google_protobuf = py_package == ["google", "protobuf"] + if importing_google_protobuf and not compiling_google_protobuf: + py_package = ["betterproto", "lib"] + py_package - return type_name + if py_package[:1] == ["betterproto"]: + return reference_absolute(imports, py_package, py_type) + + if py_package == current_package: + return reference_sibling(py_type) + + if py_package[: len(current_package)] == current_package: + return reference_descendent(current_package, imports, py_package, py_type) + + if current_package[: len(py_package)] == py_package: + return reference_ancestor(current_package, imports, py_package, py_type) + + return reference_cousin(current_package, imports, py_package, py_type) + + +def reference_absolute(imports, py_package, py_type): + """ + Returns a reference to a python type located in the root, i.e. sys.path. + """ + string_import = ".".join(py_package) + string_alias = safe_snake_case(string_import) + imports.add(f"import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + + +def reference_sibling(py_type: str) -> str: + """ + Returns a reference to a python type within the same package as the current package. + """ + return f'"{py_type}"' + + +def reference_descendent( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package that is a descendent of the current package, + and adds the required import that is aliased to avoid name conflicts. + """ + importing_descendent = py_package[len(current_package) :] + string_from = ".".join(importing_descendent[:-1]) + string_import = importing_descendent[-1] + if string_from: + string_alias = "_".join(importing_descendent) + imports.add(f"from .{string_from} import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + else: + imports.add(f"from . import {string_import}") + return f"{string_import}.{py_type}" + + +def reference_ancestor( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package which is an ancestor to the current package, + and adds the required import that is aliased (if possible) to avoid name conflicts. + + Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34). + """ + distance_up = len(current_package) - len(py_package) + if py_package: + string_import = py_package[-1] + string_alias = f"_{'_' * distance_up}{string_import}__" + string_from = f"..{'.' * distance_up}" + imports.add(f"from {string_from} import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" + else: + string_alias = f"{'_' * distance_up}{py_type}__" + imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") + return string_alias + + +def reference_cousin( + current_package: List[str], imports: Set[str], py_package: List[str], py_type: str +) -> str: + """ + Returns a reference to a python type in a package that is not descendent, ancestor or sibling, + and adds the required import that is aliased to avoid name conflicts. + """ + shared_ancestry = os.path.commonprefix([current_package, py_package]) + distance_up = len(current_package) - len(shared_ancestry) + string_from = f".{'.' * distance_up}" + ".".join( + py_package[len(shared_ancestry) : -1] + ) + string_import = py_package[-1] + # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34) + string_alias = ( + f"{'_' * distance_up}" + + safe_snake_case(".".join(py_package[len(shared_ancestry) :])) + + "__" + ) + imports.add(f"from {string_from} import {string_import} as {string_alias}") + return f"{string_alias}.{py_type}" diff --git a/betterproto/compile/naming.py b/betterproto/compile/naming.py new file mode 100644 index 0000000..3d56852 --- /dev/null +++ b/betterproto/compile/naming.py @@ -0,0 +1,13 @@ +from betterproto import casing + + +def pythonize_class_name(name): + return casing.pascal_case(name) + + +def pythonize_field_name(name: str): + return casing.safe_snake_case(name) + + +def pythonize_method_name(name: str): + return casing.safe_snake_case(name) diff --git a/betterproto/lib/google/protobuf.py b/betterproto/lib/google/protobuf/__init__.py similarity index 99% rename from betterproto/lib/google/protobuf.py rename to betterproto/lib/google/protobuf/__init__.py index fd379d5..936d175 100644 --- a/betterproto/lib/google/protobuf.py +++ b/betterproto/lib/google/protobuf/__init__.py @@ -84,7 +84,7 @@ class FieldOptionsCType(betterproto.Enum): STRING_PIECE = 2 -class FieldOptionsJSType(betterproto.Enum): +class FieldOptionsJsType(betterproto.Enum): JS_NORMAL = 0 JS_STRING = 1 JS_NUMBER = 2 @@ -717,7 +717,7 @@ class FieldOptions(betterproto.Message): # use the JavaScript "number" type. The behavior of the default option # JS_NORMAL is implementation dependent. This option is an enum to permit # additional types to be added, e.g. goog.math.Integer. - jstype: "FieldOptionsJSType" = betterproto.enum_field(6) + jstype: "FieldOptionsJsType" = betterproto.enum_field(6) # Should this field be parsed lazily? Lazy applies only to message-type # fields. It means that when the outer message is initially parsed, the # inner message's contents will not be parsed but instead stored in encoded diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 928f026..e835fab 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -2,14 +2,19 @@ import itertools import os.path +import pathlib import re -import stringcase import sys import textwrap from typing import List, Union + import betterproto -from betterproto.casing import safe_snake_case -from betterproto.compile.importing import get_ref_type +from betterproto.compile.importing import get_type_reference +from betterproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) try: # betterproto[compiler] specific dependencies @@ -35,27 +40,22 @@ except ImportError as err: raise SystemExit(1) -def py_type( - package: str, - imports: set, - message: DescriptorProto, - descriptor: FieldDescriptorProto, -) -> str: - if descriptor.type in [1, 2]: +def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: + if field.type in [1, 2]: return "float" - elif descriptor.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]: + elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]: return "int" - elif descriptor.type == 8: + elif field.type == 8: return "bool" - elif descriptor.type == 9: + elif field.type == 9: return "str" - elif descriptor.type in [11, 14]: + elif field.type in [11, 14]: # Type referencing another defined Message or a named enum - return get_ref_type(package, imports, descriptor.type_name) - elif descriptor.type == 12: + return get_type_reference(package, imports, field.type_name) + elif field.type == 12: return "bytes" else: - raise NotImplementedError(f"Unknown type {descriptor.type}") + raise NotImplementedError(f"Unknown type {field.type}") def get_py_zero(type_num: int) -> Union[str, float]: @@ -131,17 +131,17 @@ def generate_code(request, response): output_map = {} for proto_file in request.proto_file: - out = proto_file.package - - if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options: + if ( + proto_file.package == "google.protobuf" + and "INCLUDE_GOOGLE" not in plugin_options + ): continue - if not out: - out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".") + output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) - if out not in output_map: - output_map[out] = {"package": proto_file.package, "files": []} - output_map[out]["files"].append(proto_file) + if output_file not in output_map: + output_map[output_file] = {"package": proto_file.package, "files": []} + output_map[output_file]["files"].append(proto_file) # TODO: Figure out how to handle gRPC request/response messages and add # processing below for Service. @@ -160,17 +160,10 @@ def generate_code(request, response): "services": [], } - type_mapping = {} - for proto_file in options["files"]: - # print(proto_file.message_type, file=sys.stderr) - # print(proto_file.service, file=sys.stderr) - # print(proto_file.source_code_info, file=sys.stderr) - + item: DescriptorProto for item, path in traverse(proto_file): - # print(item, file=sys.stderr) - # print(path, file=sys.stderr) - data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)} + data = {"name": item.name, "py_name": pythonize_class_name(item.name)} if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) @@ -187,7 +180,7 @@ def generate_code(request, response): ) for i, f in enumerate(item.field): - t = py_type(package, output["imports"], item, f) + t = py_type(package, output["imports"], f) zero = get_py_zero(f.type) repeated = False @@ -222,13 +215,11 @@ def generate_code(request, response): k = py_type( package, output["imports"], - item, nested.field[0], ) v = py_type( package, output["imports"], - item, nested.field[1], ) t = f"Dict[{k}, {v}]" @@ -264,7 +255,7 @@ def generate_code(request, response): data["properties"].append( { "name": f.name, - "py_name": safe_snake_case(f.name), + "py_name": pythonize_field_name(f.name), "number": f.number, "comment": get_comment(proto_file, path + [2, i]), "proto_type": int(f.type), @@ -305,14 +296,14 @@ def generate_code(request, response): data = { "name": service.name, - "py_name": stringcase.pascalcase(service.name), + "py_name": pythonize_class_name(service.name), "comment": get_comment(proto_file, [6, i]), "methods": [], } for j, method in enumerate(service.method): input_message = None - input_type = get_ref_type( + input_type = get_type_reference( package, output["imports"], method.input_type ).strip('"') for msg in output["messages"]: @@ -326,14 +317,14 @@ def generate_code(request, response): data["methods"].append( { "name": method.name, - "py_name": stringcase.snakecase(method.name), + "py_name": pythonize_method_name(method.name), "comment": get_comment(proto_file, [6, i, 2, j], indent=8), "route": f"/{package}.{service.name}/{method.name}", - "input": get_ref_type( + "input": get_type_reference( package, output["imports"], method.input_type ).strip('"'), "input_message": input_message, - "output": get_ref_type( + "output": get_type_reference( package, output["imports"], method.output_type, @@ -359,8 +350,7 @@ def generate_code(request, response): # Fill response f = response.file.add() - # print(filename, file=sys.stderr) - f.name = filename.replace(".", os.path.sep) + ".py" + f.name = filename # Render and then format the output file. f.content = black.format_str( @@ -368,32 +358,23 @@ def generate_code(request, response): mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), ) - inits = set([""]) - for f in response.file: - # Ensure output paths exist - # print(f.name, file=sys.stderr) - dirnames = os.path.dirname(f.name) - if dirnames: - os.makedirs(dirnames, exist_ok=True) - base = "" - for part in dirnames.split(os.path.sep): - base = os.path.join(base, part) - inits.add(base) - - for base in inits: - name = os.path.join(base, "__init__.py") - - if os.path.exists(name): - # Never overwrite inits as they may have custom stuff in them. - continue + # Make each output directory a package with __init__ file + output_paths = set(pathlib.Path(path) for path in output_map.keys()) + init_files = ( + set( + directory.joinpath("__init__.py") + for path in output_paths + for directory in path.parents + ) + - output_paths + ) + for init_file in init_files: init = response.file.add() - init.name = name - init.content = b"" + init.name = str(init_file) - filenames = sorted([f.name for f in response.file]) - for fname in filenames: - print(f"Writing {fname}", file=sys.stderr) + for filename in sorted(output_paths.union(init_files)): + print(f"Writing {filename}", file=sys.stderr) def main(): diff --git a/betterproto/tests/README.md b/betterproto/tests/README.md index 1892cea..51cd8ec 100644 --- a/betterproto/tests/README.md +++ b/betterproto/tests/README.md @@ -12,12 +12,12 @@ inputs/ ## Test case directory structure -Each testcase has a `.proto` file with a message called `Test`, a matching `.json` file and optionally a custom test file called `test_*.py`. +Each testcase has a `.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`. ```bash bool/ bool.proto - bool.json + bool.json # optional test_bool.py # optional ``` @@ -61,21 +61,22 @@ def test_value(): The following tests are automatically executed for all cases: -- [x] Can the generated python code imported? +- [x] Can the generated python code be imported? - [x] Can the generated message class be instantiated? - [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation? + - _when `.json` is present_ ## Running the tests -- `pipenv run generate` - This generates +- `pipenv run generate` + This generates: - `betterproto/tests/output_betterproto` — *the plugin generated python classes* - `betterproto/tests/output_reference` — *reference implementation classes* - `pipenv run test` ## Intentionally Failing tests -The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrented in the future. +The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future. When running `pytest`, they show up as `x` or `X` in the test results. diff --git a/betterproto/tests/inputs/bool/test_bool.py b/betterproto/tests/inputs/bool/test_bool.py index 0d4daa6..3131236 100644 --- a/betterproto/tests/inputs/bool/test_bool.py +++ b/betterproto/tests/inputs/bool/test_bool.py @@ -1,4 +1,4 @@ -from betterproto.tests.output_betterproto.bool.bool import Test +from betterproto.tests.output_betterproto.bool import Test def test_value(): diff --git a/betterproto/tests/inputs/casing/casing.proto b/betterproto/tests/inputs/casing/casing.proto index ad0c427..ca458b5 100644 --- a/betterproto/tests/inputs/casing/casing.proto +++ b/betterproto/tests/inputs/casing/casing.proto @@ -10,6 +10,7 @@ message Test { int32 camelCase = 1; my_enum snake_case = 2; snake_case_message snake_case_message = 3; + int32 UPPERCASE = 4; } message snake_case_message { diff --git a/betterproto/tests/inputs/casing/test_casing.py b/betterproto/tests/inputs/casing/test_casing.py index 3255c4e..1c0dc80 100644 --- a/betterproto/tests/inputs/casing/test_casing.py +++ b/betterproto/tests/inputs/casing/test_casing.py @@ -1,5 +1,5 @@ -import betterproto.tests.output_betterproto.casing.casing as casing -from betterproto.tests.output_betterproto.casing.casing import Test +import betterproto.tests.output_betterproto.casing as casing +from betterproto.tests.output_betterproto.casing import Test def test_message_attributes(): @@ -8,6 +8,7 @@ def test_message_attributes(): message, "snake_case_message" ), "snake_case field name is same in python" assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python" + assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python" def test_message_casing(): diff --git a/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.json b/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.json deleted file mode 100644 index 83bd111..0000000 --- a/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "UPPERCASE": 10, - "UPPERCASE_V2": 10, - "UPPER_CAMEL_CASE": 10 -} diff --git a/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py b/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py index d77119e..e0dee0c 100644 --- a/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py +++ b/betterproto/tests/inputs/casing_message_field_uppercase/casing_message_field_uppercase.py @@ -1,6 +1,4 @@ -from betterproto.tests.output_betterproto.casing_message_field_uppercase.casing_message_field_uppercase import ( - Test, -) +from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test def test_message_casing(): diff --git a/betterproto/tests/inputs/config.py b/betterproto/tests/inputs/config.py index 4bacfac..245e508 100644 --- a/betterproto/tests/inputs/config.py +++ b/betterproto/tests/inputs/config.py @@ -1,19 +1,15 @@ # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Remove from list when fixed. -tests = { - "import_root_sibling", # 61 - "import_child_package_from_package", # 58 - "import_root_package_from_child", # 60 - "import_parent_package_from_child", # 59 - "import_circular_dependency", # failing because of other bugs now - "import_packages_same_name", # 25 +xfail = { + "import_circular_dependency", "oneof_enum", # 63 - "casing_message_field_uppercase", # 11 "namespace_keywords", # 70 "namespace_builtin_types", # 53 "googletypes_struct", # 9 "googletypes_value", # 9 "enum_skipped_value", # 93 + "import_capitalized_package", + "example", # This is the example in the readme. Not a test. } services = { diff --git a/betterproto/tests/inputs/example/example.proto b/betterproto/tests/inputs/example/example.proto new file mode 100644 index 0000000..edc4d87 --- /dev/null +++ b/betterproto/tests/inputs/example/example.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + +package hello; + +// Greeting represents a message you can tell a user. +message Greeting { + string message = 1; +} diff --git a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py index bd5f602..2e37f88 100644 --- a/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py +++ b/betterproto/tests/inputs/googletypes_response/test_googletypes_response.py @@ -4,9 +4,7 @@ import betterproto.lib.google.protobuf as protobuf import pytest from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import ( - TestStub, -) +from betterproto.tests.output_betterproto.googletypes_response import TestStub test_cases = [ (TestStub.get_double, protobuf.DoubleValue, 2.5), diff --git a/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py b/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py index 00b980a..4ef8c22 100644 --- a/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py +++ b/betterproto/tests/inputs/googletypes_response_embedded/test_googletypes_response_embedded.py @@ -1,7 +1,7 @@ import pytest from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import ( +from betterproto.tests.output_betterproto.googletypes_response_embedded import ( Output, TestStub, ) diff --git a/betterproto/tests/inputs/import_capitalized_package/capitalized.proto b/betterproto/tests/inputs/import_capitalized_package/capitalized.proto new file mode 100644 index 0000000..0b73bab --- /dev/null +++ b/betterproto/tests/inputs/import_capitalized_package/capitalized.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; + + +package Capitalized; + +message Message { + +} diff --git a/betterproto/tests/inputs/import_capitalized_package/test.proto b/betterproto/tests/inputs/import_capitalized_package/test.proto new file mode 100644 index 0000000..f94bbc9 --- /dev/null +++ b/betterproto/tests/inputs/import_capitalized_package/test.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +import "capitalized.proto"; + +// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't. + +message Test { + Capitalized.Message message = 1; +} diff --git a/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto index 589d14f..7d02aad 100644 --- a/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto +++ b/betterproto/tests/inputs/import_circular_dependency/import_circular_dependency.proto @@ -3,9 +3,9 @@ syntax = "proto3"; import "root.proto"; import "other.proto"; -// This test-case verifies that future implementations will support circular dependencies in the generated python files. +// This test-case verifies support for circular dependencies in the generated python files. // -// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file. +// This is important because we generate 1 python file/module per package, rather than 1 file per proto file. // // Scenario: // @@ -24,5 +24,5 @@ import "other.proto"; // (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage) message Test { RootPackageMessage message = 1; - other.OtherPackageMessage other =2; + other.OtherPackageMessage other = 2; } diff --git a/betterproto/tests/inputs/import_cousin_package/cousin.proto b/betterproto/tests/inputs/import_cousin_package/cousin.proto new file mode 100644 index 0000000..4361545 --- /dev/null +++ b/betterproto/tests/inputs/import_cousin_package/cousin.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package cousin.cousin_subpackage; + +message CousinMessage { +} diff --git a/betterproto/tests/inputs/import_cousin_package/test.proto b/betterproto/tests/inputs/import_cousin_package/test.proto new file mode 100644 index 0000000..53f3b7f --- /dev/null +++ b/betterproto/tests/inputs/import_cousin_package/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package test.subpackage; + +import "cousin.proto"; + +// Verify that we can import message unrelated to us + +message Test { + cousin.cousin_subpackage.CousinMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_cousin_package_same_name/cousin.proto b/betterproto/tests/inputs/import_cousin_package_same_name/cousin.proto new file mode 100644 index 0000000..9253b95 --- /dev/null +++ b/betterproto/tests/inputs/import_cousin_package_same_name/cousin.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +package cousin.subpackage; + +message CousinMessage { +} diff --git a/betterproto/tests/inputs/import_cousin_package_same_name/test.proto b/betterproto/tests/inputs/import_cousin_package_same_name/test.proto new file mode 100644 index 0000000..fe31b5f --- /dev/null +++ b/betterproto/tests/inputs/import_cousin_package_same_name/test.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package test.subpackage; + +import "cousin.proto"; + +// Verify that we can import a message unrelated to us, in a subpackage with the same name as us. + +message Test { + cousin.subpackage.CousinMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_root_package_from_child/child.proto b/betterproto/tests/inputs/import_root_package_from_child/child.proto new file mode 100644 index 0000000..d2b29cc --- /dev/null +++ b/betterproto/tests/inputs/import_root_package_from_child/child.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +package child; + +import "root.proto"; + +// Verify that we can import root message from child package + +message Test { + RootMessage message = 1; +} diff --git a/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto b/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto deleted file mode 100644 index 9e7dbcd..0000000 --- a/betterproto/tests/inputs/import_root_package_from_child/import_root_package_from_child.proto +++ /dev/null @@ -1,11 +0,0 @@ -syntax = "proto3"; - -import "root.proto"; - -package child; - -// Tests generated imports when a message inside a child-package refers to a message defined in the root. - -message Test { - RootMessage message = 1; -} diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service.py b/betterproto/tests/inputs/import_service_input_message/test_import_service.py index a8395fc..891b77a 100644 --- a/betterproto/tests/inputs/import_service_input_message/test_import_service.py +++ b/betterproto/tests/inputs/import_service_input_message/test_import_service.py @@ -1,7 +1,7 @@ import pytest from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.import_service_input_message.import_service_input_message import ( +from betterproto.tests.output_betterproto.import_service_input_message import ( RequestResponse, TestStub, ) diff --git a/betterproto/tests/inputs/nested/nested.proto b/betterproto/tests/inputs/nested/nested.proto index 974bf86..98bafd9 100644 --- a/betterproto/tests/inputs/nested/nested.proto +++ b/betterproto/tests/inputs/nested/nested.proto @@ -15,4 +15,4 @@ message Test { message Sibling { int32 foo = 1; -} +} \ No newline at end of file diff --git a/betterproto/tests/inputs/nested2/nested2.proto b/betterproto/tests/inputs/nested2/nested2.proto new file mode 100644 index 0000000..3e39918 --- /dev/null +++ b/betterproto/tests/inputs/nested2/nested2.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +import "package.proto"; + +message Game { + message Player { + enum Race { + human = 0; + orc = 1; + } + } +} + +message Test { + Game game = 1; + Game.Player GamePlayer = 2; + Game.Player.Race GamePlayerRace = 3; + equipment.Weapon Weapon = 4; +} \ No newline at end of file diff --git a/betterproto/tests/inputs/nested2/package.proto b/betterproto/tests/inputs/nested2/package.proto new file mode 100644 index 0000000..4466256 --- /dev/null +++ b/betterproto/tests/inputs/nested2/package.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package equipment; + +message Weapon { + +} \ No newline at end of file diff --git a/betterproto/tests/inputs/nestedtwice/nestedtwice.json b/betterproto/tests/inputs/nestedtwice/nestedtwice.json index 203a660..c953132 100644 --- a/betterproto/tests/inputs/nestedtwice/nestedtwice.json +++ b/betterproto/tests/inputs/nestedtwice/nestedtwice.json @@ -1,10 +1,10 @@ { - "root": { + "top": { "name": "double-nested", - "parent": { - "child": [{"foo": "hello"}], - "enumChild": ["A"], - "rootParentChild": [{"a": "hello"}], + "middle": { + "bottom": [{"foo": "hello"}], + "enumBottom": ["A"], + "topMiddleBottom": [{"a": "hello"}], "bar": true } } diff --git a/betterproto/tests/inputs/nestedtwice/nestedtwice.proto b/betterproto/tests/inputs/nestedtwice/nestedtwice.proto index 91c8050..7e9c206 100644 --- a/betterproto/tests/inputs/nestedtwice/nestedtwice.proto +++ b/betterproto/tests/inputs/nestedtwice/nestedtwice.proto @@ -1,26 +1,26 @@ syntax = "proto3"; message Test { - message Root { - message Parent { - message RootParentChild { + message Top { + message Middle { + message TopMiddleBottom { string a = 1; } - enum EnumChild{ + enum EnumBottom{ A = 0; B = 1; } - message Child { + message Bottom { string foo = 1; } reserved 1; - repeated Child child = 2; - repeated EnumChild enumChild=3; - repeated RootParentChild rootParentChild=4; + repeated Bottom bottom = 2; + repeated EnumBottom enumBottom=3; + repeated TopMiddleBottom topMiddleBottom=4; bool bar = 5; } string name = 1; - Parent parent = 2; + Middle middle = 2; } - Root root = 1; + Top top = 1; } diff --git a/betterproto/tests/inputs/oneof/test_oneof.py b/betterproto/tests/inputs/oneof/test_oneof.py index 400e4fd..058563e 100644 --- a/betterproto/tests/inputs/oneof/test_oneof.py +++ b/betterproto/tests/inputs/oneof/test_oneof.py @@ -1,5 +1,5 @@ import betterproto -from betterproto.tests.output_betterproto.oneof.oneof import Test +from betterproto.tests.output_betterproto.oneof import Test from betterproto.tests.util import get_test_case_json_data diff --git a/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py b/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py index 1d6ea98..ae9d40d 100644 --- a/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py +++ b/betterproto/tests/inputs/oneof_enum/test_oneof_enum.py @@ -1,7 +1,7 @@ import pytest import betterproto -from betterproto.tests.output_betterproto.oneof_enum.oneof_enum import ( +from betterproto.tests.output_betterproto.oneof_enum import ( Move, Signal, Test, diff --git a/betterproto/tests/inputs/ref/ref.proto b/betterproto/tests/inputs/ref/ref.proto index 6945590..e09fb15 100644 --- a/betterproto/tests/inputs/ref/ref.proto +++ b/betterproto/tests/inputs/ref/ref.proto @@ -1,7 +1,5 @@ syntax = "proto3"; -package ref; - import "repeatedmessage.proto"; message Test { diff --git a/betterproto/tests/test_casing.py b/betterproto/tests/test_casing.py new file mode 100644 index 0000000..ec60483 --- /dev/null +++ b/betterproto/tests/test_casing.py @@ -0,0 +1,125 @@ +import pytest + +from betterproto.casing import camel_case, pascal_case, snake_case + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "A"), + ("foobar", "Foobar"), + ("fooBar", "FooBar"), + ("FooBar", "FooBar"), + ("foo.bar", "FooBar"), + ("foo_bar", "FooBar"), + ("FOOBAR", "Foobar"), + ("FOOBar", "FooBar"), + ("UInt32", "UInt32"), + ("FOO_BAR", "FooBar"), + ("FOOBAR1", "Foobar1"), + ("FOOBAR_1", "Foobar1"), + ("FOO1BAR2", "Foo1Bar2"), + ("foo__bar", "FooBar"), + ("_foobar", "Foobar"), + ("foobaR", "FoobaR"), + ("foo~bar", "FooBar"), + ("foo:bar", "FooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_pascal_case(value, expected): + actual = pascal_case(value, strict=True) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("fooBar", "fooBar"), + ("FooBar", "fooBar"), + ("foo.bar", "fooBar"), + ("foo_bar", "fooBar"), + ("FOOBAR", "foobar"), + ("FOO_BAR", "fooBar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar1"), + ("FOO1BAR2", "foo1Bar2"), + ("foo__bar", "fooBar"), + ("_foobar", "foobar"), + ("foobaR", "foobaR"), + ("foo~bar", "fooBar"), + ("foo:bar", "fooBar"), + ("1foobar", "1Foobar"), + ], +) +def test_camel_case_strict(value, expected): + actual = camel_case(value, strict=True) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("foo_bar", "fooBar"), + ("FooBar", "fooBar"), + ("foo__bar", "foo_Bar"), + ("foo__Bar", "foo__Bar"), + ], +) +def test_camel_case_not_strict(value, expected): + actual = camel_case(value, strict=False) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("", ""), + ("a", "a"), + ("foobar", "foobar"), + ("fooBar", "foo_bar"), + ("FooBar", "foo_bar"), + ("foo.bar", "foo_bar"), + ("foo_bar", "foo_bar"), + ("foo_Bar", "foo_bar"), + ("FOOBAR", "foobar"), + ("FOOBar", "foo_bar"), + ("UInt32", "u_int32"), + ("FOO_BAR", "foo_bar"), + ("FOOBAR1", "foobar1"), + ("FOOBAR_1", "foobar_1"), + ("FOOBAR_123", "foobar_123"), + ("FOO1BAR2", "foo1_bar2"), + ("foo__bar", "foo_bar"), + ("_foobar", "foobar"), + ("foobaR", "fooba_r"), + ("foo~bar", "foo_bar"), + ("foo:bar", "foo_bar"), + ("1foobar", "1_foobar"), + ("GetUInt64", "get_u_int64"), + ], +) +def test_snake_case_strict(value, expected): + actual = snake_case(value) + assert actual == expected, f"{value} => {expected} (actual: {actual})" + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + ("fooBar", "foo_bar"), + ("FooBar", "foo_bar"), + ("foo_Bar", "foo__bar"), + ("foo__bar", "foo__bar"), + ("FOOBar", "foo_bar"), + ("__foo", "__foo"), + ("GetUInt64", "get_u_int64"), + ], +) +def test_snake_case_not_strict(value, expected): + actual = snake_case(value, strict=False) + assert actual == expected, f"{value} => {expected} (actual: {actual})" diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index 9635356..2bedf76 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -1,6 +1,6 @@ import pytest -from ..compile.importing import get_ref_type +from ..compile.importing import get_type_reference, parse_source_type_name @pytest.mark.parametrize( @@ -28,14 +28,16 @@ from ..compile.importing import get_ref_type ), ], ) -def test_import_google_wellknown_types_non_wrappers( +def test_reference_google_wellknown_types_non_wrappers( google_type: str, expected_name: str, expected_import: str ): imports = set() - name = get_ref_type(package="", imports=imports, type_name=google_type) + name = get_type_reference(package="", imports=imports, source_type=google_type) assert name == expected_name - assert imports.__contains__(expected_import) + assert imports.__contains__( + expected_import + ), f"{expected_import} not found in {imports}" @pytest.mark.parametrize( @@ -52,9 +54,11 @@ def test_import_google_wellknown_types_non_wrappers( (".google.protobuf.BytesValue", "Optional[bytes]"), ], ) -def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str): +def test_referenceing_google_wrappers_unwraps_them( + google_type: str, expected_name: str +): imports = set() - name = get_ref_type(package="", imports=imports, type_name=google_type) + name = get_type_reference(package="", imports=imports, source_type=google_type) assert name == expected_name assert imports == set() @@ -74,9 +78,238 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"), ], ) -def test_importing_google_wrappers_without_unwrapping( +def test_referenceing_google_wrappers_without_unwrapping( google_type: str, expected_name: str ): - name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False) + name = get_type_reference( + package="", imports=set(), source_type=google_type, unwrap=False + ) assert name == expected_name + + +def test_reference_child_package_from_package(): + imports = set() + name = get_type_reference( + package="package", imports=imports, source_type="package.child.Message" + ) + + assert imports == {"from . import child"} + assert name == "child.Message" + + +def test_reference_child_package_from_root(): + imports = set() + name = get_type_reference(package="", imports=imports, source_type="child.Message") + + assert imports == {"from . import child"} + assert name == "child.Message" + + +def test_reference_camel_cased(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="child_package.example_message" + ) + + assert imports == {"from . import child_package"} + assert name == "child_package.ExampleMessage" + + +def test_reference_nested_child_from_root(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="nested.child.Message" + ) + + assert imports == {"from .nested import child as nested_child"} + assert name == "nested_child.Message" + + +def test_reference_deeply_nested_child_from_root(): + imports = set() + name = get_type_reference( + package="", imports=imports, source_type="deeply.nested.child.Message" + ) + + assert imports == {"from .deeply.nested import child as deeply_nested_child"} + assert name == "deeply_nested_child.Message" + + +def test_reference_deeply_nested_child_from_package(): + imports = set() + name = get_type_reference( + package="package", + imports=imports, + source_type="package.deeply.nested.child.Message", + ) + + assert imports == {"from .deeply.nested import child as deeply_nested_child"} + assert name == "deeply_nested_child.Message" + + +def test_reference_root_sibling(): + imports = set() + name = get_type_reference(package="", imports=imports, source_type="Message") + + assert imports == set() + assert name == '"Message"' + + +def test_reference_nested_siblings(): + imports = set() + name = get_type_reference(package="foo", imports=imports, source_type="foo.Message") + + assert imports == set() + assert name == '"Message"' + + +def test_reference_deeply_nested_siblings(): + imports = set() + name = get_type_reference( + package="foo.bar", imports=imports, source_type="foo.bar.Message" + ) + + assert imports == set() + assert name == '"Message"' + + +def test_reference_parent_package_from_child(): + imports = set() + name = get_type_reference( + package="package.child", imports=imports, source_type="package.Message" + ) + + assert imports == {"from ... import package as __package__"} + assert name == "__package__.Message" + + +def test_reference_parent_package_from_deeply_nested_child(): + imports = set() + name = get_type_reference( + package="package.deeply.nested.child", + imports=imports, + source_type="package.deeply.nested.Message", + ) + + assert imports == {"from ... import nested as __nested__"} + assert name == "__nested__.Message" + + +def test_reference_ancestor_package_from_nested_child(): + imports = set() + name = get_type_reference( + package="package.ancestor.nested.child", + imports=imports, + source_type="package.ancestor.Message", + ) + + assert imports == {"from .... import ancestor as ___ancestor__"} + assert name == "___ancestor__.Message" + + +def test_reference_root_package_from_child(): + imports = set() + name = get_type_reference( + package="package.child", imports=imports, source_type="Message" + ) + + assert imports == {"from ... import Message as __Message__"} + assert name == "__Message__" + + +def test_reference_root_package_from_deeply_nested_child(): + imports = set() + name = get_type_reference( + package="package.deeply.nested.child", imports=imports, source_type="Message" + ) + + assert imports == {"from ..... import Message as ____Message__"} + assert name == "____Message__" + + +def test_reference_unrelated_package(): + imports = set() + name = get_type_reference(package="a", imports=imports, source_type="p.Message") + + assert imports == {"from .. import p as _p__"} + assert name == "_p__.Message" + + +def test_reference_unrelated_nested_package(): + imports = set() + name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") + + assert imports == {"from ...p import q as __p_q__"} + assert name == "__p_q__.Message" + + +def test_reference_unrelated_deeply_nested_package(): + imports = set() + name = get_type_reference( + package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message" + ) + + assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} + assert name == "____p_q_r_s__.Message" + + +def test_reference_cousin_package(): + imports = set() + name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") + + assert imports == {"from .. import y as _y__"} + assert name == "_y__.Message" + + +def test_reference_cousin_package_different_name(): + imports = set() + name = get_type_reference( + package="test.package1", imports=imports, source_type="cousin.package2.Message" + ) + + assert imports == {"from ...cousin import package2 as __cousin_package2__"} + assert name == "__cousin_package2__.Message" + + +def test_reference_cousin_package_same_name(): + imports = set() + name = get_type_reference( + package="test.package", imports=imports, source_type="cousin.package.Message" + ) + + assert imports == {"from ...cousin import package as __cousin_package__"} + assert name == "__cousin_package__.Message" + + +def test_reference_far_cousin_package(): + imports = set() + name = get_type_reference( + package="a.x.y", imports=imports, source_type="a.b.c.Message" + ) + + assert imports == {"from ...b import c as __b_c__"} + assert name == "__b_c__.Message" + + +def test_reference_far_far_cousin_package(): + imports = set() + name = get_type_reference( + package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message" + ) + + assert imports == {"from ....b.c import d as ___b_c_d__"} + assert name == "___b_c_d__.Message" + + +@pytest.mark.parametrize( + ["full_name", "expected_output"], + [ + ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")), + (".service.ExampleRequest", ("service", "ExampleRequest")), + (".package.lower_case_message", ("package", "lower_case_message")), + ], +) +def test_parse_field_type_name(full_name, expected_output): + assert parse_source_type_name(full_name) == expected_output diff --git a/betterproto/tests/test_inputs.py b/betterproto/tests/test_inputs.py index cb5974d..3183957 100644 --- a/betterproto/tests/test_inputs.py +++ b/betterproto/tests/test_inputs.py @@ -3,6 +3,7 @@ import json import os import sys from collections import namedtuple +from types import ModuleType from typing import Set import pytest @@ -10,7 +11,12 @@ import pytest import betterproto from betterproto.tests.inputs import config as test_input_config from betterproto.tests.mocks import MockChannel -from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path +from betterproto.tests.util import ( + find_module, + get_directories, + get_test_case_json_data, + inputs_path, +) # Force pure-python implementation instead of C++, otherwise imports # break things because we can't properly reset the symbol database. @@ -50,14 +56,17 @@ class TestCases: test_cases = TestCases( path=inputs_path, services=test_input_config.services, - xfail=test_input_config.tests, + xfail=test_input_config.xfail, ) plugin_output_package = "betterproto.tests.output_betterproto" reference_output_package = "betterproto.tests.output_reference" +TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"]) -TestData = namedtuple("TestData", "plugin_module, reference_module, json_data") + +def module_has_entry_point(module: ModuleType): + return any(hasattr(module, attr) for attr in ["Test", "TestStub"]) @pytest.fixture @@ -75,11 +84,19 @@ def test_data(request): sys.path.append(reference_module_root) + plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}") + + plugin_module_entry_point = find_module(plugin_module, module_has_entry_point) + + if not plugin_module_entry_point: + raise Exception( + f"Test case {repr(test_case_name)} has no entry point. " + "Please add a proto message or service called Test and recompile." + ) + yield ( TestData( - plugin_module=importlib.import_module( - f"{plugin_output_package}.{test_case_name}.{test_case_name}" - ), + plugin_module=plugin_module_entry_point, reference_module=lambda: importlib.import_module( f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2" ), diff --git a/betterproto/tests/util.py b/betterproto/tests/util.py index 61ba53e..3689cb8 100644 --- a/betterproto/tests/util.py +++ b/betterproto/tests/util.py @@ -1,7 +1,10 @@ import asyncio +import importlib import os +import pathlib from pathlib import Path -from typing import Generator, IO, Optional +from types import ModuleType +from typing import Callable, Generator, Optional os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" @@ -55,3 +58,35 @@ def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] = with test_data_file_path.open("r") as fh: return fh.read() + + +def find_module( + module: ModuleType, predicate: Callable[[ModuleType], bool] +) -> Optional[ModuleType]: + """ + Recursively search module tree for a module that matches the search predicate. + Assumes that the submodules are directories containing __init__.py. + + Example: + + # find module inside foo that contains Test + import foo + test_module = find_module(foo, lambda m: hasattr(m, 'Test')) + """ + if predicate(module): + return module + + module_path = pathlib.Path(*module.__path__) + + for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")): + if sub == module_path: + continue + sub_module_path = sub.relative_to(module_path) + sub_module_name = ".".join(sub_module_path.parts) + + sub_module = importlib.import_module(f".{sub_module_name}", module.__name__) + + if predicate(sub_module): + return sub_module + + return None diff --git a/poetry.lock b/poetry.lock index 6afb23f..3434f55 100644 --- a/poetry.lock +++ b/poetry.lock @@ -271,7 +271,7 @@ docs = ["sphinx", "rst.linker", "jaraco.packaging"] category = "main" description = "A very fast and expressive template engine." name = "jinja2" -optional = true +optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" version = "2.11.2" @@ -285,7 +285,7 @@ i18n = ["Babel (>=0.8)"] category = "main" description = "Safely add untrusted strings to HTML/XML markup." name = "markupsafe" -optional = true +optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" version = "1.1.1" @@ -369,7 +369,7 @@ dev = ["pre-commit", "tox"] category = "main" description = "Protocol Buffers" name = "protobuf" -optional = true +optional = false python-versions = "*" version = "3.12.2" @@ -490,14 +490,6 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" version = "1.15.0" -[[package]] -category = "main" -description = "String case converter." -name = "stringcase" -optional = false -python-versions = "*" -version = "1.2.0" - [[package]] category = "main" description = "Python Library for Tom's Obvious, Minimal Language" @@ -612,7 +604,7 @@ testing = ["jaraco.itertools", "func-timeout"] compiler = ["black", "jinja2", "protobuf"] [metadata] -content-hash = "ecafcaed2d4a25c2829e6dc3ef3c56cd72a8bc28c25c7aeae3484c978c816722" +content-hash = "8a4fa01ede86e1b5ba35b9dab8b6eacee766a9b5666f48ab41445c01882ab003" python-versions = "^3.6" [metadata.files] @@ -865,6 +857,8 @@ protobuf = [ {file = "protobuf-3.12.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e72736dd822748b0721f41f9aaaf6a5b6d5cfc78f6c8690263aef8bba4457f0e"}, {file = "protobuf-3.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:87535dc2d2ef007b9d44e309d2b8ea27a03d2fa09556a72364d706fcb7090828"}, {file = "protobuf-3.12.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:50b5fee674878b14baea73b4568dc478c46a31dd50157a5b5d2f71138243b1a9"}, + {file = "protobuf-3.12.2-py2.py3-none-any.whl", hash = "sha256:a96f8fc625e9ff568838e556f6f6ae8eca8b4837cdfb3f90efcb7c00e342a2eb"}, + {file = "protobuf-3.12.2.tar.gz", hash = "sha256:49ef8ab4c27812a89a76fa894fe7a08f42f2147078392c0dee51d4a444ef6df5"}, ] py = [ {file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"}, @@ -920,9 +914,6 @@ six = [ {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"}, {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"}, ] -stringcase = [ - {file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"}, -] toml = [ {file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"}, {file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"}, diff --git a/pyproject.toml b/pyproject.toml index f4466d9..f1e966b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ dataclasses = { version = "^0.7", python = ">=3.6, <3.7" } grpclib = "^0.3.1" jinja2 = { version = "^2.11.2", optional = true } protobuf = { version = "^3.12.2", optional = true } -stringcase = "^1.2.0" [tool.poetry.dev-dependencies] black = "^19.10b0"