Merge branch 'refs/heads/master_gh'

Georg K 2024-07-30 22:10:55 +03:00
commit 32eaa51e8d
19 changed files with 1726 additions and 791 deletions


@@ -16,7 +16,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [Ubuntu, MacOS, Windows]
-        python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
+        python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
     steps:
       - uses: actions/checkout@v4


@@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the
 pydantic dataclass. You must have pydantic as a dependency in your project for
 this to work.
 
+## Configuring typing imports
+
+By default, typing types are imported directly from `typing`. This can sometimes lead to generation issues when a generated type's name conflicts with one of those imports. In that case, you can configure the way types are imported by choosing from three options:
+
+### Direct
+```
+protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
+```
+This configuration is the default and imports types as follows:
+```
+from typing import (
+    List,
+    Optional,
+    Union
+)
+...
+value: List[str] = []
+value2: Optional[str] = None
+value3: Union[str, int] = 1
+```
+
+### Root
+```
+protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
+```
+This configuration imports the root `typing` module and accesses the types directly off of it:
+```
+import typing
+...
+value: typing.List[str] = []
+value2: typing.Optional[str] = None
+value3: typing.Union[str, int] = 1
+```
+
+### 310
+```
+protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
+```
+This configuration avoids importing `typing` altogether where possible and uses the Python 3.10+ syntax instead:
+```
+...
+value: list[str] = []
+value2: str | None = None
+value3: str | int = 1
+```
+
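As an illustration of the collision these options address (a hypothetical example, not part of this commit): a proto message named `List`, generated with `typing.direct` imports, shadows the import of the same name, while the `typing.root` and `typing.310` styles leave the name free:

```
# Hypothetical module generated with typing.direct imports
from typing import List

import betterproto


class List(betterproto.Message):  # generated from a proto message named "List"
    ...


values: List[str] = []  # "List" now refers to the message class; this fails at runtime
```
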
 ## Development

poetry.lock (generated, 1138 lines changed)

File diff suppressed because it is too large.


@@ -39,7 +39,7 @@ pytest = "^6.2.5"
 pytest-asyncio = "^0.12.0"
 pytest-cov = "^2.9.0"
 pytest-mock = "^3.1.1"
-pydantic = ">=1.8.0,<2"
+pydantic = ">=2.0,<3"
 protobuf = "^4"
 cachelib = "^0.10.2"
 tomlkit = ">=0.7.0"


@@ -1852,7 +1852,9 @@ class Message(ABC):
                 continue
             set_fields = [
-                field.name for field in field_set if values[field.name] is not None
+                field.name
+                for field in field_set
+                if getattr(values, field.name, None) is not None
             ]
             if not set_fields:
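The switch from `values[field.name]` to `getattr(values, field.name, None)` follows from pydantic v2 semantics: a `model_validator(mode="after")` receives the constructed model object, not the raw dict that v1's `root_validator` passed. A minimal sketch (hypothetical model, not from this commit):

```
from typing import Optional

from pydantic import BaseModel, model_validator


class Example(BaseModel):
    value: Optional[int] = None

    @model_validator(mode="after")
    def check(self):
        # v1 validators read values["value"]; v2 "after" validators use attribute access
        if getattr(self, "value", None) is not None and self.value < 0:
            raise ValueError("value must be non-negative")
        return self
```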


@@ -47,6 +47,7 @@ def get_type_reference(
     package: str,
     imports: set,
     source_type: str,
+    typing_compiler: "TypingCompiler",
     unwrap: bool = True,
     pydantic: bool = False,
 ) -> str:
@@ -57,7 +58,7 @@ def get_type_reference(
     if unwrap:
         if source_type in WRAPPER_TYPES:
             wrapped_type = type(WRAPPER_TYPES[source_type]().value)
-            return f"Optional[{wrapped_type.__name__}]"
+            return typing_compiler.optional(wrapped_type.__name__)
         if source_type == ".google.protobuf.Duration":
             return "timedelta"


@@ -1,9 +1,16 @@
 # Generated by the protocol buffer compiler. DO NOT EDIT!
 # sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
 # plugin: python-betterproto
+# This file has been @generated
 
 import warnings
-from typing import TYPE_CHECKING
+from typing import (
+    TYPE_CHECKING,
+    Mapping,
+)
+
+from typing_extensions import Self
+
+from betterproto import hybridmethod
 
 if TYPE_CHECKING:
@@ -14,15 +21,13 @@ else:
 from typing import (
     Dict,
     List,
-    Mapping,
     Optional,
 )
 
-from pydantic import root_validator
-from typing_extensions import Self
+from pydantic import model_validator
+from pydantic.dataclasses import rebuild_dataclass
 
 import betterproto
-from betterproto.utils import hybridmethod
 
 
 class Syntax(betterproto.Enum):
@@ -37,6 +42,12 @@ class Syntax(betterproto.Enum):
     EDITIONS = 2
     """Syntax `editions`."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldKind(betterproto.Enum):
     """Basic field types."""
@@ -98,6 +109,12 @@ class FieldKind(betterproto.Enum):
     TYPE_SINT64 = 18
     """Field type sint64."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldCardinality(betterproto.Enum):
     """Whether a field is optional, required, or repeated."""
@@ -114,6 +131,12 @@ class FieldCardinality(betterproto.Enum):
     CARDINALITY_REPEATED = 3
     """For repeated fields."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class Edition(betterproto.Enum):
     """The full set of known editions."""
@@ -155,6 +178,12 @@ class Edition(betterproto.Enum):
     support a new edition.
     """
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class ExtensionRangeOptionsVerificationState(betterproto.Enum):
     """The verification state of the extension range."""
@@ -164,6 +193,12 @@ class ExtensionRangeOptionsVerificationState(betterproto.Enum):
     UNVERIFIED = 1
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldDescriptorProtoType(betterproto.Enum):
     TYPE_DOUBLE = 1
@@ -210,6 +245,12 @@ class FieldDescriptorProtoType(betterproto.Enum):
     TYPE_SINT32 = 17
     TYPE_SINT64 = 18
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldDescriptorProtoLabel(betterproto.Enum):
     LABEL_OPTIONAL = 1
@@ -223,6 +264,12 @@ class FieldDescriptorProtoLabel(betterproto.Enum):
     can be used to get this behavior.
     """
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FileOptionsOptimizeMode(betterproto.Enum):
     """Generated classes can be optimized for speed or code size."""
@@ -233,6 +280,12 @@ class FileOptionsOptimizeMode(betterproto.Enum):
     LITE_RUNTIME = 3
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldOptionsCType(betterproto.Enum):
     STRING = 0
@@ -250,6 +303,12 @@ class FieldOptionsCType(betterproto.Enum):
     STRING_PIECE = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldOptionsJsType(betterproto.Enum):
     JS_NORMAL = 0
@@ -261,6 +320,12 @@ class FieldOptionsJsType(betterproto.Enum):
     JS_NUMBER = 2
     """Use JavaScript numbers."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldOptionsOptionRetention(betterproto.Enum):
     """
@@ -273,6 +338,12 @@ class FieldOptionsOptionRetention(betterproto.Enum):
     RETENTION_RUNTIME = 1
     RETENTION_SOURCE = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FieldOptionsOptionTargetType(betterproto.Enum):
     """
@@ -293,6 +364,12 @@ class FieldOptionsOptionTargetType(betterproto.Enum):
     TARGET_TYPE_SERVICE = 8
     TARGET_TYPE_METHOD = 9
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class MethodOptionsIdempotencyLevel(betterproto.Enum):
     """
@@ -305,6 +382,12 @@ class MethodOptionsIdempotencyLevel(betterproto.Enum):
     NO_SIDE_EFFECTS = 1
     IDEMPOTENT = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetFieldPresence(betterproto.Enum):
     FIELD_PRESENCE_UNKNOWN = 0
@@ -312,36 +395,72 @@ class FeatureSetFieldPresence(betterproto.Enum):
     IMPLICIT = 2
     LEGACY_REQUIRED = 3
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetEnumType(betterproto.Enum):
     ENUM_TYPE_UNKNOWN = 0
     OPEN = 1
     CLOSED = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetRepeatedFieldEncoding(betterproto.Enum):
     REPEATED_FIELD_ENCODING_UNKNOWN = 0
     PACKED = 1
     EXPANDED = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetUtf8Validation(betterproto.Enum):
     UTF8_VALIDATION_UNKNOWN = 0
     VERIFY = 2
     NONE = 3
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetMessageEncoding(betterproto.Enum):
     MESSAGE_ENCODING_UNKNOWN = 0
     LENGTH_PREFIXED = 1
     DELIMITED = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class FeatureSetJsonFormat(betterproto.Enum):
     JSON_FORMAT_UNKNOWN = 0
     ALLOW = 1
     LEGACY_BEST_EFFORT = 2
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
     """
@@ -358,6 +477,12 @@ class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
     ALIAS = 2
     """An alias to the element is returned."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 
 class NullValue(betterproto.Enum):
     """
@@ -370,6 +495,12 @@ class NullValue(betterproto.Enum):
     _ = 0
     """Null value."""
 
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+
 @dataclass(eq=False, repr=False)
 class Any(betterproto.Message):
@@ -1176,16 +1307,12 @@ class FileOptions(betterproto.Message):
     java_string_check_utf8: bool = betterproto.bool_field(27)
     """
-    A proto2 file can set this to true to opt in to UTF-8 checking for Java,
-    which will throw an exception if invalid UTF-8 is parsed from the wire or
-    assigned to a string field.
-    Message reflection will do the same.
-    TODO: clarify exactly what kinds of field types this option
-    applies to, and update these docs accordingly.
-    Proto3 files already perform these checks. Setting the option explicitly to
-    false has no effect: it cannot be used to opt proto3 files out of UTF-8
-    checks.
+    If set true, then the Java2 code generator will generate code that
+    throws an exception whenever an attempt is made to assign a non-UTF-8
+    byte sequence to a string field.
+    However, an extension field still accepts non-UTF-8 byte sequences.
+    This option has no effect on when used with the lite runtime.
     """
 
     optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9)
@@ -1477,7 +1604,6 @@ class FieldOptions(betterproto.Message):
     features: "FeatureSet" = betterproto.message_field(21)
     """Any features defined in the specific edition."""
 
-    feature_support: "FieldOptionsFeatureSupport" = betterproto.message_field(22)
     uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
     """The parser stores options it doesn't recognize here. See above."""
 
@@ -1488,37 +1614,6 @@ class FieldOptionsEditionDefault(betterproto.Message):
     value: str = betterproto.string_field(2)
 
 
-@dataclass(eq=False, repr=False)
-class FieldOptionsFeatureSupport(betterproto.Message):
-    """Information about the support window of a feature."""
-
-    edition_introduced: "Edition" = betterproto.enum_field(1)
-    """
-    The edition that this feature was first available in. In editions
-    earlier than this one, the default assigned to EDITION_LEGACY will be
-    used, and proto files will not be able to override it.
-    """
-
-    edition_deprecated: "Edition" = betterproto.enum_field(2)
-    """
-    The edition this feature becomes deprecated in. Using this after this
-    edition may trigger warnings.
-    """
-
-    deprecation_warning: str = betterproto.string_field(3)
-    """
-    The deprecation warning text if this feature is used after the edition it
-    was marked deprecated in.
-    """
-
-    edition_removed: "Edition" = betterproto.enum_field(4)
-    """
-    The edition this feature is no longer available in. In editions after
-    this one, the last default assigned will be used, and proto files will
-    not be able to override it.
-    """
-
-
 @dataclass(eq=False, repr=False)
 class OneofOptions(betterproto.Message):
     features: "FeatureSet" = betterproto.message_field(1)
@@ -1723,17 +1818,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message):
     """
 
     edition: "Edition" = betterproto.enum_field(3)
-    overridable_features: "FeatureSet" = betterproto.message_field(4)
-    """Defaults of features that can be overridden in this edition."""
-
-    fixed_features: "FeatureSet" = betterproto.message_field(5)
-    """Defaults of features that can't be overridden in this edition."""
-
     features: "FeatureSet" = betterproto.message_field(2)
-    """
-    TODO Deprecate and remove this field, which is just the
-    above two merged.
-    """
 
 
 @dataclass(eq=False, repr=False)
@@ -2314,7 +2399,7 @@ class Value(betterproto.Message):
     )
     """Represents a repeated `Value`."""
 
-    @root_validator()
+    @model_validator(mode="after")
     def check_oneof(cls, values):
         return cls._validate_field_groups(values)
@@ -2549,41 +2634,40 @@ class BytesValue(betterproto.Message):
     """The bytes value."""
 
 
-Type.__pydantic_model__.update_forward_refs()  # type: ignore
-Field.__pydantic_model__.update_forward_refs()  # type: ignore
-Enum.__pydantic_model__.update_forward_refs()  # type: ignore
-EnumValue.__pydantic_model__.update_forward_refs()  # type: ignore
-Option.__pydantic_model__.update_forward_refs()  # type: ignore
-Api.__pydantic_model__.update_forward_refs()  # type: ignore
-Method.__pydantic_model__.update_forward_refs()  # type: ignore
-FileDescriptorSet.__pydantic_model__.update_forward_refs()  # type: ignore
-FileDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-DescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-DescriptorProtoExtensionRange.__pydantic_model__.update_forward_refs()  # type: ignore
-ExtensionRangeOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-FieldDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-OneofDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-EnumDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-EnumValueDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-ServiceDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-MethodDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore
-FileOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-MessageOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-FieldOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-FieldOptionsEditionDefault.__pydantic_model__.update_forward_refs()  # type: ignore
-FieldOptionsFeatureSupport.__pydantic_model__.update_forward_refs()  # type: ignore
-OneofOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-EnumOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-EnumValueOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-ServiceOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-MethodOptions.__pydantic_model__.update_forward_refs()  # type: ignore
-UninterpretedOption.__pydantic_model__.update_forward_refs()  # type: ignore
-FeatureSet.__pydantic_model__.update_forward_refs()  # type: ignore
-FeatureSetDefaults.__pydantic_model__.update_forward_refs()  # type: ignore
-FeatureSetDefaultsFeatureSetEditionDefault.__pydantic_model__.update_forward_refs()  # type: ignore
-SourceCodeInfo.__pydantic_model__.update_forward_refs()  # type: ignore
-GeneratedCodeInfo.__pydantic_model__.update_forward_refs()  # type: ignore
-GeneratedCodeInfoAnnotation.__pydantic_model__.update_forward_refs()  # type: ignore
-Struct.__pydantic_model__.update_forward_refs()  # type: ignore
-Value.__pydantic_model__.update_forward_refs()  # type: ignore
-ListValue.__pydantic_model__.update_forward_refs()  # type: ignore
+rebuild_dataclass(Type)  # type: ignore
+rebuild_dataclass(Field)  # type: ignore
+rebuild_dataclass(Enum)  # type: ignore
+rebuild_dataclass(EnumValue)  # type: ignore
+rebuild_dataclass(Option)  # type: ignore
+rebuild_dataclass(Api)  # type: ignore
+rebuild_dataclass(Method)  # type: ignore
+rebuild_dataclass(FileDescriptorSet)  # type: ignore
+rebuild_dataclass(FileDescriptorProto)  # type: ignore
+rebuild_dataclass(DescriptorProto)  # type: ignore
+rebuild_dataclass(DescriptorProtoExtensionRange)  # type: ignore
+rebuild_dataclass(ExtensionRangeOptions)  # type: ignore
+rebuild_dataclass(FieldDescriptorProto)  # type: ignore
+rebuild_dataclass(OneofDescriptorProto)  # type: ignore
+rebuild_dataclass(EnumDescriptorProto)  # type: ignore
+rebuild_dataclass(EnumValueDescriptorProto)  # type: ignore
+rebuild_dataclass(ServiceDescriptorProto)  # type: ignore
+rebuild_dataclass(MethodDescriptorProto)  # type: ignore
+rebuild_dataclass(FileOptions)  # type: ignore
+rebuild_dataclass(MessageOptions)  # type: ignore
+rebuild_dataclass(FieldOptions)  # type: ignore
+rebuild_dataclass(FieldOptionsEditionDefault)  # type: ignore
+rebuild_dataclass(OneofOptions)  # type: ignore
+rebuild_dataclass(EnumOptions)  # type: ignore
+rebuild_dataclass(EnumValueOptions)  # type: ignore
+rebuild_dataclass(ServiceOptions)  # type: ignore
+rebuild_dataclass(MethodOptions)  # type: ignore
+rebuild_dataclass(UninterpretedOption)  # type: ignore
+rebuild_dataclass(FeatureSet)  # type: ignore
+rebuild_dataclass(FeatureSetDefaults)  # type: ignore
+rebuild_dataclass(FeatureSetDefaultsFeatureSetEditionDefault)  # type: ignore
+rebuild_dataclass(SourceCodeInfo)  # type: ignore
+rebuild_dataclass(GeneratedCodeInfo)  # type: ignore
+rebuild_dataclass(GeneratedCodeInfoAnnotation)  # type: ignore
+rebuild_dataclass(Struct)  # type: ignore
+rebuild_dataclass(Value)  # type: ignore
+rebuild_dataclass(ListValue)  # type: ignore
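The tail of this file illustrates the pydantic v1 to v2 migration: `Model.__pydantic_model__.update_forward_refs()` is replaced by `rebuild_dataclass(Model)`. A minimal sketch with a hypothetical self-referential dataclass:

```
from typing import Optional

from pydantic.dataclasses import dataclass, rebuild_dataclass


@dataclass
class TreeNode:
    left: "Optional[TreeNode]" = None
    right: "Optional[TreeNode]" = None


rebuild_dataclass(TreeNode)  # resolves the "TreeNode" forward references
```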


@@ -1,4 +1,7 @@
 import os.path
+import sys
+
+from .module_validation import ModuleValidator
 
 try:
@@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
         lstrip_blocks=True,
         loader=jinja2.FileSystemLoader(templates_folder),
     )
-    template = env.get_template("template.py.j2")
+    # Load the body first so we have a complete list of imports needed.
+    body_template = env.get_template("template.py.j2")
+    header_template = env.get_template("header.py.j2")
 
-    code = template.render(output_file=output_file)
+    code = body_template.render(output_file=output_file)
+    code = header_template.render(output_file=output_file) + code
     code = isort.api.sort_code_string(
         code=code,
         show_diff=False,
@@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
         force_grid_wrap=2,
         known_third_party=["grpclib", "betterproto"],
     )
-    return black.format_str(
+    code = black.format_str(
         src_contents=code,
         mode=black.Mode(),
     )
+
+    # Validate the generated code.
+    validator = ModuleValidator(iter(code.splitlines()))
+    if not validator.validate():
+        message_builder = ["[WARNING]: Generated code has collisions in the module:"]
+        for collision, lines in validator.collisions.items():
+            message_builder.append(f'    "{collision}" on lines:')
+            for num, line in lines:
+                message_builder.append(f"        {num}:{line}")
+        print("\n".join(message_builder), file=sys.stderr)
+    return code
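Based on the message builder above, a collision report printed to stderr would look roughly like this (the names and line numbers are illustrative):

```
[WARNING]: Generated code has collisions in the module:
    "List" on lines:
        4:from typing import List
        38:class List(betterproto.Message):
```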


@@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
 reference to `A` to `B`'s `fields` attribute.
 """
 
 import builtins
 import re
-import textwrap
 from dataclasses import (
     dataclass,
     field,
@@ -49,12 +47,6 @@ from typing import (
 )
 
 import betterproto
-from betterproto import which_one_of
-from betterproto.casing import sanitize_name
-from betterproto.compile.importing import (
-    get_type_reference,
-    parse_source_type_name,
-)
 from betterproto.compile.naming import (
     pythonize_class_name,
     pythonize_field_name,
@@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
 )
 from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
 
+from .. import which_one_of
 from ..compile.importing import (
     get_type_reference,
     parse_source_type_name,
@@ -82,6 +75,10 @@ from ..compile.naming import (
     pythonize_field_name,
     pythonize_method_name,
 )
+from .typing_compiler import (
+    DirectImportTypingCompiler,
+    TypingCompiler,
+)
 # Create a unique placeholder to deal with
@@ -173,6 +170,7 @@ class ProtoContentBase:
     """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
 
     source_file: FileDescriptorProto
+    typing_compiler: TypingCompiler
     path: List[int]
     comment_indent: int = 4
     parent: Union["betterproto.Message", "OutputTemplate"]
@@ -242,7 +240,6 @@ class OutputTemplate:
     input_files: List[str] = field(default_factory=list)
     imports: Set[str] = field(default_factory=set)
     datetime_imports: Set[str] = field(default_factory=set)
-    typing_imports: Set[str] = field(default_factory=set)
     pydantic_imports: Set[str] = field(default_factory=set)
     builtins_import: bool = False
     messages: List["MessageCompiler"] = field(default_factory=list)
@@ -251,6 +248,7 @@ class OutputTemplate:
     imports_type_checking_only: Set[str] = field(default_factory=set)
     pydantic_dataclasses: bool = False
     output: bool = True
+    typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
 
     @property
     def package(self) -> str:
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
     """Representation of a protobuf message."""
 
     source_file: FileDescriptorProto
+    typing_compiler: TypingCompiler
     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
     proto_obj: DescriptorProto = PLACEHOLDER
     path: List[int] = PLACEHOLDER
@@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
     @property
     def annotation(self) -> str:
         if self.repeated:
-            return f"List[{self.py_name}]"
+            return self.typing_compiler.list(self.py_name)
         return self.py_name
 
     @property
@@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
             imports.add("datetime")
         return imports
 
-    @property
-    def typing_imports(self) -> Set[str]:
-        imports = set()
-        annotation = self.annotation
-        if "Optional[" in annotation:
-            imports.add("Optional")
-        if "List[" in annotation:
-            imports.add("List")
-        if "Dict[" in annotation:
-            imports.add("Dict")
-        return imports
-
     @property
     def pydantic_imports(self) -> Set[str]:
         return set()
@@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
     def add_imports_to(self, output_file: OutputTemplate) -> None:
         output_file.datetime_imports.update(self.datetime_imports)
-        output_file.typing_imports.update(self.typing_imports)
         output_file.pydantic_imports.update(self.pydantic_imports)
         output_file.builtins_import = output_file.builtins_import or self.use_builtins
@@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
     @property
     def mutable(self) -> bool:
         """True if the field is a mutable type, otherwise False."""
-        return self.annotation.startswith(("List[", "Dict["))
+        return self.annotation.startswith(
+            ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
+        )
 
     @property
     def field_type(self) -> str:
@@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
                 package=self.output_file.package,
                 imports=self.output_file.imports,
                 source_type=self.proto_obj.type_name,
+                typing_compiler=self.typing_compiler,
                 pydantic=self.output_file.pydantic_dataclasses,
             )
         else:
@@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
         if self.use_builtins:
             py_type = f"builtins.{py_type}"
         if self.repeated:
-            return f"List[{py_type}]"
+            return self.typing_compiler.list(py_type)
         if self.optional:
-            return f"Optional[{py_type}]"
+            return self.typing_compiler.optional(py_type)
         return py_type
 
@@ -600,7 +589,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
     @property
     def pydantic_imports(self) -> Set[str]:
-        return {"root_validator"}
+        return {"model_validator"}
 
 
 @dataclass
@@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
             source_file=self.source_file,
             parent=self,
             proto_obj=nested.field[0],  # key
+            typing_compiler=self.typing_compiler,
         ).py_type
         self.py_v_type = FieldCompiler(
             source_file=self.source_file,
             parent=self,
            proto_obj=nested.field[1],  # value
+            typing_compiler=self.typing_compiler,
         ).py_type
 
         # Get proto types
@@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
     @property
     def annotation(self) -> str:
-        return f"Dict[{self.py_k_type}, {self.py_v_type}]"
+        return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
 
     @property
     def repeated(self) -> bool:
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
     def __post_init__(self) -> None:
         # Add service to output file
         self.output_file.services.append(self)
-        self.output_file.typing_imports.add("Dict")
         super().__post_init__()  # check for unset fields
 
     @property
@@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
         # Add method to service
         self.parent.methods.append(self)
 
-        # Check for imports
-        if "Optional" in self.py_output_message_type:
-            self.output_file.typing_imports.add("Optional")
-
-        # Check for Async imports
-        if self.client_streaming:
-            self.output_file.typing_imports.add("AsyncIterable")
-            self.output_file.typing_imports.add("Iterable")
-            self.output_file.typing_imports.add("Union")
-
-        # Required by both client and server
-        if self.client_streaming or self.server_streaming:
-            self.output_file.typing_imports.add("AsyncIterator")
-
-        # add imports required for request arguments timeout, deadline and metadata
-        self.output_file.typing_imports.add("Optional")
-
         self.output_file.imports_type_checking_only.add("import grpclib.server")
         self.output_file.imports_type_checking_only.add(
             "from betterproto.grpc.grpclib_client import MetadataLike"
@@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             package=self.output_file.package,
             imports=self.output_file.imports,
             source_type=self.proto_obj.input_type,
+            typing_compiler=self.output_file.typing_compiler,
             unwrap=False,
             pydantic=self.output_file.pydantic_dataclasses,
         ).strip('"')
@@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
             package=self.output_file.package,
             imports=self.output_file.imports,
             source_type=self.proto_obj.output_type,
+            typing_compiler=self.output_file.typing_compiler,
             unwrap=False,
             pydantic=self.output_file.pydantic_dataclasses,
         ).strip('"')


@@ -0,0 +1,163 @@
+import re
+from collections import defaultdict
+from dataclasses import (
+    dataclass,
+    field,
+)
+from typing import (
+    Dict,
+    Iterator,
+    List,
+    Tuple,
+)
+
+
+@dataclass
+class ModuleValidator:
+    line_iterator: Iterator[str]
+    line_number: int = field(init=False, default=0)
+
+    collisions: Dict[str, List[Tuple[int, str]]] = field(
+        init=False, default_factory=lambda: defaultdict(list)
+    )
+
+    def add_import(self, imp: str, number: int, full_line: str):
+        """
+        Adds an import to be tracked.
+        """
+        self.collisions[imp].append((number, full_line))
+
+    def process_import(self, imp: str):
+        """
+        Filters out the import to its actual value.
+        """
+        if " as " in imp:
+            imp = imp[imp.index(" as ") + 4 :]
+
+        imp = imp.strip()
+        assert " " not in imp, imp
+        return imp
+
+    def evaluate_multiline_import(self, line: str):
+        """
+        Evaluates a multiline import from a starting line.
+        """
+        # Filter the first line and remove anything before the import statement.
+        full_line = line
+        line = line.split("import", 1)[1]
+        if "(" in line:
+            conditional = lambda line: ")" not in line
+        else:
+            conditional = lambda line: "\\" in line
+
+        # Remove the opening parenthesis if it exists.
+        if "(" in line:
+            line = line[line.index("(") + 1 :]
+
+        # Choose the conditional based on how multiline imports are formatted.
+        while conditional(line):
+            # Split the line by commas.
+            imports = line.split(",")
+            for imp in imports:
+                # Add the import to the namespace.
+                imp = self.process_import(imp)
+                if imp:
+                    self.add_import(imp, self.line_number, full_line)
+            # Get the next line.
+            full_line = line = next(self.line_iterator)
+            # Increment the line number.
+            self.line_number += 1
+
+        # Validate the last line.
+        if ")" in line:
+            line = line[: line.index(")")]
+        imports = line.split(",")
+        for imp in imports:
+            imp = self.process_import(imp)
+            if imp:
+                self.add_import(imp, self.line_number, full_line)
+
+    def evaluate_import(self, line: str):
+        """
+        Extracts an import from a line.
+        """
+        whole_line = line
+        line = line[line.index("import") + 6 :]
+        values = line.split(",")
+        for v in values:
+            self.add_import(self.process_import(v), self.line_number, whole_line)
+
+    def next(self):
+        """
+        Evaluate each line for names in the module.
+        """
+        line = next(self.line_iterator)
+
+        # Skip lines with indentation or comments.
+        if (
+            # Skip indents and whitespace.
+            line.startswith(" ")
+            or line == "\n"
+            or line.startswith("\t")
+            or
+            # Skip comments.
+            line.startswith("#")
+            or
+            # Skip decorators.
+            line.startswith("@")
+        ):
+            self.line_number += 1
+            return
+
+        # Skip docstrings.
+        if line.startswith('"""') or line.startswith("'''"):
+            quote = line[0] * 3
+            line = line[3:]
+            while quote not in line:
+                line = next(self.line_iterator)
+            self.line_number += 1
+            return
+
+        # Evaluate imports.
+        if line.startswith("from ") or line.startswith("import "):
+            if "(" in line or "\\" in line:
+                self.evaluate_multiline_import(line)
+            else:
+                self.evaluate_import(line)
+
+        # Evaluate classes.
+        elif line.startswith("class "):
+            class_name = re.search(r"class (\w+)", line).group(1)
+            if class_name:
+                self.add_import(class_name, self.line_number, line)
+
+        # Evaluate functions.
+        elif line.startswith("def "):
+            function_name = re.search(r"def (\w+)", line).group(1)
+            if function_name:
+                self.add_import(function_name, self.line_number, line)
+
+        # Evaluate direct assignments.
+        elif "=" in line:
+            assignment = re.search(r"(\w+)\s*=", line).group(1)
+            if assignment:
+                self.add_import(assignment, self.line_number, line)
+
+        self.line_number += 1
+
+    def validate(self) -> bool:
+        """
+        Run validation.
+        """
+        try:
+            while True:
+                self.next()
+        except StopIteration:
+            pass
+
+        # Filter collisions for those with more than one occurrence.
+        self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
+
+        # Return True if no collisions are found.
+        return not bool(self.collisions)
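For orientation, a small usage sketch of the validator above; the input module is hypothetical, and the recorded collision entries follow directly from the `add_import` calls:

```
# Hypothetical usage of ModuleValidator on a module with a name collision
source = """from typing import List

class List:
    pass
"""

validator = ModuleValidator(iter(source.splitlines()))
assert not validator.validate()  # "List" is defined twice
print(validator.collisions)
# {'List': [(0, 'from typing import List'), (2, 'class List:')]}
```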


@@ -37,6 +37,12 @@ from .models import (
     is_map,
     is_oneof,
 )
+from .typing_compiler import (
+    DirectImportTypingCompiler,
+    NoTyping310TypingCompiler,
+    TypingCompiler,
+    TypingImportTypingCompiler,
+)
 
 
 def traverse(
@@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
                 output_package_name
             ].pydantic_dataclasses = True
 
+        # Gather any typing generation options.
+        typing_opts = [
+            opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
+        ]
+
+        if len(typing_opts) > 1:
+            raise ValueError("Multiple typing options provided")
+
+        # Set the compiler type.
+        typing_opt = typing_opts[0] if typing_opts else "direct"
+        if typing_opt == "direct":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = DirectImportTypingCompiler()
+        elif typing_opt == "root":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = TypingImportTypingCompiler()
+        elif typing_opt == "310":
+            request_data.output_packages[
+                output_package_name
+            ].typing_compiler = NoTyping310TypingCompiler()
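Condensed into a standalone function, the dispatch above reads roughly as follows. This is a sketch, not the plugin's actual code; note that, unlike the inline version, the dict lookup here would raise `KeyError` on an unknown option, whereas the plugin simply keeps the default compiler:

```
def pick_typing_compiler(plugin_options: list) -> TypingCompiler:
    # e.g. plugin_options = ["typing.310"] from --python_betterproto_opt=typing.310
    typing_opts = [o[len("typing.") :] for o in plugin_options if o.startswith("typing.")]
    if len(typing_opts) > 1:
        raise ValueError("Multiple typing options provided")
    choice = typing_opts[0] if typing_opts else "direct"
    return {
        "direct": DirectImportTypingCompiler,
        "root": TypingImportTypingCompiler,
        "310": NoTyping310TypingCompiler,
    }[choice]()
```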
     # Read Messages and Enums
     # We need to read Messages before Services so that we can
     # get the references to input/output messages for each service
@@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
         parent=parent,
         proto_obj=proto_obj,
         path=path,
+        typing_compiler=output_package.typing_compiler,
     )
 
@@ -181,7 +210,11 @@ def read_protobuf_type(
         return
 
     # Process Message
     message_data = MessageCompiler(
-        source_file=source_file, parent=output_package, proto_obj=item, path=path
+        source_file=source_file,
+        parent=output_package,
+        proto_obj=item,
+        path=path,
+        typing_compiler=output_package.typing_compiler,
     )
     for index, field in enumerate(item.field):
         if is_map(field, item):
@@ -190,6 +223,7 @@ def read_protobuf_type(
                 parent=message_data,
                 proto_obj=field,
                 path=path + [2, index],
+                typing_compiler=output_package.typing_compiler,
             )
         elif is_oneof(field):
             _make_one_of_field_compiler(
@@ -201,11 +235,16 @@ def read_protobuf_type(
                 parent=message_data,
                 proto_obj=field,
                 path=path + [2, index],
+                typing_compiler=output_package.typing_compiler,
             )
     elif isinstance(item, EnumDescriptorProto):
         # Enum
         EnumDefinitionCompiler(
-            source_file=source_file, parent=output_package, proto_obj=item, path=path
+            source_file=source_file,
+            parent=output_package,
+            proto_obj=item,
+            path=path,
+            typing_compiler=output_package.typing_compiler,
         )


@@ -0,0 +1,167 @@
+import abc
+from collections import defaultdict
+from dataclasses import (
+    dataclass,
+    field,
+)
+from typing import (
+    Dict,
+    Iterator,
+    Optional,
+    Set,
+)
+
+
+class TypingCompiler(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def optional(self, type: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def list(self, type: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def dict(self, key: str, value: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def union(self, *types: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def iterable(self, type: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def async_iterable(self, type: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def async_iterator(self, type: str) -> str:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def imports(self) -> Dict[str, Optional[Set[str]]]:
+        """
+        Returns either the direct import as a key with None as value, or a set of
+        values to import from the key.
+        """
+        raise NotImplementedError()
+
+    def import_lines(self) -> Iterator:
+        imports = self.imports()
+        for key, value in imports.items():
+            if value is None:
+                yield f"import {key}"
+            else:
+                yield f"from {key} import ("
+                for v in sorted(value):
+                    yield f"    {v},"
+                yield ")"
+
+
+@dataclass
+class DirectImportTypingCompiler(TypingCompiler):
+    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+
+    def optional(self, type: str) -> str:
+        self._imports["typing"].add("Optional")
+        return f"Optional[{type}]"
+
+    def list(self, type: str) -> str:
+        self._imports["typing"].add("List")
+        return f"List[{type}]"
+
+    def dict(self, key: str, value: str) -> str:
+        self._imports["typing"].add("Dict")
+        return f"Dict[{key}, {value}]"
+
+    def union(self, *types: str) -> str:
+        self._imports["typing"].add("Union")
+        return f"Union[{', '.join(types)}]"
+
+    def iterable(self, type: str) -> str:
+        self._imports["typing"].add("Iterable")
+        return f"Iterable[{type}]"
+
+    def async_iterable(self, type: str) -> str:
+        self._imports["typing"].add("AsyncIterable")
+        return f"AsyncIterable[{type}]"
+
+    def async_iterator(self, type: str) -> str:
+        self._imports["typing"].add("AsyncIterator")
+        return f"AsyncIterator[{type}]"
+
+    def imports(self) -> Dict[str, Optional[Set[str]]]:
+        return {k: v if v else None for k, v in self._imports.items()}
+
+
+@dataclass
+class TypingImportTypingCompiler(TypingCompiler):
+    _imported: bool = False
+
+    def optional(self, type: str) -> str:
+        self._imported = True
+        return f"typing.Optional[{type}]"
+
+    def list(self, type: str) -> str:
+        self._imported = True
+        return f"typing.List[{type}]"
+
+    def dict(self, key: str, value: str) -> str:
+        self._imported = True
+        return f"typing.Dict[{key}, {value}]"
+
+    def union(self, *types: str) -> str:
+        self._imported = True
+        return f"typing.Union[{', '.join(types)}]"
+
+    def iterable(self, type: str) -> str:
+        self._imported = True
+        return f"typing.Iterable[{type}]"
+
+    def async_iterable(self, type: str) -> str:
+        self._imported = True
+        return f"typing.AsyncIterable[{type}]"
+
+    def async_iterator(self, type: str) -> str:
+        self._imported = True
+        return f"typing.AsyncIterator[{type}]"
+
+    def imports(self) -> Dict[str, Optional[Set[str]]]:
+        if self._imported:
+            return {"typing": None}
+        return {}
+
+
+@dataclass
+class NoTyping310TypingCompiler(TypingCompiler):
+    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
+
+    def optional(self, type: str) -> str:
+        return f"{type} | None"
+
+    def list(self, type: str) -> str:
+        return f"list[{type}]"
+
+    def dict(self, key: str, value: str) -> str:
+        return f"dict[{key}, {value}]"
+
+    def union(self, *types: str) -> str:
+        return " | ".join(types)
+
+    def iterable(self, type: str) -> str:
+        self._imports["typing"].add("Iterable")
+        return f"Iterable[{type}]"
+
+    def async_iterable(self, type: str) -> str:
+        self._imports["typing"].add("AsyncIterable")
+        return f"AsyncIterable[{type}]"
+
+    def async_iterator(self, type: str) -> str:
+        self._imports["typing"].add("AsyncIterator")
+        return f"AsyncIterator[{type}]"
+
+    def imports(self) -> Dict[str, Optional[Set[str]]]:
+        return {k: v if v else None for k, v in self._imports.items()}
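A brief usage sketch contrasting the three compilers above (the assertions follow from the method bodies shown):

```
direct = DirectImportTypingCompiler()
assert direct.optional("str") == "Optional[str]"
assert direct.imports() == {"typing": {"Optional"}}

root = TypingImportTypingCompiler()
assert root.dict("str", "int") == "typing.Dict[str, int]"
assert list(root.import_lines()) == ["import typing"]

py310 = NoTyping310TypingCompiler()
assert py310.optional("str") == "str | None"
assert py310.imports() == {}  # builtins and PEP 604 unions need no import
```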


@@ -0,0 +1,55 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# sources: {{ ', '.join(output_file.input_filenames) }}
+# plugin: python-betterproto
+# This file has been @generated
+
+{% for i in output_file.python_module_imports|sort %}
+import {{ i }}
+{% endfor %}
+
+{% set type_checking_imported = False %}
+{% if output_file.pydantic_dataclasses %}
+from typing import TYPE_CHECKING
+{% set type_checking_imported = True %}
+if TYPE_CHECKING:
+    from dataclasses import dataclass
+else:
+    from pydantic.dataclasses import dataclass
+    from pydantic.dataclasses import rebuild_dataclass
+{%- else -%}
+from dataclasses import dataclass
+{% endif %}
+
+{% if output_file.datetime_imports %}
+from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
+
+{% endif %}
+{% set typing_imports = output_file.typing_compiler.imports() %}
+{% if typing_imports %}
+{% for line in output_file.typing_compiler.import_lines() %}
+{{ line }}
+{% endfor %}
+{% endif %}
+
+{% if output_file.pydantic_imports %}
+from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
+
+{% endif %}
+
+import betterproto
+{% if output_file.services %}
+from betterproto.grpc.grpclib_server import ServiceBase
+import grpclib
+{% endif %}
+
+{% for i in output_file.imports|sort %}
+{{ i }}
+{% endfor %}
+
+{% if output_file.imports_type_checking_only and not type_checking_imported %}
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
+{% endfor %}
+{% endif %}


@@ -1,53 +1,3 @@
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# sources: {{ ', '.join(output_file.input_filenames) }}
-# plugin: python-betterproto
-# This file has been @generated
-
-{% for i in output_file.python_module_imports|sort %}
-import {{ i }}
-{% endfor %}
-
-{% if output_file.pydantic_dataclasses %}
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from dataclasses import dataclass
-else:
-    from pydantic.dataclasses import dataclass
-{%- else -%}
-from dataclasses import dataclass
-{% endif %}
-
-{% if output_file.datetime_imports %}
-from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-{% if output_file.typing_imports %}
-from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-
-{% if output_file.pydantic_imports %}
-from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
-
-{% endif %}
-
-import betterproto
-{% if output_file.services %}
-from betterproto.grpc.grpclib_server import ServiceBase
-import grpclib
-{% endif %}
-
-{% for i in output_file.imports|sort %}
-{{ i }}
-{% endfor %}
-
-{% if output_file.imports_type_checking_only %}
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-{% for i in output_file.imports_type_checking_only|sort %}    {{ i }}
-{% endfor %}
-{% endif %}
 {% if output_file.enums %}{% for enum in output_file.enums %}
 class {{ enum.py_name }}(betterproto.Enum):
 {% if enum.comment %}
@@ -62,6 +12,13 @@ class {{ enum.py_name }}(betterproto.Enum):
 {% endif %}
 {% endfor %}
 
+{% if output_file.pydantic_dataclasses %}
+    @classmethod
+    def __get_pydantic_core_schema__(cls, _source_type, _handler):
+        from pydantic_core import core_schema
+
+        return core_schema.int_schema(ge=0)
+{% endif %}
 {% endfor %}
 {% endif %}
@@ -96,7 +53,7 @@ class {{ message.py_name }}(betterproto.Message):
 {% endif %}
 
 {% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
-    @root_validator()
+    @model_validator(mode='after')
     def check_oneof(cls, values):
         return cls._validate_field_groups(values)
 {% endif %}
@@ -116,14 +73,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
 {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
 {%- else -%}
 {# Client streaming: need a request iterator instead #}
-, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
+, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
 {%- endif -%}
 ,
 *
-, timeout: Optional[float] = None
-, deadline: Optional["Deadline"] = None
-, metadata: Optional["MetadataLike"] = None
+, timeout: {{ output_file.typing_compiler.optional("float") }} = None
+, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
+, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
-) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
+) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
 {% if method.comment %}
 {{ method.comment }}
@@ -191,9 +148,9 @@ class {{ service.py_name }}Base(ServiceBase):
 {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
 {%- else -%}
 {# Client streaming: need a request iterator instead #}
-, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
+, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
 {%- endif -%}
-) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
+) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
 {% if method.comment %}
 {{ method.comment }}
@@ -225,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
 {% endfor %}
 
-    def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
+    def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
         return {
 {% for method in service.methods %}
             "{{ method.route }}": grpclib.const.Handler(
@@ -250,7 +207,7 @@ class {{ service.py_name }}Base(ServiceBase):
 {% if output_file.pydantic_dataclasses %}
 {% for message in output_file.messages %}
 {% if message.has_message_field %}
-{{ message.py_name }}.__pydantic_model__.update_forward_refs()  # type: ignore
+rebuild_dataclass({{ message.py_name }})  # type: ignore
 {% endif %}
 {% endfor %}
 {% endif %}


@@ -108,6 +108,7 @@ async def generate_test_case_output(
         print(
             f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
         )
+        print(ref_err.decode())
 
     if verbose:
         if ref_out:
@@ -126,6 +127,7 @@ async def generate_test_case_output(
         print(
             f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
         )
+        print(plg_err.decode())
 
     if verbose:
         if plg_out:
@@ -146,6 +148,7 @@ async def generate_test_case_output(
         print(
             f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
         )
+        print(plg_err_pyd.decode())
 
     if verbose:
         if plg_out_pyd:


@@ -10,10 +10,15 @@ def test_value():
 
 def test_pydantic_no_value():
-    with pytest.raises(ValueError):
-        TestPyd()
+    message = TestPyd()
+    assert not message.value, "Boolean is False by default"
 
 
 def test_pydantic_value():
-    message = Test(value=False)
+    message = TestPyd(value=False)
     assert not message.value
+
+
+def test_pydantic_bad_value():
+    with pytest.raises(ValueError):
+        TestPyd(value=123)


@@ -4,6 +4,15 @@ from betterproto.compile.importing import (
     get_type_reference,
     parse_source_type_name,
 )
+from betterproto.plugin.typing_compiler import DirectImportTypingCompiler
+
+
+@pytest.fixture
+def typing_compiler() -> DirectImportTypingCompiler:
+    """
+    Generates a simple Direct Import Typing Compiler for testing.
+    """
+    return DirectImportTypingCompiler()
 
 
 @pytest.mark.parametrize(
@@ -32,11 +41,18 @@ from betterproto.compile.importing import (
     ],
 )
 def test_reference_google_wellknown_types_non_wrappers(
-    google_type: str, expected_name: str, expected_import: str
+    google_type: str,
+    expected_name: str,
+    expected_import: str,
+    typing_compiler: DirectImportTypingCompiler,
 ):
     imports = set()
     name = get_type_reference(
-        package="", imports=imports, source_type=google_type, pydantic=False
+        package="",
+        imports=imports,
+        source_type=google_type,
+        typing_compiler=typing_compiler,
+        pydantic=False,
     )
 
     assert name == expected_name
@@ -71,11 +87,18 @@ def test_reference_google_wellknown_types_non_wrappers(
     ],
 )
 def test_reference_google_wellknown_types_non_wrappers_pydantic(
-    google_type: str, expected_name: str, expected_import: str
+    google_type: str,
+    expected_name: str,
+    expected_import: str,
+    typing_compiler: DirectImportTypingCompiler,
 ):
     imports = set()
     name = get_type_reference(
-        package="", imports=imports, source_type=google_type, pydantic=True
+        package="",
+        imports=imports,
+        source_type=google_type,
+        typing_compiler=typing_compiler,
+        pydantic=True,
     )
 
     assert name == expected_name
@@ -99,10 +122,15 @@ def test_reference_google_wellknown_types_non_wrappers_pydantic(
     ],
 )
 def test_referenceing_google_wrappers_unwraps_them(
-    google_type: str, expected_name: str
+    google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
 ):
     imports = set()
-    name = get_type_reference(package="", imports=imports, source_type=google_type)
+    name = get_type_reference(
+        package="",
+        imports=imports,
+        source_type=google_type,
+        typing_compiler=typing_compiler,
+    )
 
     assert name == expected_name
     assert imports == set()
@@ -135,223 +163,321 @@ def test_referenceing_google_wrappers_unwraps_them(
     ],
 )
 def test_referenceing_google_wrappers_without_unwrapping(
-    google_type: str, expected_name: str
+    google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
 ):
     name = get_type_reference(
-        package="", imports=set(), source_type=google_type, unwrap=False
+        package="",
+        imports=set(),
+        source_type=google_type,
+        typing_compiler=typing_compiler,
+        unwrap=False,
     )

     assert name == expected_name


-def test_reference_child_package_from_package():
+def test_reference_child_package_from_package(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="package", imports=imports, source_type="package.child.Message"
+        package="package",
+        imports=imports,
+        source_type="package.child.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from . import child"}
     assert name == '"child.Message"'


-def test_reference_child_package_from_root():
+def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler):
     imports = set()
-    name = get_type_reference(package="", imports=imports, source_type="child.Message")
+    name = get_type_reference(
+        package="",
+        imports=imports,
+        source_type="child.Message",
+        typing_compiler=typing_compiler,
+    )

     assert imports == {"from . import child"}
     assert name == '"child.Message"'


-def test_reference_camel_cased():
+def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="", imports=imports, source_type="child_package.example_message"
+        package="",
+        imports=imports,
+        source_type="child_package.example_message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from . import child_package"}
     assert name == '"child_package.ExampleMessage"'


-def test_reference_nested_child_from_root():
+def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="", imports=imports, source_type="nested.child.Message"
+        package="",
+        imports=imports,
+        source_type="nested.child.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from .nested import child as nested_child"}
     assert name == '"nested_child.Message"'


-def test_reference_deeply_nested_child_from_root():
+def test_reference_deeply_nested_child_from_root(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="", imports=imports, source_type="deeply.nested.child.Message"
+        package="",
+        imports=imports,
+        source_type="deeply.nested.child.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from .deeply.nested import child as deeply_nested_child"}
     assert name == '"deeply_nested_child.Message"'


-def test_reference_deeply_nested_child_from_package():
+def test_reference_deeply_nested_child_from_package(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
         package="package",
         imports=imports,
         source_type="package.deeply.nested.child.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from .deeply.nested import child as deeply_nested_child"}
     assert name == '"deeply_nested_child.Message"'


-def test_reference_root_sibling():
-    imports = set()
-    name = get_type_reference(package="", imports=imports, source_type="Message")
-
-    assert imports == set()
-    assert name == '"Message"'
-
-
-def test_reference_nested_siblings():
-    imports = set()
-    name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
-
-    assert imports == set()
-    assert name == '"Message"'
-
-
-def test_reference_deeply_nested_siblings():
+def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="foo.bar", imports=imports, source_type="foo.bar.Message"
+        package="",
+        imports=imports,
+        source_type="Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == set()
     assert name == '"Message"'


-def test_reference_parent_package_from_child():
+def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="package.child", imports=imports, source_type="package.Message"
+        package="foo",
+        imports=imports,
+        source_type="foo.Message",
+        typing_compiler=typing_compiler,
+    )
+
+    assert imports == set()
+    assert name == '"Message"'
+
+
+def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler):
+    imports = set()
+    name = get_type_reference(
+        package="foo.bar",
+        imports=imports,
+        source_type="foo.bar.Message",
+        typing_compiler=typing_compiler,
+    )
+
+    assert imports == set()
+    assert name == '"Message"'
+
+
+def test_reference_parent_package_from_child(
+    typing_compiler: DirectImportTypingCompiler,
+):
+    imports = set()
+    name = get_type_reference(
+        package="package.child",
+        imports=imports,
+        source_type="package.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ... import package as __package__"}
     assert name == '"__package__.Message"'


-def test_reference_parent_package_from_deeply_nested_child():
+def test_reference_parent_package_from_deeply_nested_child(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
         package="package.deeply.nested.child",
         imports=imports,
         source_type="package.deeply.nested.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ... import nested as __nested__"}
     assert name == '"__nested__.Message"'


-def test_reference_ancestor_package_from_nested_child():
+def test_reference_ancestor_package_from_nested_child(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
         package="package.ancestor.nested.child",
         imports=imports,
         source_type="package.ancestor.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from .... import ancestor as ___ancestor__"}
     assert name == '"___ancestor__.Message"'


-def test_reference_root_package_from_child():
+def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="package.child", imports=imports, source_type="Message"
+        package="package.child",
+        imports=imports,
+        source_type="Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ... import Message as __Message__"}
     assert name == '"__Message__"'


-def test_reference_root_package_from_deeply_nested_child():
+def test_reference_root_package_from_deeply_nested_child(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="package.deeply.nested.child", imports=imports, source_type="Message"
+        package="package.deeply.nested.child",
+        imports=imports,
+        source_type="Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ..... import Message as ____Message__"}
     assert name == '"____Message__"'


-def test_reference_unrelated_package():
+def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler):
     imports = set()
-    name = get_type_reference(package="a", imports=imports, source_type="p.Message")
+    name = get_type_reference(
+        package="a",
+        imports=imports,
+        source_type="p.Message",
+        typing_compiler=typing_compiler,
+    )

     assert imports == {"from .. import p as _p__"}
     assert name == '"_p__.Message"'


-def test_reference_unrelated_nested_package():
+def test_reference_unrelated_nested_package(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
-    name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
+    name = get_type_reference(
+        package="a.b",
+        imports=imports,
+        source_type="p.q.Message",
+        typing_compiler=typing_compiler,
+    )

     assert imports == {"from ...p import q as __p_q__"}
     assert name == '"__p_q__.Message"'


-def test_reference_unrelated_deeply_nested_package():
+def test_reference_unrelated_deeply_nested_package(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
+        package="a.b.c.d",
+        imports=imports,
+        source_type="p.q.r.s.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
     assert name == '"____p_q_r_s__.Message"'


-def test_reference_cousin_package():
+def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler):
     imports = set()
-    name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
+    name = get_type_reference(
+        package="a.x",
+        imports=imports,
+        source_type="a.y.Message",
+        typing_compiler=typing_compiler,
+    )

     assert imports == {"from .. import y as _y__"}
     assert name == '"_y__.Message"'


-def test_reference_cousin_package_different_name():
+def test_reference_cousin_package_different_name(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="test.package1", imports=imports, source_type="cousin.package2.Message"
+        package="test.package1",
+        imports=imports,
+        source_type="cousin.package2.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ...cousin import package2 as __cousin_package2__"}
     assert name == '"__cousin_package2__.Message"'


-def test_reference_cousin_package_same_name():
+def test_reference_cousin_package_same_name(
+    typing_compiler: DirectImportTypingCompiler,
+):
     imports = set()
     name = get_type_reference(
-        package="test.package", imports=imports, source_type="cousin.package.Message"
+        package="test.package",
+        imports=imports,
+        source_type="cousin.package.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ...cousin import package as __cousin_package__"}
     assert name == '"__cousin_package__.Message"'


-def test_reference_far_cousin_package():
+def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="a.x.y", imports=imports, source_type="a.b.c.Message"
+        package="a.x.y",
+        imports=imports,
+        source_type="a.b.c.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ...b import c as __b_c__"}
     assert name == '"__b_c__.Message"'


-def test_reference_far_far_cousin_package():
+def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
     imports = set()
     name = get_type_reference(
-        package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
+        package="a.x.y.z",
+        imports=imports,
+        source_type="a.b.c.d.Message",
+        typing_compiler=typing_compiler,
     )

     assert imports == {"from ....b.c import d as ___b_c_d__"}

View File

@@ -0,0 +1,111 @@
from typing import (
    List,
    Optional,
    Set,
)

import pytest

from betterproto.plugin.module_validation import ModuleValidator


@pytest.mark.parametrize(
    ["text", "expected_collisions"],
    [
        pytest.param(
            ["import os"],
            None,
            id="single import",
        ),
        pytest.param(
            ["import os", "import sys"],
            None,
            id="multiple imports",
        ),
        pytest.param(
            ["import os", "import os"],
            {"os"},
            id="duplicate imports",
        ),
        pytest.param(
            ["from os import path", "import os"],
            None,
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["from os import path", "import os as os_alias"],
            None,
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["from os import path", "import os as path"],
            {"path"},
            id="duplicate imports with alias",
        ),
        pytest.param(
            ["import os", "class os:"],
            {"os"},
            id="duplicate import with class",
        ),
        pytest.param(
            ["import os", "class os:", " pass", "import sys"],
            {"os"},
            id="duplicate import with class and another",
        ),
        pytest.param(
            ["def test(): pass", "class test:"],
            {"test"},
            id="duplicate class and function",
        ),
        pytest.param(
            ["def test(): pass", "def test(): pass"],
            {"test"},
            id="duplicate functions",
        ),
        pytest.param(
            ["def test(): pass", "test = 100"],
            {"test"},
            id="function and variable",
        ),
        pytest.param(
            ["def test():", " test = 3"],
            None,
            id="function and variable in function",
        ),
        pytest.param(
            [
                "def test(): pass",
                "'''",
                "def test(): pass",
                "'''",
                "def test_2(): pass",
            ],
            None,
            id="duplicate functions with multiline string",
        ),
        pytest.param(
            ["def test(): pass", "# def test(): pass"],
            None,
            id="duplicate functions with comments",
        ),
        pytest.param(
            ["from test import (", " A", " B", " C", ")"],
            None,
            id="multiline import",
        ),
        pytest.param(
            ["from test import (", " A", " B", " C", ")", "from test import A"],
            {"A"},
            id="multiline import with duplicate",
        ),
    ],
)
def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
    line_iterator = iter(text)
    validator = ModuleValidator(line_iterator)
    valid = validator.validate()
    if expected_collisions is None:
        assert valid
    else:
        assert set(validator.collisions.keys()) == expected_collisions
        assert not valid
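
The cases above exercise ModuleValidator through pytest parameters; used directly, the API looks like this (a sketch built only from the calls the test makes):

```
from betterproto.plugin.module_validation import ModuleValidator

lines = ["import os", "class os:", "    pass"]
validator = ModuleValidator(iter(lines))
assert not validator.validate()                    # "os" is defined twice at module scope
assert set(validator.collisions.keys()) == {"os"}  # colliding names are reported by key
```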

View File

@@ -0,0 +1,80 @@
import pytest

from betterproto.plugin.typing_compiler import (
    DirectImportTypingCompiler,
    NoTyping310TypingCompiler,
    TypingImportTypingCompiler,
)


def test_direct_import_typing_compiler():
    compiler = DirectImportTypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == "Optional[str]"
    assert compiler.imports() == {"typing": {"Optional"}}
    assert compiler.list("str") == "List[str]"
    assert compiler.imports() == {"typing": {"Optional", "List"}}
    assert compiler.dict("str", "int") == "Dict[str, int]"
    assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}}
    assert compiler.union("str", "int") == "Union[str, int]"
    assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}}
    assert compiler.iterable("str") == "Iterable[str]"
    assert compiler.imports() == {
        "typing": {"Optional", "List", "Dict", "Union", "Iterable"}
    }
    assert compiler.async_iterable("str") == "AsyncIterable[str]"
    assert compiler.imports() == {
        "typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
    }
    assert compiler.async_iterator("str") == "AsyncIterator[str]"
    assert compiler.imports() == {
        "typing": {
            "Optional",
            "List",
            "Dict",
            "Union",
            "Iterable",
            "AsyncIterable",
            "AsyncIterator",
        }
    }


def test_typing_import_typing_compiler():
    compiler = TypingImportTypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == "typing.Optional[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.list("str") == "typing.List[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.dict("str", "int") == "typing.Dict[str, int]"
    assert compiler.imports() == {"typing": None}
    assert compiler.union("str", "int") == "typing.Union[str, int]"
    assert compiler.imports() == {"typing": None}
    assert compiler.iterable("str") == "typing.Iterable[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.async_iterable("str") == "typing.AsyncIterable[str]"
    assert compiler.imports() == {"typing": None}
    assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
    assert compiler.imports() == {"typing": None}


def test_no_typing_311_typing_compiler():
    compiler = NoTyping310TypingCompiler()
    assert compiler.imports() == {}
    assert compiler.optional("str") == "str | None"
    assert compiler.imports() == {}
    assert compiler.list("str") == "list[str]"
    assert compiler.imports() == {}
    assert compiler.dict("str", "int") == "dict[str, int]"
    assert compiler.imports() == {}
    assert compiler.union("str", "int") == "str | int"
    assert compiler.imports() == {}
    assert compiler.iterable("str") == "Iterable[str]"
    assert compiler.imports() == {"typing": {"Iterable"}}
    assert compiler.async_iterable("str") == "AsyncIterable[str]"
    assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
    assert compiler.async_iterator("str") == "AsyncIterator[str]"
    assert compiler.imports() == {
        "typing": {"Iterable", "AsyncIterable", "AsyncIterator"}
    }
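
Judging by the assertions above, imports() maps a module to None for a plain `import module` and to a set of names for a `from module import ...`. A hedged sketch of how a consumer might render that mapping (render_imports is illustrative, not a plugin API):

```
from typing import Dict, Optional, Set


def render_imports(imports: Dict[str, Optional[Set[str]]]) -> str:
    lines = []
    for module, names in imports.items():
        if names is None:
            lines.append(f"import {module}")  # e.g. {"typing": None}
        else:
            # e.g. {"typing": {"Optional", "List"}}
            lines.append(f"from {module} import {', '.join(sorted(names))}")
    return "\n".join(lines)


print(render_imports({"typing": {"List", "Optional"}}))
# from typing import List, Optional
```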