Merge branch 'master' into michael-sayapin/master

This commit is contained in:
Bouke Versteegh 2020-07-04 11:23:42 +02:00 committed by GitHub
commit ac32bcd25a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
41 changed files with 870 additions and 238 deletions

View File

@ -68,14 +68,15 @@ message Greeting {
You can run the following:
```sh
protoc -I . --python_betterproto_out=. example.proto
mkdir lib
protoc -I . --python_betterproto_out=lib example.proto
```
This will generate `hello.py` which looks like:
This will generate `lib/hello/__init__.py` which looks like:
```py
```python
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: hello.proto
# sources: example.proto
# plugin: python-betterproto
from dataclasses import dataclass
@ -83,7 +84,7 @@ import betterproto
@dataclass
class Hello(betterproto.Message):
class Greeting(betterproto.Message):
"""Greeting represents a message you can tell a user."""
message: str = betterproto.string_field(1)
@ -91,23 +92,23 @@ class Hello(betterproto.Message):
Now you can use it!
```py
>>> from hello import Hello
>>> test = Hello()
```python
>>> from lib.hello import Greeting
>>> test = Greeting()
>>> test
Hello(message='')
Greeting(message='')
>>> test.message = "Hey!"
>>> test
Hello(message="Hey!")
Greeting(message="Hey!")
>>> serialized = bytes(test)
>>> serialized
b'\n\x04Hey!'
>>> another = Hello().parse(serialized)
>>> another = Greeting().parse(serialized)
>>> another
Hello(message="Hey!")
Greeting(message="Hey!")
>>> another.to_dict()
{"message": "Hey!"}
@ -315,7 +316,7 @@ To benefit from the collection of standard development tasks ensure you have mak
This project enforces [black](https://github.com/psf/black) python code formatting.
Before commiting changes run:
Before committing changes run:
```sh
make format
@ -336,7 +337,7 @@ Adding a standard test case is easy.
- Create a new directory `betterproto/tests/inputs/<name>`
- add `<name>.proto` with a message called `Test`
- add `<name>.json` with some test data
- add `<name>.json` with some test data (optional)
It will be picked up automatically when you run the tests.

View File

@ -7,27 +7,22 @@ import sys
from abc import ABC
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone
import stringcase
from typing import (
Any,
AsyncGenerator,
Callable,
Collection,
Dict,
Generator,
Iterator,
List,
Mapping,
Optional,
Set,
SupportsBytes,
Tuple,
Type,
Union,
get_type_hints,
)
from ._types import ST, T
from .casing import safe_snake_case
from ._types import T
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
from .grpc.grpclib_client import ServiceStub
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
@ -124,8 +119,8 @@ DATETIME_ZERO = datetime_default_gen()
class Casing(enum.Enum):
"""Casing constants for serialization."""
CAMEL = stringcase.camelcase
SNAKE = stringcase.snakecase
CAMEL = camel_case
SNAKE = snake_case
class _PLACEHOLDER:

View File

@ -1,9 +1,21 @@
import stringcase
import re
# Word delimiters and symbols that will not be preserved when re-casing.
# language=PythonRegExp
SYMBOLS = "[^a-zA-Z0-9]*"
# Optionally capitalized word.
# language=PythonRegExp
WORD = "[A-Z]*[a-z]*[0-9]*"
# Uppercase word, not followed by lowercase letters.
# language=PythonRegExp
WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
def safe_snake_case(value: str) -> str:
"""Snake case a value taking into account Python keywords."""
value = stringcase.snakecase(value)
value = snake_case(value)
if value in [
"and",
"as",
@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
value += "_"
return value
def snake_case(value: str, strict: bool = True):
"""
Join words with an underscore into lowercase and remove symbols.
@param value: value to convert
@param strict: force single underscores
"""
def substitute_word(symbols, word, is_start):
if not word:
return ""
if strict:
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
elif is_start:
delimiter_count = len(symbols)
elif word.isupper() or word.islower():
delimiter_count = max(
1, len(symbols)
) # Preserve all delimiters if not strict.
else:
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
return ("_" * delimiter_count) + word.lower()
snake = re.sub(
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
value,
)
return snake
def pascal_case(value: str, strict: bool = True):
"""
Capitalize each word and remove symbols.
@param value: value to convert
@param strict: output only alphanumeric characters
"""
def substitute_word(symbols, word):
if strict:
return word.capitalize() # Remove all delimiters
if word.islower():
delimiter_length = len(symbols[:-1]) # Lose one delimiter
else:
delimiter_length = len(symbols) # Preserve all delimiters
return ("_" * delimiter_length) + word.capitalize()
return re.sub(
f"({SYMBOLS})({WORD_UPPER}|{WORD})",
lambda groups: substitute_word(groups[1], groups[2]),
value,
)
def camel_case(value: str, strict: bool = True):
"""
Capitalize all words except first and remove symbols.
"""
return lowercase_first(pascal_case(value, strict=strict))
def lowercase_first(value: str):
return value[0:1].lower() + value[1:]

View File

@ -1,70 +1,160 @@
from typing import Dict, Type
import stringcase
import os
import re
from typing import Dict, List, Set, Type
from betterproto import safe_snake_case
from betterproto.compile.naming import pythonize_class_name
from betterproto.lib.google import protobuf as google_protobuf
WRAPPER_TYPES: Dict[str, Type] = {
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
"google.protobuf.FloatValue": google_protobuf.FloatValue,
"google.protobuf.Int32Value": google_protobuf.Int32Value,
"google.protobuf.Int64Value": google_protobuf.Int64Value,
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
"google.protobuf.BoolValue": google_protobuf.BoolValue,
"google.protobuf.StringValue": google_protobuf.StringValue,
"google.protobuf.BytesValue": google_protobuf.BytesValue,
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
".google.protobuf.FloatValue": google_protobuf.FloatValue,
".google.protobuf.Int32Value": google_protobuf.Int32Value,
".google.protobuf.Int64Value": google_protobuf.Int64Value,
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
".google.protobuf.BoolValue": google_protobuf.BoolValue,
".google.protobuf.StringValue": google_protobuf.StringValue,
".google.protobuf.BytesValue": google_protobuf.BytesValue,
}
def get_ref_type(
package: str, imports: set, type_name: str, unwrap: bool = True
def parse_source_type_name(field_type_name):
"""
Split full source type name into package and type name.
E.g. 'root.package.Message' -> ('root.package', 'Message')
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
"""
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
if package_match:
package = package_match.group(1)
name = package_match.group(2)
else:
package = ""
name = field_type_name.lstrip(".")
return package, name
def get_type_reference(
package: str, imports: set, source_type: str, unwrap: bool = True,
) -> str:
"""
Return a Python type name for a proto type reference. Adds the import if
necessary. Unwraps well known type if required.
"""
# If the package name is a blank string, then this should still work
# because by convention packages are lowercase and message/enum types are
# pascal-cased. May require refactoring in the future.
type_name = type_name.lstrip(".")
is_wrapper = type_name in WRAPPER_TYPES
if unwrap:
if is_wrapper:
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
if type_name == "google.protobuf.Duration":
if source_type == ".google.protobuf.Duration":
return "timedelta"
if type_name == "google.protobuf.Timestamp":
if source_type == ".google.protobuf.Timestamp":
return "datetime"
if type_name.startswith(package):
parts = type_name.lstrip(package).lstrip(".").split(".")
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
# This is the current package, which has nested types flattened.
# foo.bar_thing => FooBarThing
cased = [stringcase.pascalcase(part) for part in parts]
type_name = f'"{"".join(cased)}"'
source_package, source_type = parse_source_type_name(source_type)
# Use precompiled classes for google.protobuf.* objects
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
type_name = type_name.rsplit(".", maxsplit=1)[1]
import_package = "betterproto.lib.google.protobuf"
import_alias = safe_snake_case(import_package)
imports.add(f"import {import_package} as {import_alias}")
return f"{import_alias}.{type_name}"
current_package: List[str] = package.split(".") if package else []
py_package: List[str] = source_package.split(".") if source_package else []
py_type: str = pythonize_class_name(source_type)
if "." in type_name:
# This is imported from another package. No need
# to use a forward ref and we need to add the import.
parts = type_name.split(".")
parts[-1] = stringcase.pascalcase(parts[-1])
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
type_name = f"{parts[-2]}.{parts[-1]}"
compiling_google_protobuf = current_package == ["google", "protobuf"]
importing_google_protobuf = py_package == ["google", "protobuf"]
if importing_google_protobuf and not compiling_google_protobuf:
py_package = ["betterproto", "lib"] + py_package
return type_name
if py_package[:1] == ["betterproto"]:
return reference_absolute(imports, py_package, py_type)
if py_package == current_package:
return reference_sibling(py_type)
if py_package[: len(current_package)] == current_package:
return reference_descendent(current_package, imports, py_package, py_type)
if current_package[: len(py_package)] == py_package:
return reference_ancestor(current_package, imports, py_package, py_type)
return reference_cousin(current_package, imports, py_package, py_type)
def reference_absolute(imports, py_package, py_type):
"""
Returns a reference to a python type located in the root, i.e. sys.path.
"""
string_import = ".".join(py_package)
string_alias = safe_snake_case(string_import)
imports.add(f"import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
def reference_sibling(py_type: str) -> str:
"""
Returns a reference to a python type within the same package as the current package.
"""
return f'"{py_type}"'
def reference_descendent(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package that is a descendent of the current package,
and adds the required import that is aliased to avoid name conflicts.
"""
importing_descendent = py_package[len(current_package) :]
string_from = ".".join(importing_descendent[:-1])
string_import = importing_descendent[-1]
if string_from:
string_alias = "_".join(importing_descendent)
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
else:
imports.add(f"from . import {string_import}")
return f"{string_import}.{py_type}"
def reference_ancestor(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package which is an ancestor to the current package,
and adds the required import that is aliased (if possible) to avoid name conflicts.
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
"""
distance_up = len(current_package) - len(py_package)
if py_package:
string_import = py_package[-1]
string_alias = f"_{'_' * distance_up}{string_import}__"
string_from = f"..{'.' * distance_up}"
imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"
else:
string_alias = f"{'_' * distance_up}{py_type}__"
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
return string_alias
def reference_cousin(
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
) -> str:
"""
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
and adds the required import that is aliased to avoid name conflicts.
"""
shared_ancestry = os.path.commonprefix([current_package, py_package])
distance_up = len(current_package) - len(shared_ancestry)
string_from = f".{'.' * distance_up}" + ".".join(
py_package[len(shared_ancestry) : -1]
)
string_import = py_package[-1]
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
string_alias = (
f"{'_' * distance_up}"
+ safe_snake_case(".".join(py_package[len(shared_ancestry) :]))
+ "__"
)
imports.add(f"from {string_from} import {string_import} as {string_alias}")
return f"{string_alias}.{py_type}"

View File

@ -0,0 +1,13 @@
from betterproto import casing
def pythonize_class_name(name):
return casing.pascal_case(name)
def pythonize_field_name(name: str):
return casing.safe_snake_case(name)
def pythonize_method_name(name: str):
return casing.safe_snake_case(name)

View File

@ -84,7 +84,7 @@ class FieldOptionsCType(betterproto.Enum):
STRING_PIECE = 2
class FieldOptionsJSType(betterproto.Enum):
class FieldOptionsJsType(betterproto.Enum):
JS_NORMAL = 0
JS_STRING = 1
JS_NUMBER = 2
@ -717,7 +717,7 @@ class FieldOptions(betterproto.Message):
# use the JavaScript "number" type. The behavior of the default option
# JS_NORMAL is implementation dependent. This option is an enum to permit
# additional types to be added, e.g. goog.math.Integer.
jstype: "FieldOptionsJSType" = betterproto.enum_field(6)
jstype: "FieldOptionsJsType" = betterproto.enum_field(6)
# Should this field be parsed lazily? Lazy applies only to message-type
# fields. It means that when the outer message is initially parsed, the
# inner message's contents will not be parsed but instead stored in encoded

View File

@ -2,14 +2,19 @@
import itertools
import os.path
import pathlib
import re
import stringcase
import sys
import textwrap
from typing import List, Union
import betterproto
from betterproto.casing import safe_snake_case
from betterproto.compile.importing import get_ref_type
from betterproto.compile.importing import get_type_reference
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
@ -35,27 +40,22 @@ except ImportError as err:
raise SystemExit(1)
def py_type(
package: str,
imports: set,
message: DescriptorProto,
descriptor: FieldDescriptorProto,
) -> str:
if descriptor.type in [1, 2]:
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
if field.type in [1, 2]:
return "float"
elif descriptor.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
return "int"
elif descriptor.type == 8:
elif field.type == 8:
return "bool"
elif descriptor.type == 9:
elif field.type == 9:
return "str"
elif descriptor.type in [11, 14]:
elif field.type in [11, 14]:
# Type referencing another defined Message or a named enum
return get_ref_type(package, imports, descriptor.type_name)
elif descriptor.type == 12:
return get_type_reference(package, imports, field.type_name)
elif field.type == 12:
return "bytes"
else:
raise NotImplementedError(f"Unknown type {descriptor.type}")
raise NotImplementedError(f"Unknown type {field.type}")
def get_py_zero(type_num: int) -> Union[str, float]:
@ -131,17 +131,17 @@ def generate_code(request, response):
output_map = {}
for proto_file in request.proto_file:
out = proto_file.package
if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
if (
proto_file.package == "google.protobuf"
and "INCLUDE_GOOGLE" not in plugin_options
):
continue
if not out:
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
if out not in output_map:
output_map[out] = {"package": proto_file.package, "files": []}
output_map[out]["files"].append(proto_file)
if output_file not in output_map:
output_map[output_file] = {"package": proto_file.package, "files": []}
output_map[output_file]["files"].append(proto_file)
# TODO: Figure out how to handle gRPC request/response messages and add
# processing below for Service.
@ -160,17 +160,10 @@ def generate_code(request, response):
"services": [],
}
type_mapping = {}
for proto_file in options["files"]:
# print(proto_file.message_type, file=sys.stderr)
# print(proto_file.service, file=sys.stderr)
# print(proto_file.source_code_info, file=sys.stderr)
item: DescriptorProto
for item, path in traverse(proto_file):
# print(item, file=sys.stderr)
# print(path, file=sys.stderr)
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
if isinstance(item, DescriptorProto):
# print(item, file=sys.stderr)
@ -187,7 +180,7 @@ def generate_code(request, response):
)
for i, f in enumerate(item.field):
t = py_type(package, output["imports"], item, f)
t = py_type(package, output["imports"], f)
zero = get_py_zero(f.type)
repeated = False
@ -222,13 +215,11 @@ def generate_code(request, response):
k = py_type(
package,
output["imports"],
item,
nested.field[0],
)
v = py_type(
package,
output["imports"],
item,
nested.field[1],
)
t = f"Dict[{k}, {v}]"
@ -264,7 +255,7 @@ def generate_code(request, response):
data["properties"].append(
{
"name": f.name,
"py_name": safe_snake_case(f.name),
"py_name": pythonize_field_name(f.name),
"number": f.number,
"comment": get_comment(proto_file, path + [2, i]),
"proto_type": int(f.type),
@ -305,14 +296,14 @@ def generate_code(request, response):
data = {
"name": service.name,
"py_name": stringcase.pascalcase(service.name),
"py_name": pythonize_class_name(service.name),
"comment": get_comment(proto_file, [6, i]),
"methods": [],
}
for j, method in enumerate(service.method):
input_message = None
input_type = get_ref_type(
input_type = get_type_reference(
package, output["imports"], method.input_type
).strip('"')
for msg in output["messages"]:
@ -326,14 +317,14 @@ def generate_code(request, response):
data["methods"].append(
{
"name": method.name,
"py_name": stringcase.snakecase(method.name),
"py_name": pythonize_method_name(method.name),
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
"route": f"/{package}.{service.name}/{method.name}",
"input": get_ref_type(
"input": get_type_reference(
package, output["imports"], method.input_type
).strip('"'),
"input_message": input_message,
"output": get_ref_type(
"output": get_type_reference(
package,
output["imports"],
method.output_type,
@ -359,8 +350,7 @@ def generate_code(request, response):
# Fill response
f = response.file.add()
# print(filename, file=sys.stderr)
f.name = filename.replace(".", os.path.sep) + ".py"
f.name = filename
# Render and then format the output file.
f.content = black.format_str(
@ -368,32 +358,23 @@ def generate_code(request, response):
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
)
inits = set([""])
for f in response.file:
# Ensure output paths exist
# print(f.name, file=sys.stderr)
dirnames = os.path.dirname(f.name)
if dirnames:
os.makedirs(dirnames, exist_ok=True)
base = ""
for part in dirnames.split(os.path.sep):
base = os.path.join(base, part)
inits.add(base)
for base in inits:
name = os.path.join(base, "__init__.py")
if os.path.exists(name):
# Never overwrite inits as they may have custom stuff in them.
continue
# Make each output directory a package with __init__ file
output_paths = set(pathlib.Path(path) for path in output_map.keys())
init_files = (
set(
directory.joinpath("__init__.py")
for path in output_paths
for directory in path.parents
)
- output_paths
)
for init_file in init_files:
init = response.file.add()
init.name = name
init.content = b""
init.name = str(init_file)
filenames = sorted([f.name for f in response.file])
for fname in filenames:
print(f"Writing {fname}", file=sys.stderr)
for filename in sorted(output_paths.union(init_files)):
print(f"Writing {filename}", file=sys.stderr)
def main():

View File

@ -12,12 +12,12 @@ inputs/
## Test case directory structure
Each testcase has a `<name>.proto` file with a message called `Test`, a matching `.json` file and optionally a custom test file called `test_*.py`.
Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`.
```bash
bool/
bool.proto
bool.json
bool.json # optional
test_bool.py # optional
```
@ -61,21 +61,22 @@ def test_value():
The following tests are automatically executed for all cases:
- [x] Can the generated python code imported?
- [x] Can the generated python code be imported?
- [x] Can the generated message class be instantiated?
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
- _when `.json` is present_
## Running the tests
- `pipenv run generate`
This generates
- `pipenv run generate`
This generates:
- `betterproto/tests/output_betterproto` &mdash; *the plugin generated python classes*
- `betterproto/tests/output_reference` &mdash; *reference implementation classes*
- `pipenv run test`
## Intentionally Failing tests
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrented in the future.
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
When running `pytest`, they show up as `x` or `X` in the test results.

View File

@ -1,4 +1,4 @@
from betterproto.tests.output_betterproto.bool.bool import Test
from betterproto.tests.output_betterproto.bool import Test
def test_value():

View File

@ -10,6 +10,7 @@ message Test {
int32 camelCase = 1;
my_enum snake_case = 2;
snake_case_message snake_case_message = 3;
int32 UPPERCASE = 4;
}
message snake_case_message {

View File

@ -1,5 +1,5 @@
import betterproto.tests.output_betterproto.casing.casing as casing
from betterproto.tests.output_betterproto.casing.casing import Test
import betterproto.tests.output_betterproto.casing as casing
from betterproto.tests.output_betterproto.casing import Test
def test_message_attributes():
@ -8,6 +8,7 @@ def test_message_attributes():
message, "snake_case_message"
), "snake_case field name is same in python"
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
def test_message_casing():

View File

@ -1,5 +0,0 @@
{
"UPPERCASE": 10,
"UPPERCASE_V2": 10,
"UPPER_CAMEL_CASE": 10
}

View File

@ -1,6 +1,4 @@
from betterproto.tests.output_betterproto.casing_message_field_uppercase.casing_message_field_uppercase import (
Test,
)
from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
def test_message_casing():

View File

@ -1,19 +1,15 @@
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
# Remove from list when fixed.
tests = {
"import_root_sibling", # 61
"import_child_package_from_package", # 58
"import_root_package_from_child", # 60
"import_parent_package_from_child", # 59
"import_circular_dependency", # failing because of other bugs now
"import_packages_same_name", # 25
xfail = {
"import_circular_dependency",
"oneof_enum", # 63
"casing_message_field_uppercase", # 11
"namespace_keywords", # 70
"namespace_builtin_types", # 53
"googletypes_struct", # 9
"googletypes_value", # 9
"enum_skipped_value", # 93
"import_capitalized_package",
"example", # This is the example in the readme. Not a test.
}
services = {

View File

@ -0,0 +1,8 @@
syntax = "proto3";
package hello;
// Greeting represents a message you can tell a user.
message Greeting {
string message = 1;
}

View File

@ -4,9 +4,7 @@ import betterproto.lib.google.protobuf as protobuf
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
TestStub,
)
from betterproto.tests.output_betterproto.googletypes_response import TestStub
test_cases = [
(TestStub.get_double, protobuf.DoubleValue, 2.5),

View File

@ -1,7 +1,7 @@
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
from betterproto.tests.output_betterproto.googletypes_response_embedded import (
Output,
TestStub,
)

View File

@ -0,0 +1,8 @@
syntax = "proto3";
package Capitalized;
message Message {
}

View File

@ -0,0 +1,9 @@
syntax = "proto3";
import "capitalized.proto";
// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.
message Test {
Capitalized.Message message = 1;
}

View File

@ -3,9 +3,9 @@ syntax = "proto3";
import "root.proto";
import "other.proto";
// This test-case verifies that future implementations will support circular dependencies in the generated python files.
// This test-case verifies support for circular dependencies in the generated python files.
//
// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file.
// This is important because we generate 1 python file/module per package, rather than 1 file per proto file.
//
// Scenario:
//
@ -24,5 +24,5 @@ import "other.proto";
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
message Test {
RootPackageMessage message = 1;
other.OtherPackageMessage other =2;
other.OtherPackageMessage other = 2;
}

View File

@ -0,0 +1,6 @@
syntax = "proto3";
package cousin.cousin_subpackage;
message CousinMessage {
}

View File

@ -0,0 +1,11 @@
syntax = "proto3";
package test.subpackage;
import "cousin.proto";
// Verify that we can import message unrelated to us
message Test {
cousin.cousin_subpackage.CousinMessage message = 1;
}

View File

@ -0,0 +1,6 @@
syntax = "proto3";
package cousin.subpackage;
message CousinMessage {
}

View File

@ -0,0 +1,11 @@
syntax = "proto3";
package test.subpackage;
import "cousin.proto";
// Verify that we can import a message unrelated to us, in a subpackage with the same name as us.
message Test {
cousin.subpackage.CousinMessage message = 1;
}

View File

@ -0,0 +1,11 @@
syntax = "proto3";
package child;
import "root.proto";
// Verify that we can import root message from child package
message Test {
RootMessage message = 1;
}

View File

@ -1,11 +0,0 @@
syntax = "proto3";
import "root.proto";
package child;
// Tests generated imports when a message inside a child-package refers to a message defined in the root.
message Test {
RootMessage message = 1;
}

View File

@ -1,7 +1,7 @@
import pytest
from betterproto.tests.mocks import MockChannel
from betterproto.tests.output_betterproto.import_service_input_message.import_service_input_message import (
from betterproto.tests.output_betterproto.import_service_input_message import (
RequestResponse,
TestStub,
)

View File

@ -15,4 +15,4 @@ message Test {
message Sibling {
int32 foo = 1;
}
}

View File

@ -0,0 +1,19 @@
syntax = "proto3";
import "package.proto";
message Game {
message Player {
enum Race {
human = 0;
orc = 1;
}
}
}
message Test {
Game game = 1;
Game.Player GamePlayer = 2;
Game.Player.Race GamePlayerRace = 3;
equipment.Weapon Weapon = 4;
}

View File

@ -0,0 +1,7 @@
syntax = "proto3";
package equipment;
message Weapon {
}

View File

@ -1,10 +1,10 @@
{
"root": {
"top": {
"name": "double-nested",
"parent": {
"child": [{"foo": "hello"}],
"enumChild": ["A"],
"rootParentChild": [{"a": "hello"}],
"middle": {
"bottom": [{"foo": "hello"}],
"enumBottom": ["A"],
"topMiddleBottom": [{"a": "hello"}],
"bar": true
}
}

View File

@ -1,26 +1,26 @@
syntax = "proto3";
message Test {
message Root {
message Parent {
message RootParentChild {
message Top {
message Middle {
message TopMiddleBottom {
string a = 1;
}
enum EnumChild{
enum EnumBottom{
A = 0;
B = 1;
}
message Child {
message Bottom {
string foo = 1;
}
reserved 1;
repeated Child child = 2;
repeated EnumChild enumChild=3;
repeated RootParentChild rootParentChild=4;
repeated Bottom bottom = 2;
repeated EnumBottom enumBottom=3;
repeated TopMiddleBottom topMiddleBottom=4;
bool bar = 5;
}
string name = 1;
Parent parent = 2;
Middle middle = 2;
}
Root root = 1;
Top top = 1;
}

View File

@ -1,5 +1,5 @@
import betterproto
from betterproto.tests.output_betterproto.oneof.oneof import Test
from betterproto.tests.output_betterproto.oneof import Test
from betterproto.tests.util import get_test_case_json_data

View File

@ -1,7 +1,7 @@
import pytest
import betterproto
from betterproto.tests.output_betterproto.oneof_enum.oneof_enum import (
from betterproto.tests.output_betterproto.oneof_enum import (
Move,
Signal,
Test,

View File

@ -1,7 +1,5 @@
syntax = "proto3";
package ref;
import "repeatedmessage.proto";
message Test {

View File

@ -0,0 +1,125 @@
import pytest
from betterproto.casing import camel_case, pascal_case, snake_case
@pytest.mark.parametrize(
["value", "expected"],
[
("", ""),
("a", "A"),
("foobar", "Foobar"),
("fooBar", "FooBar"),
("FooBar", "FooBar"),
("foo.bar", "FooBar"),
("foo_bar", "FooBar"),
("FOOBAR", "Foobar"),
("FOOBar", "FooBar"),
("UInt32", "UInt32"),
("FOO_BAR", "FooBar"),
("FOOBAR1", "Foobar1"),
("FOOBAR_1", "Foobar1"),
("FOO1BAR2", "Foo1Bar2"),
("foo__bar", "FooBar"),
("_foobar", "Foobar"),
("foobaR", "FoobaR"),
("foo~bar", "FooBar"),
("foo:bar", "FooBar"),
("1foobar", "1Foobar"),
],
)
def test_pascal_case(value, expected):
actual = pascal_case(value, strict=True)
assert actual == expected, f"{value} => {expected} (actual: {actual})"
@pytest.mark.parametrize(
["value", "expected"],
[
("", ""),
("a", "a"),
("foobar", "foobar"),
("fooBar", "fooBar"),
("FooBar", "fooBar"),
("foo.bar", "fooBar"),
("foo_bar", "fooBar"),
("FOOBAR", "foobar"),
("FOO_BAR", "fooBar"),
("FOOBAR1", "foobar1"),
("FOOBAR_1", "foobar1"),
("FOO1BAR2", "foo1Bar2"),
("foo__bar", "fooBar"),
("_foobar", "foobar"),
("foobaR", "foobaR"),
("foo~bar", "fooBar"),
("foo:bar", "fooBar"),
("1foobar", "1Foobar"),
],
)
def test_camel_case_strict(value, expected):
actual = camel_case(value, strict=True)
assert actual == expected, f"{value} => {expected} (actual: {actual})"
@pytest.mark.parametrize(
["value", "expected"],
[
("foo_bar", "fooBar"),
("FooBar", "fooBar"),
("foo__bar", "foo_Bar"),
("foo__Bar", "foo__Bar"),
],
)
def test_camel_case_not_strict(value, expected):
actual = camel_case(value, strict=False)
assert actual == expected, f"{value} => {expected} (actual: {actual})"
@pytest.mark.parametrize(
["value", "expected"],
[
("", ""),
("a", "a"),
("foobar", "foobar"),
("fooBar", "foo_bar"),
("FooBar", "foo_bar"),
("foo.bar", "foo_bar"),
("foo_bar", "foo_bar"),
("foo_Bar", "foo_bar"),
("FOOBAR", "foobar"),
("FOOBar", "foo_bar"),
("UInt32", "u_int32"),
("FOO_BAR", "foo_bar"),
("FOOBAR1", "foobar1"),
("FOOBAR_1", "foobar_1"),
("FOOBAR_123", "foobar_123"),
("FOO1BAR2", "foo1_bar2"),
("foo__bar", "foo_bar"),
("_foobar", "foobar"),
("foobaR", "fooba_r"),
("foo~bar", "foo_bar"),
("foo:bar", "foo_bar"),
("1foobar", "1_foobar"),
("GetUInt64", "get_u_int64"),
],
)
def test_snake_case_strict(value, expected):
actual = snake_case(value)
assert actual == expected, f"{value} => {expected} (actual: {actual})"
@pytest.mark.parametrize(
["value", "expected"],
[
("fooBar", "foo_bar"),
("FooBar", "foo_bar"),
("foo_Bar", "foo__bar"),
("foo__bar", "foo__bar"),
("FOOBar", "foo_bar"),
("__foo", "__foo"),
("GetUInt64", "get_u_int64"),
],
)
def test_snake_case_not_strict(value, expected):
actual = snake_case(value, strict=False)
assert actual == expected, f"{value} => {expected} (actual: {actual})"

View File

@ -1,6 +1,6 @@
import pytest
from ..compile.importing import get_ref_type
from ..compile.importing import get_type_reference, parse_source_type_name
@pytest.mark.parametrize(
@ -28,14 +28,16 @@ from ..compile.importing import get_ref_type
),
],
)
def test_import_google_wellknown_types_non_wrappers(
def test_reference_google_wellknown_types_non_wrappers(
google_type: str, expected_name: str, expected_import: str
):
imports = set()
name = get_ref_type(package="", imports=imports, type_name=google_type)
name = get_type_reference(package="", imports=imports, source_type=google_type)
assert name == expected_name
assert imports.__contains__(expected_import)
assert imports.__contains__(
expected_import
), f"{expected_import} not found in {imports}"
@pytest.mark.parametrize(
@ -52,9 +54,11 @@ def test_import_google_wellknown_types_non_wrappers(
(".google.protobuf.BytesValue", "Optional[bytes]"),
],
)
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
def test_referenceing_google_wrappers_unwraps_them(
google_type: str, expected_name: str
):
imports = set()
name = get_ref_type(package="", imports=imports, type_name=google_type)
name = get_type_reference(package="", imports=imports, source_type=google_type)
assert name == expected_name
assert imports == set()
@ -74,9 +78,238 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name:
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
],
)
def test_importing_google_wrappers_without_unwrapping(
def test_referenceing_google_wrappers_without_unwrapping(
google_type: str, expected_name: str
):
name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
name = get_type_reference(
package="", imports=set(), source_type=google_type, unwrap=False
)
assert name == expected_name
def test_reference_child_package_from_package():
imports = set()
name = get_type_reference(
package="package", imports=imports, source_type="package.child.Message"
)
assert imports == {"from . import child"}
assert name == "child.Message"
def test_reference_child_package_from_root():
imports = set()
name = get_type_reference(package="", imports=imports, source_type="child.Message")
assert imports == {"from . import child"}
assert name == "child.Message"
def test_reference_camel_cased():
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="child_package.example_message"
)
assert imports == {"from . import child_package"}
assert name == "child_package.ExampleMessage"
def test_reference_nested_child_from_root():
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="nested.child.Message"
)
assert imports == {"from .nested import child as nested_child"}
assert name == "nested_child.Message"
def test_reference_deeply_nested_child_from_root():
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="deeply.nested.child.Message"
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == "deeply_nested_child.Message"
def test_reference_deeply_nested_child_from_package():
imports = set()
name = get_type_reference(
package="package",
imports=imports,
source_type="package.deeply.nested.child.Message",
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == "deeply_nested_child.Message"
def test_reference_root_sibling():
imports = set()
name = get_type_reference(package="", imports=imports, source_type="Message")
assert imports == set()
assert name == '"Message"'
def test_reference_nested_siblings():
imports = set()
name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
assert imports == set()
assert name == '"Message"'
def test_reference_deeply_nested_siblings():
imports = set()
name = get_type_reference(
package="foo.bar", imports=imports, source_type="foo.bar.Message"
)
assert imports == set()
assert name == '"Message"'
def test_reference_parent_package_from_child():
imports = set()
name = get_type_reference(
package="package.child", imports=imports, source_type="package.Message"
)
assert imports == {"from ... import package as __package__"}
assert name == "__package__.Message"
def test_reference_parent_package_from_deeply_nested_child():
imports = set()
name = get_type_reference(
package="package.deeply.nested.child",
imports=imports,
source_type="package.deeply.nested.Message",
)
assert imports == {"from ... import nested as __nested__"}
assert name == "__nested__.Message"
def test_reference_ancestor_package_from_nested_child():
imports = set()
name = get_type_reference(
package="package.ancestor.nested.child",
imports=imports,
source_type="package.ancestor.Message",
)
assert imports == {"from .... import ancestor as ___ancestor__"}
assert name == "___ancestor__.Message"
def test_reference_root_package_from_child():
imports = set()
name = get_type_reference(
package="package.child", imports=imports, source_type="Message"
)
assert imports == {"from ... import Message as __Message__"}
assert name == "__Message__"
def test_reference_root_package_from_deeply_nested_child():
imports = set()
name = get_type_reference(
package="package.deeply.nested.child", imports=imports, source_type="Message"
)
assert imports == {"from ..... import Message as ____Message__"}
assert name == "____Message__"
def test_reference_unrelated_package():
imports = set()
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
assert imports == {"from .. import p as _p__"}
assert name == "_p__.Message"
def test_reference_unrelated_nested_package():
imports = set()
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
assert imports == {"from ...p import q as __p_q__"}
assert name == "__p_q__.Message"
def test_reference_unrelated_deeply_nested_package():
imports = set()
name = get_type_reference(
package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
)
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
assert name == "____p_q_r_s__.Message"
def test_reference_cousin_package():
imports = set()
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
assert imports == {"from .. import y as _y__"}
assert name == "_y__.Message"
def test_reference_cousin_package_different_name():
imports = set()
name = get_type_reference(
package="test.package1", imports=imports, source_type="cousin.package2.Message"
)
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
assert name == "__cousin_package2__.Message"
def test_reference_cousin_package_same_name():
imports = set()
name = get_type_reference(
package="test.package", imports=imports, source_type="cousin.package.Message"
)
assert imports == {"from ...cousin import package as __cousin_package__"}
assert name == "__cousin_package__.Message"
def test_reference_far_cousin_package():
imports = set()
name = get_type_reference(
package="a.x.y", imports=imports, source_type="a.b.c.Message"
)
assert imports == {"from ...b import c as __b_c__"}
assert name == "__b_c__.Message"
def test_reference_far_far_cousin_package():
imports = set()
name = get_type_reference(
package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
)
assert imports == {"from ....b.c import d as ___b_c_d__"}
assert name == "___b_c_d__.Message"
@pytest.mark.parametrize(
["full_name", "expected_output"],
[
("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
(".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
(".service.ExampleRequest", ("service", "ExampleRequest")),
(".package.lower_case_message", ("package", "lower_case_message")),
],
)
def test_parse_field_type_name(full_name, expected_output):
assert parse_source_type_name(full_name) == expected_output

View File

@ -3,6 +3,7 @@ import json
import os
import sys
from collections import namedtuple
from types import ModuleType
from typing import Set
import pytest
@ -10,7 +11,12 @@ import pytest
import betterproto
from betterproto.tests.inputs import config as test_input_config
from betterproto.tests.mocks import MockChannel
from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path
from betterproto.tests.util import (
find_module,
get_directories,
get_test_case_json_data,
inputs_path,
)
# Force pure-python implementation instead of C++, otherwise imports
# break things because we can't properly reset the symbol database.
@ -50,14 +56,17 @@ class TestCases:
test_cases = TestCases(
path=inputs_path,
services=test_input_config.services,
xfail=test_input_config.tests,
xfail=test_input_config.xfail,
)
plugin_output_package = "betterproto.tests.output_betterproto"
reference_output_package = "betterproto.tests.output_reference"
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
TestData = namedtuple("TestData", "plugin_module, reference_module, json_data")
def module_has_entry_point(module: ModuleType):
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
@pytest.fixture
@ -75,11 +84,19 @@ def test_data(request):
sys.path.append(reference_module_root)
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
if not plugin_module_entry_point:
raise Exception(
f"Test case {repr(test_case_name)} has no entry point. "
"Please add a proto message or service called Test and recompile."
)
yield (
TestData(
plugin_module=importlib.import_module(
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
),
plugin_module=plugin_module_entry_point,
reference_module=lambda: importlib.import_module(
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
),

View File

@ -1,7 +1,10 @@
import asyncio
import importlib
import os
import pathlib
from pathlib import Path
from typing import Generator, IO, Optional
from types import ModuleType
from typing import Callable, Generator, Optional
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@ -55,3 +58,35 @@ def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] =
with test_data_file_path.open("r") as fh:
return fh.read()
def find_module(
module: ModuleType, predicate: Callable[[ModuleType], bool]
) -> Optional[ModuleType]:
"""
Recursively search module tree for a module that matches the search predicate.
Assumes that the submodules are directories containing __init__.py.
Example:
# find module inside foo that contains Test
import foo
test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
"""
if predicate(module):
return module
module_path = pathlib.Path(*module.__path__)
for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
if sub == module_path:
continue
sub_module_path = sub.relative_to(module_path)
sub_module_name = ".".join(sub_module_path.parts)
sub_module = importlib.import_module(f".{sub_module_name}", module.__name__)
if predicate(sub_module):
return sub_module
return None

21
poetry.lock generated
View File

@ -271,7 +271,7 @@ docs = ["sphinx", "rst.linker", "jaraco.packaging"]
category = "main"
description = "A very fast and expressive template engine."
name = "jinja2"
optional = true
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
version = "2.11.2"
@ -285,7 +285,7 @@ i18n = ["Babel (>=0.8)"]
category = "main"
description = "Safely add untrusted strings to HTML/XML markup."
name = "markupsafe"
optional = true
optional = false
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*"
version = "1.1.1"
@ -369,7 +369,7 @@ dev = ["pre-commit", "tox"]
category = "main"
description = "Protocol Buffers"
name = "protobuf"
optional = true
optional = false
python-versions = "*"
version = "3.12.2"
@ -490,14 +490,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
version = "1.15.0"
[[package]]
category = "main"
description = "String case converter."
name = "stringcase"
optional = false
python-versions = "*"
version = "1.2.0"
[[package]]
category = "main"
description = "Python Library for Tom's Obvious, Minimal Language"
@ -612,7 +604,7 @@ testing = ["jaraco.itertools", "func-timeout"]
compiler = ["black", "jinja2", "protobuf"]
[metadata]
content-hash = "ecafcaed2d4a25c2829e6dc3ef3c56cd72a8bc28c25c7aeae3484c978c816722"
content-hash = "8a4fa01ede86e1b5ba35b9dab8b6eacee766a9b5666f48ab41445c01882ab003"
python-versions = "^3.6"
[metadata.files]
@ -865,6 +857,8 @@ protobuf = [
{file = "protobuf-3.12.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e72736dd822748b0721f41f9aaaf6a5b6d5cfc78f6c8690263aef8bba4457f0e"},
{file = "protobuf-3.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:87535dc2d2ef007b9d44e309d2b8ea27a03d2fa09556a72364d706fcb7090828"},
{file = "protobuf-3.12.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:50b5fee674878b14baea73b4568dc478c46a31dd50157a5b5d2f71138243b1a9"},
{file = "protobuf-3.12.2-py2.py3-none-any.whl", hash = "sha256:a96f8fc625e9ff568838e556f6f6ae8eca8b4837cdfb3f90efcb7c00e342a2eb"},
{file = "protobuf-3.12.2.tar.gz", hash = "sha256:49ef8ab4c27812a89a76fa894fe7a08f42f2147078392c0dee51d4a444ef6df5"},
]
py = [
{file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"},
@ -920,9 +914,6 @@ six = [
{file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
{file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
]
stringcase = [
{file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"},
]
toml = [
{file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"},
{file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"},

View File

@ -18,7 +18,6 @@ dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
grpclib = "^0.3.1"
jinja2 = { version = "^2.11.2", optional = true }
protobuf = { version = "^3.12.2", optional = true }
stringcase = "^1.2.0"
[tool.poetry.dev-dependencies]
black = "^19.10b0"