Merge branch 'master' into michael-sayapin/master
This commit is contained in:
		
							
								
								
									
										29
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										29
									
								
								README.md
									
									
									
									
									
								
							@@ -68,14 +68,15 @@ message Greeting {
 | 
				
			|||||||
You can run the following:
 | 
					You can run the following:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```sh
 | 
					```sh
 | 
				
			||||||
protoc -I . --python_betterproto_out=. example.proto
 | 
					mkdir lib
 | 
				
			||||||
 | 
					protoc -I . --python_betterproto_out=lib example.proto
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
This will generate `hello.py` which looks like:
 | 
					This will generate `lib/hello/__init__.py` which looks like:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```py
 | 
					```python
 | 
				
			||||||
# Generated by the protocol buffer compiler.  DO NOT EDIT!
 | 
					# Generated by the protocol buffer compiler.  DO NOT EDIT!
 | 
				
			||||||
# sources: hello.proto
 | 
					# sources: example.proto
 | 
				
			||||||
# plugin: python-betterproto
 | 
					# plugin: python-betterproto
 | 
				
			||||||
from dataclasses import dataclass
 | 
					from dataclasses import dataclass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -83,7 +84,7 @@ import betterproto
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@dataclass
 | 
					@dataclass
 | 
				
			||||||
class Hello(betterproto.Message):
 | 
					class Greeting(betterproto.Message):
 | 
				
			||||||
    """Greeting represents a message you can tell a user."""
 | 
					    """Greeting represents a message you can tell a user."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    message: str = betterproto.string_field(1)
 | 
					    message: str = betterproto.string_field(1)
 | 
				
			||||||
@@ -91,23 +92,23 @@ class Hello(betterproto.Message):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
Now you can use it!
 | 
					Now you can use it!
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```py
 | 
					```python
 | 
				
			||||||
>>> from hello import Hello
 | 
					>>> from lib.hello import Greeting
 | 
				
			||||||
>>> test = Hello()
 | 
					>>> test = Greeting()
 | 
				
			||||||
>>> test
 | 
					>>> test
 | 
				
			||||||
Hello(message='')
 | 
					Greeting(message='')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
>>> test.message = "Hey!"
 | 
					>>> test.message = "Hey!"
 | 
				
			||||||
>>> test
 | 
					>>> test
 | 
				
			||||||
Hello(message="Hey!")
 | 
					Greeting(message="Hey!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
>>> serialized = bytes(test)
 | 
					>>> serialized = bytes(test)
 | 
				
			||||||
>>> serialized
 | 
					>>> serialized
 | 
				
			||||||
b'\n\x04Hey!'
 | 
					b'\n\x04Hey!'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
>>> another = Hello().parse(serialized)
 | 
					>>> another = Greeting().parse(serialized)
 | 
				
			||||||
>>> another
 | 
					>>> another
 | 
				
			||||||
Hello(message="Hey!")
 | 
					Greeting(message="Hey!")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
>>> another.to_dict()
 | 
					>>> another.to_dict()
 | 
				
			||||||
{"message": "Hey!"}
 | 
					{"message": "Hey!"}
 | 
				
			||||||
@@ -315,7 +316,7 @@ To benefit from the collection of standard development tasks ensure you have mak
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
This project enforces [black](https://github.com/psf/black) python code formatting.
 | 
					This project enforces [black](https://github.com/psf/black) python code formatting.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Before commiting changes run:
 | 
					Before committing changes run:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```sh
 | 
					```sh
 | 
				
			||||||
make format
 | 
					make format
 | 
				
			||||||
@@ -336,7 +337,7 @@ Adding a standard test case is easy.
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
- Create a new directory `betterproto/tests/inputs/<name>`
 | 
					- Create a new directory `betterproto/tests/inputs/<name>`
 | 
				
			||||||
  - add `<name>.proto`  with a message called `Test`
 | 
					  - add `<name>.proto`  with a message called `Test`
 | 
				
			||||||
  - add `<name>.json` with some test data
 | 
					  - add `<name>.json` with some test data (optional)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
It will be picked up automatically when you run the tests.
 | 
					It will be picked up automatically when you run the tests.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -7,27 +7,22 @@ import sys
 | 
				
			|||||||
from abc import ABC
 | 
					from abc import ABC
 | 
				
			||||||
from base64 import b64decode, b64encode
 | 
					from base64 import b64decode, b64encode
 | 
				
			||||||
from datetime import datetime, timedelta, timezone
 | 
					from datetime import datetime, timedelta, timezone
 | 
				
			||||||
import stringcase
 | 
					 | 
				
			||||||
from typing import (
 | 
					from typing import (
 | 
				
			||||||
    Any,
 | 
					    Any,
 | 
				
			||||||
    AsyncGenerator,
 | 
					 | 
				
			||||||
    Callable,
 | 
					    Callable,
 | 
				
			||||||
    Collection,
 | 
					 | 
				
			||||||
    Dict,
 | 
					    Dict,
 | 
				
			||||||
    Generator,
 | 
					    Generator,
 | 
				
			||||||
    Iterator,
 | 
					 | 
				
			||||||
    List,
 | 
					    List,
 | 
				
			||||||
    Mapping,
 | 
					 | 
				
			||||||
    Optional,
 | 
					    Optional,
 | 
				
			||||||
    Set,
 | 
					    Set,
 | 
				
			||||||
    SupportsBytes,
 | 
					 | 
				
			||||||
    Tuple,
 | 
					    Tuple,
 | 
				
			||||||
    Type,
 | 
					    Type,
 | 
				
			||||||
    Union,
 | 
					    Union,
 | 
				
			||||||
    get_type_hints,
 | 
					    get_type_hints,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
from ._types import ST, T
 | 
					
 | 
				
			||||||
from .casing import safe_snake_case
 | 
					from ._types import T
 | 
				
			||||||
 | 
					from .casing import camel_case, safe_snake_case, safe_snake_case, snake_case
 | 
				
			||||||
from .grpc.grpclib_client import ServiceStub
 | 
					from .grpc.grpclib_client import ServiceStub
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
 | 
					if not (sys.version_info.major == 3 and sys.version_info.minor >= 7):
 | 
				
			||||||
@@ -124,8 +119,8 @@ DATETIME_ZERO = datetime_default_gen()
 | 
				
			|||||||
class Casing(enum.Enum):
 | 
					class Casing(enum.Enum):
 | 
				
			||||||
    """Casing constants for serialization."""
 | 
					    """Casing constants for serialization."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    CAMEL = stringcase.camelcase
 | 
					    CAMEL = camel_case
 | 
				
			||||||
    SNAKE = stringcase.snakecase
 | 
					    SNAKE = snake_case
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _PLACEHOLDER:
 | 
					class _PLACEHOLDER:
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,9 +1,21 @@
 | 
				
			|||||||
import stringcase
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Word delimiters and symbols that will not be preserved when re-casing.
 | 
				
			||||||
 | 
					# language=PythonRegExp
 | 
				
			||||||
 | 
					SYMBOLS = "[^a-zA-Z0-9]*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Optionally capitalized word.
 | 
				
			||||||
 | 
					# language=PythonRegExp
 | 
				
			||||||
 | 
					WORD = "[A-Z]*[a-z]*[0-9]*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# Uppercase word, not followed by lowercase letters.
 | 
				
			||||||
 | 
					# language=PythonRegExp
 | 
				
			||||||
 | 
					WORD_UPPER = "[A-Z]+(?![a-z])[0-9]*"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def safe_snake_case(value: str) -> str:
 | 
					def safe_snake_case(value: str) -> str:
 | 
				
			||||||
    """Snake case a value taking into account Python keywords."""
 | 
					    """Snake case a value taking into account Python keywords."""
 | 
				
			||||||
    value = stringcase.snakecase(value)
 | 
					    value = snake_case(value)
 | 
				
			||||||
    if value in [
 | 
					    if value in [
 | 
				
			||||||
        "and",
 | 
					        "and",
 | 
				
			||||||
        "as",
 | 
					        "as",
 | 
				
			||||||
@@ -39,3 +51,70 @@ def safe_snake_case(value: str) -> str:
 | 
				
			|||||||
        # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
 | 
					        # https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
 | 
				
			||||||
        value += "_"
 | 
					        value += "_"
 | 
				
			||||||
    return value
 | 
					    return value
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def snake_case(value: str, strict: bool = True):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Join words with an underscore into lowercase and remove symbols.
 | 
				
			||||||
 | 
					    @param value: value to convert
 | 
				
			||||||
 | 
					    @param strict: force single underscores
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def substitute_word(symbols, word, is_start):
 | 
				
			||||||
 | 
					        if not word:
 | 
				
			||||||
 | 
					            return ""
 | 
				
			||||||
 | 
					        if strict:
 | 
				
			||||||
 | 
					            delimiter_count = 0 if is_start else 1  # Single underscore if strict.
 | 
				
			||||||
 | 
					        elif is_start:
 | 
				
			||||||
 | 
					            delimiter_count = len(symbols)
 | 
				
			||||||
 | 
					        elif word.isupper() or word.islower():
 | 
				
			||||||
 | 
					            delimiter_count = max(
 | 
				
			||||||
 | 
					                1, len(symbols)
 | 
				
			||||||
 | 
					            )  # Preserve all delimiters if not strict.
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            delimiter_count = len(symbols) + 1  # Extra underscore for leading capital.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ("_" * delimiter_count) + word.lower()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    snake = re.sub(
 | 
				
			||||||
 | 
					        f"(^)?({SYMBOLS})({WORD_UPPER}|{WORD})",
 | 
				
			||||||
 | 
					        lambda groups: substitute_word(groups[2], groups[3], groups[1] is not None),
 | 
				
			||||||
 | 
					        value,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    return snake
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pascal_case(value: str, strict: bool = True):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Capitalize each word and remove symbols.
 | 
				
			||||||
 | 
					    @param value: value to convert
 | 
				
			||||||
 | 
					    @param strict: output only alphanumeric characters
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def substitute_word(symbols, word):
 | 
				
			||||||
 | 
					        if strict:
 | 
				
			||||||
 | 
					            return word.capitalize()  # Remove all delimiters
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if word.islower():
 | 
				
			||||||
 | 
					            delimiter_length = len(symbols[:-1])  # Lose one delimiter
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            delimiter_length = len(symbols)  # Preserve all delimiters
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return ("_" * delimiter_length) + word.capitalize()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return re.sub(
 | 
				
			||||||
 | 
					        f"({SYMBOLS})({WORD_UPPER}|{WORD})",
 | 
				
			||||||
 | 
					        lambda groups: substitute_word(groups[1], groups[2]),
 | 
				
			||||||
 | 
					        value,
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def camel_case(value: str, strict: bool = True):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Capitalize all words except first and remove symbols.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return lowercase_first(pascal_case(value, strict=strict))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def lowercase_first(value: str):
 | 
				
			||||||
 | 
					    return value[0:1].lower() + value[1:]
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,70 +1,160 @@
 | 
				
			|||||||
from typing import Dict, Type
 | 
					import os
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
import stringcase
 | 
					from typing import Dict, List, Set, Type
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from betterproto import safe_snake_case
 | 
					from betterproto import safe_snake_case
 | 
				
			||||||
 | 
					from betterproto.compile.naming import pythonize_class_name
 | 
				
			||||||
from betterproto.lib.google import protobuf as google_protobuf
 | 
					from betterproto.lib.google import protobuf as google_protobuf
 | 
				
			||||||
 | 
					
 | 
				
			||||||
WRAPPER_TYPES: Dict[str, Type] = {
 | 
					WRAPPER_TYPES: Dict[str, Type] = {
 | 
				
			||||||
    "google.protobuf.DoubleValue": google_protobuf.DoubleValue,
 | 
					    ".google.protobuf.DoubleValue": google_protobuf.DoubleValue,
 | 
				
			||||||
    "google.protobuf.FloatValue": google_protobuf.FloatValue,
 | 
					    ".google.protobuf.FloatValue": google_protobuf.FloatValue,
 | 
				
			||||||
    "google.protobuf.Int32Value": google_protobuf.Int32Value,
 | 
					    ".google.protobuf.Int32Value": google_protobuf.Int32Value,
 | 
				
			||||||
    "google.protobuf.Int64Value": google_protobuf.Int64Value,
 | 
					    ".google.protobuf.Int64Value": google_protobuf.Int64Value,
 | 
				
			||||||
    "google.protobuf.UInt32Value": google_protobuf.UInt32Value,
 | 
					    ".google.protobuf.UInt32Value": google_protobuf.UInt32Value,
 | 
				
			||||||
    "google.protobuf.UInt64Value": google_protobuf.UInt64Value,
 | 
					    ".google.protobuf.UInt64Value": google_protobuf.UInt64Value,
 | 
				
			||||||
    "google.protobuf.BoolValue": google_protobuf.BoolValue,
 | 
					    ".google.protobuf.BoolValue": google_protobuf.BoolValue,
 | 
				
			||||||
    "google.protobuf.StringValue": google_protobuf.StringValue,
 | 
					    ".google.protobuf.StringValue": google_protobuf.StringValue,
 | 
				
			||||||
    "google.protobuf.BytesValue": google_protobuf.BytesValue,
 | 
					    ".google.protobuf.BytesValue": google_protobuf.BytesValue,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_ref_type(
 | 
					def parse_source_type_name(field_type_name):
 | 
				
			||||||
    package: str, imports: set, type_name: str, unwrap: bool = True
 | 
					    """
 | 
				
			||||||
 | 
					    Split full source type name into package and type name.
 | 
				
			||||||
 | 
					    E.g. 'root.package.Message' -> ('root.package', 'Message')
 | 
				
			||||||
 | 
					         'root.Message.SomeEnum' -> ('root', 'Message.SomeEnum')
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    package_match = re.match(r"^\.?([^A-Z]+)\.(.+)", field_type_name)
 | 
				
			||||||
 | 
					    if package_match:
 | 
				
			||||||
 | 
					        package = package_match.group(1)
 | 
				
			||||||
 | 
					        name = package_match.group(2)
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        package = ""
 | 
				
			||||||
 | 
					        name = field_type_name.lstrip(".")
 | 
				
			||||||
 | 
					    return package, name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_type_reference(
 | 
				
			||||||
 | 
					    package: str, imports: set, source_type: str, unwrap: bool = True,
 | 
				
			||||||
) -> str:
 | 
					) -> str:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Return a Python type name for a proto type reference. Adds the import if
 | 
					    Return a Python type name for a proto type reference. Adds the import if
 | 
				
			||||||
    necessary. Unwraps well known type if required.
 | 
					    necessary. Unwraps well known type if required.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    # If the package name is a blank string, then this should still work
 | 
					 | 
				
			||||||
    # because by convention packages are lowercase and message/enum types are
 | 
					 | 
				
			||||||
    # pascal-cased. May require refactoring in the future.
 | 
					 | 
				
			||||||
    type_name = type_name.lstrip(".")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    is_wrapper = type_name in WRAPPER_TYPES
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if unwrap:
 | 
					    if unwrap:
 | 
				
			||||||
        if is_wrapper:
 | 
					        if source_type in WRAPPER_TYPES:
 | 
				
			||||||
            wrapped_type = type(WRAPPER_TYPES[type_name]().value)
 | 
					            wrapped_type = type(WRAPPER_TYPES[source_type]().value)
 | 
				
			||||||
            return f"Optional[{wrapped_type.__name__}]"
 | 
					            return f"Optional[{wrapped_type.__name__}]"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if type_name == "google.protobuf.Duration":
 | 
					        if source_type == ".google.protobuf.Duration":
 | 
				
			||||||
            return "timedelta"
 | 
					            return "timedelta"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if type_name == "google.protobuf.Timestamp":
 | 
					        if source_type == ".google.protobuf.Timestamp":
 | 
				
			||||||
            return "datetime"
 | 
					            return "datetime"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if type_name.startswith(package):
 | 
					    source_package, source_type = parse_source_type_name(source_type)
 | 
				
			||||||
        parts = type_name.lstrip(package).lstrip(".").split(".")
 | 
					 | 
				
			||||||
        if len(parts) == 1 or (len(parts) > 1 and parts[0][0] == parts[0][0].upper()):
 | 
					 | 
				
			||||||
            # This is the current package, which has nested types flattened.
 | 
					 | 
				
			||||||
            # foo.bar_thing => FooBarThing
 | 
					 | 
				
			||||||
            cased = [stringcase.pascalcase(part) for part in parts]
 | 
					 | 
				
			||||||
            type_name = f'"{"".join(cased)}"'
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Use precompiled classes for google.protobuf.* objects
 | 
					    current_package: List[str] = package.split(".") if package else []
 | 
				
			||||||
    if type_name.startswith("google.protobuf.") and type_name.count(".") == 2:
 | 
					    py_package: List[str] = source_package.split(".") if source_package else []
 | 
				
			||||||
        type_name = type_name.rsplit(".", maxsplit=1)[1]
 | 
					    py_type: str = pythonize_class_name(source_type)
 | 
				
			||||||
        import_package = "betterproto.lib.google.protobuf"
 | 
					 | 
				
			||||||
        import_alias = safe_snake_case(import_package)
 | 
					 | 
				
			||||||
        imports.add(f"import {import_package} as {import_alias}")
 | 
					 | 
				
			||||||
        return f"{import_alias}.{type_name}"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if "." in type_name:
 | 
					    compiling_google_protobuf = current_package == ["google", "protobuf"]
 | 
				
			||||||
        # This is imported from another package. No need
 | 
					    importing_google_protobuf = py_package == ["google", "protobuf"]
 | 
				
			||||||
        # to use a forward ref and we need to add the import.
 | 
					    if importing_google_protobuf and not compiling_google_protobuf:
 | 
				
			||||||
        parts = type_name.split(".")
 | 
					        py_package = ["betterproto", "lib"] + py_package
 | 
				
			||||||
        parts[-1] = stringcase.pascalcase(parts[-1])
 | 
					 | 
				
			||||||
        imports.add(f"from .{'.'.join(parts[:-2])} import {parts[-2]}")
 | 
					 | 
				
			||||||
        type_name = f"{parts[-2]}.{parts[-1]}"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    return type_name
 | 
					    if py_package[:1] == ["betterproto"]:
 | 
				
			||||||
 | 
					        return reference_absolute(imports, py_package, py_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if py_package == current_package:
 | 
				
			||||||
 | 
					        return reference_sibling(py_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if py_package[: len(current_package)] == current_package:
 | 
				
			||||||
 | 
					        return reference_descendent(current_package, imports, py_package, py_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if current_package[: len(py_package)] == py_package:
 | 
				
			||||||
 | 
					        return reference_ancestor(current_package, imports, py_package, py_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return reference_cousin(current_package, imports, py_package, py_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def reference_absolute(imports, py_package, py_type):
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Returns a reference to a python type located in the root, i.e. sys.path.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    string_import = ".".join(py_package)
 | 
				
			||||||
 | 
					    string_alias = safe_snake_case(string_import)
 | 
				
			||||||
 | 
					    imports.add(f"import {string_import} as {string_alias}")
 | 
				
			||||||
 | 
					    return f"{string_alias}.{py_type}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def reference_sibling(py_type: str) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Returns a reference to a python type within the same package as the current package.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    return f'"{py_type}"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def reference_descendent(
 | 
				
			||||||
 | 
					    current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
 | 
				
			||||||
 | 
					) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Returns a reference to a python type in a package that is a descendent of the current package,
 | 
				
			||||||
 | 
					    and adds the required import that is aliased to avoid name conflicts.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    importing_descendent = py_package[len(current_package) :]
 | 
				
			||||||
 | 
					    string_from = ".".join(importing_descendent[:-1])
 | 
				
			||||||
 | 
					    string_import = importing_descendent[-1]
 | 
				
			||||||
 | 
					    if string_from:
 | 
				
			||||||
 | 
					        string_alias = "_".join(importing_descendent)
 | 
				
			||||||
 | 
					        imports.add(f"from .{string_from} import {string_import} as {string_alias}")
 | 
				
			||||||
 | 
					        return f"{string_alias}.{py_type}"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        imports.add(f"from . import {string_import}")
 | 
				
			||||||
 | 
					        return f"{string_import}.{py_type}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def reference_ancestor(
 | 
				
			||||||
 | 
					    current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
 | 
				
			||||||
 | 
					) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Returns a reference to a python type in a package which is an ancestor to the current package,
 | 
				
			||||||
 | 
					    and adds the required import that is aliased (if possible) to avoid name conflicts.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Adds trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34).
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    distance_up = len(current_package) - len(py_package)
 | 
				
			||||||
 | 
					    if py_package:
 | 
				
			||||||
 | 
					        string_import = py_package[-1]
 | 
				
			||||||
 | 
					        string_alias = f"_{'_' * distance_up}{string_import}__"
 | 
				
			||||||
 | 
					        string_from = f"..{'.' * distance_up}"
 | 
				
			||||||
 | 
					        imports.add(f"from {string_from} import {string_import} as {string_alias}")
 | 
				
			||||||
 | 
					        return f"{string_alias}.{py_type}"
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        string_alias = f"{'_' * distance_up}{py_type}__"
 | 
				
			||||||
 | 
					        imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}")
 | 
				
			||||||
 | 
					        return string_alias
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def reference_cousin(
 | 
				
			||||||
 | 
					    current_package: List[str], imports: Set[str], py_package: List[str], py_type: str
 | 
				
			||||||
 | 
					) -> str:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Returns a reference to a python type in a package that is not descendent, ancestor or sibling,
 | 
				
			||||||
 | 
					    and adds the required import that is aliased to avoid name conflicts.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    shared_ancestry = os.path.commonprefix([current_package, py_package])
 | 
				
			||||||
 | 
					    distance_up = len(current_package) - len(shared_ancestry)
 | 
				
			||||||
 | 
					    string_from = f".{'.' * distance_up}" + ".".join(
 | 
				
			||||||
 | 
					        py_package[len(shared_ancestry) : -1]
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    string_import = py_package[-1]
 | 
				
			||||||
 | 
					    # Add trailing __ to avoid name mangling (python.org/dev/peps/pep-0008/#id34)
 | 
				
			||||||
 | 
					    string_alias = (
 | 
				
			||||||
 | 
					        f"{'_' * distance_up}"
 | 
				
			||||||
 | 
					        + safe_snake_case(".".join(py_package[len(shared_ancestry) :]))
 | 
				
			||||||
 | 
					        + "__"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					    imports.add(f"from {string_from} import {string_import} as {string_alias}")
 | 
				
			||||||
 | 
					    return f"{string_alias}.{py_type}"
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										13
									
								
								betterproto/compile/naming.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								betterproto/compile/naming.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,13 @@
 | 
				
			|||||||
 | 
					from betterproto import casing
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pythonize_class_name(name):
 | 
				
			||||||
 | 
					    return casing.pascal_case(name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pythonize_field_name(name: str):
 | 
				
			||||||
 | 
					    return casing.safe_snake_case(name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def pythonize_method_name(name: str):
 | 
				
			||||||
 | 
					    return casing.safe_snake_case(name)
 | 
				
			||||||
@@ -84,7 +84,7 @@ class FieldOptionsCType(betterproto.Enum):
 | 
				
			|||||||
    STRING_PIECE = 2
 | 
					    STRING_PIECE = 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class FieldOptionsJSType(betterproto.Enum):
 | 
					class FieldOptionsJsType(betterproto.Enum):
 | 
				
			||||||
    JS_NORMAL = 0
 | 
					    JS_NORMAL = 0
 | 
				
			||||||
    JS_STRING = 1
 | 
					    JS_STRING = 1
 | 
				
			||||||
    JS_NUMBER = 2
 | 
					    JS_NUMBER = 2
 | 
				
			||||||
@@ -717,7 +717,7 @@ class FieldOptions(betterproto.Message):
 | 
				
			|||||||
    # use the JavaScript "number" type.  The behavior of the default option
 | 
					    # use the JavaScript "number" type.  The behavior of the default option
 | 
				
			||||||
    # JS_NORMAL is implementation dependent. This option is an enum to permit
 | 
					    # JS_NORMAL is implementation dependent. This option is an enum to permit
 | 
				
			||||||
    # additional types to be added, e.g. goog.math.Integer.
 | 
					    # additional types to be added, e.g. goog.math.Integer.
 | 
				
			||||||
    jstype: "FieldOptionsJSType" = betterproto.enum_field(6)
 | 
					    jstype: "FieldOptionsJsType" = betterproto.enum_field(6)
 | 
				
			||||||
    # Should this field be parsed lazily?  Lazy applies only to message-type
 | 
					    # Should this field be parsed lazily?  Lazy applies only to message-type
 | 
				
			||||||
    # fields.  It means that when the outer message is initially parsed, the
 | 
					    # fields.  It means that when the outer message is initially parsed, the
 | 
				
			||||||
    # inner message's contents will not be parsed but instead stored in encoded
 | 
					    # inner message's contents will not be parsed but instead stored in encoded
 | 
				
			||||||
@@ -2,14 +2,19 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import itertools
 | 
					import itertools
 | 
				
			||||||
import os.path
 | 
					import os.path
 | 
				
			||||||
 | 
					import pathlib
 | 
				
			||||||
import re
 | 
					import re
 | 
				
			||||||
import stringcase
 | 
					 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
import textwrap
 | 
					import textwrap
 | 
				
			||||||
from typing import List, Union
 | 
					from typing import List, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import betterproto
 | 
					import betterproto
 | 
				
			||||||
from betterproto.casing import safe_snake_case
 | 
					from betterproto.compile.importing import get_type_reference
 | 
				
			||||||
from betterproto.compile.importing import get_ref_type
 | 
					from betterproto.compile.naming import (
 | 
				
			||||||
 | 
					    pythonize_class_name,
 | 
				
			||||||
 | 
					    pythonize_field_name,
 | 
				
			||||||
 | 
					    pythonize_method_name,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
try:
 | 
					try:
 | 
				
			||||||
    # betterproto[compiler] specific dependencies
 | 
					    # betterproto[compiler] specific dependencies
 | 
				
			||||||
@@ -35,27 +40,22 @@ except ImportError as err:
 | 
				
			|||||||
    raise SystemExit(1)
 | 
					    raise SystemExit(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def py_type(
 | 
					def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
 | 
				
			||||||
    package: str,
 | 
					    if field.type in [1, 2]:
 | 
				
			||||||
    imports: set,
 | 
					 | 
				
			||||||
    message: DescriptorProto,
 | 
					 | 
				
			||||||
    descriptor: FieldDescriptorProto,
 | 
					 | 
				
			||||||
) -> str:
 | 
					 | 
				
			||||||
    if descriptor.type in [1, 2]:
 | 
					 | 
				
			||||||
        return "float"
 | 
					        return "float"
 | 
				
			||||||
    elif descriptor.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
 | 
					    elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]:
 | 
				
			||||||
        return "int"
 | 
					        return "int"
 | 
				
			||||||
    elif descriptor.type == 8:
 | 
					    elif field.type == 8:
 | 
				
			||||||
        return "bool"
 | 
					        return "bool"
 | 
				
			||||||
    elif descriptor.type == 9:
 | 
					    elif field.type == 9:
 | 
				
			||||||
        return "str"
 | 
					        return "str"
 | 
				
			||||||
    elif descriptor.type in [11, 14]:
 | 
					    elif field.type in [11, 14]:
 | 
				
			||||||
        # Type referencing another defined Message or a named enum
 | 
					        # Type referencing another defined Message or a named enum
 | 
				
			||||||
        return get_ref_type(package, imports, descriptor.type_name)
 | 
					        return get_type_reference(package, imports, field.type_name)
 | 
				
			||||||
    elif descriptor.type == 12:
 | 
					    elif field.type == 12:
 | 
				
			||||||
        return "bytes"
 | 
					        return "bytes"
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        raise NotImplementedError(f"Unknown type {descriptor.type}")
 | 
					        raise NotImplementedError(f"Unknown type {field.type}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_py_zero(type_num: int) -> Union[str, float]:
 | 
					def get_py_zero(type_num: int) -> Union[str, float]:
 | 
				
			||||||
@@ -131,17 +131,17 @@ def generate_code(request, response):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    output_map = {}
 | 
					    output_map = {}
 | 
				
			||||||
    for proto_file in request.proto_file:
 | 
					    for proto_file in request.proto_file:
 | 
				
			||||||
        out = proto_file.package
 | 
					        if (
 | 
				
			||||||
 | 
					            proto_file.package == "google.protobuf"
 | 
				
			||||||
        if out == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options:
 | 
					            and "INCLUDE_GOOGLE" not in plugin_options
 | 
				
			||||||
 | 
					        ):
 | 
				
			||||||
            continue
 | 
					            continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not out:
 | 
					        output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py"))
 | 
				
			||||||
            out = os.path.splitext(proto_file.name)[0].replace(os.path.sep, ".")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if out not in output_map:
 | 
					        if output_file not in output_map:
 | 
				
			||||||
            output_map[out] = {"package": proto_file.package, "files": []}
 | 
					            output_map[output_file] = {"package": proto_file.package, "files": []}
 | 
				
			||||||
        output_map[out]["files"].append(proto_file)
 | 
					        output_map[output_file]["files"].append(proto_file)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # TODO: Figure out how to handle gRPC request/response messages and add
 | 
					    # TODO: Figure out how to handle gRPC request/response messages and add
 | 
				
			||||||
    # processing below for Service.
 | 
					    # processing below for Service.
 | 
				
			||||||
@@ -160,17 +160,10 @@ def generate_code(request, response):
 | 
				
			|||||||
            "services": [],
 | 
					            "services": [],
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        type_mapping = {}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for proto_file in options["files"]:
 | 
					        for proto_file in options["files"]:
 | 
				
			||||||
            # print(proto_file.message_type, file=sys.stderr)
 | 
					            item: DescriptorProto
 | 
				
			||||||
            # print(proto_file.service, file=sys.stderr)
 | 
					 | 
				
			||||||
            # print(proto_file.source_code_info, file=sys.stderr)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for item, path in traverse(proto_file):
 | 
					            for item, path in traverse(proto_file):
 | 
				
			||||||
                # print(item, file=sys.stderr)
 | 
					                data = {"name": item.name, "py_name": pythonize_class_name(item.name)}
 | 
				
			||||||
                # print(path, file=sys.stderr)
 | 
					 | 
				
			||||||
                data = {"name": item.name, "py_name": stringcase.pascalcase(item.name)}
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if isinstance(item, DescriptorProto):
 | 
					                if isinstance(item, DescriptorProto):
 | 
				
			||||||
                    # print(item, file=sys.stderr)
 | 
					                    # print(item, file=sys.stderr)
 | 
				
			||||||
@@ -187,7 +180,7 @@ def generate_code(request, response):
 | 
				
			|||||||
                    )
 | 
					                    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    for i, f in enumerate(item.field):
 | 
					                    for i, f in enumerate(item.field):
 | 
				
			||||||
                        t = py_type(package, output["imports"], item, f)
 | 
					                        t = py_type(package, output["imports"], f)
 | 
				
			||||||
                        zero = get_py_zero(f.type)
 | 
					                        zero = get_py_zero(f.type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        repeated = False
 | 
					                        repeated = False
 | 
				
			||||||
@@ -222,13 +215,11 @@ def generate_code(request, response):
 | 
				
			|||||||
                                            k = py_type(
 | 
					                                            k = py_type(
 | 
				
			||||||
                                                package,
 | 
					                                                package,
 | 
				
			||||||
                                                output["imports"],
 | 
					                                                output["imports"],
 | 
				
			||||||
                                                item,
 | 
					 | 
				
			||||||
                                                nested.field[0],
 | 
					                                                nested.field[0],
 | 
				
			||||||
                                            )
 | 
					                                            )
 | 
				
			||||||
                                            v = py_type(
 | 
					                                            v = py_type(
 | 
				
			||||||
                                                package,
 | 
					                                                package,
 | 
				
			||||||
                                                output["imports"],
 | 
					                                                output["imports"],
 | 
				
			||||||
                                                item,
 | 
					 | 
				
			||||||
                                                nested.field[1],
 | 
					                                                nested.field[1],
 | 
				
			||||||
                                            )
 | 
					                                            )
 | 
				
			||||||
                                            t = f"Dict[{k}, {v}]"
 | 
					                                            t = f"Dict[{k}, {v}]"
 | 
				
			||||||
@@ -264,7 +255,7 @@ def generate_code(request, response):
 | 
				
			|||||||
                        data["properties"].append(
 | 
					                        data["properties"].append(
 | 
				
			||||||
                            {
 | 
					                            {
 | 
				
			||||||
                                "name": f.name,
 | 
					                                "name": f.name,
 | 
				
			||||||
                                "py_name": safe_snake_case(f.name),
 | 
					                                "py_name": pythonize_field_name(f.name),
 | 
				
			||||||
                                "number": f.number,
 | 
					                                "number": f.number,
 | 
				
			||||||
                                "comment": get_comment(proto_file, path + [2, i]),
 | 
					                                "comment": get_comment(proto_file, path + [2, i]),
 | 
				
			||||||
                                "proto_type": int(f.type),
 | 
					                                "proto_type": int(f.type),
 | 
				
			||||||
@@ -305,14 +296,14 @@ def generate_code(request, response):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                data = {
 | 
					                data = {
 | 
				
			||||||
                    "name": service.name,
 | 
					                    "name": service.name,
 | 
				
			||||||
                    "py_name": stringcase.pascalcase(service.name),
 | 
					                    "py_name": pythonize_class_name(service.name),
 | 
				
			||||||
                    "comment": get_comment(proto_file, [6, i]),
 | 
					                    "comment": get_comment(proto_file, [6, i]),
 | 
				
			||||||
                    "methods": [],
 | 
					                    "methods": [],
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                for j, method in enumerate(service.method):
 | 
					                for j, method in enumerate(service.method):
 | 
				
			||||||
                    input_message = None
 | 
					                    input_message = None
 | 
				
			||||||
                    input_type = get_ref_type(
 | 
					                    input_type = get_type_reference(
 | 
				
			||||||
                        package, output["imports"], method.input_type
 | 
					                        package, output["imports"], method.input_type
 | 
				
			||||||
                    ).strip('"')
 | 
					                    ).strip('"')
 | 
				
			||||||
                    for msg in output["messages"]:
 | 
					                    for msg in output["messages"]:
 | 
				
			||||||
@@ -326,14 +317,14 @@ def generate_code(request, response):
 | 
				
			|||||||
                    data["methods"].append(
 | 
					                    data["methods"].append(
 | 
				
			||||||
                        {
 | 
					                        {
 | 
				
			||||||
                            "name": method.name,
 | 
					                            "name": method.name,
 | 
				
			||||||
                            "py_name": stringcase.snakecase(method.name),
 | 
					                            "py_name": pythonize_method_name(method.name),
 | 
				
			||||||
                            "comment": get_comment(proto_file, [6, i, 2, j], indent=8),
 | 
					                            "comment": get_comment(proto_file, [6, i, 2, j], indent=8),
 | 
				
			||||||
                            "route": f"/{package}.{service.name}/{method.name}",
 | 
					                            "route": f"/{package}.{service.name}/{method.name}",
 | 
				
			||||||
                            "input": get_ref_type(
 | 
					                            "input": get_type_reference(
 | 
				
			||||||
                                package, output["imports"], method.input_type
 | 
					                                package, output["imports"], method.input_type
 | 
				
			||||||
                            ).strip('"'),
 | 
					                            ).strip('"'),
 | 
				
			||||||
                            "input_message": input_message,
 | 
					                            "input_message": input_message,
 | 
				
			||||||
                            "output": get_ref_type(
 | 
					                            "output": get_type_reference(
 | 
				
			||||||
                                package,
 | 
					                                package,
 | 
				
			||||||
                                output["imports"],
 | 
					                                output["imports"],
 | 
				
			||||||
                                method.output_type,
 | 
					                                method.output_type,
 | 
				
			||||||
@@ -359,8 +350,7 @@ def generate_code(request, response):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        # Fill response
 | 
					        # Fill response
 | 
				
			||||||
        f = response.file.add()
 | 
					        f = response.file.add()
 | 
				
			||||||
        # print(filename, file=sys.stderr)
 | 
					        f.name = filename
 | 
				
			||||||
        f.name = filename.replace(".", os.path.sep) + ".py"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Render and then format the output file.
 | 
					        # Render and then format the output file.
 | 
				
			||||||
        f.content = black.format_str(
 | 
					        f.content = black.format_str(
 | 
				
			||||||
@@ -368,32 +358,23 @@ def generate_code(request, response):
 | 
				
			|||||||
            mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
 | 
					            mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])),
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    inits = set([""])
 | 
					    # Make each output directory a package with __init__ file
 | 
				
			||||||
    for f in response.file:
 | 
					    output_paths = set(pathlib.Path(path) for path in output_map.keys())
 | 
				
			||||||
        # Ensure output paths exist
 | 
					    init_files = (
 | 
				
			||||||
        # print(f.name, file=sys.stderr)
 | 
					        set(
 | 
				
			||||||
        dirnames = os.path.dirname(f.name)
 | 
					            directory.joinpath("__init__.py")
 | 
				
			||||||
        if dirnames:
 | 
					            for path in output_paths
 | 
				
			||||||
            os.makedirs(dirnames, exist_ok=True)
 | 
					            for directory in path.parents
 | 
				
			||||||
            base = ""
 | 
					        )
 | 
				
			||||||
            for part in dirnames.split(os.path.sep):
 | 
					        - output_paths
 | 
				
			||||||
                base = os.path.join(base, part)
 | 
					    )
 | 
				
			||||||
                inits.add(base)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for base in inits:
 | 
					 | 
				
			||||||
        name = os.path.join(base, "__init__.py")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if os.path.exists(name):
 | 
					 | 
				
			||||||
            # Never overwrite inits as they may have custom stuff in them.
 | 
					 | 
				
			||||||
            continue
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for init_file in init_files:
 | 
				
			||||||
        init = response.file.add()
 | 
					        init = response.file.add()
 | 
				
			||||||
        init.name = name
 | 
					        init.name = str(init_file)
 | 
				
			||||||
        init.content = b""
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    filenames = sorted([f.name for f in response.file])
 | 
					    for filename in sorted(output_paths.union(init_files)):
 | 
				
			||||||
    for fname in filenames:
 | 
					        print(f"Writing {filename}", file=sys.stderr)
 | 
				
			||||||
        print(f"Writing {fname}", file=sys.stderr)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					def main():
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,12 +12,12 @@ inputs/
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
## Test case directory structure
 | 
					## Test case directory structure
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Each testcase has a `<name>.proto` file with a message called `Test`, a matching `.json` file and optionally a custom test file called `test_*.py`.
 | 
					Each testcase has a `<name>.proto` file with a message called `Test`, and optionally a matching `.json` file and a custom test called `test_*.py`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
```bash
 | 
					```bash
 | 
				
			||||||
bool/
 | 
					bool/
 | 
				
			||||||
  bool.proto
 | 
					  bool.proto
 | 
				
			||||||
  bool.json
 | 
					  bool.json     # optional
 | 
				
			||||||
  test_bool.py  # optional
 | 
					  test_bool.py  # optional
 | 
				
			||||||
```
 | 
					```
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -61,21 +61,22 @@ def test_value():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
The following tests are automatically executed for all cases:
 | 
					The following tests are automatically executed for all cases:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- [x] Can the generated python code imported?
 | 
					- [x] Can the generated python code be imported?
 | 
				
			||||||
- [x] Can the generated message class be instantiated?
 | 
					- [x] Can the generated message class be instantiated?
 | 
				
			||||||
- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
 | 
					- [x] Is the generated code compatible with the Google's `grpc_tools.protoc` implementation?
 | 
				
			||||||
 | 
					  - _when `.json` is present_ 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Running the tests
 | 
					## Running the tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
- `pipenv run generate`  
 | 
					- `pipenv run generate`  
 | 
				
			||||||
  This generates
 | 
					  This generates:
 | 
				
			||||||
  - `betterproto/tests/output_betterproto` — *the plugin generated python classes*
 | 
					  - `betterproto/tests/output_betterproto` — *the plugin generated python classes*
 | 
				
			||||||
  - `betterproto/tests/output_reference` — *reference implementation classes*
 | 
					  - `betterproto/tests/output_reference` — *reference implementation classes*
 | 
				
			||||||
- `pipenv run test`
 | 
					- `pipenv run test`
 | 
				
			||||||
 | 
					
 | 
				
			||||||
## Intentionally Failing tests
 | 
					## Intentionally Failing tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrented in the future.
 | 
					The standard test suite includes tests that fail by intention. These tests document known bugs and missing features that are intended to be corrected in the future.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
When running `pytest`, they show up as `x` or  `X` in the test results.
 | 
					When running `pytest`, they show up as `x` or  `X` in the test results.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,4 +1,4 @@
 | 
				
			|||||||
from betterproto.tests.output_betterproto.bool.bool import Test
 | 
					from betterproto.tests.output_betterproto.bool import Test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_value():
 | 
					def test_value():
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -10,6 +10,7 @@ message Test {
 | 
				
			|||||||
  int32 camelCase = 1;
 | 
					  int32 camelCase = 1;
 | 
				
			||||||
  my_enum snake_case = 2;
 | 
					  my_enum snake_case = 2;
 | 
				
			||||||
  snake_case_message snake_case_message = 3;
 | 
					  snake_case_message snake_case_message = 3;
 | 
				
			||||||
 | 
					  int32 UPPERCASE = 4;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message snake_case_message {
 | 
					message snake_case_message {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,5 @@
 | 
				
			|||||||
import betterproto.tests.output_betterproto.casing.casing as casing
 | 
					import betterproto.tests.output_betterproto.casing as casing
 | 
				
			||||||
from betterproto.tests.output_betterproto.casing.casing import Test
 | 
					from betterproto.tests.output_betterproto.casing import Test
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_message_attributes():
 | 
					def test_message_attributes():
 | 
				
			||||||
@@ -8,6 +8,7 @@ def test_message_attributes():
 | 
				
			|||||||
        message, "snake_case_message"
 | 
					        message, "snake_case_message"
 | 
				
			||||||
    ), "snake_case field name is same in python"
 | 
					    ), "snake_case field name is same in python"
 | 
				
			||||||
    assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
 | 
					    assert hasattr(message, "camel_case"), "CamelCase field is snake_case in python"
 | 
				
			||||||
 | 
					    assert hasattr(message, "uppercase"), "UPPERCASE field is lowercase in python"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_message_casing():
 | 
					def test_message_casing():
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +0,0 @@
 | 
				
			|||||||
{
 | 
					 | 
				
			||||||
  "UPPERCASE": 10,
 | 
					 | 
				
			||||||
  "UPPERCASE_V2": 10,
 | 
					 | 
				
			||||||
  "UPPER_CAMEL_CASE": 10
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1,6 +1,4 @@
 | 
				
			|||||||
from betterproto.tests.output_betterproto.casing_message_field_uppercase.casing_message_field_uppercase import (
 | 
					from betterproto.tests.output_betterproto.casing_message_field_uppercase import Test
 | 
				
			||||||
    Test,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_message_casing():
 | 
					def test_message_casing():
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,19 +1,15 @@
 | 
				
			|||||||
# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
 | 
					# Test cases that are expected to fail, e.g. unimplemented features or bug-fixes.
 | 
				
			||||||
# Remove from list when fixed.
 | 
					# Remove from list when fixed.
 | 
				
			||||||
tests = {
 | 
					xfail = {
 | 
				
			||||||
    "import_root_sibling",  # 61
 | 
					    "import_circular_dependency",
 | 
				
			||||||
    "import_child_package_from_package",  # 58
 | 
					 | 
				
			||||||
    "import_root_package_from_child",  # 60
 | 
					 | 
				
			||||||
    "import_parent_package_from_child",  # 59
 | 
					 | 
				
			||||||
    "import_circular_dependency",  # failing because of other bugs now
 | 
					 | 
				
			||||||
    "import_packages_same_name",  # 25
 | 
					 | 
				
			||||||
    "oneof_enum",  # 63
 | 
					    "oneof_enum",  # 63
 | 
				
			||||||
    "casing_message_field_uppercase",  # 11
 | 
					 | 
				
			||||||
    "namespace_keywords",  # 70
 | 
					    "namespace_keywords",  # 70
 | 
				
			||||||
    "namespace_builtin_types",  # 53
 | 
					    "namespace_builtin_types",  # 53
 | 
				
			||||||
    "googletypes_struct",  # 9
 | 
					    "googletypes_struct",  # 9
 | 
				
			||||||
    "googletypes_value",  # 9
 | 
					    "googletypes_value",  # 9
 | 
				
			||||||
    "enum_skipped_value",  # 93
 | 
					    "enum_skipped_value",  # 93
 | 
				
			||||||
 | 
					    "import_capitalized_package",
 | 
				
			||||||
 | 
					    "example",  # This is the example in the readme. Not a test.
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
services = {
 | 
					services = {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										8
									
								
								betterproto/tests/inputs/example/example.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										8
									
								
								betterproto/tests/inputs/example/example.proto
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package hello;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Greeting represents a message you can tell a user.
 | 
				
			||||||
 | 
					message Greeting {
 | 
				
			||||||
 | 
					  string message = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -4,9 +4,7 @@ import betterproto.lib.google.protobuf as protobuf
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from betterproto.tests.mocks import MockChannel
 | 
					from betterproto.tests.mocks import MockChannel
 | 
				
			||||||
from betterproto.tests.output_betterproto.googletypes_response.googletypes_response import (
 | 
					from betterproto.tests.output_betterproto.googletypes_response import TestStub
 | 
				
			||||||
    TestStub,
 | 
					 | 
				
			||||||
)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
test_cases = [
 | 
					test_cases = [
 | 
				
			||||||
    (TestStub.get_double, protobuf.DoubleValue, 2.5),
 | 
					    (TestStub.get_double, protobuf.DoubleValue, 2.5),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,7 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from betterproto.tests.mocks import MockChannel
 | 
					from betterproto.tests.mocks import MockChannel
 | 
				
			||||||
from betterproto.tests.output_betterproto.googletypes_response_embedded.googletypes_response_embedded import (
 | 
					from betterproto.tests.output_betterproto.googletypes_response_embedded import (
 | 
				
			||||||
    Output,
 | 
					    Output,
 | 
				
			||||||
    TestStub,
 | 
					    TestStub,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,8 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package Capitalized;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Message {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,9 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "capitalized.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Tests that we can import from a package with a capital name, that looks like a nested type, but isn't.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					  Capitalized.Message message = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -3,9 +3,9 @@ syntax = "proto3";
 | 
				
			|||||||
import "root.proto";
 | 
					import "root.proto";
 | 
				
			||||||
import "other.proto";
 | 
					import "other.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// This test-case verifies that future implementations will support circular dependencies in the generated python files.
 | 
					// This test-case verifies support for circular dependencies in the generated python files.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// This becomes important when generating 1 python file/module per package, rather than 1 file per proto file.
 | 
					// This is important because we generate 1 python file/module per package, rather than 1 file per proto file.
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// Scenario:
 | 
					// Scenario:
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
@@ -24,5 +24,5 @@ import "other.proto";
 | 
				
			|||||||
//           (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
 | 
					//           (root: Test & RootPackageMessage) <-------> (other: OtherPackageMessage)
 | 
				
			||||||
message Test {
 | 
					message Test {
 | 
				
			||||||
  RootPackageMessage message = 1;
 | 
					  RootPackageMessage message = 1;
 | 
				
			||||||
  other.OtherPackageMessage other =2;
 | 
					  other.OtherPackageMessage other = 2;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package cousin.cousin_subpackage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message CousinMessage {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										11
									
								
								betterproto/tests/inputs/import_cousin_package/test.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								betterproto/tests/inputs/import_cousin_package/test.proto
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package test.subpackage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "cousin.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Verify that we can import message unrelated to us
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					    cousin.cousin_subpackage.CousinMessage message = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,6 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package cousin.subpackage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message CousinMessage {
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package test.subpackage;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "cousin.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Verify that we can import a message unrelated to us, in a subpackage with the same name as us.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					    cousin.subpackage.CousinMessage message = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -0,0 +1,11 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package child;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "root.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// Verify that we can import root message from child package
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					    RootMessage message = 1;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,11 +0,0 @@
 | 
				
			|||||||
syntax = "proto3";
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import "root.proto";
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
package child;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
// Tests generated imports when a message inside a child-package refers to a message defined in the root.
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
message Test {
 | 
					 | 
				
			||||||
  RootMessage message = 1;
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1,7 +1,7 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from betterproto.tests.mocks import MockChannel
 | 
					from betterproto.tests.mocks import MockChannel
 | 
				
			||||||
from betterproto.tests.output_betterproto.import_service_input_message.import_service_input_message import (
 | 
					from betterproto.tests.output_betterproto.import_service_input_message import (
 | 
				
			||||||
    RequestResponse,
 | 
					    RequestResponse,
 | 
				
			||||||
    TestStub,
 | 
					    TestStub,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										19
									
								
								betterproto/tests/inputs/nested2/nested2.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								betterproto/tests/inputs/nested2/nested2.proto
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import "package.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Game {
 | 
				
			||||||
 | 
					    message Player {
 | 
				
			||||||
 | 
					        enum Race {
 | 
				
			||||||
 | 
					            human = 0;
 | 
				
			||||||
 | 
					            orc = 1;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Test {
 | 
				
			||||||
 | 
					    Game game = 1;
 | 
				
			||||||
 | 
					    Game.Player GamePlayer = 2;
 | 
				
			||||||
 | 
					    Game.Player.Race GamePlayerRace = 3;
 | 
				
			||||||
 | 
					    equipment.Weapon Weapon = 4;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										7
									
								
								betterproto/tests/inputs/nested2/package.proto
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								betterproto/tests/inputs/nested2/package.proto
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,7 @@
 | 
				
			|||||||
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					package equipment;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					message Weapon {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -1,10 +1,10 @@
 | 
				
			|||||||
{
 | 
					{
 | 
				
			||||||
  "root": {
 | 
					  "top": {
 | 
				
			||||||
    "name": "double-nested",
 | 
					    "name": "double-nested",
 | 
				
			||||||
    "parent": {
 | 
					    "middle": {
 | 
				
			||||||
      "child": [{"foo": "hello"}],
 | 
					      "bottom": [{"foo": "hello"}],
 | 
				
			||||||
      "enumChild": ["A"],
 | 
					      "enumBottom": ["A"],
 | 
				
			||||||
      "rootParentChild": [{"a": "hello"}],
 | 
					      "topMiddleBottom": [{"a": "hello"}],
 | 
				
			||||||
      "bar": true
 | 
					      "bar": true
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,26 +1,26 @@
 | 
				
			|||||||
syntax = "proto3";
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message Test {
 | 
					message Test {
 | 
				
			||||||
  message Root {
 | 
					  message Top {
 | 
				
			||||||
    message Parent {
 | 
					    message Middle {
 | 
				
			||||||
      message RootParentChild {
 | 
					      message TopMiddleBottom {
 | 
				
			||||||
        string a = 1;
 | 
					        string a = 1;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      enum EnumChild{
 | 
					      enum EnumBottom{
 | 
				
			||||||
        A = 0;
 | 
					        A = 0;
 | 
				
			||||||
        B = 1;
 | 
					        B = 1;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      message Child {
 | 
					      message Bottom {
 | 
				
			||||||
        string foo = 1;
 | 
					        string foo = 1;
 | 
				
			||||||
      }
 | 
					      }
 | 
				
			||||||
      reserved 1;
 | 
					      reserved 1;
 | 
				
			||||||
      repeated Child child = 2;
 | 
					      repeated Bottom bottom = 2;
 | 
				
			||||||
      repeated EnumChild enumChild=3;
 | 
					      repeated EnumBottom enumBottom=3;
 | 
				
			||||||
      repeated RootParentChild rootParentChild=4;
 | 
					      repeated TopMiddleBottom topMiddleBottom=4;
 | 
				
			||||||
      bool bar = 5;
 | 
					      bool bar = 5;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    string name = 1;
 | 
					    string name = 1;
 | 
				
			||||||
    Parent parent = 2;
 | 
					    Middle middle = 2;
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
  Root root = 1;
 | 
					  Top top = 1;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,5 @@
 | 
				
			|||||||
import betterproto
 | 
					import betterproto
 | 
				
			||||||
from betterproto.tests.output_betterproto.oneof.oneof import Test
 | 
					from betterproto.tests.output_betterproto.oneof import Test
 | 
				
			||||||
from betterproto.tests.util import get_test_case_json_data
 | 
					from betterproto.tests.util import get_test_case_json_data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,7 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import betterproto
 | 
					import betterproto
 | 
				
			||||||
from betterproto.tests.output_betterproto.oneof_enum.oneof_enum import (
 | 
					from betterproto.tests.output_betterproto.oneof_enum import (
 | 
				
			||||||
    Move,
 | 
					    Move,
 | 
				
			||||||
    Signal,
 | 
					    Signal,
 | 
				
			||||||
    Test,
 | 
					    Test,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,5 @@
 | 
				
			|||||||
syntax = "proto3";
 | 
					syntax = "proto3";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
package ref;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
import "repeatedmessage.proto";
 | 
					import "repeatedmessage.proto";
 | 
				
			||||||
 | 
					
 | 
				
			||||||
message Test {
 | 
					message Test {
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										125
									
								
								betterproto/tests/test_casing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										125
									
								
								betterproto/tests/test_casing.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,125 @@
 | 
				
			|||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from betterproto.casing import camel_case, pascal_case, snake_case
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["value", "expected"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("", ""),
 | 
				
			||||||
 | 
					        ("a", "A"),
 | 
				
			||||||
 | 
					        ("foobar", "Foobar"),
 | 
				
			||||||
 | 
					        ("fooBar", "FooBar"),
 | 
				
			||||||
 | 
					        ("FooBar", "FooBar"),
 | 
				
			||||||
 | 
					        ("foo.bar", "FooBar"),
 | 
				
			||||||
 | 
					        ("foo_bar", "FooBar"),
 | 
				
			||||||
 | 
					        ("FOOBAR", "Foobar"),
 | 
				
			||||||
 | 
					        ("FOOBar", "FooBar"),
 | 
				
			||||||
 | 
					        ("UInt32", "UInt32"),
 | 
				
			||||||
 | 
					        ("FOO_BAR", "FooBar"),
 | 
				
			||||||
 | 
					        ("FOOBAR1", "Foobar1"),
 | 
				
			||||||
 | 
					        ("FOOBAR_1", "Foobar1"),
 | 
				
			||||||
 | 
					        ("FOO1BAR2", "Foo1Bar2"),
 | 
				
			||||||
 | 
					        ("foo__bar", "FooBar"),
 | 
				
			||||||
 | 
					        ("_foobar", "Foobar"),
 | 
				
			||||||
 | 
					        ("foobaR", "FoobaR"),
 | 
				
			||||||
 | 
					        ("foo~bar", "FooBar"),
 | 
				
			||||||
 | 
					        ("foo:bar", "FooBar"),
 | 
				
			||||||
 | 
					        ("1foobar", "1Foobar"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_pascal_case(value, expected):
 | 
				
			||||||
 | 
					    actual = pascal_case(value, strict=True)
 | 
				
			||||||
 | 
					    assert actual == expected, f"{value} => {expected} (actual: {actual})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["value", "expected"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("", ""),
 | 
				
			||||||
 | 
					        ("a", "a"),
 | 
				
			||||||
 | 
					        ("foobar", "foobar"),
 | 
				
			||||||
 | 
					        ("fooBar", "fooBar"),
 | 
				
			||||||
 | 
					        ("FooBar", "fooBar"),
 | 
				
			||||||
 | 
					        ("foo.bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("foo_bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("FOOBAR", "foobar"),
 | 
				
			||||||
 | 
					        ("FOO_BAR", "fooBar"),
 | 
				
			||||||
 | 
					        ("FOOBAR1", "foobar1"),
 | 
				
			||||||
 | 
					        ("FOOBAR_1", "foobar1"),
 | 
				
			||||||
 | 
					        ("FOO1BAR2", "foo1Bar2"),
 | 
				
			||||||
 | 
					        ("foo__bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("_foobar", "foobar"),
 | 
				
			||||||
 | 
					        ("foobaR", "foobaR"),
 | 
				
			||||||
 | 
					        ("foo~bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("foo:bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("1foobar", "1Foobar"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_camel_case_strict(value, expected):
 | 
				
			||||||
 | 
					    actual = camel_case(value, strict=True)
 | 
				
			||||||
 | 
					    assert actual == expected, f"{value} => {expected} (actual: {actual})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["value", "expected"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("foo_bar", "fooBar"),
 | 
				
			||||||
 | 
					        ("FooBar", "fooBar"),
 | 
				
			||||||
 | 
					        ("foo__bar", "foo_Bar"),
 | 
				
			||||||
 | 
					        ("foo__Bar", "foo__Bar"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_camel_case_not_strict(value, expected):
 | 
				
			||||||
 | 
					    actual = camel_case(value, strict=False)
 | 
				
			||||||
 | 
					    assert actual == expected, f"{value} => {expected} (actual: {actual})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["value", "expected"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("", ""),
 | 
				
			||||||
 | 
					        ("a", "a"),
 | 
				
			||||||
 | 
					        ("foobar", "foobar"),
 | 
				
			||||||
 | 
					        ("fooBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("FooBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("foo.bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("foo_bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("foo_Bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("FOOBAR", "foobar"),
 | 
				
			||||||
 | 
					        ("FOOBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("UInt32", "u_int32"),
 | 
				
			||||||
 | 
					        ("FOO_BAR", "foo_bar"),
 | 
				
			||||||
 | 
					        ("FOOBAR1", "foobar1"),
 | 
				
			||||||
 | 
					        ("FOOBAR_1", "foobar_1"),
 | 
				
			||||||
 | 
					        ("FOOBAR_123", "foobar_123"),
 | 
				
			||||||
 | 
					        ("FOO1BAR2", "foo1_bar2"),
 | 
				
			||||||
 | 
					        ("foo__bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("_foobar", "foobar"),
 | 
				
			||||||
 | 
					        ("foobaR", "fooba_r"),
 | 
				
			||||||
 | 
					        ("foo~bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("foo:bar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("1foobar", "1_foobar"),
 | 
				
			||||||
 | 
					        ("GetUInt64", "get_u_int64"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_snake_case_strict(value, expected):
 | 
				
			||||||
 | 
					    actual = snake_case(value)
 | 
				
			||||||
 | 
					    assert actual == expected, f"{value} => {expected} (actual: {actual})"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["value", "expected"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("fooBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("FooBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("foo_Bar", "foo__bar"),
 | 
				
			||||||
 | 
					        ("foo__bar", "foo__bar"),
 | 
				
			||||||
 | 
					        ("FOOBar", "foo_bar"),
 | 
				
			||||||
 | 
					        ("__foo", "__foo"),
 | 
				
			||||||
 | 
					        ("GetUInt64", "get_u_int64"),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_snake_case_not_strict(value, expected):
 | 
				
			||||||
 | 
					    actual = snake_case(value, strict=False)
 | 
				
			||||||
 | 
					    assert actual == expected, f"{value} => {expected} (actual: {actual})"
 | 
				
			||||||
@@ -1,6 +1,6 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from ..compile.importing import get_ref_type
 | 
					from ..compile.importing import get_type_reference, parse_source_type_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize(
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
@@ -28,14 +28,16 @@ from ..compile.importing import get_ref_type
 | 
				
			|||||||
        ),
 | 
					        ),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_import_google_wellknown_types_non_wrappers(
 | 
					def test_reference_google_wellknown_types_non_wrappers(
 | 
				
			||||||
    google_type: str, expected_name: str, expected_import: str
 | 
					    google_type: str, expected_name: str, expected_import: str
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    imports = set()
 | 
					    imports = set()
 | 
				
			||||||
    name = get_ref_type(package="", imports=imports, type_name=google_type)
 | 
					    name = get_type_reference(package="", imports=imports, source_type=google_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert name == expected_name
 | 
					    assert name == expected_name
 | 
				
			||||||
    assert imports.__contains__(expected_import)
 | 
					    assert imports.__contains__(
 | 
				
			||||||
 | 
					        expected_import
 | 
				
			||||||
 | 
					    ), f"{expected_import} not found in {imports}"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.mark.parametrize(
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
@@ -52,9 +54,11 @@ def test_import_google_wellknown_types_non_wrappers(
 | 
				
			|||||||
        (".google.protobuf.BytesValue", "Optional[bytes]"),
 | 
					        (".google.protobuf.BytesValue", "Optional[bytes]"),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name: str):
 | 
					def test_referenceing_google_wrappers_unwraps_them(
 | 
				
			||||||
 | 
					    google_type: str, expected_name: str
 | 
				
			||||||
 | 
					):
 | 
				
			||||||
    imports = set()
 | 
					    imports = set()
 | 
				
			||||||
    name = get_ref_type(package="", imports=imports, type_name=google_type)
 | 
					    name = get_type_reference(package="", imports=imports, source_type=google_type)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert name == expected_name
 | 
					    assert name == expected_name
 | 
				
			||||||
    assert imports == set()
 | 
					    assert imports == set()
 | 
				
			||||||
@@ -74,9 +78,238 @@ def test_importing_google_wrappers_unwraps_them(google_type: str, expected_name:
 | 
				
			|||||||
        (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
 | 
					        (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"),
 | 
				
			||||||
    ],
 | 
					    ],
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
def test_importing_google_wrappers_without_unwrapping(
 | 
					def test_referenceing_google_wrappers_without_unwrapping(
 | 
				
			||||||
    google_type: str, expected_name: str
 | 
					    google_type: str, expected_name: str
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    name = get_ref_type(package="", imports=set(), type_name=google_type, unwrap=False)
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="", imports=set(), source_type=google_type, unwrap=False
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert name == expected_name
 | 
					    assert name == expected_name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_child_package_from_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package", imports=imports, source_type="package.child.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from . import child"}
 | 
				
			||||||
 | 
					    assert name == "child.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_child_package_from_root():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="", imports=imports, source_type="child.Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from . import child"}
 | 
				
			||||||
 | 
					    assert name == "child.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_camel_cased():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="", imports=imports, source_type="child_package.example_message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from . import child_package"}
 | 
				
			||||||
 | 
					    assert name == "child_package.ExampleMessage"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_nested_child_from_root():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="", imports=imports, source_type="nested.child.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .nested import child as nested_child"}
 | 
				
			||||||
 | 
					    assert name == "nested_child.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_deeply_nested_child_from_root():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="", imports=imports, source_type="deeply.nested.child.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .deeply.nested import child as deeply_nested_child"}
 | 
				
			||||||
 | 
					    assert name == "deeply_nested_child.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_deeply_nested_child_from_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package",
 | 
				
			||||||
 | 
					        imports=imports,
 | 
				
			||||||
 | 
					        source_type="package.deeply.nested.child.Message",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .deeply.nested import child as deeply_nested_child"}
 | 
				
			||||||
 | 
					    assert name == "deeply_nested_child.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_root_sibling():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="", imports=imports, source_type="Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == set()
 | 
				
			||||||
 | 
					    assert name == '"Message"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_nested_siblings():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == set()
 | 
				
			||||||
 | 
					    assert name == '"Message"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_deeply_nested_siblings():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="foo.bar", imports=imports, source_type="foo.bar.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == set()
 | 
				
			||||||
 | 
					    assert name == '"Message"'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_parent_package_from_child():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package.child", imports=imports, source_type="package.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ... import package as __package__"}
 | 
				
			||||||
 | 
					    assert name == "__package__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_parent_package_from_deeply_nested_child():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package.deeply.nested.child",
 | 
				
			||||||
 | 
					        imports=imports,
 | 
				
			||||||
 | 
					        source_type="package.deeply.nested.Message",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ... import nested as __nested__"}
 | 
				
			||||||
 | 
					    assert name == "__nested__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_ancestor_package_from_nested_child():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package.ancestor.nested.child",
 | 
				
			||||||
 | 
					        imports=imports,
 | 
				
			||||||
 | 
					        source_type="package.ancestor.Message",
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .... import ancestor as ___ancestor__"}
 | 
				
			||||||
 | 
					    assert name == "___ancestor__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_root_package_from_child():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package.child", imports=imports, source_type="Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ... import Message as __Message__"}
 | 
				
			||||||
 | 
					    assert name == "__Message__"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_root_package_from_deeply_nested_child():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="package.deeply.nested.child", imports=imports, source_type="Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ..... import Message as ____Message__"}
 | 
				
			||||||
 | 
					    assert name == "____Message__"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_unrelated_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="a", imports=imports, source_type="p.Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .. import p as _p__"}
 | 
				
			||||||
 | 
					    assert name == "_p__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_unrelated_nested_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ...p import q as __p_q__"}
 | 
				
			||||||
 | 
					    assert name == "__p_q__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_unrelated_deeply_nested_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
 | 
				
			||||||
 | 
					    assert name == "____p_q_r_s__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_cousin_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from .. import y as _y__"}
 | 
				
			||||||
 | 
					    assert name == "_y__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_cousin_package_different_name():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="test.package1", imports=imports, source_type="cousin.package2.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ...cousin import package2 as __cousin_package2__"}
 | 
				
			||||||
 | 
					    assert name == "__cousin_package2__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_cousin_package_same_name():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="test.package", imports=imports, source_type="cousin.package.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ...cousin import package as __cousin_package__"}
 | 
				
			||||||
 | 
					    assert name == "__cousin_package__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_far_cousin_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="a.x.y", imports=imports, source_type="a.b.c.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ...b import c as __b_c__"}
 | 
				
			||||||
 | 
					    assert name == "__b_c__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def test_reference_far_far_cousin_package():
 | 
				
			||||||
 | 
					    imports = set()
 | 
				
			||||||
 | 
					    name = get_type_reference(
 | 
				
			||||||
 | 
					        package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert imports == {"from ....b.c import d as ___b_c_d__"}
 | 
				
			||||||
 | 
					    assert name == "___b_c_d__.Message"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.mark.parametrize(
 | 
				
			||||||
 | 
					    ["full_name", "expected_output"],
 | 
				
			||||||
 | 
					    [
 | 
				
			||||||
 | 
					        ("package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
 | 
				
			||||||
 | 
					        (".package.SomeMessage.NestedType", ("package", "SomeMessage.NestedType")),
 | 
				
			||||||
 | 
					        (".service.ExampleRequest", ("service", "ExampleRequest")),
 | 
				
			||||||
 | 
					        (".package.lower_case_message", ("package", "lower_case_message")),
 | 
				
			||||||
 | 
					    ],
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					def test_parse_field_type_name(full_name, expected_output):
 | 
				
			||||||
 | 
					    assert parse_source_type_name(full_name) == expected_output
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@ import json
 | 
				
			|||||||
import os
 | 
					import os
 | 
				
			||||||
import sys
 | 
					import sys
 | 
				
			||||||
from collections import namedtuple
 | 
					from collections import namedtuple
 | 
				
			||||||
 | 
					from types import ModuleType
 | 
				
			||||||
from typing import Set
 | 
					from typing import Set
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
@@ -10,7 +11,12 @@ import pytest
 | 
				
			|||||||
import betterproto
 | 
					import betterproto
 | 
				
			||||||
from betterproto.tests.inputs import config as test_input_config
 | 
					from betterproto.tests.inputs import config as test_input_config
 | 
				
			||||||
from betterproto.tests.mocks import MockChannel
 | 
					from betterproto.tests.mocks import MockChannel
 | 
				
			||||||
from betterproto.tests.util import get_directories, get_test_case_json_data, inputs_path
 | 
					from betterproto.tests.util import (
 | 
				
			||||||
 | 
					    find_module,
 | 
				
			||||||
 | 
					    get_directories,
 | 
				
			||||||
 | 
					    get_test_case_json_data,
 | 
				
			||||||
 | 
					    inputs_path,
 | 
				
			||||||
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Force pure-python implementation instead of C++, otherwise imports
 | 
					# Force pure-python implementation instead of C++, otherwise imports
 | 
				
			||||||
# break things because we can't properly reset the symbol database.
 | 
					# break things because we can't properly reset the symbol database.
 | 
				
			||||||
@@ -50,14 +56,17 @@ class TestCases:
 | 
				
			|||||||
test_cases = TestCases(
 | 
					test_cases = TestCases(
 | 
				
			||||||
    path=inputs_path,
 | 
					    path=inputs_path,
 | 
				
			||||||
    services=test_input_config.services,
 | 
					    services=test_input_config.services,
 | 
				
			||||||
    xfail=test_input_config.tests,
 | 
					    xfail=test_input_config.xfail,
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
plugin_output_package = "betterproto.tests.output_betterproto"
 | 
					plugin_output_package = "betterproto.tests.output_betterproto"
 | 
				
			||||||
reference_output_package = "betterproto.tests.output_reference"
 | 
					reference_output_package = "betterproto.tests.output_reference"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					TestData = namedtuple("TestData", ["plugin_module", "reference_module", "json_data"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
TestData = namedtuple("TestData", "plugin_module, reference_module, json_data")
 | 
					
 | 
				
			||||||
 | 
					def module_has_entry_point(module: ModuleType):
 | 
				
			||||||
 | 
					    return any(hasattr(module, attr) for attr in ["Test", "TestStub"])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@pytest.fixture
 | 
					@pytest.fixture
 | 
				
			||||||
@@ -75,11 +84,19 @@ def test_data(request):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    sys.path.append(reference_module_root)
 | 
					    sys.path.append(reference_module_root)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plugin_module = importlib.import_module(f"{plugin_output_package}.{test_case_name}")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    plugin_module_entry_point = find_module(plugin_module, module_has_entry_point)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if not plugin_module_entry_point:
 | 
				
			||||||
 | 
					        raise Exception(
 | 
				
			||||||
 | 
					            f"Test case {repr(test_case_name)} has no entry point. "
 | 
				
			||||||
 | 
					            "Please add a proto message or service called Test and recompile."
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    yield (
 | 
					    yield (
 | 
				
			||||||
        TestData(
 | 
					        TestData(
 | 
				
			||||||
            plugin_module=importlib.import_module(
 | 
					            plugin_module=plugin_module_entry_point,
 | 
				
			||||||
                f"{plugin_output_package}.{test_case_name}.{test_case_name}"
 | 
					 | 
				
			||||||
            ),
 | 
					 | 
				
			||||||
            reference_module=lambda: importlib.import_module(
 | 
					            reference_module=lambda: importlib.import_module(
 | 
				
			||||||
                f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
 | 
					                f"{reference_output_package}.{test_case_name}.{test_case_name}_pb2"
 | 
				
			||||||
            ),
 | 
					            ),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,7 +1,10 @@
 | 
				
			|||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
 | 
					import importlib
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
 | 
					import pathlib
 | 
				
			||||||
from pathlib import Path
 | 
					from pathlib import Path
 | 
				
			||||||
from typing import Generator, IO, Optional
 | 
					from types import ModuleType
 | 
				
			||||||
 | 
					from typing import Callable, Generator, Optional
 | 
				
			||||||
 | 
					
 | 
				
			||||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
					os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -55,3 +58,35 @@ def get_test_case_json_data(test_case_name: str, json_file_name: Optional[str] =
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    with test_data_file_path.open("r") as fh:
 | 
					    with test_data_file_path.open("r") as fh:
 | 
				
			||||||
        return fh.read()
 | 
					        return fh.read()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def find_module(
 | 
				
			||||||
 | 
					    module: ModuleType, predicate: Callable[[ModuleType], bool]
 | 
				
			||||||
 | 
					) -> Optional[ModuleType]:
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    Recursively search module tree for a module that matches the search predicate.
 | 
				
			||||||
 | 
					    Assumes that the submodules are directories containing __init__.py.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Example:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # find module inside foo that contains Test
 | 
				
			||||||
 | 
					        import foo
 | 
				
			||||||
 | 
					        test_module = find_module(foo, lambda m: hasattr(m, 'Test'))
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    if predicate(module):
 | 
				
			||||||
 | 
					        return module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    module_path = pathlib.Path(*module.__path__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for sub in list(sub.parent for sub in module_path.glob("**/__init__.py")):
 | 
				
			||||||
 | 
					        if sub == module_path:
 | 
				
			||||||
 | 
					            continue
 | 
				
			||||||
 | 
					        sub_module_path = sub.relative_to(module_path)
 | 
				
			||||||
 | 
					        sub_module_name = ".".join(sub_module_path.parts)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        sub_module = importlib.import_module(f".{sub_module_name}", module.__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if predicate(sub_module):
 | 
				
			||||||
 | 
					            return sub_module
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return None
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										21
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							
							
						
						
									
										21
									
								
								poetry.lock
									
									
									
										generated
									
									
									
								
							@@ -271,7 +271,7 @@ docs = ["sphinx", "rst.linker", "jaraco.packaging"]
 | 
				
			|||||||
category = "main"
 | 
					category = "main"
 | 
				
			||||||
description = "A very fast and expressive template engine."
 | 
					description = "A very fast and expressive template engine."
 | 
				
			||||||
name = "jinja2"
 | 
					name = "jinja2"
 | 
				
			||||||
optional = true
 | 
					optional = false
 | 
				
			||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
 | 
					python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
 | 
				
			||||||
version = "2.11.2"
 | 
					version = "2.11.2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -285,7 +285,7 @@ i18n = ["Babel (>=0.8)"]
 | 
				
			|||||||
category = "main"
 | 
					category = "main"
 | 
				
			||||||
description = "Safely add untrusted strings to HTML/XML markup."
 | 
					description = "Safely add untrusted strings to HTML/XML markup."
 | 
				
			||||||
name = "markupsafe"
 | 
					name = "markupsafe"
 | 
				
			||||||
optional = true
 | 
					optional = false
 | 
				
			||||||
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*"
 | 
					python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*"
 | 
				
			||||||
version = "1.1.1"
 | 
					version = "1.1.1"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -369,7 +369,7 @@ dev = ["pre-commit", "tox"]
 | 
				
			|||||||
category = "main"
 | 
					category = "main"
 | 
				
			||||||
description = "Protocol Buffers"
 | 
					description = "Protocol Buffers"
 | 
				
			||||||
name = "protobuf"
 | 
					name = "protobuf"
 | 
				
			||||||
optional = true
 | 
					optional = false
 | 
				
			||||||
python-versions = "*"
 | 
					python-versions = "*"
 | 
				
			||||||
version = "3.12.2"
 | 
					version = "3.12.2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -490,14 +490,6 @@ optional = false
 | 
				
			|||||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
 | 
					python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
 | 
				
			||||||
version = "1.15.0"
 | 
					version = "1.15.0"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[package]]
 | 
					 | 
				
			||||||
category = "main"
 | 
					 | 
				
			||||||
description = "String case converter."
 | 
					 | 
				
			||||||
name = "stringcase"
 | 
					 | 
				
			||||||
optional = false
 | 
					 | 
				
			||||||
python-versions = "*"
 | 
					 | 
				
			||||||
version = "1.2.0"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
[[package]]
 | 
					[[package]]
 | 
				
			||||||
category = "main"
 | 
					category = "main"
 | 
				
			||||||
description = "Python Library for Tom's Obvious, Minimal Language"
 | 
					description = "Python Library for Tom's Obvious, Minimal Language"
 | 
				
			||||||
@@ -612,7 +604,7 @@ testing = ["jaraco.itertools", "func-timeout"]
 | 
				
			|||||||
compiler = ["black", "jinja2", "protobuf"]
 | 
					compiler = ["black", "jinja2", "protobuf"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[metadata]
 | 
					[metadata]
 | 
				
			||||||
content-hash = "ecafcaed2d4a25c2829e6dc3ef3c56cd72a8bc28c25c7aeae3484c978c816722"
 | 
					content-hash = "8a4fa01ede86e1b5ba35b9dab8b6eacee766a9b5666f48ab41445c01882ab003"
 | 
				
			||||||
python-versions = "^3.6"
 | 
					python-versions = "^3.6"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[metadata.files]
 | 
					[metadata.files]
 | 
				
			||||||
@@ -865,6 +857,8 @@ protobuf = [
 | 
				
			|||||||
    {file = "protobuf-3.12.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e72736dd822748b0721f41f9aaaf6a5b6d5cfc78f6c8690263aef8bba4457f0e"},
 | 
					    {file = "protobuf-3.12.2-cp37-cp37m-win_amd64.whl", hash = "sha256:e72736dd822748b0721f41f9aaaf6a5b6d5cfc78f6c8690263aef8bba4457f0e"},
 | 
				
			||||||
    {file = "protobuf-3.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:87535dc2d2ef007b9d44e309d2b8ea27a03d2fa09556a72364d706fcb7090828"},
 | 
					    {file = "protobuf-3.12.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:87535dc2d2ef007b9d44e309d2b8ea27a03d2fa09556a72364d706fcb7090828"},
 | 
				
			||||||
    {file = "protobuf-3.12.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:50b5fee674878b14baea73b4568dc478c46a31dd50157a5b5d2f71138243b1a9"},
 | 
					    {file = "protobuf-3.12.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:50b5fee674878b14baea73b4568dc478c46a31dd50157a5b5d2f71138243b1a9"},
 | 
				
			||||||
 | 
					    {file = "protobuf-3.12.2-py2.py3-none-any.whl", hash = "sha256:a96f8fc625e9ff568838e556f6f6ae8eca8b4837cdfb3f90efcb7c00e342a2eb"},
 | 
				
			||||||
 | 
					    {file = "protobuf-3.12.2.tar.gz", hash = "sha256:49ef8ab4c27812a89a76fa894fe7a08f42f2147078392c0dee51d4a444ef6df5"},
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
py = [
 | 
					py = [
 | 
				
			||||||
    {file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"},
 | 
					    {file = "py-1.8.2-py2.py3-none-any.whl", hash = "sha256:a673fa23d7000440cc885c17dbd34fafcb7d7a6e230b29f6766400de36a33c44"},
 | 
				
			||||||
@@ -920,9 +914,6 @@ six = [
 | 
				
			|||||||
    {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
 | 
					    {file = "six-1.15.0-py2.py3-none-any.whl", hash = "sha256:8b74bedcbbbaca38ff6d7491d76f2b06b3592611af620f8426e82dddb04a5ced"},
 | 
				
			||||||
    {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
 | 
					    {file = "six-1.15.0.tar.gz", hash = "sha256:30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259"},
 | 
				
			||||||
]
 | 
					]
 | 
				
			||||||
stringcase = [
 | 
					 | 
				
			||||||
    {file = "stringcase-1.2.0.tar.gz", hash = "sha256:48a06980661908efe8d9d34eab2b6c13aefa2163b3ced26972902e3bdfd87008"},
 | 
					 | 
				
			||||||
]
 | 
					 | 
				
			||||||
toml = [
 | 
					toml = [
 | 
				
			||||||
    {file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"},
 | 
					    {file = "toml-0.10.1-py2.py3-none-any.whl", hash = "sha256:bda89d5935c2eac546d648028b9901107a595863cb36bae0c73ac804a9b4ce88"},
 | 
				
			||||||
    {file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"},
 | 
					    {file = "toml-0.10.1.tar.gz", hash = "sha256:926b612be1e5ce0634a2ca03470f95169cf16f939018233a670519cb4ac58b0f"},
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,6 @@ dataclasses = { version = "^0.7", python = ">=3.6, <3.7" }
 | 
				
			|||||||
grpclib = "^0.3.1"
 | 
					grpclib = "^0.3.1"
 | 
				
			||||||
jinja2 = { version = "^2.11.2", optional = true }
 | 
					jinja2 = { version = "^2.11.2", optional = true }
 | 
				
			||||||
protobuf = { version = "^3.12.2", optional = true }
 | 
					protobuf = { version = "^3.12.2", optional = true }
 | 
				
			||||||
stringcase = "^1.2.0"
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
[tool.poetry.dev-dependencies]
 | 
					[tool.poetry.dev-dependencies]
 | 
				
			||||||
black = "^19.10b0"
 | 
					black = "^19.10b0"
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user