Merge branch 'master' into michael-sayapin/master
This commit is contained in:
commit
ac32bcd25a
29
README.md
29
README.md
@ -68,14 +68,15 @@ message Greeting {
|
||||
You can run the following:
|
||||
|
||||
```sh
|
||||
protoc -I . --python_betterproto_out=. example.proto
|
||||
mkdir lib
|
||||
protoc -I . --python_betterproto_out=lib example.proto
|
||||
```
|
||||
|
||||
This will generate `hello.py` which looks like:
|
||||
This will generate `lib/hello/__init__.py` which looks like:
|
||||
|
||||
```py
|
||||
```python
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: hello.proto
|
||||
# sources: example.proto
|
||||
# plugin: python-betterproto
|
||||
from dataclasses import dataclass
|
||||
|
||||
@ -83,7 +84,7 @@ import betterproto
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hello(betterproto.Message):
|
||||
class Greeting(betterproto.Message):
|
||||
"""Greeting represents a message you can tell a user."""
|
||||
|
||||
message: str = betterproto.string_field(1)
|
||||
@ -91,23 +92,23 @@ class Hello(betterproto.Message):
|
||||
|
||||
Now you can use it!
|
||||
|
||||
```py
|
||||
>>> from hello import Hello
|
||||
>>> test = Hello()
|
||||
```python
|
||||
>>> from lib.hello import Greeting
|
||||
>>> test = Greeting()
|
||||
>>> test
|
||||
Hello(message='')
|
||||
Greeting(message='')
|
||||
|
||||
>>> test.message = "Hey!"
|
||||
>>> test
|
||||
Hello(message="Hey!")
|
||||
Greeting(message="Hey!")
|
||||
|
||||
>>> serialized = bytes(test)
|
||||
>>> serialized
|
||||
b'\n\x04Hey!'
|
||||
|
||||
>>> another = Hello().parse(serialized)
|
||||
>>> another = Greeting().parse(serialized)
|
||||
>>> another
|
||||
Hello(message="Hey!")
|
||||
Greeting(message="Hey!")
|
||||
|
||||
>>> another.to_dict()
|
||||
{"message": "Hey!"}
|
||||
@ -315,7 +316,7 @@ To benefit from the collection of standard development tasks ensure you have mak
|
||||
|
||||
This project enforces [black](https://github.com/psf/black) python code formatting.
|
||||
|
||||
Before commiting changes run:
|
||||
Before committing changes run:
|
||||
|
||||
```sh
|
||||
make format
|
||||
@ -336,7 +337,7 @@ Adding a standard test case is easy.
|
||||
|
||||
- Create a new directory `betterproto/tests/inputs/<name>`
|
||||
- add `<name>.proto` with a message called `Test`
|
||||
- add `<name>.json` with some test data
|
||||
- add `<name>.json` with some test data (optional)
|
||||
|
||||
It will be picked up automatically when you run the tests.
|
||||
|
||||
|
@ -7,27 +7,22 @@ import sys
|
||||
from abc import ABC
|
||||
from base64 import b64decode, b64encode
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import stringcase
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
SupportsBytes,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
get_type_hints,
|
||||
)
|
||||
from ._types import ST, T
|
||||
from .casing import safe_snake_case
|
||||
|
||||
from ._types import T
|
||||
from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
|
||||
from .grpc.grpclib_client import ServiceStub
|
||||
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
|
||||
@ -124,8 +119,8 @@ DATETIME_ZERO = datetime_default_gen()
|
||||
class Casing(enum.Enum):
|
||||
"""Casing constants for serialization."""
|
||||
|
||||
CAMEL = stringcase.camelcase
|
||||
SNAKE = stringcase.snakecase
|
||||
CAMEL = camel_case
|
||||
SNAKE = snake_case
|
||||
|
||||
|
||||
class _PLACEHOLDER:
|
||||
|
@ -1,9 +1,21 @@
|
||||
import stringcase
|
||||
import re
|
||||
|
||||
# Word delimiters and symbols that will not be preserved when re-casing.
|
||||
# language=PythonRegExp
|
||||
SYMBOLS = "[^a-zA-Z0-9]*"
|
||||
|
||||
# Optionally capitalized word.
|
||||
# language=PythonRegExp
|
||||
WORD = "[A-Z]*[a-z]*[0-9]*"
|
||||
|
||||
# Uppercase word, not followed by lowercase letters.
|
||||
# language=PythonRegExp
|
||||
WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
|
||||
|
||||
|
||||
def safe_snake_case(value: str) -> str:
|
||||
"""Snake case a value taking into account Python keywords."""
|
||||
value = stringcase.snakecase(value)
|
||||
value = snake_case(value)
|
||||
if value in [
|
||||
"and",
|
||||
"as",
|
||||
@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str:
|
||||
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
|
||||
value += "_"
|
||||
return value
|
||||
|
||||
|
||||
def snake_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Join words with an underscore into lowercase and remove symbols.
|
||||
@param value: value to convert
|
||||
@param strict: force single underscores
|
||||
"""
|
||||
|
||||
def substitute_word(symbols, word, is_start):
|
||||
if not word:
|
||||
return ""
|
||||
if strict:
|
||||
delimiter_count = 0 if is_start else 1 # Single underscore if strict.
|
||||
elif is_start:
|
||||
delimiter_count = len(symbols)
|
||||
elif word.isupper() or word.islower():
|
||||
delimiter_count = max(
|
||||
1, len(symbols)
|
||||
) # Preserve all delimiters if not strict.
|
||||
else:
|
||||
delimiter_count = len(symbols) + 1 # Extra underscore for leading capital.
|
||||
|
||||
return ("_" * delimiter_count) + word.lower()
|
||||
|
||||
snake = re.sub(
|
||||
f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||
lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
|
||||
value,
|
||||
)
|
||||
return snake
|
||||
|
||||
|
||||
def pascal_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Capitalize each word and remove symbols.
|
||||
@param value: value to convert
|
||||
@param strict: output only alphanumeric characters
|
||||
"""
|
||||
|
||||
def substitute_word(symbols, word):
|
||||
if strict:
|
||||
return word.capitalize() # Remove all delimiters
|
||||
|
||||
if word.islower():
|
||||
delimiter_length = len(symbols[:-1]) # Lose one delimiter
|
||||
else:
|
||||
delimiter_length = len(symbols) # Preserve all delimiters
|
||||
|
||||
return ("_" * delimiter_length) + word.capitalize()
|
||||
|
||||
return re.sub(
|
||||
f"({SYMBOLS})({WORD_UPPER}|{WORD})",
|
||||
lambda groups: substitute_word(groups[1], groups[2]),
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
def camel_case(value: str, strict: bool = True):
|
||||
"""
|
||||
Capitalize all words except first and remove symbols.
|
||||
"""
|
||||
return lowercase_first(pascal_case(value, strict=strict))
|
||||
|
||||
|
||||
def lowercase_first(value: str):
|
||||
return value[0:1].lower() + value[1:]
|
||||
|
@ -1,70 +1,160 @@
|
||||
from typing import Dict, Type
|
||||
|
||||
import stringcase
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Set, Type
|
||||
|
||||
from betterproto import safe_snake_case
|
||||
from betterproto.compile.naming import pythonize_class_name
|
||||
from betterproto.lib.google import protobuf as google_protobuf
|
||||
|
||||
WRAPPER_TYPES: Dict[str, Type] = {
|
||||
"google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||
"google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||
"google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||
"google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||
"google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||
"google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||
"google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||
"google.protobuf.StringValue": google_protobuf.StringValue,
|
||||
"google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||
".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
|
||||
".google.protobuf.FloatValue": google_protobuf.FloatValue,
|
||||
".google.protobuf.Int32Value": google_protobuf.Int32Value,
|
||||
".google.protobuf.Int64Value": google_protobuf.Int64Value,
|
||||
".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
|
||||
".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
|
||||
".google.protobuf.BoolValue": google_protobuf.BoolValue,
|
||||
".google.protobuf.StringValue": google_protobuf.StringValue,
|
||||
".google.protobuf.BytesValue": google_protobuf.BytesValue,
|
||||
}
|
||||
|
||||
|
||||
def get_ref_type(
|
||||
package: str, imports: set, type_name: str, unwrap: bool = True
|
||||
def parse_source_type_name(field_type_name):
|
||||
"""
|
||||
Split full source type name into package and type name.
|
||||
E.g. 'root.package.Message' -> ('root.package', 'Message')
|
||||
'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
|
||||
"""
|
||||
package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
|
||||
if package_match:
|
||||
package = package_match.group(1)
|
||||
name = package_match.group(2)
|
||||
else:
|
||||
package = ""
|
||||
name = field_type_name.lstrip(".")
|
||||
return package, name
|
||||
|
||||
|
||||
def get_type_reference(
|
||||
package: str, imports: set, source_type: str, unwrap: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Return a Python type name for a proto type reference. Adds the import if
|
||||
necessary. Unwraps well known type if required.
|
||||
"""
|
||||
# If the package name is a blank string, then this should still work
|
||||
# because by convention packages are lowercase and message/enum types are
|
||||
# pascal-cased. May require refactoring in the future.
|
||||
type_name = type_name.lstrip(".")
|
||||
|
||||
is_wrapper = type_name in WRAPPER_TYPES
|
||||
|
||||
if unwrap:
|
||||
if is_wrapper:
|
||||
wrapped_type = type(WRAPPER_TYPES[type_name]().value)
|
||||
if source_type in WRAPPER_TYPES:
|
||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
|
||||
if type_name == "google.protobuf.Duration":
|
||||
if source_type == ".google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
if type_name == "google.protobuf.Timestamp":
|
||||
if source_type == ".google.protobuf.Timestamp":
|
||||
return "datetime"
|
||||
|
||||
if type_name.startswith(package):
|
||||
parts = type_name.lstrip(package).lstrip(".").split(".")
|
||||
if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
|
||||
# This is the current package, which has nested types flattened.
|
||||
# foo.bar_thing => FooBarThing
|
||||
cased = [stringcase.pascalcase(part) for part in parts]
|
||||
type_name = f'"{"".join(cased)}"'
|
||||
source_package, source_type = parse_source_type_name(source_type)
|
||||
|
||||
# Use precompiled classes for google.protobuf.* objects
|
||||
if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
|
||||
type_name = type_name.rsplit(".", maxsplit=1)[1]
|
||||
import_package = "betterproto.lib.google.protobuf"
|
||||
import_alias = safe_snake_case(import_package)
|
||||
imports.add(f"import {import_package} as {import_alias}")
|
||||
return f"{import_alias}.{type_name}"
|
||||
current_package: List[str] = package.split(".") if package else []
|
||||
py_package: List[str] = source_package.split(".") if source_package else []
|
||||
py_type: str = pythonize_class_name(source_type)
|
||||
|
||||
if "." in type_name:
|
||||
# This is imported from another package. No need
|
||||
# to use a forward ref and we need to add the import.
|
||||
parts = type_name.split(".")
|
||||
parts[-1] = stringcase.pascalcase(parts[-1])
|
||||
imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
|
||||
type_name = f"{parts[-2]}.{parts[-1]}"
|
||||
compiling_google_protobuf = current_package == ["google", "protobuf"]
|
||||
importing_google_protobuf = py_package == ["google", "protobuf"]
|
||||
if importing_google_protobuf and not compiling_google_protobuf:
|
||||
py_package = ["betterproto", "lib"] + py_package
|
||||
|
||||
return type_name
|
||||
if py_package[:1] == ["betterproto"]:
|
||||
return reference_absolute(imports, py_package, py_type)
|
||||
|
||||
if py_package == current_package:
|
||||
return reference_sibling(py_type)
|
||||
|
||||
if py_package[: len(current_package)] == current_package:
|
||||
return reference_descendent(current_package, imports, py_package, py_type)
|
||||
|
||||
if current_package[: len(py_package)] == py_package:
|
||||
return reference_ancestor(current_package, imports, py_package, py_type)
|
||||
|
||||
return reference_cousin(current_package, imports, py_package, py_type)
|
||||
|
||||
|
||||
def reference_absolute(imports, py_package, py_type):
|
||||
"""
|
||||
Returns a reference to a python type located in the root, i.e. sys.path.
|
||||
"""
|
||||
string_import = ".".join(py_package)
|
||||
string_alias = safe_snake_case(string_import)
|
||||
imports.add(f"import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
|
||||
|
||||
def reference_sibling(py_type: str) -> str:
|
||||
"""
|
||||
Returns a reference to a python type within the same package as the current package.
|
||||
"""
|
||||
return f'"{py_type}"'
|
||||
|
||||
|
||||
def reference_descendent(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is a descendent of the current package,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
"""
|
||||
importing_descendent = py_package[len(current_package) :]
|
||||
string_from = ".".join(importing_descendent[:-1])
|
||||
string_import = importing_descendent[-1]
|
||||
if string_from:
|
||||
string_alias = "_".join(importing_descendent)
|
||||
imports.add(f"from .{string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
else:
|
||||
imports.add(f"from . import {string_import}")
|
||||
return f"{string_import}.{py_type}"
|
||||
|
||||
|
||||
def reference_ancestor(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package which is an ancestor to the current package,
|
||||
and adds the required import that is aliased (if possible) to avoid name conflicts.
|
||||
|
||||
Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
|
||||
"""
|
||||
distance_up = len(current_package) - len(py_package)
|
||||
if py_package:
|
||||
string_import = py_package[-1]
|
||||
string_alias = f"_{'_' * distance_up}{string_import}__"
|
||||
string_from = f"..{'.' * distance_up}"
|
||||
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
else:
|
||||
string_alias = f"{'_' * distance_up}{py_type}__"
|
||||
imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
|
||||
return string_alias
|
||||
|
||||
|
||||
def reference_cousin(
|
||||
current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
|
||||
and adds the required import that is aliased to avoid name conflicts.
|
||||
"""
|
||||
shared_ancestry = os.path.commonprefix([current_package, py_package])
|
||||
distance_up = len(current_package) - len(shared_ancestry)
|
||||
string_from = f".{'.' * distance_up}" + ".".join(
|
||||
py_package[len(shared_ancestry) : -1]
|
||||
)
|
||||
string_import = py_package[-1]
|
||||
# Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
|
||||
string_alias = (
|
||||
f"{'_' * distance_up}"
|
||||
+ safe_snake_case(".".join(py_package[len(shared_ancestry) :]))
|
||||
+ "__"
|
||||
)
|
||||
imports.add(f"from {string_from} import {string_import} as {string_alias}")
|
||||
return f"{string_alias}.{py_type}"
|
||||
|
13
betterproto/compile/naming.py
Normal file
13
betterproto/compile/naming.py
Normal file
@ -0,0 +1,13 @@
|
||||
from betterproto import casing
|
||||
|
||||
|
||||
def pythonize_class_name(name):
|
||||
return casing.pascal_case(name)
|
||||
|
||||
|
||||
def pythonize_field_name(name: str):
|
||||
return casing.safe_snake_case(name)
|
||||
|
||||
|
||||
def pythonize_method_name(name: str):
|
||||
return casing.safe_snake_case(name)
|
@ -84,7 +84,7 @@ class FieldOptionsCType(betterproto.Enum):
|
||||
STRING_PIECE = 2
|
||||
|
||||
|
||||
class FieldOptionsJSType(betterproto.Enum):
|
||||
class FieldOptionsJsType(betterproto.Enum):
|
||||
JS_NORMAL = 0
|
||||
JS_STRING = 1
|
||||
JS_NUMBER = 2
|
||||
@ -717,7 +717,7 @@ class FieldOptions(betterproto.Message):
|
||||
# use the JavaScript "number" type. The behavior of the default option
|
||||
# JS_NORMAL is implementation dependent. This option is an enum to permit
|
||||
# additional types to be added, e.g. goog.math.Integer.
|
||||
jstype: "FieldOptionsJSType" = betterproto.enum_field(6)
|
||||
jstype: "FieldOptionsJsType" = betterproto.enum_field(6)
|
||||
# Should this field be parsed lazily? Lazy applies only to message-type
|
||||
# fields. It means that when the outer message is initially parsed, the
|
||||
# inner message's contents will not be parsed but instead stored in encoded
|
@ -2,14 +2,19 @@
|
||||
|
||||
import itertools
|
||||
import os.path
|
||||
import pathlib
|
||||
import re
|
||||
import stringcase
|
||||
import sys
|
||||
import textwrap
|
||||
from typing import List, Union
|
||||
|
||||
import betterproto
|
||||
from betterproto.casing import safe_snake_case
|
||||
from betterproto.compile.importing import get_ref_type
|
||||
from betterproto.compile.importing import get_type_reference
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
@ -35,27 +40,22 @@ except ImportError as err:
|
||||
raise SystemExit(1)
|
||||
|
||||
|
||||
def py_type(
|
||||
package: str,
|
||||
imports: set,
|
||||
message: DescriptorProto,
|
||||
descriptor: FieldDescriptorProto,
|
||||
) -> str:
|
||||
if descriptor.type in [1, 2]:
|
||||
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
|
||||
if field.type in [1, 2]:
|
||||
return "float"
|
||||
elif descriptor.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
|
||||
elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
|
||||
return "int"
|
||||
elif descriptor.type == 8:
|
||||
elif field.type == 8:
|
||||
return "bool"
|
||||
elif descriptor.type == 9:
|
||||
elif field.type == 9:
|
||||
return "str"
|
||||
elif descriptor.type in [11, 14]:
|
||||
elif field.type in [11, 14]:
|
||||
# Type referencing another defined Message or a named enum
|
||||
return get_ref_type(package, imports, descriptor.type_name)
|
||||
elif descriptor.type == 12:
|
||||
return get_type_reference(package, imports, field.type_name)
|
||||
elif field.type == 12:
|
||||
return "bytes"
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {descriptor.type}")
|
||||
raise NotImplementedError(f"Unknown type {field.type}")
|
||||
|
||||
|
||||
def get_py_zero(type_num: int) -> Union[str, float]:
|
||||
@ -131,17 +131,17 @@ def generate_code(request, response):
|
||||
|
||||
output_map = {}
|
||||
for proto_file in request.proto_file:
|
||||
out = proto_file.package
|
||||
|
||||
if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
continue
|
||||
|
||||
if not out:
|
||||
out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
|
||||
output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
|
||||
|
||||
if out not in output_map:
|
||||
output_map[out] = {"package": proto_file.package, "files": []}
|
||||
output_map[out]["files"].append(proto_file)
|
||||
if output_file not in output_map:
|
||||
output_map[output_file] = {"package": proto_file.package, "files": []}
|
||||
output_map[output_file]["files"].append(proto_file)
|
||||
|
||||
# TODO: Figure out how to handle gRPC request/response messages and add
|
||||
# processing below for Service.
|
||||
@ -160,17 +160,10 @@ def generate_code(request, response):
|
||||
"services": [],
|
||||
}
|
||||
|
||||
type_mapping = {}
|
||||
|
||||
for proto_file in options["files"]:
|
||||
# print(proto_file.message_type, file=sys.stderr)
|
||||
# print(proto_file.service, file=sys.stderr)
|
||||
# print(proto_file.source_code_info, file=sys.stderr)
|
||||
|
||||
item: DescriptorProto
|
||||
for item, path in traverse(proto_file):
|
||||
# print(item, file=sys.stderr)
|
||||
# print(path, file=sys.stderr)
|
||||
data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
|
||||
data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
# print(item, file=sys.stderr)
|
||||
@ -187,7 +180,7 @@ def generate_code(request, response):
|
||||
)
|
||||
|
||||
for i, f in enumerate(item.field):
|
||||
t = py_type(package, output["imports"], item, f)
|
||||
t = py_type(package, output["imports"], f)
|
||||
zero = get_py_zero(f.type)
|
||||
|
||||
repeated = False
|
||||
@ -222,13 +215,11 @@ def generate_code(request, response):
|
||||
k = py_type(
|
||||
package,
|
||||
output["imports"],
|
||||
item,
|
||||
nested.field[0],
|
||||
)
|
||||
v = py_type(
|
||||
package,
|
||||
output["imports"],
|
||||
item,
|
||||
nested.field[1],
|
||||
)
|
||||
t = f"Dict[{k}, {v}]"
|
||||
@ -264,7 +255,7 @@ def generate_code(request, response):
|
||||
data["properties"].append(
|
||||
{
|
||||
"name": f.name,
|
||||
"py_name": safe_snake_case(f.name),
|
||||
"py_name": pythonize_field_name(f.name),
|
||||
"number": f.number,
|
||||
"comment": get_comment(proto_file, path + [2, i]),
|
||||
"proto_type": int(f.type),
|
||||
@ -305,14 +296,14 @@ def generate_code(request, response):
|
||||
|
||||
data = {
|
||||
"name": service.name,
|
||||
"py_name": stringcase.pascalcase(service.name),
|
||||
"py_name": pythonize_class_name(service.name),
|
||||
"comment": get_comment(proto_file, [6, i]),
|
||||
"methods": [],
|
||||
}
|
||||
|
||||
for j, method in enumerate(service.method):
|
||||
input_message = None
|
||||
input_type = get_ref_type(
|
||||
input_type = get_type_reference(
|
||||
package, output["imports"], method.input_type
|
||||
).strip('"')
|
||||
for msg in output["messages"]:
|
||||
@ -326,14 +317,14 @@ def generate_code(request, response):
|
||||
data["methods"].append(
|
||||
{
|
||||
"name": method.name,
|
||||
"py_name": stringcase.snakecase(method.name),
|
||||
"py_name": pythonize_method_name(method.name),
|
||||
"comment": get_comment(proto_file, [6, i, 2, j], indent=8),
|
||||
"route": f"/{package}.{service.name}/{method.name}",
|
||||
"input": get_ref_type(
|
||||
"input": get_type_reference(
|
||||
package, output["imports"], method.input_type
|
||||
).strip('"'),
|
||||
"input_message": input_message,
|
||||
"output": get_ref_type(
|
||||
"output": get_type_reference(
|
||||
package,
|
||||
output["imports"],
|
||||
method.output_type,
|
||||
@ -359,8 +350,7 @@ def generate_code(request, response):
|
||||
|
||||
# Fill response
|
||||
f = response.file.add()
|
||||
# print(filename, file=sys.stderr)
|
||||
f.name = filename.replace(".", os.path.sep) + ".py"
|
||||
f.name = filename
|
||||
|
||||
# Render and then format the output file.
|
||||
f.content = black.format_str(
|
||||
@ -368,32 +358,23 @@ def generate_code(request, response):
|
||||
mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
|
||||
)
|
||||
|
||||
inits = set([""])
|
||||
for f in response.file:
|
||||
# Ensure output paths exist
|
||||
# print(f.name, file=sys.stderr)
|
||||
dirnames = os.path.dirname(f.name)
|
||||
if dirnames:
|
||||
os.makedirs(dirnames, exist_ok=True)
|
||||
base = ""
|
||||
for part in dirnames.split(os.path.sep):
|
||||
base = os.path.join(base, part)
|
||||
inits.add(base)
|
||||
|
||||
for base in inits:
|
||||
name = os.path.join(base, "__init__.py")
|
||||
|
||||
if os.path.exists(name):
|
||||
# Never overwrite inits as they may have custom stuff in them.
|
||||
continue
|
||||
# Make each output directory a package with __init__ file
|
||||
output_paths = set(pathlib.Path(path) for path in output_map.keys())
|
||||
init_files = (
|
||||
set(
|
||||
directory.joinpath("__init__.py")
|
||||
for path in output_paths
|
||||
for directory in path.parents
|
||||
)
|
||||
- output_paths
|
||||
)
|
||||
|
||||
for init_file in init_files:
|
||||
init = response.file.add()
|
||||
init.name = name
|
||||
init.content = b""
|
||||
init.name = str(init_file)
|
||||
|
||||
filenames = sorted([f.name for f in response.file])
|
||||
for fname in filenames:
|
||||
print(f"Writing {fname}", file=sys.stderr)
|
||||
for filename in sorted(output_paths.union(init_files)):
|
||||
print(f"Writing {filename}", file=sys.stderr)
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -12,12 +12,12 @@ inputs/
|
||||
|
||||
## Test case directory structure
|
||||
|
||||
Each testcase has a `<name>.proto` file with a message called `Test`, a matching `.json` file and optionally a custom test file called `test_*.py`.
|
||||
Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`.
|
||||
|
||||
```bash
|
||||
bool/
|
||||
bool.proto
|
||||
bool.json
|
||||
bool.json # optional
|
||||
test_bool.py # optional
|
||||
```
|
||||
|
||||
@ -61,21 +61,22 @@ def test_value():
|
||||
|
||||
The following tests are automatically executed for all cases:
|
||||
|
||||
- [x] Can the generated python code imported?
|
||||
- [x] Can the generated python code be imported?
|
||||
- [x] Can the generated message class be instantiated?
|
||||
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
|
||||
- _when `.json` is present_
|
||||
|
||||
## Running the tests
|
||||
|
||||
- `pipenv run generate`
|
||||
This generates
|
||||
- `pipenv run generate`
|
||||
This generates:
|
||||
- `betterproto/tests/output_betterproto` — *the plugin generated python classes*
|
||||
- `betterproto/tests/output_reference` — *reference implementation classes*
|
||||
- `pipenv run test`
|
||||
|
||||
## Intentionally Failing tests
|
||||
|
||||
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrented in the future.
|
||||
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
|
||||
|
||||
When running `pytest`, they show up as `x` or `X` in the test results.
|
||||
|
||||
|
@ -1,4 +1,4 @@
|
||||
from betterproto.tests.output_betterproto.bool.bool import Test
|
||||
from betterproto.tests.output_betterproto.bool import Test
|
||||
|
||||
|
||||
def test_value():
|
||||
|
@ -10,6 +10,7 @@ message Test {
|
||||
int32 camelCase = 1;
|
||||
my_enum snake_case = 2;
|
||||
snake_case_message snake_case_message = 3;
|
||||
int32 UPPERCASE = 4;
|
||||
}
|
||||
|
||||
message snake_case_message {
|
||||
|
@ -1,5 +1,5 @@
|
||||
import betterproto.tests.output_betterproto.casing.casing as casing
|
||||
from betterproto.tests.output_betterproto.casing.casing import Test
|
||||
import betterproto.tests.output_betterproto.casing as casing
|
||||
from betterproto.tests.output_betterproto.casing import Test
|
||||
|
||||
|
||||
def test_message_attributes():
|
||||
@ -8,6 +8,7 @@ def test_message_attributes():
|
||||
message, "snake_case_message"
|
||||
), "snake_case field name is same in python"
|
||||
assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
|
||||
assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
|
||||
|
||||
|
||||
def test_message_casing():
|
||||
|
@ -1,5 +0,0 @@
|
||||
{
|
||||
"UPPERCASE": 10,
|
||||
"UPPERCASE_V2": 10,
|
||||
"UPPER_CAMEL_CASE": 10
|
||||
}
|
@ -1,6 +1,4 @@
|
||||
from betterproto.tests.output_betterproto.casing_message_field_uppercase.casing_message_field_uppercase import (
|
||||
Test,
|
||||
)
|
||||
from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
|
||||
|
||||
|
||||
def test_message_casing():
|
||||
|
@ -1,19 +1,15 @@
|
||||
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
|
||||
# Remove from list when fixed.
|
||||
tests = {
|
||||
"import_root_sibling", # 61
|
||||
"import_child_package_from_package", # 58
|
||||
"import_root_package_from_child", # 60
|
||||
"import_parent_package_from_child", # 59
|
||||
"import_circular_dependency", # failing because of other bugs now
|
||||
"import_packages_same_name", # 25
|
||||
xfail = {
|
||||
"import_circular_dependency",
|
||||
"oneof_enum", # 63
|
||||
"casing_message_field_uppercase", # 11
|
||||
"namespace_keywords", # 70
|
||||
"namespace_builtin_types", # 53
|
||||
"googletypes_struct", # 9
|
||||
"googletypes_value", # 9
|
||||
"enum_skipped_value", # 93
|
||||
"import_capitalized_package",
|
||||
"example", # This is the example in the readme. Not a test.
|
||||
}
|
||||
|
||||
services = {
|
||||
|
8
betterproto/tests/inputs/example/example.proto
Normal file
8
betterproto/tests/inputs/example/example.proto
Normal file
@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package hello;
|
||||
|
||||
// Greeting represents a message you can tell a user.
|
||||
message Greeting {
|
||||
string message = 1;
|
||||
}
|
@ -4,9 +4,7 @@ import betterproto.lib.google.protobuf as protobuf
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
|
||||
TestStub,
|
||||
)
|
||||
from betterproto.tests.output_betterproto.googletypes_response import TestStub
|
||||
|
||||
test_cases = [
|
||||
(TestStub.get_double, protobuf.DoubleValue, 2.5),
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
|
||||
from betterproto.tests.output_betterproto.googletypes_response_embedded import (
|
||||
Output,
|
||||
TestStub,
|
||||
)
|
||||
|
@ -0,0 +1,8 @@
|
||||
syntax = "proto3";
|
||||
|
||||
|
||||
package Capitalized;
|
||||
|
||||
message Message {
|
||||
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "capitalized.proto";
|
||||
|
||||
// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.
|
||||
|
||||
message Test {
|
||||
Capitalized.Message message = 1;
|
||||
}
|
@ -3,9 +3,9 @@ syntax = "proto3";
|
||||
import "root.proto";
|
||||
import "other.proto";
|
||||
|
||||
// This test-case verifies that future implementations will support circular dependencies in the generated python files.
|
||||
// This test-case verifies support for circular dependencies in the generated python files.
|
||||
//
|
||||
// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file.
|
||||
// This is important because we generate 1 python file/module per package, rather than 1 file per proto file.
|
||||
//
|
||||
// Scenario:
|
||||
//
|
||||
@ -24,5 +24,5 @@ import "other.proto";
|
||||
// (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
|
||||
message Test {
|
||||
RootPackageMessage message = 1;
|
||||
other.OtherPackageMessage other =2;
|
||||
other.OtherPackageMessage other = 2;
|
||||
}
|
||||
|
@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package cousin.cousin_subpackage;
|
||||
|
||||
message CousinMessage {
|
||||
}
|
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
11
betterproto/tests/inputs/import_cousin_package/test.proto
Normal file
@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package test.subpackage;
|
||||
|
||||
import "cousin.proto";
|
||||
|
||||
// Verify that we can import message unrelated to us
|
||||
|
||||
message Test {
|
||||
cousin.cousin_subpackage.CousinMessage message = 1;
|
||||
}
|
@ -0,0 +1,6 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package cousin.subpackage;
|
||||
|
||||
message CousinMessage {
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package test.subpackage;
|
||||
|
||||
import "cousin.proto";
|
||||
|
||||
// Verify that we can import a message unrelated to us, in a subpackage with the same name as us.
|
||||
|
||||
message Test {
|
||||
cousin.subpackage.CousinMessage message = 1;
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package child;
|
||||
|
||||
import "root.proto";
|
||||
|
||||
// Verify that we can import root message from child package
|
||||
|
||||
message Test {
|
||||
RootMessage message = 1;
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "root.proto";
|
||||
|
||||
package child;
|
||||
|
||||
// Tests generated imports when a message inside a child-package refers to a message defined in the root.
|
||||
|
||||
message Test {
|
||||
RootMessage message = 1;
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.output_betterproto.import_service_input_message.import_service_input_message import (
|
||||
from betterproto.tests.output_betterproto.import_service_input_message import (
|
||||
RequestResponse,
|
||||
TestStub,
|
||||
)
|
||||
|
@ -15,4 +15,4 @@ message Test {
|
||||
|
||||
message Sibling {
|
||||
int32 foo = 1;
|
||||
}
|
||||
}
|
19
betterproto/tests/inputs/nested2/nested2.proto
Normal file
19
betterproto/tests/inputs/nested2/nested2.proto
Normal file
@ -0,0 +1,19 @@
|
||||
syntax = "proto3";
|
||||
|
||||
import "package.proto";
|
||||
|
||||
message Game {
|
||||
message Player {
|
||||
enum Race {
|
||||
human = 0;
|
||||
orc = 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
message Test {
|
||||
Game game = 1;
|
||||
Game.Player GamePlayer = 2;
|
||||
Game.Player.Race GamePlayerRace = 3;
|
||||
equipment.Weapon Weapon = 4;
|
||||
}
|
7
betterproto/tests/inputs/nested2/package.proto
Normal file
7
betterproto/tests/inputs/nested2/package.proto
Normal file
@ -0,0 +1,7 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package equipment;
|
||||
|
||||
message Weapon {
|
||||
|
||||
}
|
@ -1,10 +1,10 @@
|
||||
{
|
||||
"root": {
|
||||
"top": {
|
||||
"name": "double-nested",
|
||||
"parent": {
|
||||
"child": [{"foo": "hello"}],
|
||||
"enumChild": ["A"],
|
||||
"rootParentChild": [{"a": "hello"}],
|
||||
"middle": {
|
||||
"bottom": [{"foo": "hello"}],
|
||||
"enumBottom": ["A"],
|
||||
"topMiddleBottom": [{"a": "hello"}],
|
||||
"bar": true
|
||||
}
|
||||
}
|
||||
|
@ -1,26 +1,26 @@
|
||||
syntax = "proto3";
|
||||
|
||||
message Test {
|
||||
message Root {
|
||||
message Parent {
|
||||
message RootParentChild {
|
||||
message Top {
|
||||
message Middle {
|
||||
message TopMiddleBottom {
|
||||
string a = 1;
|
||||
}
|
||||
enum EnumChild{
|
||||
enum EnumBottom{
|
||||
A = 0;
|
||||
B = 1;
|
||||
}
|
||||
message Child {
|
||||
message Bottom {
|
||||
string foo = 1;
|
||||
}
|
||||
reserved 1;
|
||||
repeated Child child = 2;
|
||||
repeated EnumChild enumChild=3;
|
||||
repeated RootParentChild rootParentChild=4;
|
||||
repeated Bottom bottom = 2;
|
||||
repeated EnumBottom enumBottom=3;
|
||||
repeated TopMiddleBottom topMiddleBottom=4;
|
||||
bool bar = 5;
|
||||
}
|
||||
string name = 1;
|
||||
Parent parent = 2;
|
||||
Middle middle = 2;
|
||||
}
|
||||
Root root = 1;
|
||||
Top top = 1;
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
import betterproto
|
||||
from betterproto.tests.output_betterproto.oneof.oneof import Test
|
||||
from betterproto.tests.output_betterproto.oneof import Test
|
||||
from betterproto.tests.util import get_test_case_json_data
|
||||
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
import pytest
|
||||
|
||||
import betterproto
|
||||
from betterproto.tests.output_betterproto.oneof_enum.oneof_enum import (
|
||||
from betterproto.tests.output_betterproto.oneof_enum import (
|
||||
Move,
|
||||
Signal,
|
||||
Test,
|
||||
|
@ -1,7 +1,5 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package ref;
|
||||
|
||||
import "repeatedmessage.proto";
|
||||
|
||||
message Test {
|
||||
|
125
betterproto/tests/test_casing.py
Normal file
125
betterproto/tests/test_casing.py
Normal file
@ -0,0 +1,125 @@
|
||||
import pytest
|
||||
|
||||
from betterproto.casing import camel_case, pascal_case, snake_case
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["value", "expected"],
|
||||
[
|
||||
("", ""),
|
||||
("a", "A"),
|
||||
("foobar", "Foobar"),
|
||||
("fooBar", "FooBar"),
|
||||
("FooBar", "FooBar"),
|
||||
("foo.bar", "FooBar"),
|
||||
("foo_bar", "FooBar"),
|
||||
("FOOBAR", "Foobar"),
|
||||
("FOOBar", "FooBar"),
|
||||
("UInt32", "UInt32"),
|
||||
("FOO_BAR", "FooBar"),
|
||||
("FOOBAR1", "Foobar1"),
|
||||
("FOOBAR_1", "Foobar1"),
|
||||
("FOO1BAR2", "Foo1Bar2"),
|
||||
("foo__bar", "FooBar"),
|
||||
("_foobar", "Foobar"),
|
||||
("foobaR", "FoobaR"),
|
||||
("foo~bar", "FooBar"),
|
||||
("foo:bar", "FooBar"),
|
||||
("1foobar", "1Foobar"),
|
||||
],
|
||||
)
|
||||
def test_pascal_case(value, expected):
|
||||
actual = pascal_case(value, strict=True)
|
||||
assert actual == expected, f"{value} => {expected} (actual: {actual})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["value", "expected"],
|
||||
[
|
||||
("", ""),
|
||||
("a", "a"),
|
||||
("foobar", "foobar"),
|
||||
("fooBar", "fooBar"),
|
||||
("FooBar", "fooBar"),
|
||||
("foo.bar", "fooBar"),
|
||||
("foo_bar", "fooBar"),
|
||||
("FOOBAR", "foobar"),
|
||||
("FOO_BAR", "fooBar"),
|
||||
("FOOBAR1", "foobar1"),
|
||||
("FOOBAR_1", "foobar1"),
|
||||
("FOO1BAR2", "foo1Bar2"),
|
||||
("foo__bar", "fooBar"),
|
||||
("_foobar", "foobar"),
|
||||
("foobaR", "foobaR"),
|
||||
("foo~bar", "fooBar"),
|
||||
("foo:bar", "fooBar"),
|
||||
("1foobar", "1Foobar"),
|
||||
],
|
||||
)
|
||||
def test_camel_case_strict(value, expected):
|
||||
actual = camel_case(value, strict=True)
|
||||
assert actual == expected, f"{value} => {expected} (actual: {actual})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["value", "expected"],
|
||||
[
|
||||
("foo_bar", "fooBar"),
|
||||
("FooBar", "fooBar"),
|
||||
("foo__bar", "foo_Bar"),
|
||||
("foo__Bar", "foo__Bar"),
|
||||
],
|
||||
)
|
||||
def test_camel_case_not_strict(value, expected):
|
||||
actual = camel_case(value, strict=False)
|
||||
assert actual == expected, f"{value} => {expected} (actual: {actual})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["value", "expected"],
|
||||
[
|
||||
("", ""),
|
||||
("a", "a"),
|
||||
("foobar", "foobar"),
|
||||
("fooBar", "foo_bar"),
|
||||
("FooBar", "foo_bar"),
|
||||
("foo.bar", "foo_bar"),
|
||||
("foo_bar", "foo_bar"),
|
||||
("foo_Bar", "foo_bar"),
|
||||
("FOOBAR", "foobar"),
|
||||
("FOOBar", "foo_bar"),
|
||||
("UInt32", "u_int32"),
|
||||
("FOO_BAR", "foo_bar"),
|
||||
("FOOBAR1", "foobar1"),
|
||||
("FOOBAR_1", "foobar_1"),
|
||||
("FOOBAR_123", "foobar_123"),
|
||||
("FOO1BAR2", "foo1_bar2"),
|
||||
("foo__bar", "foo_bar"),
|
||||
("_foobar", "foobar"),
|
||||
("foobaR", "fooba_r"),
|
||||
("foo~bar", "foo_bar"),
|
||||
("foo:bar", "foo_bar"),
|
||||
("1foobar", "1_foobar"),
|
||||
("GetUInt64", "get_u_int64"),
|
||||
],
|
||||
)
|
||||
def test_snake_case_strict(value, expected):
|
||||
actual = snake_case(value)
|
||||
assert actual == expected, f"{value} => {expected} (actual: {actual})"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["value", "expected"],
|
||||
[
|
||||
("fooBar", "foo_bar"),
|
||||
("FooBar", "foo_bar"),
|
||||
("foo_Bar", "foo__bar"),
|
||||
("foo__bar", "foo__bar"),
|
||||
("FOOBar", "foo_bar"),
|
||||
("__foo", "__foo"),
|
||||
("GetUInt64", "get_u_int64"),
|
||||
],
|
||||
)
|
||||
def test_snake_case_not_strict(value, expected):
|
||||
actual = snake_case(value, strict=False)
|
||||
assert actual == expected, f"{value} => {expected} (actual: {actual})"
|
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from ..compile.importing import get_ref_type
|
||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -28,14 +28,16 @@ from ..compile.importing import get_ref_type
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_import_google_wellknown_types_non_wrappers(
|
||||
def test_reference_google_wellknown_types_non_wrappers(
|
||||
google_type: str, expected_name: str, expected_import: str
|
||||
):
|
||||
imports = set()
|
||||
name = get_ref_type(package="", imports=imports, type_name=google_type)
|
||||
name = get_type_reference(package="", imports=imports, source_type=google_type)
|
||||
|
||||
assert name == expected_name
|
||||
assert imports.__contains__(expected_import)
|
||||
assert imports.__contains__(
|
||||
expected_import
|
||||
), f"{expected_import} not found in {imports}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -52,9 +54,11 @@ def test_import_google_wellknown_types_non_wrappers(
|
||||
(".google.protobuf.BytesValue", "Optional[bytes]"),
|
||||
],
|
||||
)
|
||||
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
|
||||
def test_referenceing_google_wrappers_unwraps_them(
|
||||
google_type: str, expected_name: str
|
||||
):
|
||||
imports = set()
|
||||
name = get_ref_type(package="", imports=imports, type_name=google_type)
|
||||
name = get_type_reference(package="", imports=imports, source_type=google_type)
|
||||
|
||||
assert name == expected_name
|
||||
assert imports == set()
|
||||
@ -74,9 +78,238 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name:
|
||||
(".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
|
||||
],
|
||||
)
|
||||
def test_importing_google_wrappers_without_unwrapping(
|
||||
def test_referenceing_google_wrappers_without_unwrapping(
|
||||
google_type: str, expected_name: str
|
||||
):
|
||||
name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
|
||||
name = get_type_reference(
|
||||
package="", imports=set(), source_type=google_type, unwrap=False
|
||||
)
|
||||
|
||||
assert name == expected_name
|
||||
|
||||
|
||||
def test_reference_child_package_from_package():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package", imports=imports, source_type="package.child.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from . import child"}
|
||||
assert name == "child.Message"
|
||||
|
||||
|
||||
def test_reference_child_package_from_root():
|
||||
imports = set()
|
||||
name = get_type_reference(package="", imports=imports, source_type="child.Message")
|
||||
|
||||
assert imports == {"from . import child"}
|
||||
assert name == "child.Message"
|
||||
|
||||
|
||||
def test_reference_camel_cased():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="", imports=imports, source_type="child_package.example_message"
|
||||
)
|
||||
|
||||
assert imports == {"from . import child_package"}
|
||||
assert name == "child_package.ExampleMessage"
|
||||
|
||||
|
||||
def test_reference_nested_child_from_root():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="", imports=imports, source_type="nested.child.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from .nested import child as nested_child"}
|
||||
assert name == "nested_child.Message"
|
||||
|
||||
|
||||
def test_reference_deeply_nested_child_from_root():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="", imports=imports, source_type="deeply.nested.child.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
||||
assert name == "deeply_nested_child.Message"
|
||||
|
||||
|
||||
def test_reference_deeply_nested_child_from_package():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package",
|
||||
imports=imports,
|
||||
source_type="package.deeply.nested.child.Message",
|
||||
)
|
||||
|
||||
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
||||
assert name == "deeply_nested_child.Message"
|
||||
|
||||
|
||||
def test_reference_root_sibling():
|
||||
imports = set()
|
||||
name = get_type_reference(package="", imports=imports, source_type="Message")
|
||||
|
||||
assert imports == set()
|
||||
assert name == '"Message"'
|
||||
|
||||
|
||||
def test_reference_nested_siblings():
|
||||
imports = set()
|
||||
name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
|
||||
|
||||
assert imports == set()
|
||||
assert name == '"Message"'
|
||||
|
||||
|
||||
def test_reference_deeply_nested_siblings():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="foo.bar", imports=imports, source_type="foo.bar.Message"
|
||||
)
|
||||
|
||||
assert imports == set()
|
||||
assert name == '"Message"'
|
||||
|
||||
|
||||
def test_reference_parent_package_from_child():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package.child", imports=imports, source_type="package.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ... import package as __package__"}
|
||||
assert name == "__package__.Message"
|
||||
|
||||
|
||||
def test_reference_parent_package_from_deeply_nested_child():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package.deeply.nested.child",
|
||||
imports=imports,
|
||||
source_type="package.deeply.nested.Message",
|
||||
)
|
||||
|
||||
assert imports == {"from ... import nested as __nested__"}
|
||||
assert name == "__nested__.Message"
|
||||
|
||||
|
||||
def test_reference_ancestor_package_from_nested_child():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package.ancestor.nested.child",
|
||||
imports=imports,
|
||||
source_type="package.ancestor.Message",
|
||||
)
|
||||
|
||||
assert imports == {"from .... import ancestor as ___ancestor__"}
|
||||
assert name == "___ancestor__.Message"
|
||||
|
||||
|
||||
def test_reference_root_package_from_child():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package.child", imports=imports, source_type="Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ... import Message as __Message__"}
|
||||
assert name == "__Message__"
|
||||
|
||||
|
||||
def test_reference_root_package_from_deeply_nested_child():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="package.deeply.nested.child", imports=imports, source_type="Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ..... import Message as ____Message__"}
|
||||
assert name == "____Message__"
|
||||
|
||||
|
||||
def test_reference_unrelated_package():
|
||||
imports = set()
|
||||
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
|
||||
|
||||
assert imports == {"from .. import p as _p__"}
|
||||
assert name == "_p__.Message"
|
||||
|
||||
|
||||
def test_reference_unrelated_nested_package():
|
||||
imports = set()
|
||||
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
|
||||
|
||||
assert imports == {"from ...p import q as __p_q__"}
|
||||
assert name == "__p_q__.Message"
|
||||
|
||||
|
||||
def test_reference_unrelated_deeply_nested_package():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
|
||||
assert name == "____p_q_r_s__.Message"
|
||||
|
||||
|
||||
def test_reference_cousin_package():
|
||||
imports = set()
|
||||
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
|
||||
|
||||
assert imports == {"from .. import y as _y__"}
|
||||
assert name == "_y__.Message"
|
||||
|
||||
|
||||
def test_reference_cousin_package_different_name():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="test.package1", imports=imports, source_type="cousin.package2.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
|
||||
assert name == "__cousin_package2__.Message"
|
||||
|
||||
|
||||
def test_reference_cousin_package_same_name():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="test.package", imports=imports, source_type="cousin.package.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ...cousin import package as __cousin_package__"}
|
||||
assert name == "__cousin_package__.Message"
|
||||
|
||||
|
||||
def test_reference_far_cousin_package():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="a.x.y", imports=imports, source_type="a.b.c.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ...b import c as __b_c__"}
|
||||
assert name == "__b_c__.Message"
|
||||
|
||||
|
||||
def test_reference_far_far_cousin_package():
|
||||
imports = set()
|
||||
name = get_type_reference(
|
||||
package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
|
||||
)
|
||||
|
||||
assert imports == {"from ....b.c import d as ___b_c_d__"}
|
||||
assert name == "___b_c_d__.Message"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["full_name", "expected_output"],
|
||||
[
|
||||
("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
|
||||
(".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
|
||||
(".service.ExampleRequest", ("service", "ExampleRequest")),
|
||||
(".package.lower_case_message", ("package", "lower_case_message")),
|
||||
],
|
||||
)
|
||||
def test_parse_field_type_name(full_name, expected_output):
|
||||
assert parse_source_type_name(full_name) == expected_output
|
||||
|
@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from types import ModuleType
|
||||
from typing import Set
|
||||
|
||||
import pytest
|
||||
@ -10,7 +11,12 @@ import pytest
|
||||
import betterproto
|
||||
from betterproto.tests.inputs import config as test_input_config
|
||||
from betterproto.tests.mocks import MockChannel
|
||||
from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path
|
||||
from betterproto.tests.util import (
|
||||
find_module,
|
||||
get_directories,
|
||||
get_test_case_json_data,
|
||||
inputs_path,
|
||||
)
|
||||
|
||||
# Force pure-python implementation instead of C++, otherwise imports
|
||||
# break things because we can't properly reset the symbol database.
|
||||
@ -50,14 +56,17 @@ class TestCases:
|
||||
test_cases = TestCases(
|
||||
path=inputs_path,
|
||||
services=test_input_config.services,
|
||||
xfail=test_input_config.tests,
|
||||
xfail=test_input_config.xfail,
|
||||
)
|
||||
|
||||
plugin_output_package = "betterproto.tests.output_betterproto"
|
||||
reference_output_package = "betterproto.tests.output_reference"
|
||||
|
||||
TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
|
||||
|
||||
TestData = namedtuple("TestData", "plugin_module, reference_module, json_data")
|
||||
|
||||
def module_has_entry_point(module: ModuleType):
|
||||
return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -75,11 +84,19 @@ def test_data(request):
|
||||
|
||||
sys.path.append(reference_module_root)
|
||||
|
||||
plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
|
||||
|
||||
plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
|
||||
|
||||
if not plugin_module_entry_point:
|
||||
raise Exception(
|
||||
f"Test case {repr(test_case_name)} has no entry point. "
|
||||
"Please add a proto message or service called Test and recompile."
|
||||
)
|
||||
|
||||
yield (
|
||||
TestData(
|
||||
plugin_module=importlib.import_module(
|
||||
f"{plugin_output_package}.{test_case_name}.{test_case_name}"
|
||||
),
|
||||
plugin_module=plugin_module_entry_point,
|
||||
reference_module=lambda: importlib.import_module(
|
||||
f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
|
||||
),
|
||||
|
@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import os
|
||||
import pathlib
|
||||
from pathlib import Path
|
||||
from typing import Generator, IO, Optional
|
||||
from types import ModuleType
|
||||
from typing import Callable, Generator, Optional
|
||||
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
|
||||
|
||||
@ -55,3 +58,35 @@ def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] =
|
||||
|
||||
with test_data_file_path.open("r") as fh:
|
||||
return fh.read()
|
||||
|
||||
|
||||
def find_module(
|
||||
module: ModuleType, predicate: Callable[[ModuleType], bool]
|
||||
) -> Optional[ModuleType]:
|
||||
"""
|
||||
Recursively search module tree for a module that matches the search predicate.
|
||||
Assumes that the submodules are directories containing __init__.py.
|
||||
|
||||
Example:
|
||||
|
||||
# find module inside foo that contains Test
|
||||
import foo
|
||||
test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
|
||||
"""
|
||||
if predicate(module):
|
||||
return module
|
||||
|
||||
module_path = pathlib.Path(*module.__path__)
|
||||
|
||||
for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
|
||||
if sub == module_path:
|
||||
continue
|
||||
sub_module_path = sub.relative_to(module_path)
|
||||
sub_module_name = ".".join(sub_module_path.parts)
|
||||
|
||||
sub_module = importlib.import_module(f".{sub_module_name}", module.__name__)
|
||||
|
||||
if predicate(sub_module):
|
||||
return sub_module
|
||||
|
||||
return None
|
||||
|
21
poetry.lock
generated
21
poetry.lock
generated
@ -271,7 +271,7 @@ docs = ["sphinx", "rst.linker", "jaraco.packaging"]
|
||||
category = "main"
|
||||
description = "A very fast and expressive template engine."
|
||||
name = "jinja2"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
version = "2.11.2"
|
||||
|
||||
@ -285,7 +285,7 @@ i18n = ["Babel (>=0.8)"]
|
||||
category = "main"
|
||||
description = "Safely add untrusted strings to HTML/XML markup."
|
||||
name = "markupsafe"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*"
|
||||
version = "1.1.1"
|
||||
|
||||
@ -369,7 +369,7 @@ dev = ["pre-commit", "tox"]
|
||||
category = "main"
|
||||
description = "Protocol Buffers"
|
||||
name = "protobuf"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "3.12.2"
|
||||
|
||||
@ -490,14 +490,6 @@ optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
|
||||
version = "1.15.0"
|
||||
|
||||
[[package]]
|
||||
category = "main"
|
||||
description = "String case converter."
|
||||
name = "stringcase"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
version = "1.2.0"
|
||||
|
||||
[[package]]
|
||||
category = "main"
|
||||
description = "Python Library for Tom's Obvious, Minimal Language"
|
||||
@ -612,7 +604,7 @@ testing = ["jaraco.itertools", "func-timeout"]
|
||||
compiler = ["black", "jinja2", "protobuf"]
|
||||
|
||||
[metadata]
|
||||
content-hash = "ecafcaed2d4a25c2829e6dc3ef3c56cd72a8bc28c25c7aeae3484c978c816722"
|
||||
content-hash = "8a4fa01ede86e1b5ba35b9dab8b6eacee766a9b5666f48ab41445c01882ab003"
|
||||
python-versions = "^3.6"
|
||||
|
||||
[metadata.files]
|
||||
@ -865,6 +857,8 @@ protobuf = [
|
||||
{file = "protobuf-3.12.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e72736dd822748b0721f41f9aaaf6a5b6d5cfc78f6c8690263aef8bba4457f0e"},
|
||||
{file = "protobuf-3.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:87535dc2d2ef007b9d44e309d2b8ea27a03d2fa09556a72364d706fcb7090828"},
|
||||
{file = "protobuf-3.12.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:50b5fee674878b14baea73b4568dc478c46a31dd50157a5b5d2f71138243b1a9"},
|
||||
{file = "protobuf-3.12.2-py2.py3-none-any.whl", hash = "sha256:a96f8fc625e9ff568838e556f6f6ae8eca8b4837cdfb3f90efcb7c00e342a2eb"},
|
||||
{file = "protobuf-3.12.2.tar.gz", hash = "sha256:49ef8ab4c27812a89a76fa894fe7a08f42f2147078392c0dee51d4a444ef6df5"},
|
||||
]
|
||||
py = [
|
||||
{file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"},
|
||||
@ -920,9 +914,6 @@ six = [
|
||||
{file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
|
||||
{file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
|
||||
]
|
||||
stringcase = [
|
||||
{file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"},
|
||||
]
|
||||
toml = [
|
||||
{file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"},
|
||||
{file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"},
|
||||
|
@ -18,7 +18,6 @@ dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
|
||||
grpclib = "^0.3.1"
|
||||
jinja2 = { version = "^2.11.2", optional = true }
|
||||
protobuf = { version = "^3.12.2", optional = true }
|
||||
stringcase = "^1.2.0"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
black = "^19.10b0"
|
||||
|
Loading…
x
Reference in New Issue
Block a user