Merge branch 'refs/heads/master_gh'

This commit is contained in:
Georg K 2024-07-30 22:10:55 +03:00
commit 32eaa51e8d
19 changed files with 1726 additions and 791 deletions

View File

@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
os: [Ubuntu, MacOS, Windows]
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
steps:
- uses: actions/checkout@v4

View File

@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.
## Configuration typing imports
By default typing types will be imported directly from typing. This sometimes can lead to issues in generation if types that are being generated conflict with the name. In this case you can configure the way types are imported from 3 different options:
### Direct
```
protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
```
this configuration is the default, and will import types as follows:
```
from typing import (
List,
Optional,
Union
)
...
value: List[str] = []
value2: Optional[str] = None
value3: Union[str, int] = 1
```
### Root
```
protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
```
this configuration loads the root typing module, and then access the types off of it directly:
```
import typing
...
value: typing.List[str] = []
value2: typing.Optional[str] = None
value3: typing.Union[str, int] = 1
```
### 310
```
protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
```
this configuration avoid loading typing all together if possible and uses the python 3.10 pattern:
```
...
value: list[str] = []
value2: str | None = None
value3: str | int = 1
```
## Development

1138
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -39,7 +39,7 @@ pytest = "^6.2.5"
pytest-asyncio = "^0.12.0"
pytest-cov = "^2.9.0"
pytest-mock = "^3.1.1"
pydantic = ">=1.8.0,<2"
pydantic = ">=2.0,<3"
protobuf = "^4"
cachelib = "^0.10.2"
tomlkit = ">=0.7.0"

View File

@ -1852,7 +1852,9 @@ class Message(ABC):
continue
set_fields = [
field.name for field in field_set if values[field.name] is not None
field.name
for field in field_set
if getattr(values, field.name, None) is not None
]
if not set_fields:

View File

@ -47,6 +47,7 @@ def get_type_reference(
package: str,
imports: set,
source_type: str,
typing_compiler: "TypingCompiler",
unwrap: bool = True,
pydantic: bool = False,
) -> str:
@ -57,7 +58,7 @@ def get_type_reference(
if unwrap:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
return typing_compiler.optional(wrapped_type.__name__)
if source_type == ".google.protobuf.Duration":
return "timedelta"

View File

@ -1,9 +1,16 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
# plugin: python-betterproto
# This file has been @generated
import warnings
from typing import TYPE_CHECKING
from typing import (
TYPE_CHECKING,
Mapping,
)
from typing_extensions import Self
from betterproto import hybridmethod
if TYPE_CHECKING:
@ -14,15 +21,13 @@ else:
from typing import (
Dict,
List,
Mapping,
Optional,
)
from pydantic import root_validator
from typing_extensions import Self
from pydantic import model_validator
from pydantic.dataclasses import rebuild_dataclass
import betterproto
from betterproto.utils import hybridmethod
class Syntax(betterproto.Enum):
@ -37,6 +42,12 @@ class Syntax(betterproto.Enum):
EDITIONS = 2
"""Syntax `editions`."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldKind(betterproto.Enum):
"""Basic field types."""
@ -98,6 +109,12 @@ class FieldKind(betterproto.Enum):
TYPE_SINT64 = 18
"""Field type sint64."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldCardinality(betterproto.Enum):
"""Whether a field is optional, required, or repeated."""
@ -114,6 +131,12 @@ class FieldCardinality(betterproto.Enum):
CARDINALITY_REPEATED = 3
"""For repeated fields."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class Edition(betterproto.Enum):
"""The full set of known editions."""
@ -155,6 +178,12 @@ class Edition(betterproto.Enum):
support a new edition.
"""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class ExtensionRangeOptionsVerificationState(betterproto.Enum):
"""The verification state of the extension range."""
@ -164,6 +193,12 @@ class ExtensionRangeOptionsVerificationState(betterproto.Enum):
UNVERIFIED = 1
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldDescriptorProtoType(betterproto.Enum):
TYPE_DOUBLE = 1
@ -210,6 +245,12 @@ class FieldDescriptorProtoType(betterproto.Enum):
TYPE_SINT32 = 17
TYPE_SINT64 = 18
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldDescriptorProtoLabel(betterproto.Enum):
LABEL_OPTIONAL = 1
@ -223,6 +264,12 @@ class FieldDescriptorProtoLabel(betterproto.Enum):
can be used to get this behavior.
"""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FileOptionsOptimizeMode(betterproto.Enum):
"""Generated classes can be optimized for speed or code size."""
@ -233,6 +280,12 @@ class FileOptionsOptimizeMode(betterproto.Enum):
LITE_RUNTIME = 3
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldOptionsCType(betterproto.Enum):
STRING = 0
@ -250,6 +303,12 @@ class FieldOptionsCType(betterproto.Enum):
STRING_PIECE = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldOptionsJsType(betterproto.Enum):
JS_NORMAL = 0
@ -261,6 +320,12 @@ class FieldOptionsJsType(betterproto.Enum):
JS_NUMBER = 2
"""Use JavaScript numbers."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldOptionsOptionRetention(betterproto.Enum):
"""
@ -273,6 +338,12 @@ class FieldOptionsOptionRetention(betterproto.Enum):
RETENTION_RUNTIME = 1
RETENTION_SOURCE = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FieldOptionsOptionTargetType(betterproto.Enum):
"""
@ -293,6 +364,12 @@ class FieldOptionsOptionTargetType(betterproto.Enum):
TARGET_TYPE_SERVICE = 8
TARGET_TYPE_METHOD = 9
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class MethodOptionsIdempotencyLevel(betterproto.Enum):
"""
@ -305,6 +382,12 @@ class MethodOptionsIdempotencyLevel(betterproto.Enum):
NO_SIDE_EFFECTS = 1
IDEMPOTENT = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetFieldPresence(betterproto.Enum):
FIELD_PRESENCE_UNKNOWN = 0
@ -312,36 +395,72 @@ class FeatureSetFieldPresence(betterproto.Enum):
IMPLICIT = 2
LEGACY_REQUIRED = 3
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetEnumType(betterproto.Enum):
ENUM_TYPE_UNKNOWN = 0
OPEN = 1
CLOSED = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetRepeatedFieldEncoding(betterproto.Enum):
REPEATED_FIELD_ENCODING_UNKNOWN = 0
PACKED = 1
EXPANDED = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetUtf8Validation(betterproto.Enum):
UTF8_VALIDATION_UNKNOWN = 0
VERIFY = 2
NONE = 3
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetMessageEncoding(betterproto.Enum):
MESSAGE_ENCODING_UNKNOWN = 0
LENGTH_PREFIXED = 1
DELIMITED = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class FeatureSetJsonFormat(betterproto.Enum):
JSON_FORMAT_UNKNOWN = 0
ALLOW = 1
LEGACY_BEST_EFFORT = 2
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
"""
@ -358,6 +477,12 @@ class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
ALIAS = 2
"""An alias to the element is returned."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
class NullValue(betterproto.Enum):
"""
@ -370,6 +495,12 @@ class NullValue(betterproto.Enum):
_ = 0
"""Null value."""
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
@dataclass(eq=False, repr=False)
class Any(betterproto.Message):
@ -1176,16 +1307,12 @@ class FileOptions(betterproto.Message):
java_string_check_utf8: bool = betterproto.bool_field(27)
"""
A proto2 file can set this to true to opt in to UTF-8 checking for Java,
which will throw an exception if invalid UTF-8 is parsed from the wire or
assigned to a string field.
TODO: clarify exactly what kinds of field types this option
applies to, and update these docs accordingly.
Proto3 files already perform these checks. Setting the option explicitly to
false has no effect: it cannot be used to opt proto3 files out of UTF-8
checks.
If set true, then the Java2 code generator will generate code that
throws an exception whenever an attempt is made to assign a non-UTF-8
byte sequence to a string field.
Message reflection will do the same.
However, an extension field still accepts non-UTF-8 byte sequences.
This option has no effect on when used with the lite runtime.
"""
optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9)
@ -1477,7 +1604,6 @@ class FieldOptions(betterproto.Message):
features: "FeatureSet" = betterproto.message_field(21)
"""Any features defined in the specific edition."""
feature_support: "FieldOptionsFeatureSupport" = betterproto.message_field(22)
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
"""The parser stores options it doesn't recognize here. See above."""
@ -1488,37 +1614,6 @@ class FieldOptionsEditionDefault(betterproto.Message):
value: str = betterproto.string_field(2)
@dataclass(eq=False, repr=False)
class FieldOptionsFeatureSupport(betterproto.Message):
"""Information about the support window of a feature."""
edition_introduced: "Edition" = betterproto.enum_field(1)
"""
The edition that this feature was first available in. In editions
earlier than this one, the default assigned to EDITION_LEGACY will be
used, and proto files will not be able to override it.
"""
edition_deprecated: "Edition" = betterproto.enum_field(2)
"""
The edition this feature becomes deprecated in. Using this after this
edition may trigger warnings.
"""
deprecation_warning: str = betterproto.string_field(3)
"""
The deprecation warning text if this feature is used after the edition it
was marked deprecated in.
"""
edition_removed: "Edition" = betterproto.enum_field(4)
"""
The edition this feature is no longer available in. In editions after
this one, the last default assigned will be used, and proto files will
not be able to override it.
"""
@dataclass(eq=False, repr=False)
class OneofOptions(betterproto.Message):
features: "FeatureSet" = betterproto.message_field(1)
@ -1723,17 +1818,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message):
"""
edition: "Edition" = betterproto.enum_field(3)
overridable_features: "FeatureSet" = betterproto.message_field(4)
"""Defaults of features that can be overridden in this edition."""
fixed_features: "FeatureSet" = betterproto.message_field(5)
"""Defaults of features that can't be overridden in this edition."""
features: "FeatureSet" = betterproto.message_field(2)
"""
TODO Deprecate and remove this field, which is just the
above two merged.
"""
@dataclass(eq=False, repr=False)
@ -2314,7 +2399,7 @@ class Value(betterproto.Message):
)
"""Represents a repeated `Value`."""
@root_validator()
@model_validator(mode="after")
def check_oneof(cls, values):
return cls._validate_field_groups(values)
@ -2549,41 +2634,40 @@ class BytesValue(betterproto.Message):
"""The bytes value."""
Type.__pydantic_model__.update_forward_refs() # type: ignore
Field.__pydantic_model__.update_forward_refs() # type: ignore
Enum.__pydantic_model__.update_forward_refs() # type: ignore
EnumValue.__pydantic_model__.update_forward_refs() # type: ignore
Option.__pydantic_model__.update_forward_refs() # type: ignore
Api.__pydantic_model__.update_forward_refs() # type: ignore
Method.__pydantic_model__.update_forward_refs() # type: ignore
FileDescriptorSet.__pydantic_model__.update_forward_refs() # type: ignore
FileDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
DescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
DescriptorProtoExtensionRange.__pydantic_model__.update_forward_refs() # type: ignore
ExtensionRangeOptions.__pydantic_model__.update_forward_refs() # type: ignore
FieldDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
OneofDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
EnumDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
EnumValueDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
ServiceDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
MethodDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
FileOptions.__pydantic_model__.update_forward_refs() # type: ignore
MessageOptions.__pydantic_model__.update_forward_refs() # type: ignore
FieldOptions.__pydantic_model__.update_forward_refs() # type: ignore
FieldOptionsEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
FieldOptionsFeatureSupport.__pydantic_model__.update_forward_refs() # type: ignore
OneofOptions.__pydantic_model__.update_forward_refs() # type: ignore
EnumOptions.__pydantic_model__.update_forward_refs() # type: ignore
EnumValueOptions.__pydantic_model__.update_forward_refs() # type: ignore
ServiceOptions.__pydantic_model__.update_forward_refs() # type: ignore
MethodOptions.__pydantic_model__.update_forward_refs() # type: ignore
UninterpretedOption.__pydantic_model__.update_forward_refs() # type: ignore
FeatureSet.__pydantic_model__.update_forward_refs() # type: ignore
FeatureSetDefaults.__pydantic_model__.update_forward_refs() # type: ignore
FeatureSetDefaultsFeatureSetEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
SourceCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
GeneratedCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
GeneratedCodeInfoAnnotation.__pydantic_model__.update_forward_refs() # type: ignore
Struct.__pydantic_model__.update_forward_refs() # type: ignore
Value.__pydantic_model__.update_forward_refs() # type: ignore
ListValue.__pydantic_model__.update_forward_refs() # type: ignore
rebuild_dataclass(Type) # type: ignore
rebuild_dataclass(Field) # type: ignore
rebuild_dataclass(Enum) # type: ignore
rebuild_dataclass(EnumValue) # type: ignore
rebuild_dataclass(Option) # type: ignore
rebuild_dataclass(Api) # type: ignore
rebuild_dataclass(Method) # type: ignore
rebuild_dataclass(FileDescriptorSet) # type: ignore
rebuild_dataclass(FileDescriptorProto) # type: ignore
rebuild_dataclass(DescriptorProto) # type: ignore
rebuild_dataclass(DescriptorProtoExtensionRange) # type: ignore
rebuild_dataclass(ExtensionRangeOptions) # type: ignore
rebuild_dataclass(FieldDescriptorProto) # type: ignore
rebuild_dataclass(OneofDescriptorProto) # type: ignore
rebuild_dataclass(EnumDescriptorProto) # type: ignore
rebuild_dataclass(EnumValueDescriptorProto) # type: ignore
rebuild_dataclass(ServiceDescriptorProto) # type: ignore
rebuild_dataclass(MethodDescriptorProto) # type: ignore
rebuild_dataclass(FileOptions) # type: ignore
rebuild_dataclass(MessageOptions) # type: ignore
rebuild_dataclass(FieldOptions) # type: ignore
rebuild_dataclass(FieldOptionsEditionDefault) # type: ignore
rebuild_dataclass(OneofOptions) # type: ignore
rebuild_dataclass(EnumOptions) # type: ignore
rebuild_dataclass(EnumValueOptions) # type: ignore
rebuild_dataclass(ServiceOptions) # type: ignore
rebuild_dataclass(MethodOptions) # type: ignore
rebuild_dataclass(UninterpretedOption) # type: ignore
rebuild_dataclass(FeatureSet) # type: ignore
rebuild_dataclass(FeatureSetDefaults) # type: ignore
rebuild_dataclass(FeatureSetDefaultsFeatureSetEditionDefault) # type: ignore
rebuild_dataclass(SourceCodeInfo) # type: ignore
rebuild_dataclass(GeneratedCodeInfo) # type: ignore
rebuild_dataclass(GeneratedCodeInfoAnnotation) # type: ignore
rebuild_dataclass(Struct) # type: ignore
rebuild_dataclass(Value) # type: ignore
rebuild_dataclass(ListValue) # type: ignore

View File

@ -1,4 +1,7 @@
import os.path
import sys
from .module_validation import ModuleValidator
try:
@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
# Load the body first so we have a compleate list of imports needed.
body_template = env.get_template("template.py.j2")
header_template = env.get_template("header.py.j2")
code = template.render(output_file=output_file)
code = body_template.render(output_file=output_file)
code = header_template.render(output_file=output_file) + code
code = isort.api.sort_code_string(
code=code,
show_diff=False,
@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
force_grid_wrap=2,
known_third_party=["grpclib", "betterproto"],
)
return black.format_str(
code = black.format_str(
src_contents=code,
mode=black.Mode(),
)
# Validate the generated code.
validator = ModuleValidator(iter(code.splitlines()))
if not validator.validate():
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
for collision, lines in validator.collisions.items():
message_builder.append(f' "{collision}" on lines:')
for num, line in lines:
message_builder.append(f" {num}:{line}")
print("\n".join(message_builder), file=sys.stderr)
return code

View File

@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import builtins
import re
import textwrap
from dataclasses import (
dataclass,
field,
@ -49,12 +47,6 @@ from typing import (
)
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
from .. import which_one_of
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
@ -82,6 +75,10 @@ from ..compile.naming import (
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import (
DirectImportTypingCompiler,
TypingCompiler,
)
# Create a unique placeholder to deal with
@ -173,6 +170,7 @@ class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
@ -242,7 +240,6 @@ class OutputTemplate:
input_files: List[str] = field(default_factory=list)
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
@ -251,6 +248,7 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
@property
def package(self) -> str:
@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
source_file: FileDescriptorProto
typing_compiler: TypingCompiler
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return self.typing_compiler.list(self.py_name)
return self.py_name
@property
@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
imports.add("datetime")
return imports
@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports
@property
def pydantic_imports(self) -> Set[str]:
return set()
@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins
@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
return self.annotation.startswith(("List[", "Dict["))
return self.annotation.startswith(
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
)
@property
def field_type(self) -> str:
@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
pydantic=self.output_file.pydantic_dataclasses,
)
else:
@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{py_type}]"
return self.typing_compiler.list(py_type)
if self.optional:
return f"Optional[{py_type}]"
return self.typing_compiler.optional(py_type)
return py_type
@ -600,7 +589,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def pydantic_imports(self) -> Set[str]:
return {"root_validator"}
return {"model_validator"}
@dataclass
@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
typing_compiler=self.typing_compiler,
).py_type
self.py_v_type = FieldCompiler(
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
typing_compiler=self.typing_compiler,
).py_type
# Get proto types
@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
@property
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
@property
def repeated(self) -> bool:
@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields
@property
@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
# Add method to service
self.parent.methods.append(self)
# Check for imports
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")
# Check for Async imports
if self.client_streaming:
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")
# Required by both client and server
if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")
# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')

View File

@ -0,0 +1,163 @@
import re
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
List,
Tuple,
)
@dataclass
class ModuleValidator:
line_iterator: Iterator[str]
line_number: int = field(init=False, default=0)
collisions: Dict[str, List[Tuple[int, str]]] = field(
init=False, default_factory=lambda: defaultdict(list)
)
def add_import(self, imp: str, number: int, full_line: str):
"""
Adds an import to be tracked.
"""
self.collisions[imp].append((number, full_line))
def process_import(self, imp: str):
"""
Filters out the import to its actual value.
"""
if " as " in imp:
imp = imp[imp.index(" as ") + 4 :]
imp = imp.strip()
assert " " not in imp, imp
return imp
def evaluate_multiline_import(self, line: str):
"""
Evaluates a multiline import from a starting line
"""
# Filter the first line and remove anything before the import statement.
full_line = line
line = line.split("import", 1)[1]
if "(" in line:
conditional = lambda line: ")" not in line
else:
conditional = lambda line: "\\" in line
# Remove open parenthesis if it exists.
if "(" in line:
line = line[line.index("(") + 1 :]
# Choose the conditional based on how multiline imports are formatted.
while conditional(line):
# Split the line by commas
imports = line.split(",")
for imp in imports:
# Add the import to the namespace
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
# Get the next line
full_line = line = next(self.line_iterator)
# Increment the line number
self.line_number += 1
# validate the last line
if ")" in line:
line = line[: line.index(")")]
imports = line.split(",")
for imp in imports:
imp = self.process_import(imp)
if imp:
self.add_import(imp, self.line_number, full_line)
def evaluate_import(self, line: str):
"""
Extracts an import from a line.
"""
whole_line = line
line = line[line.index("import") + 6 :]
values = line.split(",")
for v in values:
self.add_import(self.process_import(v), self.line_number, whole_line)
def next(self):
"""
Evaluate each line for names in the module.
"""
line = next(self.line_iterator)
# Skip lines with indentation or comments
if (
# Skip indents and whitespace.
line.startswith(" ")
or line == "\n"
or line.startswith("\t")
or
# Skip comments
line.startswith("#")
or
# Skip decorators
line.startswith("@")
):
self.line_number += 1
return
# Skip docstrings.
if line.startswith('"""') or line.startswith("'''"):
quote = line[0] * 3
line = line[3:]
while quote not in line:
line = next(self.line_iterator)
self.line_number += 1
return
# Evaluate Imports.
if line.startswith("from ") or line.startswith("import "):
if "(" in line or "\\" in line:
self.evaluate_multiline_import(line)
else:
self.evaluate_import(line)
# Evaluate Classes.
elif line.startswith("class "):
class_name = re.search(r"class (\w+)", line).group(1)
if class_name:
self.add_import(class_name, self.line_number, line)
# Evaluate Functions.
elif line.startswith("def "):
function_name = re.search(r"def (\w+)", line).group(1)
if function_name:
self.add_import(function_name, self.line_number, line)
# Evaluate direct assignments.
elif "=" in line:
assignment = re.search(r"(\w+)\s*=", line).group(1)
if assignment:
self.add_import(assignment, self.line_number, line)
self.line_number += 1
def validate(self) -> bool:
"""
Run Validation.
"""
try:
while True:
self.next()
except StopIteration:
pass
# Filter collisions for those with more than one value.
self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
# Return True if no collisions are found.
return not bool(self.collisions)

View File

@ -37,6 +37,12 @@ from .models import (
is_map,
is_oneof,
)
from .typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingCompiler,
TypingImportTypingCompiler,
)
def traverse(
@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
output_package_name
].pydantic_dataclasses = True
# Gather any typing generation options.
typing_opts = [
opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
]
if len(typing_opts) > 1:
raise ValueError("Multiple typing options provided")
# Set the compiler type.
typing_opt = typing_opts[0] if typing_opts else "direct"
if typing_opt == "direct":
request_data.output_packages[
output_package_name
].typing_compiler = DirectImportTypingCompiler()
elif typing_opt == "root":
request_data.output_packages[
output_package_name
].typing_compiler = TypingImportTypingCompiler()
elif typing_opt == "310":
request_data.output_packages[
output_package_name
].typing_compiler = NoTyping310TypingCompiler()
# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
parent=parent,
proto_obj=proto_obj,
path=path,
typing_compiler=output_package.typing_compiler,
)
@ -181,7 +210,11 @@ def read_protobuf_type(
return
# Process Message
message_data = MessageCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
)
for index, field in enumerate(item.field):
if is_map(field, item):
@ -190,6 +223,7 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif is_oneof(field):
_make_one_of_field_compiler(
@ -201,11 +235,16 @@ def read_protobuf_type(
parent=message_data,
proto_obj=field,
path=path + [2, index],
typing_compiler=output_package.typing_compiler,
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
source_file=source_file,
parent=output_package,
proto_obj=item,
path=path,
typing_compiler=output_package.typing_compiler,
)

View File

@ -0,0 +1,167 @@
import abc
from collections import defaultdict
from dataclasses import (
dataclass,
field,
)
from typing import (
Dict,
Iterator,
Optional,
Set,
)
class TypingCompiler(metaclass=abc.ABCMeta):
@abc.abstractmethod
def optional(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def list(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def dict(self, key: str, value: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def union(self, *types: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterable(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def async_iterator(self, type: str) -> str:
raise NotImplementedError()
@abc.abstractmethod
def imports(self) -> Dict[str, Optional[Set[str]]]:
"""
Returns either the direct import as a key with none as value, or a set of
values to import from the key.
"""
raise NotImplementedError()
def import_lines(self) -> Iterator:
imports = self.imports()
for key, value in imports.items():
if value is None:
yield f"import {key}"
else:
yield f"from {key} import ("
for v in sorted(value):
yield f" {v},"
yield ")"
@dataclass
class DirectImportTypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
def optional(self, type: str) -> str:
self._imports["typing"].add("Optional")
return f"Optional[{type}]"
def list(self, type: str) -> str:
self._imports["typing"].add("List")
return f"List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imports["typing"].add("Dict")
return f"Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imports["typing"].add("Union")
return f"Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable")
return f"AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}
@dataclass
class TypingImportTypingCompiler(TypingCompiler):
_imported: bool = False
def optional(self, type: str) -> str:
self._imported = True
return f"typing.Optional[{type}]"
def list(self, type: str) -> str:
self._imported = True
return f"typing.List[{type}]"
def dict(self, key: str, value: str) -> str:
self._imported = True
return f"typing.Dict[{key}, {value}]"
def union(self, *types: str) -> str:
self._imported = True
return f"typing.Union[{', '.join(types)}]"
def iterable(self, type: str) -> str:
self._imported = True
return f"typing.Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imported = True
return f"typing.AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
if self._imported:
return {"typing": None}
return {}
@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
def optional(self, type: str) -> str:
return f"{type} | None"
def list(self, type: str) -> str:
return f"list[{type}]"
def dict(self, key: str, value: str) -> str:
return f"dict[{key}, {value}]"
def union(self, *types: str) -> str:
return " | ".join(types)
def iterable(self, type: str) -> str:
self._imports["typing"].add("Iterable")
return f"Iterable[{type}]"
def async_iterable(self, type: str) -> str:
self._imports["typing"].add("AsyncIterable")
return f"AsyncIterable[{type}]"
def async_iterator(self, type: str) -> str:
self._imports["typing"].add("AsyncIterator")
return f"AsyncIterator[{type}]"
def imports(self) -> Dict[str, Optional[Set[str]]]:
return {k: v if v else None for k, v in self._imports.items()}

View File

@ -0,0 +1,55 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% set type_checking_imported = False %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
{% set type_checking_imported = True %}
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
from pydantic.dataclasses import rebuild_dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% set typing_imports = output_file.typing_compiler.imports() %}
{% if typing_imports %}
{% for line in output_file.typing_compiler.import_lines() %}
{{ line }}
{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only and not type_checking_imported %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}

View File

@ -1,53 +1,3 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto
# This file has been @generated
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from dataclasses import dataclass
else:
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}
{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if output_file.typing_imports %}
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
import grpclib
{% endif %}
{% for i in output_file.imports|sort %}
{{ i }}
{% endfor %}
{% if output_file.imports_type_checking_only %}
from typing import TYPE_CHECKING
if TYPE_CHECKING:
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
{% endfor %}
{% endif %}
{% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %}
@ -62,6 +12,13 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endif %}
{% endfor %}
{% if output_file.pydantic_dataclasses %}
@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler):
from pydantic_core import core_schema
return core_schema.int_schema(ge=0)
{% endif %}
{% endfor %}
{% endif %}
@ -96,7 +53,7 @@ class {{ message.py_name }}(betterproto.Message):
{% endif %}
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
@model_validator(mode='after')
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}
@ -116,14 +73,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
{%- endif -%}
,
*
, timeout: Optional[float] = None
, deadline: Optional["Deadline"] = None
, metadata: Optional["MetadataLike"] = None
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@ -191,9 +148,9 @@ class {{ service.py_name }}Base(ServiceBase):
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
@ -225,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% endfor %}
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
return {
{% for method in service.methods %}
"{{ method.route }}": grpclib.const.Handler(
@ -250,7 +207,7 @@ class {{ service.py_name }}Base(ServiceBase):
{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
rebuild_dataclass({{ message.py_name }}) # type: ignore
{% endif %}
{% endfor %}
{% endif %}

View File

@ -108,6 +108,7 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate reference output for {test_case_name!r}\033[0m"
)
print(ref_err.decode())
if verbose:
if ref_out:
@ -126,6 +127,7 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate plugin output for {test_case_name!r}\033[0m"
)
print(plg_err.decode())
if verbose:
if plg_out:
@ -146,6 +148,7 @@ async def generate_test_case_output(
print(
f"\033[31;1;4mFailed to generate plugin (pydantic compatible) output for {test_case_name!r}\033[0m"
)
print(plg_err_pyd.decode())
if verbose:
if plg_out_pyd:

View File

@ -10,10 +10,15 @@ def test_value():
def test_pydantic_no_value():
with pytest.raises(ValueError):
TestPyd()
message = TestPyd()
assert not message.value, "Boolean is False by default"
def test_pydantic_value():
message = Test(value=False)
message = TestPyd(value=False)
assert not message.value
def test_pydantic_bad_value():
with pytest.raises(ValueError):
TestPyd(value=123)

View File

@ -4,6 +4,15 @@ from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.plugin.typing_compiler import DirectImportTypingCompiler
@pytest.fixture
def typing_compiler() -> DirectImportTypingCompiler:
"""
Generates a simple Direct Import Typing Compiler for testing.
"""
return DirectImportTypingCompiler()
@pytest.mark.parametrize(
@ -32,11 +41,18 @@ from betterproto.compile.importing import (
],
)
def test_reference_google_wellknown_types_non_wrappers(
google_type: str, expected_name: str, expected_import: str
google_type: str,
expected_name: str,
expected_import: str,
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="", imports=imports, source_type=google_type, pydantic=False
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
pydantic=False,
)
assert name == expected_name
@ -71,11 +87,18 @@ def test_reference_google_wellknown_types_non_wrappers(
],
)
def test_reference_google_wellknown_types_non_wrappers_pydantic(
google_type: str, expected_name: str, expected_import: str
google_type: str,
expected_name: str,
expected_import: str,
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="", imports=imports, source_type=google_type, pydantic=True
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
pydantic=True,
)
assert name == expected_name
@ -99,10 +122,15 @@ def test_reference_google_wellknown_types_non_wrappers_pydantic(
],
)
def test_referenceing_google_wrappers_unwraps_them(
google_type: str, expected_name: str
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
):
imports = set()
name = get_type_reference(package="", imports=imports, source_type=google_type)
name = get_type_reference(
package="",
imports=imports,
source_type=google_type,
typing_compiler=typing_compiler,
)
assert name == expected_name
assert imports == set()
@ -135,223 +163,321 @@ def test_referenceing_google_wrappers_unwraps_them(
],
)
def test_referenceing_google_wrappers_without_unwrapping(
google_type: str, expected_name: str
google_type: str, expected_name: str, typing_compiler: DirectImportTypingCompiler
):
name = get_type_reference(
package="", imports=set(), source_type=google_type, unwrap=False
package="",
imports=set(),
source_type=google_type,
typing_compiler=typing_compiler,
unwrap=False,
)
assert name == expected_name
def test_reference_child_package_from_package():
def test_reference_child_package_from_package(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package", imports=imports, source_type="package.child.Message"
package="package",
imports=imports,
source_type="package.child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from . import child"}
assert name == '"child.Message"'
def test_reference_child_package_from_root():
def test_reference_child_package_from_root(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(package="", imports=imports, source_type="child.Message")
name = get_type_reference(
package="",
imports=imports,
source_type="child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from . import child"}
assert name == '"child.Message"'
def test_reference_camel_cased():
def test_reference_camel_cased(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="child_package.example_message"
package="",
imports=imports,
source_type="child_package.example_message",
typing_compiler=typing_compiler,
)
assert imports == {"from . import child_package"}
assert name == '"child_package.ExampleMessage"'
def test_reference_nested_child_from_root():
def test_reference_nested_child_from_root(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="nested.child.Message"
package="",
imports=imports,
source_type="nested.child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .nested import child as nested_child"}
assert name == '"nested_child.Message"'
def test_reference_deeply_nested_child_from_root():
def test_reference_deeply_nested_child_from_root(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="", imports=imports, source_type="deeply.nested.child.Message"
package="",
imports=imports,
source_type="deeply.nested.child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == '"deeply_nested_child.Message"'
def test_reference_deeply_nested_child_from_package():
def test_reference_deeply_nested_child_from_package(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package",
imports=imports,
source_type="package.deeply.nested.child.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .deeply.nested import child as deeply_nested_child"}
assert name == '"deeply_nested_child.Message"'
def test_reference_root_sibling():
imports = set()
name = get_type_reference(package="", imports=imports, source_type="Message")
assert imports == set()
assert name == '"Message"'
def test_reference_nested_siblings():
imports = set()
name = get_type_reference(package="foo", imports=imports, source_type="foo.Message")
assert imports == set()
assert name == '"Message"'
def test_reference_deeply_nested_siblings():
def test_reference_root_sibling(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="foo.bar", imports=imports, source_type="foo.bar.Message"
package="",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
)
assert imports == set()
assert name == '"Message"'
def test_reference_parent_package_from_child():
def test_reference_nested_siblings(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="package.child", imports=imports, source_type="package.Message"
package="foo",
imports=imports,
source_type="foo.Message",
typing_compiler=typing_compiler,
)
assert imports == set()
assert name == '"Message"'
def test_reference_deeply_nested_siblings(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="foo.bar",
imports=imports,
source_type="foo.bar.Message",
typing_compiler=typing_compiler,
)
assert imports == set()
assert name == '"Message"'
def test_reference_parent_package_from_child(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package.child",
imports=imports,
source_type="package.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ... import package as __package__"}
assert name == '"__package__.Message"'
def test_reference_parent_package_from_deeply_nested_child():
def test_reference_parent_package_from_deeply_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package.deeply.nested.child",
imports=imports,
source_type="package.deeply.nested.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ... import nested as __nested__"}
assert name == '"__nested__.Message"'
def test_reference_ancestor_package_from_nested_child():
def test_reference_ancestor_package_from_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package.ancestor.nested.child",
imports=imports,
source_type="package.ancestor.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .... import ancestor as ___ancestor__"}
assert name == '"___ancestor__.Message"'
def test_reference_root_package_from_child():
def test_reference_root_package_from_child(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="package.child", imports=imports, source_type="Message"
package="package.child",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ... import Message as __Message__"}
assert name == '"__Message__"'
def test_reference_root_package_from_deeply_nested_child():
def test_reference_root_package_from_deeply_nested_child(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="package.deeply.nested.child", imports=imports, source_type="Message"
package="package.deeply.nested.child",
imports=imports,
source_type="Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ..... import Message as ____Message__"}
assert name == '"____Message__"'
def test_reference_unrelated_package():
def test_reference_unrelated_package(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(package="a", imports=imports, source_type="p.Message")
name = get_type_reference(
package="a",
imports=imports,
source_type="p.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .. import p as _p__"}
assert name == '"_p__.Message"'
def test_reference_unrelated_nested_package():
def test_reference_unrelated_nested_package(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message")
name = get_type_reference(
package="a.b",
imports=imports,
source_type="p.q.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ...p import q as __p_q__"}
assert name == '"__p_q__.Message"'
def test_reference_unrelated_deeply_nested_package():
def test_reference_unrelated_deeply_nested_package(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="a.b.c.d", imports=imports, source_type="p.q.r.s.Message"
package="a.b.c.d",
imports=imports,
source_type="p.q.r.s.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .....p.q.r import s as ____p_q_r_s__"}
assert name == '"____p_q_r_s__.Message"'
def test_reference_cousin_package():
def test_reference_cousin_package(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message")
name = get_type_reference(
package="a.x",
imports=imports,
source_type="a.y.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from .. import y as _y__"}
assert name == '"_y__.Message"'
def test_reference_cousin_package_different_name():
def test_reference_cousin_package_different_name(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="test.package1", imports=imports, source_type="cousin.package2.Message"
package="test.package1",
imports=imports,
source_type="cousin.package2.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ...cousin import package2 as __cousin_package2__"}
assert name == '"__cousin_package2__.Message"'
def test_reference_cousin_package_same_name():
def test_reference_cousin_package_same_name(
typing_compiler: DirectImportTypingCompiler,
):
imports = set()
name = get_type_reference(
package="test.package", imports=imports, source_type="cousin.package.Message"
package="test.package",
imports=imports,
source_type="cousin.package.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ...cousin import package as __cousin_package__"}
assert name == '"__cousin_package__.Message"'
def test_reference_far_cousin_package():
def test_reference_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="a.x.y", imports=imports, source_type="a.b.c.Message"
package="a.x.y",
imports=imports,
source_type="a.b.c.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ...b import c as __b_c__"}
assert name == '"__b_c__.Message"'
def test_reference_far_far_cousin_package():
def test_reference_far_far_cousin_package(typing_compiler: DirectImportTypingCompiler):
imports = set()
name = get_type_reference(
package="a.x.y.z", imports=imports, source_type="a.b.c.d.Message"
package="a.x.y.z",
imports=imports,
source_type="a.b.c.d.Message",
typing_compiler=typing_compiler,
)
assert imports == {"from ....b.c import d as ___b_c_d__"}

View File

@ -0,0 +1,111 @@
from typing import (
List,
Optional,
Set,
)
import pytest
from betterproto.plugin.module_validation import ModuleValidator
@pytest.mark.parametrize(
["text", "expected_collisions"],
[
pytest.param(
["import os"],
None,
id="single import",
),
pytest.param(
["import os", "import sys"],
None,
id="multiple imports",
),
pytest.param(
["import os", "import os"],
{"os"},
id="duplicate imports",
),
pytest.param(
["from os import path", "import os"],
None,
id="duplicate imports with alias",
),
pytest.param(
["from os import path", "import os as os_alias"],
None,
id="duplicate imports with alias",
),
pytest.param(
["from os import path", "import os as path"],
{"path"},
id="duplicate imports with alias",
),
pytest.param(
["import os", "class os:"],
{"os"},
id="duplicate import with class",
),
pytest.param(
["import os", "class os:", " pass", "import sys"],
{"os"},
id="duplicate import with class and another",
),
pytest.param(
["def test(): pass", "class test:"],
{"test"},
id="duplicate class and function",
),
pytest.param(
["def test(): pass", "def test(): pass"],
{"test"},
id="duplicate functions",
),
pytest.param(
["def test(): pass", "test = 100"],
{"test"},
id="function and variable",
),
pytest.param(
["def test():", " test = 3"],
None,
id="function and variable in function",
),
pytest.param(
[
"def test(): pass",
"'''",
"def test(): pass",
"'''",
"def test_2(): pass",
],
None,
id="duplicate functions with multiline string",
),
pytest.param(
["def test(): pass", "# def test(): pass"],
None,
id="duplicate functions with comments",
),
pytest.param(
["from test import (", " A", " B", " C", ")"],
None,
id="multiline import",
),
pytest.param(
["from test import (", " A", " B", " C", ")", "from test import A"],
{"A"},
id="multiline import with duplicate",
),
],
)
def test_module_validator(text: List[str], expected_collisions: Optional[Set[str]]):
line_iterator = iter(text)
validator = ModuleValidator(line_iterator)
valid = validator.validate()
if expected_collisions is None:
assert valid
else:
assert set(validator.collisions.keys()) == expected_collisions
assert not valid

View File

@ -0,0 +1,80 @@
import pytest
from betterproto.plugin.typing_compiler import (
DirectImportTypingCompiler,
NoTyping310TypingCompiler,
TypingImportTypingCompiler,
)
def test_direct_import_typing_compiler():
compiler = DirectImportTypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == "Optional[str]"
assert compiler.imports() == {"typing": {"Optional"}}
assert compiler.list("str") == "List[str]"
assert compiler.imports() == {"typing": {"Optional", "List"}}
assert compiler.dict("str", "int") == "Dict[str, int]"
assert compiler.imports() == {"typing": {"Optional", "List", "Dict"}}
assert compiler.union("str", "int") == "Union[str, int]"
assert compiler.imports() == {"typing": {"Optional", "List", "Dict", "Union"}}
assert compiler.iterable("str") == "Iterable[str]"
assert compiler.imports() == {
"typing": {"Optional", "List", "Dict", "Union", "Iterable"}
}
assert compiler.async_iterable("str") == "AsyncIterable[str]"
assert compiler.imports() == {
"typing": {"Optional", "List", "Dict", "Union", "Iterable", "AsyncIterable"}
}
assert compiler.async_iterator("str") == "AsyncIterator[str]"
assert compiler.imports() == {
"typing": {
"Optional",
"List",
"Dict",
"Union",
"Iterable",
"AsyncIterable",
"AsyncIterator",
}
}
def test_typing_import_typing_compiler():
compiler = TypingImportTypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == "typing.Optional[str]"
assert compiler.imports() == {"typing": None}
assert compiler.list("str") == "typing.List[str]"
assert compiler.imports() == {"typing": None}
assert compiler.dict("str", "int") == "typing.Dict[str, int]"
assert compiler.imports() == {"typing": None}
assert compiler.union("str", "int") == "typing.Union[str, int]"
assert compiler.imports() == {"typing": None}
assert compiler.iterable("str") == "typing.Iterable[str]"
assert compiler.imports() == {"typing": None}
assert compiler.async_iterable("str") == "typing.AsyncIterable[str]"
assert compiler.imports() == {"typing": None}
assert compiler.async_iterator("str") == "typing.AsyncIterator[str]"
assert compiler.imports() == {"typing": None}
def test_no_typing_311_typing_compiler():
compiler = NoTyping310TypingCompiler()
assert compiler.imports() == {}
assert compiler.optional("str") == "str | None"
assert compiler.imports() == {}
assert compiler.list("str") == "list[str]"
assert compiler.imports() == {}
assert compiler.dict("str", "int") == "dict[str, int]"
assert compiler.imports() == {}
assert compiler.union("str", "int") == "str | int"
assert compiler.imports() == {}
assert compiler.iterable("str") == "Iterable[str]"
assert compiler.imports() == {"typing": {"Iterable"}}
assert compiler.async_iterable("str") == "AsyncIterable[str]"
assert compiler.imports() == {"typing": {"Iterable", "AsyncIterable"}}
assert compiler.async_iterator("str") == "AsyncIterator[str]"
assert compiler.imports() == {
"typing": {"Iterable", "AsyncIterable", "AsyncIterator"}
}