Merge branch 'refs/heads/master_gh'
This commit is contained in:
commit
32eaa51e8d
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [Ubuntu, MacOS, Windows]
|
os: [Ubuntu, MacOS, Windows]
|
||||||
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
|
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
43
README.md
43
README.md
@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the
|
|||||||
pydantic dataclass. You must have pydantic as a dependency in your project for
|
pydantic dataclass. You must have pydantic as a dependency in your project for
|
||||||
this to work.
|
this to work.
|
||||||
|
|
||||||
|
## Configuration typing imports
|
||||||
|
|
||||||
|
By default typing types will be imported directly from typing. This sometimes can lead to issues in generation if types that are being generated conflict with the name. In this case you can configure the way types are imported from 3 different options:
|
||||||
|
|
||||||
|
### Direct
|
||||||
|
```
|
||||||
|
protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
|
||||||
|
```
|
||||||
|
this configuration is the default, and will import types as follows:
|
||||||
|
```
|
||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union
|
||||||
|
)
|
||||||
|
...
|
||||||
|
value: List[str] = []
|
||||||
|
value2: Optional[str] = None
|
||||||
|
value3: Union[str, int] = 1
|
||||||
|
```
|
||||||
|
### Root
|
||||||
|
```
|
||||||
|
protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
|
||||||
|
```
|
||||||
|
this configuration loads the root typing module, and then access the types off of it directly:
|
||||||
|
```
|
||||||
|
import typing
|
||||||
|
...
|
||||||
|
value: typing.List[str] = []
|
||||||
|
value2: typing.Optional[str] = None
|
||||||
|
value3: typing.Union[str, int] = 1
|
||||||
|
```
|
||||||
|
|
||||||
|
### 310
|
||||||
|
```
|
||||||
|
protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
|
||||||
|
```
|
||||||
|
this configuration avoid loading typing all together if possible and uses the python 3.10 pattern:
|
||||||
|
```
|
||||||
|
...
|
||||||
|
value: list[str] = []
|
||||||
|
value2: str | None = None
|
||||||
|
value3: str | int = 1
|
||||||
|
```
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
|
1138
poetry.lock
generated
1138
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -39,7 +39,7 @@ pytest = "^6.2.5"
|
|||||||
pytest-asyncio = "^0.12.0"
|
pytest-asyncio = "^0.12.0"
|
||||||
pytest-cov = "^2.9.0"
|
pytest-cov = "^2.9.0"
|
||||||
pytest-mock = "^3.1.1"
|
pytest-mock = "^3.1.1"
|
||||||
pydantic = ">=1.8.0,<2"
|
pydantic = ">=2.0,<3"
|
||||||
protobuf = "^4"
|
protobuf = "^4"
|
||||||
cachelib = "^0.10.2"
|
cachelib = "^0.10.2"
|
||||||
tomlkit = ">=0.7.0"
|
tomlkit = ">=0.7.0"
|
||||||
|
@ -1852,7 +1852,9 @@ class Message(ABC):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
set_fields = [
|
set_fields = [
|
||||||
field.name for field in field_set if values[field.name] is not None
|
field.name
|
||||||
|
for field in field_set
|
||||||
|
if getattr(values, field.name, None) is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
if not set_fields:
|
if not set_fields:
|
||||||
|
@ -47,6 +47,7 @@ def get_type_reference(
|
|||||||
package: str,
|
package: str,
|
||||||
imports: set,
|
imports: set,
|
||||||
source_type: str,
|
source_type: str,
|
||||||
|
typing_compiler: "TypingCompiler",
|
||||||
unwrap: bool = True,
|
unwrap: bool = True,
|
||||||
pydantic: bool = False,
|
pydantic: bool = False,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -57,7 +58,7 @@ def get_type_reference(
|
|||||||
if unwrap:
|
if unwrap:
|
||||||
if source_type in WRAPPER_TYPES:
|
if source_type in WRAPPER_TYPES:
|
||||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||||
return f"Optional[{wrapped_type.__name__}]"
|
return typing_compiler.optional(wrapped_type.__name__)
|
||||||
|
|
||||||
if source_type == ".google.protobuf.Duration":
|
if source_type == ".google.protobuf.Duration":
|
||||||
return "timedelta"
|
return "timedelta"
|
||||||
|
@ -1,9 +1,16 @@
|
|||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
|
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
|
||||||
# plugin: python-betterproto
|
# plugin: python-betterproto
|
||||||
|
# This file has been @generated
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Mapping,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from betterproto import hybridmethod
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -14,15 +21,13 @@ else:
|
|||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import root_validator
|
from pydantic import model_validator
|
||||||
from typing_extensions import Self
|
from pydantic.dataclasses import rebuild_dataclass
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto.utils import hybridmethod
|
|
||||||
|
|
||||||
|
|
||||||
class Syntax(betterproto.Enum):
|
class Syntax(betterproto.Enum):
|
||||||
@ -37,6 +42,12 @@ class Syntax(betterproto.Enum):
|
|||||||
EDITIONS = 2
|
EDITIONS = 2
|
||||||
"""Syntax `editions`."""
|
"""Syntax `editions`."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldKind(betterproto.Enum):
|
class FieldKind(betterproto.Enum):
|
||||||
"""Basic field types."""
|
"""Basic field types."""
|
||||||
@ -98,6 +109,12 @@ class FieldKind(betterproto.Enum):
|
|||||||
TYPE_SINT64 = 18
|
TYPE_SINT64 = 18
|
||||||
"""Field type sint64."""
|
"""Field type sint64."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldCardinality(betterproto.Enum):
|
class FieldCardinality(betterproto.Enum):
|
||||||
"""Whether a field is optional, required, or repeated."""
|
"""Whether a field is optional, required, or repeated."""
|
||||||
@ -114,6 +131,12 @@ class FieldCardinality(betterproto.Enum):
|
|||||||
CARDINALITY_REPEATED = 3
|
CARDINALITY_REPEATED = 3
|
||||||
"""For repeated fields."""
|
"""For repeated fields."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class Edition(betterproto.Enum):
|
class Edition(betterproto.Enum):
|
||||||
"""The full set of known editions."""
|
"""The full set of known editions."""
|
||||||
@ -155,6 +178,12 @@ class Edition(betterproto.Enum):
|
|||||||
support a new edition.
|
support a new edition.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class ExtensionRangeOptionsVerificationState(betterproto.Enum):
|
class ExtensionRangeOptionsVerificationState(betterproto.Enum):
|
||||||
"""The verification state of the extension range."""
|
"""The verification state of the extension range."""
|
||||||
@ -164,6 +193,12 @@ class ExtensionRangeOptionsVerificationState(betterproto.Enum):
|
|||||||
|
|
||||||
UNVERIFIED = 1
|
UNVERIFIED = 1
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldDescriptorProtoType(betterproto.Enum):
|
class FieldDescriptorProtoType(betterproto.Enum):
|
||||||
TYPE_DOUBLE = 1
|
TYPE_DOUBLE = 1
|
||||||
@ -210,6 +245,12 @@ class FieldDescriptorProtoType(betterproto.Enum):
|
|||||||
TYPE_SINT32 = 17
|
TYPE_SINT32 = 17
|
||||||
TYPE_SINT64 = 18
|
TYPE_SINT64 = 18
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldDescriptorProtoLabel(betterproto.Enum):
|
class FieldDescriptorProtoLabel(betterproto.Enum):
|
||||||
LABEL_OPTIONAL = 1
|
LABEL_OPTIONAL = 1
|
||||||
@ -223,6 +264,12 @@ class FieldDescriptorProtoLabel(betterproto.Enum):
|
|||||||
can be used to get this behavior.
|
can be used to get this behavior.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FileOptionsOptimizeMode(betterproto.Enum):
|
class FileOptionsOptimizeMode(betterproto.Enum):
|
||||||
"""Generated classes can be optimized for speed or code size."""
|
"""Generated classes can be optimized for speed or code size."""
|
||||||
@ -233,6 +280,12 @@ class FileOptionsOptimizeMode(betterproto.Enum):
|
|||||||
|
|
||||||
LITE_RUNTIME = 3
|
LITE_RUNTIME = 3
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldOptionsCType(betterproto.Enum):
|
class FieldOptionsCType(betterproto.Enum):
|
||||||
STRING = 0
|
STRING = 0
|
||||||
@ -250,6 +303,12 @@ class FieldOptionsCType(betterproto.Enum):
|
|||||||
|
|
||||||
STRING_PIECE = 2
|
STRING_PIECE = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldOptionsJsType(betterproto.Enum):
|
class FieldOptionsJsType(betterproto.Enum):
|
||||||
JS_NORMAL = 0
|
JS_NORMAL = 0
|
||||||
@ -261,6 +320,12 @@ class FieldOptionsJsType(betterproto.Enum):
|
|||||||
JS_NUMBER = 2
|
JS_NUMBER = 2
|
||||||
"""Use JavaScript numbers."""
|
"""Use JavaScript numbers."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldOptionsOptionRetention(betterproto.Enum):
|
class FieldOptionsOptionRetention(betterproto.Enum):
|
||||||
"""
|
"""
|
||||||
@ -273,6 +338,12 @@ class FieldOptionsOptionRetention(betterproto.Enum):
|
|||||||
RETENTION_RUNTIME = 1
|
RETENTION_RUNTIME = 1
|
||||||
RETENTION_SOURCE = 2
|
RETENTION_SOURCE = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FieldOptionsOptionTargetType(betterproto.Enum):
|
class FieldOptionsOptionTargetType(betterproto.Enum):
|
||||||
"""
|
"""
|
||||||
@ -293,6 +364,12 @@ class FieldOptionsOptionTargetType(betterproto.Enum):
|
|||||||
TARGET_TYPE_SERVICE = 8
|
TARGET_TYPE_SERVICE = 8
|
||||||
TARGET_TYPE_METHOD = 9
|
TARGET_TYPE_METHOD = 9
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class MethodOptionsIdempotencyLevel(betterproto.Enum):
|
class MethodOptionsIdempotencyLevel(betterproto.Enum):
|
||||||
"""
|
"""
|
||||||
@ -305,6 +382,12 @@ class MethodOptionsIdempotencyLevel(betterproto.Enum):
|
|||||||
NO_SIDE_EFFECTS = 1
|
NO_SIDE_EFFECTS = 1
|
||||||
IDEMPOTENT = 2
|
IDEMPOTENT = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetFieldPresence(betterproto.Enum):
|
class FeatureSetFieldPresence(betterproto.Enum):
|
||||||
FIELD_PRESENCE_UNKNOWN = 0
|
FIELD_PRESENCE_UNKNOWN = 0
|
||||||
@ -312,36 +395,72 @@ class FeatureSetFieldPresence(betterproto.Enum):
|
|||||||
IMPLICIT = 2
|
IMPLICIT = 2
|
||||||
LEGACY_REQUIRED = 3
|
LEGACY_REQUIRED = 3
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetEnumType(betterproto.Enum):
|
class FeatureSetEnumType(betterproto.Enum):
|
||||||
ENUM_TYPE_UNKNOWN = 0
|
ENUM_TYPE_UNKNOWN = 0
|
||||||
OPEN = 1
|
OPEN = 1
|
||||||
CLOSED = 2
|
CLOSED = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetRepeatedFieldEncoding(betterproto.Enum):
|
class FeatureSetRepeatedFieldEncoding(betterproto.Enum):
|
||||||
REPEATED_FIELD_ENCODING_UNKNOWN = 0
|
REPEATED_FIELD_ENCODING_UNKNOWN = 0
|
||||||
PACKED = 1
|
PACKED = 1
|
||||||
EXPANDED = 2
|
EXPANDED = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetUtf8Validation(betterproto.Enum):
|
class FeatureSetUtf8Validation(betterproto.Enum):
|
||||||
UTF8_VALIDATION_UNKNOWN = 0
|
UTF8_VALIDATION_UNKNOWN = 0
|
||||||
VERIFY = 2
|
VERIFY = 2
|
||||||
NONE = 3
|
NONE = 3
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetMessageEncoding(betterproto.Enum):
|
class FeatureSetMessageEncoding(betterproto.Enum):
|
||||||
MESSAGE_ENCODING_UNKNOWN = 0
|
MESSAGE_ENCODING_UNKNOWN = 0
|
||||||
LENGTH_PREFIXED = 1
|
LENGTH_PREFIXED = 1
|
||||||
DELIMITED = 2
|
DELIMITED = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class FeatureSetJsonFormat(betterproto.Enum):
|
class FeatureSetJsonFormat(betterproto.Enum):
|
||||||
JSON_FORMAT_UNKNOWN = 0
|
JSON_FORMAT_UNKNOWN = 0
|
||||||
ALLOW = 1
|
ALLOW = 1
|
||||||
LEGACY_BEST_EFFORT = 2
|
LEGACY_BEST_EFFORT = 2
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
|
class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
|
||||||
"""
|
"""
|
||||||
@ -358,6 +477,12 @@ class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
|
|||||||
ALIAS = 2
|
ALIAS = 2
|
||||||
"""An alias to the element is returned."""
|
"""An alias to the element is returned."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
class NullValue(betterproto.Enum):
|
class NullValue(betterproto.Enum):
|
||||||
"""
|
"""
|
||||||
@ -370,6 +495,12 @@ class NullValue(betterproto.Enum):
|
|||||||
_ = 0
|
_ = 0
|
||||||
"""Null value."""
|
"""Null value."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False, repr=False)
|
@dataclass(eq=False, repr=False)
|
||||||
class Any(betterproto.Message):
|
class Any(betterproto.Message):
|
||||||
@ -1176,16 +1307,12 @@ class FileOptions(betterproto.Message):
|
|||||||
|
|
||||||
java_string_check_utf8: bool = betterproto.bool_field(27)
|
java_string_check_utf8: bool = betterproto.bool_field(27)
|
||||||
"""
|
"""
|
||||||
A proto2 file can set this to true to opt in to UTF-8 checking for Java,
|
If set true, then the Java2 code generator will generate code that
|
||||||
which will throw an exception if invalid UTF-8 is parsed from the wire or
|
throws an exception whenever an attempt is made to assign a non-UTF-8
|
||||||
assigned to a string field.
|
byte sequence to a string field.
|
||||||
|
Message reflection will do the same.
|
||||||
TODO: clarify exactly what kinds of field types this option
|
However, an extension field still accepts non-UTF-8 byte sequences.
|
||||||
applies to, and update these docs accordingly.
|
This option has no effect on when used with the lite runtime.
|
||||||
|
|
||||||
Proto3 files already perform these checks. Setting the option explicitly to
|
|
||||||
false has no effect: it cannot be used to opt proto3 files out of UTF-8
|
|
||||||
checks.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9)
|
optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9)
|
||||||
@ -1477,7 +1604,6 @@ class FieldOptions(betterproto.Message):
|
|||||||
features: "FeatureSet" = betterproto.message_field(21)
|
features: "FeatureSet" = betterproto.message_field(21)
|
||||||
"""Any features defined in the specific edition."""
|
"""Any features defined in the specific edition."""
|
||||||
|
|
||||||
feature_support: "FieldOptionsFeatureSupport" = betterproto.message_field(22)
|
|
||||||
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
|
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
|
||||||
"""The parser stores options it doesn't recognize here. See above."""
|
"""The parser stores options it doesn't recognize here. See above."""
|
||||||
|
|
||||||
@ -1488,37 +1614,6 @@ class FieldOptionsEditionDefault(betterproto.Message):
|
|||||||
value: str = betterproto.string_field(2)
|
value: str = betterproto.string_field(2)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False, repr=False)
|
|
||||||
class FieldOptionsFeatureSupport(betterproto.Message):
|
|
||||||
"""Information about the support window of a feature."""
|
|
||||||
|
|
||||||
edition_introduced: "Edition" = betterproto.enum_field(1)
|
|
||||||
"""
|
|
||||||
The edition that this feature was first available in. In editions
|
|
||||||
earlier than this one, the default assigned to EDITION_LEGACY will be
|
|
||||||
used, and proto files will not be able to override it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
edition_deprecated: "Edition" = betterproto.enum_field(2)
|
|
||||||
"""
|
|
||||||
The edition this feature becomes deprecated in. Using this after this
|
|
||||||
edition may trigger warnings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
deprecation_warning: str = betterproto.string_field(3)
|
|
||||||
"""
|
|
||||||
The deprecation warning text if this feature is used after the edition it
|
|
||||||
was marked deprecated in.
|
|
||||||
"""
|
|
||||||
|
|
||||||
edition_removed: "Edition" = betterproto.enum_field(4)
|
|
||||||
"""
|
|
||||||
The edition this feature is no longer available in. In editions after
|
|
||||||
this one, the last default assigned will be used, and proto files will
|
|
||||||
not be able to override it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False, repr=False)
|
@dataclass(eq=False, repr=False)
|
||||||
class OneofOptions(betterproto.Message):
|
class OneofOptions(betterproto.Message):
|
||||||
features: "FeatureSet" = betterproto.message_field(1)
|
features: "FeatureSet" = betterproto.message_field(1)
|
||||||
@ -1723,17 +1818,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
edition: "Edition" = betterproto.enum_field(3)
|
edition: "Edition" = betterproto.enum_field(3)
|
||||||
overridable_features: "FeatureSet" = betterproto.message_field(4)
|
|
||||||
"""Defaults of features that can be overridden in this edition."""
|
|
||||||
|
|
||||||
fixed_features: "FeatureSet" = betterproto.message_field(5)
|
|
||||||
"""Defaults of features that can't be overridden in this edition."""
|
|
||||||
|
|
||||||
features: "FeatureSet" = betterproto.message_field(2)
|
features: "FeatureSet" = betterproto.message_field(2)
|
||||||
"""
|
|
||||||
TODO Deprecate and remove this field, which is just the
|
|
||||||
above two merged.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False, repr=False)
|
@dataclass(eq=False, repr=False)
|
||||||
@ -2314,7 +2399,7 @@ class Value(betterproto.Message):
|
|||||||
)
|
)
|
||||||
"""Represents a repeated `Value`."""
|
"""Represents a repeated `Value`."""
|
||||||
|
|
||||||
@root_validator()
|
@model_validator(mode="after")
|
||||||
def check_oneof(cls, values):
|
def check_oneof(cls, values):
|
||||||
return cls._validate_field_groups(values)
|
return cls._validate_field_groups(values)
|
||||||
|
|
||||||
@ -2549,41 +2634,40 @@ class BytesValue(betterproto.Message):
|
|||||||
"""The bytes value."""
|
"""The bytes value."""
|
||||||
|
|
||||||
|
|
||||||
Type.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Type) # type: ignore
|
||||||
Field.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Field) # type: ignore
|
||||||
Enum.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Enum) # type: ignore
|
||||||
EnumValue.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(EnumValue) # type: ignore
|
||||||
Option.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Option) # type: ignore
|
||||||
Api.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Api) # type: ignore
|
||||||
Method.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Method) # type: ignore
|
||||||
FileDescriptorSet.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FileDescriptorSet) # type: ignore
|
||||||
FileDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FileDescriptorProto) # type: ignore
|
||||||
DescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(DescriptorProto) # type: ignore
|
||||||
DescriptorProtoExtensionRange.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(DescriptorProtoExtensionRange) # type: ignore
|
||||||
ExtensionRangeOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(ExtensionRangeOptions) # type: ignore
|
||||||
FieldDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FieldDescriptorProto) # type: ignore
|
||||||
OneofDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(OneofDescriptorProto) # type: ignore
|
||||||
EnumDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(EnumDescriptorProto) # type: ignore
|
||||||
EnumValueDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(EnumValueDescriptorProto) # type: ignore
|
||||||
ServiceDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(ServiceDescriptorProto) # type: ignore
|
||||||
MethodDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(MethodDescriptorProto) # type: ignore
|
||||||
FileOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FileOptions) # type: ignore
|
||||||
MessageOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(MessageOptions) # type: ignore
|
||||||
FieldOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FieldOptions) # type: ignore
|
||||||
FieldOptionsEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FieldOptionsEditionDefault) # type: ignore
|
||||||
FieldOptionsFeatureSupport.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(OneofOptions) # type: ignore
|
||||||
OneofOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(EnumOptions) # type: ignore
|
||||||
EnumOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(EnumValueOptions) # type: ignore
|
||||||
EnumValueOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(ServiceOptions) # type: ignore
|
||||||
ServiceOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(MethodOptions) # type: ignore
|
||||||
MethodOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(UninterpretedOption) # type: ignore
|
||||||
UninterpretedOption.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FeatureSet) # type: ignore
|
||||||
FeatureSet.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FeatureSetDefaults) # type: ignore
|
||||||
FeatureSetDefaults.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(FeatureSetDefaultsFeatureSetEditionDefault) # type: ignore
|
||||||
FeatureSetDefaultsFeatureSetEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(SourceCodeInfo) # type: ignore
|
||||||
SourceCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(GeneratedCodeInfo) # type: ignore
|
||||||
GeneratedCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(GeneratedCodeInfoAnnotation) # type: ignore
|
||||||
GeneratedCodeInfoAnnotation.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Struct) # type: ignore
|
||||||
Struct.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(Value) # type: ignore
|
||||||
Value.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass(ListValue) # type: ignore
|
||||||
ListValue.__pydantic_model__.update_forward_refs() # type: ignore
|
|
||||||
|
@ -1,4 +1,7 @@
|
|||||||
import os.path
|
import os.path
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from .module_validation import ModuleValidator
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
|||||||
lstrip_blocks=True,
|
lstrip_blocks=True,
|
||||||
loader=jinja2.FileSystemLoader(templates_folder),
|
loader=jinja2.FileSystemLoader(templates_folder),
|
||||||
)
|
)
|
||||||
template = env.get_template("template.py.j2")
|
# Load the body first so we have a compleate list of imports needed.
|
||||||
|
body_template = env.get_template("template.py.j2")
|
||||||
|
header_template = env.get_template("header.py.j2")
|
||||||
|
|
||||||
code = template.render(output_file=output_file)
|
code = body_template.render(output_file=output_file)
|
||||||
|
code = header_template.render(output_file=output_file) + code
|
||||||
code = isort.api.sort_code_string(
|
code = isort.api.sort_code_string(
|
||||||
code=code,
|
code=code,
|
||||||
show_diff=False,
|
show_diff=False,
|
||||||
@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
|||||||
force_grid_wrap=2,
|
force_grid_wrap=2,
|
||||||
known_third_party=["grpclib", "betterproto"],
|
known_third_party=["grpclib", "betterproto"],
|
||||||
)
|
)
|
||||||
return black.format_str(
|
code = black.format_str(
|
||||||
src_contents=code,
|
src_contents=code,
|
||||||
mode=black.Mode(),
|
mode=black.Mode(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate the generated code.
|
||||||
|
validator = ModuleValidator(iter(code.splitlines()))
|
||||||
|
if not validator.validate():
|
||||||
|
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
|
||||||
|
for collision, lines in validator.collisions.items():
|
||||||
|
message_builder.append(f' "{collision}" on lines:')
|
||||||
|
for num, line in lines:
|
||||||
|
message_builder.append(f" {num}:{line}")
|
||||||
|
print("\n".join(message_builder), file=sys.stderr)
|
||||||
|
return code
|
||||||
|
@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
|
|||||||
reference to `A` to `B`'s `fields` attribute.
|
reference to `A` to `B`'s `fields` attribute.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
import re
|
import re
|
||||||
import textwrap
|
|
||||||
from dataclasses import (
|
from dataclasses import (
|
||||||
dataclass,
|
dataclass,
|
||||||
field,
|
field,
|
||||||
@ -49,12 +47,6 @@ from typing import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
import betterproto
|
import betterproto
|
||||||
from betterproto import which_one_of
|
|
||||||
from betterproto.casing import sanitize_name
|
|
||||||
from betterproto.compile.importing import (
|
|
||||||
get_type_reference,
|
|
||||||
parse_source_type_name,
|
|
||||||
)
|
|
||||||
from betterproto.compile.naming import (
|
from betterproto.compile.naming import (
|
||||||
pythonize_class_name,
|
pythonize_class_name,
|
||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
|
|||||||
)
|
)
|
||||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||||
|
|
||||||
|
from .. import which_one_of
|
||||||
from ..compile.importing import (
|
from ..compile.importing import (
|
||||||
get_type_reference,
|
get_type_reference,
|
||||||
parse_source_type_name,
|
parse_source_type_name,
|
||||||
@ -82,6 +75,10 @@ from ..compile.naming import (
|
|||||||
pythonize_field_name,
|
pythonize_field_name,
|
||||||
pythonize_method_name,
|
pythonize_method_name,
|
||||||
)
|
)
|
||||||
|
from .typing_compiler import (
|
||||||
|
DirectImportTypingCompiler,
|
||||||
|
TypingCompiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Create a unique placeholder to deal with
|
# Create a unique placeholder to deal with
|
||||||
@ -173,6 +170,7 @@ class ProtoContentBase:
|
|||||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||||
|
|
||||||
source_file: FileDescriptorProto
|
source_file: FileDescriptorProto
|
||||||
|
typing_compiler: TypingCompiler
|
||||||
path: List[int]
|
path: List[int]
|
||||||
comment_indent: int = 4
|
comment_indent: int = 4
|
||||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||||
@ -242,7 +240,6 @@ class OutputTemplate:
|
|||||||
input_files: List[str] = field(default_factory=list)
|
input_files: List[str] = field(default_factory=list)
|
||||||
imports: Set[str] = field(default_factory=set)
|
imports: Set[str] = field(default_factory=set)
|
||||||
datetime_imports: Set[str] = field(default_factory=set)
|
datetime_imports: Set[str] = field(default_factory=set)
|
||||||
typing_imports: Set[str] = field(default_factory=set)
|
|
||||||
pydantic_imports: Set[str] = field(default_factory=set)
|
pydantic_imports: Set[str] = field(default_factory=set)
|
||||||
builtins_import: bool = False
|
builtins_import: bool = False
|
||||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||||
@ -251,6 +248,7 @@ class OutputTemplate:
|
|||||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||||
pydantic_dataclasses: bool = False
|
pydantic_dataclasses: bool = False
|
||||||
output: bool = True
|
output: bool = True
|
||||||
|
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def package(self) -> str:
|
def package(self) -> str:
|
||||||
@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
|
|||||||
"""Representation of a protobuf message."""
|
"""Representation of a protobuf message."""
|
||||||
|
|
||||||
source_file: FileDescriptorProto
|
source_file: FileDescriptorProto
|
||||||
|
typing_compiler: TypingCompiler
|
||||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||||
proto_obj: DescriptorProto = PLACEHOLDER
|
proto_obj: DescriptorProto = PLACEHOLDER
|
||||||
path: List[int] = PLACEHOLDER
|
path: List[int] = PLACEHOLDER
|
||||||
@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
|
|||||||
@property
|
@property
|
||||||
def annotation(self) -> str:
|
def annotation(self) -> str:
|
||||||
if self.repeated:
|
if self.repeated:
|
||||||
return f"List[{self.py_name}]"
|
return self.typing_compiler.list(self.py_name)
|
||||||
return self.py_name
|
return self.py_name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
|
|||||||
imports.add("datetime")
|
imports.add("datetime")
|
||||||
return imports
|
return imports
|
||||||
|
|
||||||
@property
|
|
||||||
def typing_imports(self) -> Set[str]:
|
|
||||||
imports = set()
|
|
||||||
annotation = self.annotation
|
|
||||||
if "Optional[" in annotation:
|
|
||||||
imports.add("Optional")
|
|
||||||
if "List[" in annotation:
|
|
||||||
imports.add("List")
|
|
||||||
if "Dict[" in annotation:
|
|
||||||
imports.add("Dict")
|
|
||||||
return imports
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pydantic_imports(self) -> Set[str]:
|
def pydantic_imports(self) -> Set[str]:
|
||||||
return set()
|
return set()
|
||||||
@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
|
|||||||
|
|
||||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||||
output_file.datetime_imports.update(self.datetime_imports)
|
output_file.datetime_imports.update(self.datetime_imports)
|
||||||
output_file.typing_imports.update(self.typing_imports)
|
|
||||||
output_file.pydantic_imports.update(self.pydantic_imports)
|
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||||
|
|
||||||
@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
|
|||||||
@property
|
@property
|
||||||
def mutable(self) -> bool:
|
def mutable(self) -> bool:
|
||||||
"""True if the field is a mutable type, otherwise False."""
|
"""True if the field is a mutable type, otherwise False."""
|
||||||
return self.annotation.startswith(("List[", "Dict["))
|
return self.annotation.startswith(
|
||||||
|
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_type(self) -> str:
|
def field_type(self) -> str:
|
||||||
@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
|
|||||||
package=self.output_file.package,
|
package=self.output_file.package,
|
||||||
imports=self.output_file.imports,
|
imports=self.output_file.imports,
|
||||||
source_type=self.proto_obj.type_name,
|
source_type=self.proto_obj.type_name,
|
||||||
|
typing_compiler=self.typing_compiler,
|
||||||
pydantic=self.output_file.pydantic_dataclasses,
|
pydantic=self.output_file.pydantic_dataclasses,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
|
|||||||
if self.use_builtins:
|
if self.use_builtins:
|
||||||
py_type = f"builtins.{py_type}"
|
py_type = f"builtins.{py_type}"
|
||||||
if self.repeated:
|
if self.repeated:
|
||||||
return f"List[{py_type}]"
|
return self.typing_compiler.list(py_type)
|
||||||
if self.optional:
|
if self.optional:
|
||||||
return f"Optional[{py_type}]"
|
return self.typing_compiler.optional(py_type)
|
||||||
return py_type
|
return py_type
|
||||||
|
|
||||||
|
|
||||||
@ -600,7 +589,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def pydantic_imports(self) -> Set[str]:
|
def pydantic_imports(self) -> Set[str]:
|
||||||
return {"root_validator"}
|
return {"model_validator"}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
|
|||||||
source_file=self.source_file,
|
source_file=self.source_file,
|
||||||
parent=self,
|
parent=self,
|
||||||
proto_obj=nested.field[0], # key
|
proto_obj=nested.field[0], # key
|
||||||
|
typing_compiler=self.typing_compiler,
|
||||||
).py_type
|
).py_type
|
||||||
self.py_v_type = FieldCompiler(
|
self.py_v_type = FieldCompiler(
|
||||||
source_file=self.source_file,
|
source_file=self.source_file,
|
||||||
parent=self,
|
parent=self,
|
||||||
proto_obj=nested.field[1], # value
|
proto_obj=nested.field[1], # value
|
||||||
|
typing_compiler=self.typing_compiler,
|
||||||
).py_type
|
).py_type
|
||||||
|
|
||||||
# Get proto types
|
# Get proto types
|
||||||
@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def annotation(self) -> str:
|
def annotation(self) -> str:
|
||||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def repeated(self) -> bool:
|
def repeated(self) -> bool:
|
||||||
@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
|
|||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Add service to output file
|
# Add service to output file
|
||||||
self.output_file.services.append(self)
|
self.output_file.services.append(self)
|
||||||
self.output_file.typing_imports.add("Dict")
|
|
||||||
super().__post_init__() # check for unset fields
|
super().__post_init__() # check for unset fields
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
# Add method to service
|
# Add method to service
|
||||||
self.parent.methods.append(self)
|
self.parent.methods.append(self)
|
||||||
|
|
||||||
# Check for imports
|
|
||||||
if "Optional" in self.py_output_message_type:
|
|
||||||
self.output_file.typing_imports.add("Optional")
|
|
||||||
|
|
||||||
# Check for Async imports
|
|
||||||
if self.client_streaming:
|
|
||||||
self.output_file.typing_imports.add("AsyncIterable")
|
|
||||||
self.output_file.typing_imports.add("Iterable")
|
|
||||||
self.output_file.typing_imports.add("Union")
|
|
||||||
|
|
||||||
# Required by both client and server
|
|
||||||
if self.client_streaming or self.server_streaming:
|
|
||||||
self.output_file.typing_imports.add("AsyncIterator")
|
|
||||||
|
|
||||||
# add imports required for request arguments timeout, deadline and metadata
|
|
||||||
self.output_file.typing_imports.add("Optional")
|
|
||||||
self.output_file.imports_type_checking_only.add("import grpclib.server")
|
self.output_file.imports_type_checking_only.add("import grpclib.server")
|
||||||
self.output_file.imports_type_checking_only.add(
|
self.output_file.imports_type_checking_only.add(
|
||||||
"from betterproto.grpc.grpclib_client import MetadataLike"
|
"from betterproto.grpc.grpclib_client import MetadataLike"
|
||||||
@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
package=self.output_file.package,
|
package=self.output_file.package,
|
||||||
imports=self.output_file.imports,
|
imports=self.output_file.imports,
|
||||||
source_type=self.proto_obj.input_type,
|
source_type=self.proto_obj.input_type,
|
||||||
|
typing_compiler=self.output_file.typing_compiler,
|
||||||
unwrap=False,
|
unwrap=False,
|
||||||
pydantic=self.output_file.pydantic_dataclasses,
|
pydantic=self.output_file.pydantic_dataclasses,
|
||||||
).strip('"')
|
).strip('"')
|
||||||
@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
|||||||
package=self.output_file.package,
|
package=self.output_file.package,
|
||||||
imports=self.output_file.imports,
|
imports=self.output_file.imports,
|
||||||
source_type=self.proto_obj.output_type,
|
source_type=self.proto_obj.output_type,
|
||||||
|
typing_compiler=self.output_file.typing_compiler,
|
||||||
unwrap=False,
|
unwrap=False,
|
||||||
pydantic=self.output_file.pydantic_dataclasses,
|
pydantic=self.output_file.pydantic_dataclasses,
|
||||||
).strip('"')
|
).strip('"')
|
||||||
|
163
src/betterproto/plugin/module_validation.py
Normal file
163
src/betterproto/plugin/module_validation.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
import re
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import (
|
||||||
|
dataclass,
|
||||||
|
field,
|
||||||
|
)
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Tuple,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModuleValidator:
|
||||||
|
line_iterator: Iterator[str]
|
||||||
|
line_number: int = field(init=False, default=0)
|
||||||
|
|
||||||
|
collisions: Dict[str, List[Tuple[int, str]]] = field(
|
||||||
|
init=False, default_factory=lambda: defaultdict(list)
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_import(self, imp: str, number: int, full_line: str):
|
||||||
|
"""
|
||||||
|
Adds an import to be tracked.
|
||||||
|
"""
|
||||||
|
self.collisions[imp].append((number, full_line))
|
||||||
|
|
||||||
|
def process_import(self, imp: str):
|
||||||
|
"""
|
||||||
|
Filters out the import to its actual value.
|
||||||
|
"""
|
||||||
|
if " as " in imp:
|
||||||
|
imp = imp[imp.index(" as ") + 4 :]
|
||||||
|
|
||||||
|
imp = imp.strip()
|
||||||
|
assert " " not in imp, imp
|
||||||
|
return imp
|
||||||
|
|
||||||
|
def evaluate_multiline_import(self, line: str):
|
||||||
|
"""
|
||||||
|
Evaluates a multiline import from a starting line
|
||||||
|
"""
|
||||||
|
# Filter the first line and remove anything before the import statement.
|
||||||
|
full_line = line
|
||||||
|
line = line.split("import", 1)[1]
|
||||||
|
if "(" in line:
|
||||||
|
conditional = lambda line: ")" not in line
|
||||||
|
else:
|
||||||
|
conditional = lambda line: "\\" in line
|
||||||
|
|
||||||
|
# Remove open parenthesis if it exists.
|
||||||
|
if "(" in line:
|
||||||
|
line = line[line.index("(") + 1 :]
|
||||||
|
|
||||||
|
# Choose the conditional based on how multiline imports are formatted.
|
||||||
|
while conditional(line):
|
||||||
|
# Split the line by commas
|
||||||
|
imports = line.split(",")
|
||||||
|
|
||||||
|
for imp in imports:
|
||||||
|
# Add the import to the namespace
|
||||||
|
imp = self.process_import(imp)
|
||||||
|
if imp:
|
||||||
|
self.add_import(imp, self.line_number, full_line)
|
||||||
|
# Get the next line
|
||||||
|
full_line = line = next(self.line_iterator)
|
||||||
|
# Increment the line number
|
||||||
|
self.line_number += 1
|
||||||
|
|
||||||
|
# validate the last line
|
||||||
|
if ")" in line:
|
||||||
|
line = line[: line.index(")")]
|
||||||
|
imports = line.split(",")
|
||||||
|
for imp in imports:
|
||||||
|
imp = self.process_import(imp)
|
||||||
|
if imp:
|
||||||
|
self.add_import(imp, self.line_number, full_line)
|
||||||
|
|
||||||
|
def evaluate_import(self, line: str):
|
||||||
|
"""
|
||||||
|
Extracts an import from a line.
|
||||||
|
"""
|
||||||
|
whole_line = line
|
||||||
|
line = line[line.index("import") + 6 :]
|
||||||
|
values = line.split(",")
|
||||||
|
for v in values:
|
||||||
|
self.add_import(self.process_import(v), self.line_number, whole_line)
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
"""
|
||||||
|
Evaluate each line for names in the module.
|
||||||
|
"""
|
||||||
|
line = next(self.line_iterator)
|
||||||
|
|
||||||
|
# Skip lines with indentation or comments
|
||||||
|
if (
|
||||||
|
# Skip indents and whitespace.
|
||||||
|
line.startswith(" ")
|
||||||
|
or line == "\n"
|
||||||
|
or line.startswith("\t")
|
||||||
|
or
|
||||||
|
# Skip comments
|
||||||
|
line.startswith("#")
|
||||||
|
or
|
||||||
|
# Skip decorators
|
||||||
|
line.startswith("@")
|
||||||
|
):
|
||||||
|
self.line_number += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# Skip docstrings.
|
||||||
|
if line.startswith('"""') or line.startswith("'''"):
|
||||||
|
quote = line[0] * 3
|
||||||
|
line = line[3:]
|
||||||
|
while quote not in line:
|
||||||
|
line = next(self.line_iterator)
|
||||||
|
self.line_number += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
# Evaluate Imports.
|
||||||
|
if line.startswith("from ") or line.startswith("import "):
|
||||||
|
if "(" in line or "\\" in line:
|
||||||
|
self.evaluate_multiline_import(line)
|
||||||
|
else:
|
||||||
|
self.evaluate_import(line)
|
||||||
|
|
||||||
|
# Evaluate Classes.
|
||||||
|
elif line.startswith("class "):
|
||||||
|
class_name = re.search(r"class (\w+)", line).group(1)
|
||||||
|
if class_name:
|
||||||
|
self.add_import(class_name, self.line_number, line)
|
||||||
|
|
||||||
|
# Evaluate Functions.
|
||||||
|
elif line.startswith("def "):
|
||||||
|
function_name = re.search(r"def (\w+)", line).group(1)
|
||||||
|
if function_name:
|
||||||
|
self.add_import(function_name, self.line_number, line)
|
||||||
|
|
||||||
|
# Evaluate direct assignments.
|
||||||
|
elif "=" in line:
|
||||||
|
assignment = re.search(r"(\w+)\s*=", line).group(1)
|
||||||
|
if assignment:
|
||||||
|
self.add_import(assignment, self.line_number, line)
|
||||||
|
|
||||||
|
self.line_number += 1
|
||||||
|
|
||||||
|
def validate(self) -> bool:
|
||||||
|
"""
|
||||||
|
Run Validation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
self.next()
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Filter collisions for those with more than one value.
|
||||||
|
self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
|
||||||
|
|
||||||
|
# Return True if no collisions are found.
|
||||||
|
return not bool(self.collisions)
|
@ -37,6 +37,12 @@ from .models import (
|
|||||||
is_map,
|
is_map,
|
||||||
is_oneof,
|
is_oneof,
|
||||||
)
|
)
|
||||||
|
from .typing_compiler import (
|
||||||
|
DirectImportTypingCompiler,
|
||||||
|
NoTyping310TypingCompiler,
|
||||||
|
TypingCompiler,
|
||||||
|
TypingImportTypingCompiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def traverse(
|
def traverse(
|
||||||
@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
|||||||
output_package_name
|
output_package_name
|
||||||
].pydantic_dataclasses = True
|
].pydantic_dataclasses = True
|
||||||
|
|
||||||
|
# Gather any typing generation options.
|
||||||
|
typing_opts = [
|
||||||
|
opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(typing_opts) > 1:
|
||||||
|
raise ValueError("Multiple typing options provided")
|
||||||
|
# Set the compiler type.
|
||||||
|
typing_opt = typing_opts[0] if typing_opts else "direct"
|
||||||
|
if typing_opt == "direct":
|
||||||
|
request_data.output_packages[
|
||||||
|
output_package_name
|
||||||
|
].typing_compiler = DirectImportTypingCompiler()
|
||||||
|
elif typing_opt == "root":
|
||||||
|
request_data.output_packages[
|
||||||
|
output_package_name
|
||||||
|
].typing_compiler = TypingImportTypingCompiler()
|
||||||
|
elif typing_opt == "310":
|
||||||
|
request_data.output_packages[
|
||||||
|
output_package_name
|
||||||
|
].typing_compiler = NoTyping310TypingCompiler()
|
||||||
|
|
||||||
# Read Messages and Enums
|
# Read Messages and Enums
|
||||||
# We need to read Messages before Services in so that we can
|
# We need to read Messages before Services in so that we can
|
||||||
# get the references to input/output messages for each service
|
# get the references to input/output messages for each service
|
||||||
@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
|
|||||||
parent=parent,
|
parent=parent,
|
||||||
proto_obj=proto_obj,
|
proto_obj=proto_obj,
|
||||||
path=path,
|
path=path,
|
||||||
|
typing_compiler=output_package.typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -181,7 +210,11 @@ def read_protobuf_type(
|
|||||||
return
|
return
|
||||||
# Process Message
|
# Process Message
|
||||||
message_data = MessageCompiler(
|
message_data = MessageCompiler(
|
||||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
source_file=source_file,
|
||||||
|
parent=output_package,
|
||||||
|
proto_obj=item,
|
||||||
|
path=path,
|
||||||
|
typing_compiler=output_package.typing_compiler,
|
||||||
)
|
)
|
||||||
for index, field in enumerate(item.field):
|
for index, field in enumerate(item.field):
|
||||||
if is_map(field, item):
|
if is_map(field, item):
|
||||||
@ -190,6 +223,7 @@ def read_protobuf_type(
|
|||||||
parent=message_data,
|
parent=message_data,
|
||||||
proto_obj=field,
|
proto_obj=field,
|
||||||
path=path + [2, index],
|
path=path + [2, index],
|
||||||
|
typing_compiler=output_package.typing_compiler,
|
||||||
)
|
)
|
||||||
elif is_oneof(field):
|
elif is_oneof(field):
|
||||||
_make_one_of_field_compiler(
|
_make_one_of_field_compiler(
|
||||||
@ -201,11 +235,16 @@ def read_protobuf_type(
|
|||||||
parent=message_data,
|
parent=message_data,
|
||||||
proto_obj=field,
|
proto_obj=field,
|
||||||
path=path + [2, index],
|
path=path + [2, index],
|
||||||
|
typing_compiler=output_package.typing_compiler,
|
||||||
)
|
)
|
||||||
elif isinstance(item, EnumDescriptorProto):
|
elif isinstance(item, EnumDescriptorProto):
|
||||||
# Enum
|
# Enum
|
||||||
EnumDefinitionCompiler(
|
EnumDefinitionCompiler(
|
||||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
source_file=source_file,
|
||||||
|
parent=output_package,
|
||||||
|
proto_obj=item,
|
||||||
|
path=path,
|
||||||
|
typing_compiler=output_package.typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
167
src/betterproto/plugin/typing_compiler.py
Normal file
167
src/betterproto/plugin/typing_compiler.py
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
import abc
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import (
|
||||||
|
dataclass,
|
||||||
|
field,
|
||||||
|
)
|
||||||
|
from typing import (
|
||||||
|
Dict,
|
||||||
|
Iterator,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TypingCompiler(metaclass=abc.ABCMeta):
|
||||||
|
@abc.abstractmethod
|
||||||
|
def optional(self, type: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def list(self, type: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def dict(self, key: str, value: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def union(self, *types: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def iterable(self, type: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def async_iterable(self, type: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def async_iterator(self, type: str) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||||
|
"""
|
||||||
|
Returns either the direct import as a key with none as value, or a set of
|
||||||
|
values to import from the key.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def import_lines(self) -> Iterator:
|
||||||
|
imports = self.imports()
|
||||||
|
for key, value in imports.items():
|
||||||
|
if value is None:
|
||||||
|
yield f"import {key}"
|
||||||
|
else:
|
||||||
|
yield f"from {key} import ("
|
||||||
|
for v in sorted(value):
|
||||||
|
yield f" {v},"
|
||||||
|
yield ")"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DirectImportTypingCompiler(TypingCompiler):
|
||||||
|
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
||||||
|
|
||||||
|
def optional(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("Optional")
|
||||||
|
return f"Optional[{type}]"
|
||||||
|
|
||||||
|
def list(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("List")
|
||||||
|
return f"List[{type}]"
|
||||||
|
|
||||||
|
def dict(self, key: str, value: str) -> str:
|
||||||
|
self._imports["typing"].add("Dict")
|
||||||
|
return f"Dict[{key}, {value}]"
|
||||||
|
|
||||||
|
def union(self, *types: str) -> str:
|
||||||
|
self._imports["typing"].add("Union")
|
||||||
|
return f"Union[{', '.join(types)}]"
|
||||||
|
|
||||||
|
def iterable(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("Iterable")
|
||||||
|
return f"Iterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterable(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("AsyncIterable")
|
||||||
|
return f"AsyncIterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterator(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("AsyncIterator")
|
||||||
|
return f"AsyncIterator[{type}]"
|
||||||
|
|
||||||
|
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||||
|
return {k: v if v else None for k, v in self._imports.items()}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TypingImportTypingCompiler(TypingCompiler):
|
||||||
|
_imported: bool = False
|
||||||
|
|
||||||
|
def optional(self, type: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.Optional[{type}]"
|
||||||
|
|
||||||
|
def list(self, type: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.List[{type}]"
|
||||||
|
|
||||||
|
def dict(self, key: str, value: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.Dict[{key}, {value}]"
|
||||||
|
|
||||||
|
def union(self, *types: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.Union[{', '.join(types)}]"
|
||||||
|
|
||||||
|
def iterable(self, type: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.Iterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterable(self, type: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.AsyncIterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterator(self, type: str) -> str:
|
||||||
|
self._imported = True
|
||||||
|
return f"typing.AsyncIterator[{type}]"
|
||||||
|
|
||||||
|
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||||
|
if self._imported:
|
||||||
|
return {"typing": None}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NoTyping310TypingCompiler(TypingCompiler):
|
||||||
|
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
||||||
|
|
||||||
|
def optional(self, type: str) -> str:
|
||||||
|
return f"{type} | None"
|
||||||
|
|
||||||
|
def list(self, type: str) -> str:
|
||||||
|
return f"list[{type}]"
|
||||||
|
|
||||||
|
def dict(self, key: str, value: str) -> str:
|
||||||
|
return f"dict[{key}, {value}]"
|
||||||
|
|
||||||
|
def union(self, *types: str) -> str:
|
||||||
|
return " | ".join(types)
|
||||||
|
|
||||||
|
def iterable(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("Iterable")
|
||||||
|
return f"Iterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterable(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("AsyncIterable")
|
||||||
|
return f"AsyncIterable[{type}]"
|
||||||
|
|
||||||
|
def async_iterator(self, type: str) -> str:
|
||||||
|
self._imports["typing"].add("AsyncIterator")
|
||||||
|
return f"AsyncIterator[{type}]"
|
||||||
|
|
||||||
|
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||||
|
return {k: v if v else None for k, v in self._imports.items()}
|
55
src/betterproto/templates/header.py.j2
Normal file
55
src/betterproto/templates/header.py.j2
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||||
|
# sources: {{ ', '.join(output_file.input_filenames) }}
|
||||||
|
# plugin: python-betterproto
|
||||||
|
# This file has been @generated
|
||||||
|
{% for i in output_file.python_module_imports|sort %}
|
||||||
|
import {{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
{% set type_checking_imported = False %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_dataclasses %}
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
{% set type_checking_imported = True %}
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from dataclasses import dataclass
|
||||||
|
else:
|
||||||
|
from pydantic.dataclasses import dataclass
|
||||||
|
from pydantic.dataclasses import rebuild_dataclass
|
||||||
|
{%- else -%}
|
||||||
|
from dataclasses import dataclass
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if output_file.datetime_imports %}
|
||||||
|
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
|
{% endif%}
|
||||||
|
{% set typing_imports = output_file.typing_compiler.imports() %}
|
||||||
|
{% if typing_imports %}
|
||||||
|
{% for line in output_file.typing_compiler.import_lines() %}
|
||||||
|
{{ line }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_imports %}
|
||||||
|
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||||
|
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
import betterproto
|
||||||
|
{% if output_file.services %}
|
||||||
|
from betterproto.grpc.grpclib_server import ServiceBase
|
||||||
|
import grpclib
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% for i in output_file.imports|sort %}
|
||||||
|
{{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
{% if output_file.imports_type_checking_only and not type_checking_imported %}
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
@ -1,53 +1,3 @@
|
|||||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
||||||
# sources: {{ ', '.join(output_file.input_filenames) }}
|
|
||||||
# plugin: python-betterproto
|
|
||||||
# This file has been @generated
|
|
||||||
{% for i in output_file.python_module_imports|sort %}
|
|
||||||
import {{ i }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
{% if output_file.pydantic_dataclasses %}
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from dataclasses import dataclass
|
|
||||||
else:
|
|
||||||
from pydantic.dataclasses import dataclass
|
|
||||||
{%- else -%}
|
|
||||||
from dataclasses import dataclass
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% if output_file.datetime_imports %}
|
|
||||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
|
||||||
|
|
||||||
{% endif%}
|
|
||||||
{% if output_file.typing_imports %}
|
|
||||||
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
|
||||||
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% if output_file.pydantic_imports %}
|
|
||||||
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
|
||||||
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
import betterproto
|
|
||||||
{% if output_file.services %}
|
|
||||||
from betterproto.grpc.grpclib_server import ServiceBase
|
|
||||||
import grpclib
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% for i in output_file.imports|sort %}
|
|
||||||
{{ i }}
|
|
||||||
{% endfor %}
|
|
||||||
|
|
||||||
{% if output_file.imports_type_checking_only %}
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
|
|
||||||
{% endfor %}
|
|
||||||
{% endif %}
|
|
||||||
|
|
||||||
{% if output_file.enums %}{% for enum in output_file.enums %}
|
{% if output_file.enums %}{% for enum in output_file.enums %}
|
||||||
class {{ enum.py_name }}(betterproto.Enum):
|
class {{ enum.py_name }}(betterproto.Enum):
|
||||||
{% if enum.comment %}
|
{% if enum.comment %}
|
||||||
@ -62,6 +12,13 @@ class {{ enum.py_name }}(betterproto.Enum):
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
{% if output_file.pydantic_dataclasses %}
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||||
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
return core_schema.int_schema(ge=0)
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@ -96,7 +53,7 @@ class {{ message.py_name }}(betterproto.Message):
|
|||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
||||||
@root_validator()
|
@model_validator(mode='after')
|
||||||
def check_oneof(cls, values):
|
def check_oneof(cls, values):
|
||||||
return cls._validate_field_groups(values)
|
return cls._validate_field_groups(values)
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@ -116,14 +73,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
|||||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{# Client streaming: need a request iterator instead #}
|
{# Client streaming: need a request iterator instead #}
|
||||||
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
,
|
,
|
||||||
*
|
*
|
||||||
, timeout: Optional[float] = None
|
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
|
||||||
, deadline: Optional["Deadline"] = None
|
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
|
||||||
, metadata: Optional["MetadataLike"] = None
|
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
|
||||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
@ -191,9 +148,9 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||||
{%- else -%}
|
{%- else -%}
|
||||||
{# Client streaming: need a request iterator instead #}
|
{# Client streaming: need a request iterator instead #}
|
||||||
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
|
||||||
{%- endif -%}
|
{%- endif -%}
|
||||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||||
{% if method.comment %}
|
{% if method.comment %}
|
||||||
{{ method.comment }}
|
{{ method.comment }}
|
||||||
|
|
||||||
@ -225,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
|
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
|
||||||
return {
|
return {
|
||||||
{% for method in service.methods %}
|
{% for method in service.methods %}
|
||||||
"{{ method.route }}": grpclib.const.Handler(
|
"{{ method.route }}": grpclib.const.Handler(
|
||||||
@ -250,7 +207,7 @@ class {{ service.py_name }}Base(ServiceBase):
|
|||||||
{% if output_file.pydantic_dataclasses %}
|
{% if output_file.pydantic_dataclasses %}
|
||||||
{% for message in output_file.messages %}
|
{% for message in output_file.messages %}
|
||||||
{% if message.has_message_field %}
|
{% if message.has_message_field %}
|
||||||
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
|
rebuild_dataclass({{ message.py_name }}) # type: ignore
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
@ -108,6 +108,7 @@ async def generate_test_case_output(
|
|||||||
print(
|
print(
|
||||||
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
|
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
|
||||||
)
|
)
|
||||||
|
print(ref_err.decode())
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if ref_out:
|
if ref_out:
|
||||||
@ -126,6 +127,7 @@ async def generate_test_case_output(
|
|||||||
print(
|
print(
|
||||||
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
|
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
|
||||||
)
|
)
|
||||||
|
print(plg_err.decode())
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if plg_out:
|
if plg_out:
|
||||||
@ -146,6 +148,7 @@ async def generate_test_case_output(
|
|||||||
print(
|
print(
|
||||||
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
|
||||||
)
|
)
|
||||||
|
print(plg_err_pyd.decode())
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
if plg_out_pyd:
|
if plg_out_pyd:
|
||||||
|
@ -10,10 +10,15 @@ def test_value():
|
|||||||
|
|
||||||
|
|
||||||
def test_pydantic_no_value():
|
def test_pydantic_no_value():
|
||||||
with pytest.raises(ValueError):
|
message = TestPyd()
|
||||||
TestPyd()
|
assert not message.value, "Boolean is False by default"
|
||||||
|
|
||||||
|
|
||||||
def test_pydantic_value():
|
def test_pydantic_value():
|
||||||
message = Test(value=False)
|
message = TestPyd(value=False)
|
||||||
assert not message.value
|
assert not message.value
|
||||||
|
|
||||||
|
|
||||||
|
def test_pydantic_bad_value():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
TestPyd(value=123)
|
||||||
|
@ -4,6 +4,15 @@ from betterproto.compile.importing import (
|
|||||||
get_type_reference,
|
get_type_reference,
|
||||||
parse_source_type_name,
|
parse_source_type_name,
|
||||||
)
|
)
|
||||||
|
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def typing_compiler() -> DirectImportTypingCompiler:
|
||||||
|
"""
|
||||||
|
Generates a simple Direct Import Typing Compiler for testing.
|
||||||
|
"""
|
||||||
|
return DirectImportTypingCompiler()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@ -32,11 +41,18 @@ from betterproto.compile.importing import (
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_reference_google_wellknown_types_non_wrappers(
|
def test_reference_google_wellknown_types_non_wrappers(
|
||||||
google_type: str, expected_name: str, expected_import: str
|
google_type: str,
|
||||||
|
expected_name: str,
|
||||||
|
expected_import: str,
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
):
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=imports, source_type=google_type, pydantic=False
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type=google_type,
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
pydantic=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert name == expected_name
|
assert name == expected_name
|
||||||
@ -71,11 +87,18 @@ def test_reference_google_wellknown_types_non_wrappers(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_reference_google_wellknown_types_non_wrappers_pydantic(
|
def test_reference_google_wellknown_types_non_wrappers_pydantic(
|
||||||
google_type: str, expected_name: str, expected_import: str
|
google_type: str,
|
||||||
|
expected_name: str,
|
||||||
|
expected_import: str,
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
):
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=imports, source_type=google_type, pydantic=True
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type=google_type,
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
pydantic=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert name == expected_name
|
assert name == expected_name
|
||||||
@ -99,10 +122,15 @@ def test_reference_google_wellknown_types_non_wrappers_pydantic(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_referenceing_google_wrappers_unwraps_them(
|
def test_referenceing_google_wrappers_unwraps_them(
|
||||||
google_type: str, expected_name: str
|
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
|
||||||
):
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(package="", imports=imports, source_type=google_type)
|
name = get_type_reference(
|
||||||
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type=google_type,
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
assert name == expected_name
|
assert name == expected_name
|
||||||
assert imports == set()
|
assert imports == set()
|
||||||
@ -135,223 +163,321 @@ def test_referenceing_google_wrappers_unwraps_them(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_referenceing_google_wrappers_without_unwrapping(
|
def test_referenceing_google_wrappers_without_unwrapping(
|
||||||
google_type: str, expected_name: str
|
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
|
||||||
):
|
):
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=set(), source_type=google_type, unwrap=False
|
package="",
|
||||||
|
imports=set(),
|
||||||
|
source_type=google_type,
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
unwrap=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert name == expected_name
|
assert name == expected_name
|
||||||
|
|
||||||
|
|
||||||
def test_reference_child_package_from_package():
|
def test_reference_child_package_from_package(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package", imports=imports, source_type="package.child.Message"
|
package="package",
|
||||||
|
imports=imports,
|
||||||
|
source_type="package.child.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from . import child"}
|
assert imports == {"from . import child"}
|
||||||
assert name == '"child.Message"'
|
assert name == '"child.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_child_package_from_root():
|
def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(package="", imports=imports, source_type="child.Message")
|
name = get_type_reference(
|
||||||
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type="child.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
assert imports == {"from . import child"}
|
assert imports == {"from . import child"}
|
||||||
assert name == '"child.Message"'
|
assert name == '"child.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_camel_cased():
|
def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=imports, source_type="child_package.example_message"
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type="child_package.example_message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from . import child_package"}
|
assert imports == {"from . import child_package"}
|
||||||
assert name == '"child_package.ExampleMessage"'
|
assert name == '"child_package.ExampleMessage"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_nested_child_from_root():
|
def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=imports, source_type="nested.child.Message"
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type="nested.child.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from .nested import child as nested_child"}
|
assert imports == {"from .nested import child as nested_child"}
|
||||||
assert name == '"nested_child.Message"'
|
assert name == '"nested_child.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_deeply_nested_child_from_root():
|
def test_reference_deeply_nested_child_from_root(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="", imports=imports, source_type="deeply.nested.child.Message"
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type="deeply.nested.child.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
||||||
assert name == '"deeply_nested_child.Message"'
|
assert name == '"deeply_nested_child.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_deeply_nested_child_from_package():
|
def test_reference_deeply_nested_child_from_package(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package",
|
package="package",
|
||||||
imports=imports,
|
imports=imports,
|
||||||
source_type="package.deeply.nested.child.Message",
|
source_type="package.deeply.nested.child.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
|
||||||
assert name == '"deeply_nested_child.Message"'
|
assert name == '"deeply_nested_child.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_root_sibling():
|
def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
|
||||||
name = get_type_reference(package="", imports=imports, source_type="Message")
|
|
||||||
|
|
||||||
assert imports == set()
|
|
||||||
assert name == '"Message"'
|
|
||||||
|
|
||||||
|
|
||||||
def test_reference_nested_siblings():
|
|
||||||
imports = set()
|
|
||||||
name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
|
|
||||||
|
|
||||||
assert imports == set()
|
|
||||||
assert name == '"Message"'
|
|
||||||
|
|
||||||
|
|
||||||
def test_reference_deeply_nested_siblings():
|
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="foo.bar", imports=imports, source_type="foo.bar.Message"
|
package="",
|
||||||
|
imports=imports,
|
||||||
|
source_type="Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == set()
|
assert imports == set()
|
||||||
assert name == '"Message"'
|
assert name == '"Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_parent_package_from_child():
|
def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package.child", imports=imports, source_type="package.Message"
|
package="foo",
|
||||||
|
imports=imports,
|
||||||
|
source_type="foo.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert imports == set()
|
||||||
|
assert name == '"Message"'
|
||||||
|
|
||||||
|
|
||||||
|
def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler):
|
||||||
|
imports = set()
|
||||||
|
name = get_type_reference(
|
||||||
|
package="foo.bar",
|
||||||
|
imports=imports,
|
||||||
|
source_type="foo.bar.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert imports == set()
|
||||||
|
assert name == '"Message"'
|
||||||
|
|
||||||
|
|
||||||
|
def test_reference_parent_package_from_child(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
|
imports = set()
|
||||||
|
name = get_type_reference(
|
||||||
|
package="package.child",
|
||||||
|
imports=imports,
|
||||||
|
source_type="package.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ... import package as __package__"}
|
assert imports == {"from ... import package as __package__"}
|
||||||
assert name == '"__package__.Message"'
|
assert name == '"__package__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_parent_package_from_deeply_nested_child():
|
def test_reference_parent_package_from_deeply_nested_child(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package.deeply.nested.child",
|
package="package.deeply.nested.child",
|
||||||
imports=imports,
|
imports=imports,
|
||||||
source_type="package.deeply.nested.Message",
|
source_type="package.deeply.nested.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ... import nested as __nested__"}
|
assert imports == {"from ... import nested as __nested__"}
|
||||||
assert name == '"__nested__.Message"'
|
assert name == '"__nested__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_ancestor_package_from_nested_child():
|
def test_reference_ancestor_package_from_nested_child(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package.ancestor.nested.child",
|
package="package.ancestor.nested.child",
|
||||||
imports=imports,
|
imports=imports,
|
||||||
source_type="package.ancestor.Message",
|
source_type="package.ancestor.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from .... import ancestor as ___ancestor__"}
|
assert imports == {"from .... import ancestor as ___ancestor__"}
|
||||||
assert name == '"___ancestor__.Message"'
|
assert name == '"___ancestor__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_root_package_from_child():
|
def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package.child", imports=imports, source_type="Message"
|
package="package.child",
|
||||||
|
imports=imports,
|
||||||
|
source_type="Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ... import Message as __Message__"}
|
assert imports == {"from ... import Message as __Message__"}
|
||||||
assert name == '"__Message__"'
|
assert name == '"__Message__"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_root_package_from_deeply_nested_child():
|
def test_reference_root_package_from_deeply_nested_child(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="package.deeply.nested.child", imports=imports, source_type="Message"
|
package="package.deeply.nested.child",
|
||||||
|
imports=imports,
|
||||||
|
source_type="Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ..... import Message as ____Message__"}
|
assert imports == {"from ..... import Message as ____Message__"}
|
||||||
assert name == '"____Message__"'
|
assert name == '"____Message__"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_unrelated_package():
|
def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
|
name = get_type_reference(
|
||||||
|
package="a",
|
||||||
|
imports=imports,
|
||||||
|
source_type="p.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
assert imports == {"from .. import p as _p__"}
|
assert imports == {"from .. import p as _p__"}
|
||||||
assert name == '"_p__.Message"'
|
assert name == '"_p__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_unrelated_nested_package():
|
def test_reference_unrelated_nested_package(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
|
name = get_type_reference(
|
||||||
|
package="a.b",
|
||||||
|
imports=imports,
|
||||||
|
source_type="p.q.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
assert imports == {"from ...p import q as __p_q__"}
|
assert imports == {"from ...p import q as __p_q__"}
|
||||||
assert name == '"__p_q__.Message"'
|
assert name == '"__p_q__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_unrelated_deeply_nested_package():
|
def test_reference_unrelated_deeply_nested_package(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
|
package="a.b.c.d",
|
||||||
|
imports=imports,
|
||||||
|
source_type="p.q.r.s.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
|
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
|
||||||
assert name == '"____p_q_r_s__.Message"'
|
assert name == '"____p_q_r_s__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_cousin_package():
|
def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
|
name = get_type_reference(
|
||||||
|
package="a.x",
|
||||||
|
imports=imports,
|
||||||
|
source_type="a.y.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
|
)
|
||||||
|
|
||||||
assert imports == {"from .. import y as _y__"}
|
assert imports == {"from .. import y as _y__"}
|
||||||
assert name == '"_y__.Message"'
|
assert name == '"_y__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_cousin_package_different_name():
|
def test_reference_cousin_package_different_name(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="test.package1", imports=imports, source_type="cousin.package2.Message"
|
package="test.package1",
|
||||||
|
imports=imports,
|
||||||
|
source_type="cousin.package2.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
|
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
|
||||||
assert name == '"__cousin_package2__.Message"'
|
assert name == '"__cousin_package2__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_cousin_package_same_name():
|
def test_reference_cousin_package_same_name(
|
||||||
|
typing_compiler: DirectImportTypingCompiler,
|
||||||
|
):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="test.package", imports=imports, source_type="cousin.package.Message"
|
package="test.package",
|
||||||
|
imports=imports,
|
||||||
|
source_type="cousin.package.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ...cousin import package as __cousin_package__"}
|
assert imports == {"from ...cousin import package as __cousin_package__"}
|
||||||
assert name == '"__cousin_package__.Message"'
|
assert name == '"__cousin_package__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_far_cousin_package():
|
def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="a.x.y", imports=imports, source_type="a.b.c.Message"
|
package="a.x.y",
|
||||||
|
imports=imports,
|
||||||
|
source_type="a.b.c.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ...b import c as __b_c__"}
|
assert imports == {"from ...b import c as __b_c__"}
|
||||||
assert name == '"__b_c__.Message"'
|
assert name == '"__b_c__.Message"'
|
||||||
|
|
||||||
|
|
||||||
def test_reference_far_far_cousin_package():
|
def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
|
||||||
imports = set()
|
imports = set()
|
||||||
name = get_type_reference(
|
name = get_type_reference(
|
||||||
package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
|
package="a.x.y.z",
|
||||||
|
imports=imports,
|
||||||
|
source_type="a.b.c.d.Message",
|
||||||
|
typing_compiler=typing_compiler,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert imports == {"from ....b.c import d as ___b_c_d__"}
|
assert imports == {"from ....b.c import d as ___b_c_d__"}
|
||||||
|
111
tests/test_module_validation.py
Normal file
111
tests/test_module_validation.py
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
from typing import (
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Set,
|
||||||
|
)
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.plugin.module_validation import ModuleValidator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["text", "expected_collisions"],
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
["import os"],
|
||||||
|
None,
|
||||||
|
id="single import",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["import os", "import sys"],
|
||||||
|
None,
|
||||||
|
id="multiple imports",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["import os", "import os"],
|
||||||
|
{"os"},
|
||||||
|
id="duplicate imports",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["from os import path", "import os"],
|
||||||
|
None,
|
||||||
|
id="duplicate imports with alias",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["from os import path", "import os as os_alias"],
|
||||||
|
None,
|
||||||
|
id="duplicate imports with alias",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["from os import path", "import os as path"],
|
||||||
|
{"path"},
|
||||||
|
id="duplicate imports with alias",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["import os", "class os:"],
|
||||||
|
{"os"},
|
||||||
|
id="duplicate import with class",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["import os", "class os:", " pass", "import sys"],
|
||||||
|
{"os"},
|
||||||
|
id="duplicate import with class and another",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["def test(): pass", "class test:"],
|
||||||
|
{"test"},
|
||||||
|
id="duplicate class and function",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["def test(): pass", "def test(): pass"],
|
||||||
|
{"test"},
|
||||||
|
id="duplicate functions",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["def test(): pass", "test = 100"],
|
||||||
|
{"test"},
|
||||||
|
id="function and variable",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["def test():", " test = 3"],
|
||||||
|
None,
|
||||||
|
id="function and variable in function",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
[
|
||||||
|
"def test(): pass",
|
||||||
|
"'''",
|
||||||
|
"def test(): pass",
|
||||||
|
"'''",
|
||||||
|
"def test_2(): pass",
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
id="duplicate functions with multiline string",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["def test(): pass", "# def test(): pass"],
|
||||||
|
None,
|
||||||
|
id="duplicate functions with comments",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["from test import (", " A", " B", " C", ")"],
|
||||||
|
None,
|
||||||
|
id="multiline import",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
["from test import (", " A", " B", " C", ")", "from test import A"],
|
||||||
|
{"A"},
|
||||||
|
id="multiline import with duplicate",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
|
||||||
|
line_iterator = iter(text)
|
||||||
|
validator = ModuleValidator(line_iterator)
|
||||||
|
valid = validator.validate()
|
||||||
|
if expected_collisions is None:
|
||||||
|
assert valid
|
||||||
|
else:
|
||||||
|
assert set(validator.collisions.keys()) == expected_collisions
|
||||||
|
assert not valid
|
80
tests/test_typing_compiler.py
Normal file
80
tests/test_typing_compiler.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from betterproto.plugin.typing_compiler import (
|
||||||
|
DirectImportTypingCompiler,
|
||||||
|
NoTyping310TypingCompiler,
|
||||||
|
TypingImportTypingCompiler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_direct_import_typing_compiler():
|
||||||
|
compiler = DirectImportTypingCompiler()
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.optional("str") == "Optional[str]"
|
||||||
|
assert compiler.imports() == {"typing": {"Optional"}}
|
||||||
|
assert compiler.list("str") == "List[str]"
|
||||||
|
assert compiler.imports() == {"typing": {"Optional", "List"}}
|
||||||
|
assert compiler.dict("str", "int") == "Dict[str, int]"
|
||||||
|
assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}}
|
||||||
|
assert compiler.union("str", "int") == "Union[str, int]"
|
||||||
|
assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}}
|
||||||
|
assert compiler.iterable("str") == "Iterable[str]"
|
||||||
|
assert compiler.imports() == {
|
||||||
|
"typing": {"Optional", "List", "Dict", "Union", "Iterable"}
|
||||||
|
}
|
||||||
|
assert compiler.async_iterable("str") == "AsyncIterable[str]"
|
||||||
|
assert compiler.imports() == {
|
||||||
|
"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
|
||||||
|
}
|
||||||
|
assert compiler.async_iterator("str") == "AsyncIterator[str]"
|
||||||
|
assert compiler.imports() == {
|
||||||
|
"typing": {
|
||||||
|
"Optional",
|
||||||
|
"List",
|
||||||
|
"Dict",
|
||||||
|
"Union",
|
||||||
|
"Iterable",
|
||||||
|
"AsyncIterable",
|
||||||
|
"AsyncIterator",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_typing_import_typing_compiler():
|
||||||
|
compiler = TypingImportTypingCompiler()
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.optional("str") == "typing.Optional[str]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.list("str") == "typing.List[str]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.dict("str", "int") == "typing.Dict[str, int]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.union("str", "int") == "typing.Union[str, int]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.iterable("str") == "typing.Iterable[str]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.async_iterable("str") == "typing.AsyncIterable[str]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
|
||||||
|
assert compiler.imports() == {"typing": None}
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_typing_311_typing_compiler():
|
||||||
|
compiler = NoTyping310TypingCompiler()
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.optional("str") == "str | None"
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.list("str") == "list[str]"
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.dict("str", "int") == "dict[str, int]"
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.union("str", "int") == "str | int"
|
||||||
|
assert compiler.imports() == {}
|
||||||
|
assert compiler.iterable("str") == "Iterable[str]"
|
||||||
|
assert compiler.imports() == {"typing": {"Iterable"}}
|
||||||
|
assert compiler.async_iterable("str") == "AsyncIterable[str]"
|
||||||
|
assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
|
||||||
|
assert compiler.async_iterator("str") == "AsyncIterator[str]"
|
||||||
|
assert compiler.imports() == {
|
||||||
|
"typing": {"Iterable", "AsyncIterable", "AsyncIterator"}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user