Merge branch 'refs/heads/master_gh'
This commit is contained in:
@@ -1852,7 +1852,9 @@ class Message(ABC):
|
||||
continue
|
||||
|
||||
set_fields = [
|
||||
field.name for field in field_set if values[field.name] is not None
|
||||
field.name
|
||||
for field in field_set
|
||||
if getattr(values, field.name, None) is not None
|
||||
]
|
||||
|
||||
if not set_fields:
|
||||
|
||||
@@ -47,6 +47,7 @@ def get_type_reference(
|
||||
package: str,
|
||||
imports: set,
|
||||
source_type: str,
|
||||
typing_compiler: "TypingCompiler",
|
||||
unwrap: bool = True,
|
||||
pydantic: bool = False,
|
||||
) -> str:
|
||||
@@ -57,7 +58,7 @@ def get_type_reference(
|
||||
if unwrap:
|
||||
if source_type in WRAPPER_TYPES:
|
||||
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
|
||||
return f"Optional[{wrapped_type.__name__}]"
|
||||
return typing_compiler.optional(wrapped_type.__name__)
|
||||
|
||||
if source_type == ".google.protobuf.Duration":
|
||||
return "timedelta"
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto
|
||||
# plugin: python-betterproto
|
||||
|
||||
# This file has been @generated
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Mapping,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from betterproto import hybridmethod
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -14,15 +21,13 @@ else:
|
||||
from typing import (
|
||||
Dict,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from pydantic import root_validator
|
||||
from typing_extensions import Self
|
||||
from pydantic import model_validator
|
||||
from pydantic.dataclasses import rebuild_dataclass
|
||||
|
||||
import betterproto
|
||||
from betterproto.utils import hybridmethod
|
||||
|
||||
|
||||
class Syntax(betterproto.Enum):
|
||||
@@ -37,6 +42,12 @@ class Syntax(betterproto.Enum):
|
||||
EDITIONS = 2
|
||||
"""Syntax `editions`."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldKind(betterproto.Enum):
|
||||
"""Basic field types."""
|
||||
@@ -98,6 +109,12 @@ class FieldKind(betterproto.Enum):
|
||||
TYPE_SINT64 = 18
|
||||
"""Field type sint64."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldCardinality(betterproto.Enum):
|
||||
"""Whether a field is optional, required, or repeated."""
|
||||
@@ -114,6 +131,12 @@ class FieldCardinality(betterproto.Enum):
|
||||
CARDINALITY_REPEATED = 3
|
||||
"""For repeated fields."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class Edition(betterproto.Enum):
|
||||
"""The full set of known editions."""
|
||||
@@ -155,6 +178,12 @@ class Edition(betterproto.Enum):
|
||||
support a new edition.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class ExtensionRangeOptionsVerificationState(betterproto.Enum):
|
||||
"""The verification state of the extension range."""
|
||||
@@ -164,6 +193,12 @@ class ExtensionRangeOptionsVerificationState(betterproto.Enum):
|
||||
|
||||
UNVERIFIED = 1
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldDescriptorProtoType(betterproto.Enum):
|
||||
TYPE_DOUBLE = 1
|
||||
@@ -210,6 +245,12 @@ class FieldDescriptorProtoType(betterproto.Enum):
|
||||
TYPE_SINT32 = 17
|
||||
TYPE_SINT64 = 18
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldDescriptorProtoLabel(betterproto.Enum):
|
||||
LABEL_OPTIONAL = 1
|
||||
@@ -223,6 +264,12 @@ class FieldDescriptorProtoLabel(betterproto.Enum):
|
||||
can be used to get this behavior.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FileOptionsOptimizeMode(betterproto.Enum):
|
||||
"""Generated classes can be optimized for speed or code size."""
|
||||
@@ -233,6 +280,12 @@ class FileOptionsOptimizeMode(betterproto.Enum):
|
||||
|
||||
LITE_RUNTIME = 3
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldOptionsCType(betterproto.Enum):
|
||||
STRING = 0
|
||||
@@ -250,6 +303,12 @@ class FieldOptionsCType(betterproto.Enum):
|
||||
|
||||
STRING_PIECE = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldOptionsJsType(betterproto.Enum):
|
||||
JS_NORMAL = 0
|
||||
@@ -261,6 +320,12 @@ class FieldOptionsJsType(betterproto.Enum):
|
||||
JS_NUMBER = 2
|
||||
"""Use JavaScript numbers."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldOptionsOptionRetention(betterproto.Enum):
|
||||
"""
|
||||
@@ -273,6 +338,12 @@ class FieldOptionsOptionRetention(betterproto.Enum):
|
||||
RETENTION_RUNTIME = 1
|
||||
RETENTION_SOURCE = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FieldOptionsOptionTargetType(betterproto.Enum):
|
||||
"""
|
||||
@@ -293,6 +364,12 @@ class FieldOptionsOptionTargetType(betterproto.Enum):
|
||||
TARGET_TYPE_SERVICE = 8
|
||||
TARGET_TYPE_METHOD = 9
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class MethodOptionsIdempotencyLevel(betterproto.Enum):
|
||||
"""
|
||||
@@ -305,6 +382,12 @@ class MethodOptionsIdempotencyLevel(betterproto.Enum):
|
||||
NO_SIDE_EFFECTS = 1
|
||||
IDEMPOTENT = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetFieldPresence(betterproto.Enum):
|
||||
FIELD_PRESENCE_UNKNOWN = 0
|
||||
@@ -312,36 +395,72 @@ class FeatureSetFieldPresence(betterproto.Enum):
|
||||
IMPLICIT = 2
|
||||
LEGACY_REQUIRED = 3
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetEnumType(betterproto.Enum):
|
||||
ENUM_TYPE_UNKNOWN = 0
|
||||
OPEN = 1
|
||||
CLOSED = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetRepeatedFieldEncoding(betterproto.Enum):
|
||||
REPEATED_FIELD_ENCODING_UNKNOWN = 0
|
||||
PACKED = 1
|
||||
EXPANDED = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetUtf8Validation(betterproto.Enum):
|
||||
UTF8_VALIDATION_UNKNOWN = 0
|
||||
VERIFY = 2
|
||||
NONE = 3
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetMessageEncoding(betterproto.Enum):
|
||||
MESSAGE_ENCODING_UNKNOWN = 0
|
||||
LENGTH_PREFIXED = 1
|
||||
DELIMITED = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class FeatureSetJsonFormat(betterproto.Enum):
|
||||
JSON_FORMAT_UNKNOWN = 0
|
||||
ALLOW = 1
|
||||
LEGACY_BEST_EFFORT = 2
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
|
||||
"""
|
||||
@@ -358,6 +477,12 @@ class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum):
|
||||
ALIAS = 2
|
||||
"""An alias to the element is returned."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
class NullValue(betterproto.Enum):
|
||||
"""
|
||||
@@ -370,6 +495,12 @@ class NullValue(betterproto.Enum):
|
||||
_ = 0
|
||||
"""Null value."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Any(betterproto.Message):
|
||||
@@ -1176,16 +1307,12 @@ class FileOptions(betterproto.Message):
|
||||
|
||||
java_string_check_utf8: bool = betterproto.bool_field(27)
|
||||
"""
|
||||
A proto2 file can set this to true to opt in to UTF-8 checking for Java,
|
||||
which will throw an exception if invalid UTF-8 is parsed from the wire or
|
||||
assigned to a string field.
|
||||
|
||||
TODO: clarify exactly what kinds of field types this option
|
||||
applies to, and update these docs accordingly.
|
||||
|
||||
Proto3 files already perform these checks. Setting the option explicitly to
|
||||
false has no effect: it cannot be used to opt proto3 files out of UTF-8
|
||||
checks.
|
||||
If set true, then the Java2 code generator will generate code that
|
||||
throws an exception whenever an attempt is made to assign a non-UTF-8
|
||||
byte sequence to a string field.
|
||||
Message reflection will do the same.
|
||||
However, an extension field still accepts non-UTF-8 byte sequences.
|
||||
This option has no effect on when used with the lite runtime.
|
||||
"""
|
||||
|
||||
optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9)
|
||||
@@ -1477,7 +1604,6 @@ class FieldOptions(betterproto.Message):
|
||||
features: "FeatureSet" = betterproto.message_field(21)
|
||||
"""Any features defined in the specific edition."""
|
||||
|
||||
feature_support: "FieldOptionsFeatureSupport" = betterproto.message_field(22)
|
||||
uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999)
|
||||
"""The parser stores options it doesn't recognize here. See above."""
|
||||
|
||||
@@ -1488,37 +1614,6 @@ class FieldOptionsEditionDefault(betterproto.Message):
|
||||
value: str = betterproto.string_field(2)
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class FieldOptionsFeatureSupport(betterproto.Message):
|
||||
"""Information about the support window of a feature."""
|
||||
|
||||
edition_introduced: "Edition" = betterproto.enum_field(1)
|
||||
"""
|
||||
The edition that this feature was first available in. In editions
|
||||
earlier than this one, the default assigned to EDITION_LEGACY will be
|
||||
used, and proto files will not be able to override it.
|
||||
"""
|
||||
|
||||
edition_deprecated: "Edition" = betterproto.enum_field(2)
|
||||
"""
|
||||
The edition this feature becomes deprecated in. Using this after this
|
||||
edition may trigger warnings.
|
||||
"""
|
||||
|
||||
deprecation_warning: str = betterproto.string_field(3)
|
||||
"""
|
||||
The deprecation warning text if this feature is used after the edition it
|
||||
was marked deprecated in.
|
||||
"""
|
||||
|
||||
edition_removed: "Edition" = betterproto.enum_field(4)
|
||||
"""
|
||||
The edition this feature is no longer available in. In editions after
|
||||
this one, the last default assigned will be used, and proto files will
|
||||
not be able to override it.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class OneofOptions(betterproto.Message):
|
||||
features: "FeatureSet" = betterproto.message_field(1)
|
||||
@@ -1723,17 +1818,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message):
|
||||
"""
|
||||
|
||||
edition: "Edition" = betterproto.enum_field(3)
|
||||
overridable_features: "FeatureSet" = betterproto.message_field(4)
|
||||
"""Defaults of features that can be overridden in this edition."""
|
||||
|
||||
fixed_features: "FeatureSet" = betterproto.message_field(5)
|
||||
"""Defaults of features that can't be overridden in this edition."""
|
||||
|
||||
features: "FeatureSet" = betterproto.message_field(2)
|
||||
"""
|
||||
TODO Deprecate and remove this field, which is just the
|
||||
above two merged.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
@@ -2314,7 +2399,7 @@ class Value(betterproto.Message):
|
||||
)
|
||||
"""Represents a repeated `Value`."""
|
||||
|
||||
@root_validator()
|
||||
@model_validator(mode="after")
|
||||
def check_oneof(cls, values):
|
||||
return cls._validate_field_groups(values)
|
||||
|
||||
@@ -2549,41 +2634,40 @@ class BytesValue(betterproto.Message):
|
||||
"""The bytes value."""
|
||||
|
||||
|
||||
Type.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Field.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Enum.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
EnumValue.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Option.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Api.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Method.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FileDescriptorSet.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FileDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
DescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
DescriptorProtoExtensionRange.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
ExtensionRangeOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FieldDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
OneofDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
EnumDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
EnumValueDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
ServiceDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
MethodDescriptorProto.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FileOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
MessageOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FieldOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FieldOptionsEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FieldOptionsFeatureSupport.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
OneofOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
EnumOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
EnumValueOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
ServiceOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
MethodOptions.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
UninterpretedOption.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FeatureSet.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FeatureSetDefaults.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
FeatureSetDefaultsFeatureSetEditionDefault.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
SourceCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
GeneratedCodeInfo.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
GeneratedCodeInfoAnnotation.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Struct.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
Value.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
ListValue.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
rebuild_dataclass(Type) # type: ignore
|
||||
rebuild_dataclass(Field) # type: ignore
|
||||
rebuild_dataclass(Enum) # type: ignore
|
||||
rebuild_dataclass(EnumValue) # type: ignore
|
||||
rebuild_dataclass(Option) # type: ignore
|
||||
rebuild_dataclass(Api) # type: ignore
|
||||
rebuild_dataclass(Method) # type: ignore
|
||||
rebuild_dataclass(FileDescriptorSet) # type: ignore
|
||||
rebuild_dataclass(FileDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(DescriptorProto) # type: ignore
|
||||
rebuild_dataclass(DescriptorProtoExtensionRange) # type: ignore
|
||||
rebuild_dataclass(ExtensionRangeOptions) # type: ignore
|
||||
rebuild_dataclass(FieldDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(OneofDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(EnumDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(EnumValueDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(ServiceDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(MethodDescriptorProto) # type: ignore
|
||||
rebuild_dataclass(FileOptions) # type: ignore
|
||||
rebuild_dataclass(MessageOptions) # type: ignore
|
||||
rebuild_dataclass(FieldOptions) # type: ignore
|
||||
rebuild_dataclass(FieldOptionsEditionDefault) # type: ignore
|
||||
rebuild_dataclass(OneofOptions) # type: ignore
|
||||
rebuild_dataclass(EnumOptions) # type: ignore
|
||||
rebuild_dataclass(EnumValueOptions) # type: ignore
|
||||
rebuild_dataclass(ServiceOptions) # type: ignore
|
||||
rebuild_dataclass(MethodOptions) # type: ignore
|
||||
rebuild_dataclass(UninterpretedOption) # type: ignore
|
||||
rebuild_dataclass(FeatureSet) # type: ignore
|
||||
rebuild_dataclass(FeatureSetDefaults) # type: ignore
|
||||
rebuild_dataclass(FeatureSetDefaultsFeatureSetEditionDefault) # type: ignore
|
||||
rebuild_dataclass(SourceCodeInfo) # type: ignore
|
||||
rebuild_dataclass(GeneratedCodeInfo) # type: ignore
|
||||
rebuild_dataclass(GeneratedCodeInfoAnnotation) # type: ignore
|
||||
rebuild_dataclass(Struct) # type: ignore
|
||||
rebuild_dataclass(Value) # type: ignore
|
||||
rebuild_dataclass(ListValue) # type: ignore
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
from .module_validation import ModuleValidator
|
||||
|
||||
|
||||
try:
|
||||
@@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
||||
lstrip_blocks=True,
|
||||
loader=jinja2.FileSystemLoader(templates_folder),
|
||||
)
|
||||
template = env.get_template("template.py.j2")
|
||||
# Load the body first so we have a compleate list of imports needed.
|
||||
body_template = env.get_template("template.py.j2")
|
||||
header_template = env.get_template("header.py.j2")
|
||||
|
||||
code = template.render(output_file=output_file)
|
||||
code = body_template.render(output_file=output_file)
|
||||
code = header_template.render(output_file=output_file) + code
|
||||
code = isort.api.sort_code_string(
|
||||
code=code,
|
||||
show_diff=False,
|
||||
@@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
|
||||
force_grid_wrap=2,
|
||||
known_third_party=["grpclib", "betterproto"],
|
||||
)
|
||||
return black.format_str(
|
||||
code = black.format_str(
|
||||
src_contents=code,
|
||||
mode=black.Mode(),
|
||||
)
|
||||
|
||||
# Validate the generated code.
|
||||
validator = ModuleValidator(iter(code.splitlines()))
|
||||
if not validator.validate():
|
||||
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
|
||||
for collision, lines in validator.collisions.items():
|
||||
message_builder.append(f' "{collision}" on lines:')
|
||||
for num, line in lines:
|
||||
message_builder.append(f" {num}:{line}")
|
||||
print("\n".join(message_builder), file=sys.stderr)
|
||||
return code
|
||||
|
||||
@@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attribute.
|
||||
"""
|
||||
|
||||
|
||||
import builtins
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
@@ -49,12 +47,6 @@ from typing import (
|
||||
)
|
||||
|
||||
import betterproto
|
||||
from betterproto import which_one_of
|
||||
from betterproto.casing import sanitize_name
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
@@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import (
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
from .. import which_one_of
|
||||
from ..compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
@@ -82,6 +75,10 @@ from ..compile.naming import (
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
from .typing_compiler import (
|
||||
DirectImportTypingCompiler,
|
||||
TypingCompiler,
|
||||
)
|
||||
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
@@ -173,6 +170,7 @@ class ProtoContentBase:
|
||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
typing_compiler: TypingCompiler
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||
@@ -242,7 +240,6 @@ class OutputTemplate:
|
||||
input_files: List[str] = field(default_factory=list)
|
||||
imports: Set[str] = field(default_factory=set)
|
||||
datetime_imports: Set[str] = field(default_factory=set)
|
||||
typing_imports: Set[str] = field(default_factory=set)
|
||||
pydantic_imports: Set[str] = field(default_factory=set)
|
||||
builtins_import: bool = False
|
||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||
@@ -251,6 +248,7 @@ class OutputTemplate:
|
||||
imports_type_checking_only: Set[str] = field(default_factory=set)
|
||||
pydantic_dataclasses: bool = False
|
||||
output: bool = True
|
||||
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)
|
||||
|
||||
@property
|
||||
def package(self) -> str:
|
||||
@@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
|
||||
"""Representation of a protobuf message."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
typing_compiler: TypingCompiler
|
||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
@@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase):
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
if self.repeated:
|
||||
return f"List[{self.py_name}]"
|
||||
return self.typing_compiler.list(self.py_name)
|
||||
return self.py_name
|
||||
|
||||
@property
|
||||
@@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler):
|
||||
imports.add("datetime")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def typing_imports(self) -> Set[str]:
|
||||
imports = set()
|
||||
annotation = self.annotation
|
||||
if "Optional[" in annotation:
|
||||
imports.add("Optional")
|
||||
if "List[" in annotation:
|
||||
imports.add("List")
|
||||
if "Dict[" in annotation:
|
||||
imports.add("Dict")
|
||||
return imports
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return set()
|
||||
@@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler):
|
||||
|
||||
def add_imports_to(self, output_file: OutputTemplate) -> None:
|
||||
output_file.datetime_imports.update(self.datetime_imports)
|
||||
output_file.typing_imports.update(self.typing_imports)
|
||||
output_file.pydantic_imports.update(self.pydantic_imports)
|
||||
output_file.builtins_import = output_file.builtins_import or self.use_builtins
|
||||
|
||||
@@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler):
|
||||
@property
|
||||
def mutable(self) -> bool:
|
||||
"""True if the field is a mutable type, otherwise False."""
|
||||
return self.annotation.startswith(("List[", "Dict["))
|
||||
return self.annotation.startswith(
|
||||
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
|
||||
)
|
||||
|
||||
@property
|
||||
def field_type(self) -> str:
|
||||
@@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler):
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.type_name,
|
||||
typing_compiler=self.typing_compiler,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
)
|
||||
else:
|
||||
@@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler):
|
||||
if self.use_builtins:
|
||||
py_type = f"builtins.{py_type}"
|
||||
if self.repeated:
|
||||
return f"List[{py_type}]"
|
||||
return self.typing_compiler.list(py_type)
|
||||
if self.optional:
|
||||
return f"Optional[{py_type}]"
|
||||
return self.typing_compiler.optional(py_type)
|
||||
return py_type
|
||||
|
||||
|
||||
@@ -600,7 +589,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
|
||||
|
||||
@property
|
||||
def pydantic_imports(self) -> Set[str]:
|
||||
return {"root_validator"}
|
||||
return {"model_validator"}
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler):
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[0], # key
|
||||
typing_compiler=self.typing_compiler,
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[1], # value
|
||||
typing_compiler=self.typing_compiler,
|
||||
).py_type
|
||||
|
||||
# Get proto types
|
||||
@@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler):
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
||||
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)
|
||||
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
@@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
|
||||
def __post_init__(self) -> None:
|
||||
# Add service to output file
|
||||
self.output_file.services.append(self)
|
||||
self.output_file.typing_imports.add("Dict")
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
@@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
# Add method to service
|
||||
self.parent.methods.append(self)
|
||||
|
||||
# Check for imports
|
||||
if "Optional" in self.py_output_message_type:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
|
||||
# Check for Async imports
|
||||
if self.client_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterable")
|
||||
self.output_file.typing_imports.add("Iterable")
|
||||
self.output_file.typing_imports.add("Union")
|
||||
|
||||
# Required by both client and server
|
||||
if self.client_streaming or self.server_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterator")
|
||||
|
||||
# add imports required for request arguments timeout, deadline and metadata
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.output_file.imports_type_checking_only.add("import grpclib.server")
|
||||
self.output_file.imports_type_checking_only.add(
|
||||
"from betterproto.grpc.grpclib_client import MetadataLike"
|
||||
@@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.input_type,
|
||||
typing_compiler=self.output_file.typing_compiler,
|
||||
unwrap=False,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
).strip('"')
|
||||
@@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase):
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.output_type,
|
||||
typing_compiler=self.output_file.typing_compiler,
|
||||
unwrap=False,
|
||||
pydantic=self.output_file.pydantic_dataclasses,
|
||||
).strip('"')
|
||||
|
||||
163
src/betterproto/plugin/module_validation.py
Normal file
163
src/betterproto/plugin/module_validation.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModuleValidator:
|
||||
line_iterator: Iterator[str]
|
||||
line_number: int = field(init=False, default=0)
|
||||
|
||||
collisions: Dict[str, List[Tuple[int, str]]] = field(
|
||||
init=False, default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
def add_import(self, imp: str, number: int, full_line: str):
|
||||
"""
|
||||
Adds an import to be tracked.
|
||||
"""
|
||||
self.collisions[imp].append((number, full_line))
|
||||
|
||||
def process_import(self, imp: str):
|
||||
"""
|
||||
Filters out the import to its actual value.
|
||||
"""
|
||||
if " as " in imp:
|
||||
imp = imp[imp.index(" as ") + 4 :]
|
||||
|
||||
imp = imp.strip()
|
||||
assert " " not in imp, imp
|
||||
return imp
|
||||
|
||||
def evaluate_multiline_import(self, line: str):
|
||||
"""
|
||||
Evaluates a multiline import from a starting line
|
||||
"""
|
||||
# Filter the first line and remove anything before the import statement.
|
||||
full_line = line
|
||||
line = line.split("import", 1)[1]
|
||||
if "(" in line:
|
||||
conditional = lambda line: ")" not in line
|
||||
else:
|
||||
conditional = lambda line: "\\" in line
|
||||
|
||||
# Remove open parenthesis if it exists.
|
||||
if "(" in line:
|
||||
line = line[line.index("(") + 1 :]
|
||||
|
||||
# Choose the conditional based on how multiline imports are formatted.
|
||||
while conditional(line):
|
||||
# Split the line by commas
|
||||
imports = line.split(",")
|
||||
|
||||
for imp in imports:
|
||||
# Add the import to the namespace
|
||||
imp = self.process_import(imp)
|
||||
if imp:
|
||||
self.add_import(imp, self.line_number, full_line)
|
||||
# Get the next line
|
||||
full_line = line = next(self.line_iterator)
|
||||
# Increment the line number
|
||||
self.line_number += 1
|
||||
|
||||
# validate the last line
|
||||
if ")" in line:
|
||||
line = line[: line.index(")")]
|
||||
imports = line.split(",")
|
||||
for imp in imports:
|
||||
imp = self.process_import(imp)
|
||||
if imp:
|
||||
self.add_import(imp, self.line_number, full_line)
|
||||
|
||||
def evaluate_import(self, line: str):
|
||||
"""
|
||||
Extracts an import from a line.
|
||||
"""
|
||||
whole_line = line
|
||||
line = line[line.index("import") + 6 :]
|
||||
values = line.split(",")
|
||||
for v in values:
|
||||
self.add_import(self.process_import(v), self.line_number, whole_line)
|
||||
|
||||
def next(self):
|
||||
"""
|
||||
Evaluate each line for names in the module.
|
||||
"""
|
||||
line = next(self.line_iterator)
|
||||
|
||||
# Skip lines with indentation or comments
|
||||
if (
|
||||
# Skip indents and whitespace.
|
||||
line.startswith(" ")
|
||||
or line == "\n"
|
||||
or line.startswith("\t")
|
||||
or
|
||||
# Skip comments
|
||||
line.startswith("#")
|
||||
or
|
||||
# Skip decorators
|
||||
line.startswith("@")
|
||||
):
|
||||
self.line_number += 1
|
||||
return
|
||||
|
||||
# Skip docstrings.
|
||||
if line.startswith('"""') or line.startswith("'''"):
|
||||
quote = line[0] * 3
|
||||
line = line[3:]
|
||||
while quote not in line:
|
||||
line = next(self.line_iterator)
|
||||
self.line_number += 1
|
||||
return
|
||||
|
||||
# Evaluate Imports.
|
||||
if line.startswith("from ") or line.startswith("import "):
|
||||
if "(" in line or "\\" in line:
|
||||
self.evaluate_multiline_import(line)
|
||||
else:
|
||||
self.evaluate_import(line)
|
||||
|
||||
# Evaluate Classes.
|
||||
elif line.startswith("class "):
|
||||
class_name = re.search(r"class (\w+)", line).group(1)
|
||||
if class_name:
|
||||
self.add_import(class_name, self.line_number, line)
|
||||
|
||||
# Evaluate Functions.
|
||||
elif line.startswith("def "):
|
||||
function_name = re.search(r"def (\w+)", line).group(1)
|
||||
if function_name:
|
||||
self.add_import(function_name, self.line_number, line)
|
||||
|
||||
# Evaluate direct assignments.
|
||||
elif "=" in line:
|
||||
assignment = re.search(r"(\w+)\s*=", line).group(1)
|
||||
if assignment:
|
||||
self.add_import(assignment, self.line_number, line)
|
||||
|
||||
self.line_number += 1
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""
|
||||
Run Validation.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
self.next()
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
# Filter collisions for those with more than one value.
|
||||
self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}
|
||||
|
||||
# Return True if no collisions are found.
|
||||
return not bool(self.collisions)
|
||||
@@ -37,6 +37,12 @@ from .models import (
|
||||
is_map,
|
||||
is_oneof,
|
||||
)
|
||||
from .typing_compiler import (
|
||||
DirectImportTypingCompiler,
|
||||
NoTyping310TypingCompiler,
|
||||
TypingCompiler,
|
||||
TypingImportTypingCompiler,
|
||||
)
|
||||
|
||||
|
||||
def traverse(
|
||||
@@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
|
||||
output_package_name
|
||||
].pydantic_dataclasses = True
|
||||
|
||||
# Gather any typing generation options.
|
||||
typing_opts = [
|
||||
opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.")
|
||||
]
|
||||
|
||||
if len(typing_opts) > 1:
|
||||
raise ValueError("Multiple typing options provided")
|
||||
# Set the compiler type.
|
||||
typing_opt = typing_opts[0] if typing_opts else "direct"
|
||||
if typing_opt == "direct":
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].typing_compiler = DirectImportTypingCompiler()
|
||||
elif typing_opt == "root":
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].typing_compiler = TypingImportTypingCompiler()
|
||||
elif typing_opt == "310":
|
||||
request_data.output_packages[
|
||||
output_package_name
|
||||
].typing_compiler = NoTyping310TypingCompiler()
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
@@ -166,6 +194,7 @@ def _make_one_of_field_compiler(
|
||||
parent=parent,
|
||||
proto_obj=proto_obj,
|
||||
path=path,
|
||||
typing_compiler=output_package.typing_compiler,
|
||||
)
|
||||
|
||||
|
||||
@@ -181,7 +210,11 @@ def read_protobuf_type(
|
||||
return
|
||||
# Process Message
|
||||
message_data = MessageCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
source_file=source_file,
|
||||
parent=output_package,
|
||||
proto_obj=item,
|
||||
path=path,
|
||||
typing_compiler=output_package.typing_compiler,
|
||||
)
|
||||
for index, field in enumerate(item.field):
|
||||
if is_map(field, item):
|
||||
@@ -190,6 +223,7 @@ def read_protobuf_type(
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
typing_compiler=output_package.typing_compiler,
|
||||
)
|
||||
elif is_oneof(field):
|
||||
_make_one_of_field_compiler(
|
||||
@@ -201,11 +235,16 @@ def read_protobuf_type(
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
typing_compiler=output_package.typing_compiler,
|
||||
)
|
||||
elif isinstance(item, EnumDescriptorProto):
|
||||
# Enum
|
||||
EnumDefinitionCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
source_file=source_file,
|
||||
parent=output_package,
|
||||
proto_obj=item,
|
||||
path=path,
|
||||
typing_compiler=output_package.typing_compiler,
|
||||
)
|
||||
|
||||
|
||||
|
||||
167
src/betterproto/plugin/typing_compiler.py
Normal file
167
src/betterproto/plugin/typing_compiler.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
from dataclasses import (
|
||||
dataclass,
|
||||
field,
|
||||
)
|
||||
from typing import (
|
||||
Dict,
|
||||
Iterator,
|
||||
Optional,
|
||||
Set,
|
||||
)
|
||||
|
||||
|
||||
class TypingCompiler(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def optional(self, type: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def list(self, type: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def dict(self, key: str, value: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def union(self, *types: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def iterable(self, type: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_iterable(self, type: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def async_iterator(self, type: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||
"""
|
||||
Returns either the direct import as a key with none as value, or a set of
|
||||
values to import from the key.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def import_lines(self) -> Iterator:
|
||||
imports = self.imports()
|
||||
for key, value in imports.items():
|
||||
if value is None:
|
||||
yield f"import {key}"
|
||||
else:
|
||||
yield f"from {key} import ("
|
||||
for v in sorted(value):
|
||||
yield f" {v},"
|
||||
yield ")"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DirectImportTypingCompiler(TypingCompiler):
|
||||
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
||||
|
||||
def optional(self, type: str) -> str:
|
||||
self._imports["typing"].add("Optional")
|
||||
return f"Optional[{type}]"
|
||||
|
||||
def list(self, type: str) -> str:
|
||||
self._imports["typing"].add("List")
|
||||
return f"List[{type}]"
|
||||
|
||||
def dict(self, key: str, value: str) -> str:
|
||||
self._imports["typing"].add("Dict")
|
||||
return f"Dict[{key}, {value}]"
|
||||
|
||||
def union(self, *types: str) -> str:
|
||||
self._imports["typing"].add("Union")
|
||||
return f"Union[{', '.join(types)}]"
|
||||
|
||||
def iterable(self, type: str) -> str:
|
||||
self._imports["typing"].add("Iterable")
|
||||
return f"Iterable[{type}]"
|
||||
|
||||
def async_iterable(self, type: str) -> str:
|
||||
self._imports["typing"].add("AsyncIterable")
|
||||
return f"AsyncIterable[{type}]"
|
||||
|
||||
def async_iterator(self, type: str) -> str:
|
||||
self._imports["typing"].add("AsyncIterator")
|
||||
return f"AsyncIterator[{type}]"
|
||||
|
||||
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||
return {k: v if v else None for k, v in self._imports.items()}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TypingImportTypingCompiler(TypingCompiler):
|
||||
_imported: bool = False
|
||||
|
||||
def optional(self, type: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.Optional[{type}]"
|
||||
|
||||
def list(self, type: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.List[{type}]"
|
||||
|
||||
def dict(self, key: str, value: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.Dict[{key}, {value}]"
|
||||
|
||||
def union(self, *types: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.Union[{', '.join(types)}]"
|
||||
|
||||
def iterable(self, type: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.Iterable[{type}]"
|
||||
|
||||
def async_iterable(self, type: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.AsyncIterable[{type}]"
|
||||
|
||||
def async_iterator(self, type: str) -> str:
|
||||
self._imported = True
|
||||
return f"typing.AsyncIterator[{type}]"
|
||||
|
||||
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||
if self._imported:
|
||||
return {"typing": None}
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class NoTyping310TypingCompiler(TypingCompiler):
|
||||
_imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))
|
||||
|
||||
def optional(self, type: str) -> str:
|
||||
return f"{type} | None"
|
||||
|
||||
def list(self, type: str) -> str:
|
||||
return f"list[{type}]"
|
||||
|
||||
def dict(self, key: str, value: str) -> str:
|
||||
return f"dict[{key}, {value}]"
|
||||
|
||||
def union(self, *types: str) -> str:
|
||||
return " | ".join(types)
|
||||
|
||||
def iterable(self, type: str) -> str:
|
||||
self._imports["typing"].add("Iterable")
|
||||
return f"Iterable[{type}]"
|
||||
|
||||
def async_iterable(self, type: str) -> str:
|
||||
self._imports["typing"].add("AsyncIterable")
|
||||
return f"AsyncIterable[{type}]"
|
||||
|
||||
def async_iterator(self, type: str) -> str:
|
||||
self._imports["typing"].add("AsyncIterator")
|
||||
return f"AsyncIterator[{type}]"
|
||||
|
||||
def imports(self) -> Dict[str, Optional[Set[str]]]:
|
||||
return {k: v if v else None for k, v in self._imports.items()}
|
||||
55
src/betterproto/templates/header.py.j2
Normal file
55
src/betterproto/templates/header.py.j2
Normal file
@@ -0,0 +1,55 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: {{ ', '.join(output_file.input_filenames) }}
|
||||
# plugin: python-betterproto
|
||||
# This file has been @generated
|
||||
{% for i in output_file.python_module_imports|sort %}
|
||||
import {{ i }}
|
||||
{% endfor %}
|
||||
{% set type_checking_imported = False %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
from typing import TYPE_CHECKING
|
||||
{% set type_checking_imported = True %}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dataclasses import dataclass
|
||||
else:
|
||||
from pydantic.dataclasses import dataclass
|
||||
from pydantic.dataclasses import rebuild_dataclass
|
||||
{%- else -%}
|
||||
from dataclasses import dataclass
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.datetime_imports %}
|
||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif%}
|
||||
{% set typing_imports = output_file.typing_compiler.imports() %}
|
||||
{% if typing_imports %}
|
||||
{% for line in output_file.typing_compiler.import_lines() %}
|
||||
{{ line }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_imports %}
|
||||
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
{% if output_file.services %}
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
import grpclib
|
||||
{% endif %}
|
||||
|
||||
{% for i in output_file.imports|sort %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.imports_type_checking_only and not type_checking_imported %}
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
@@ -1,53 +1,3 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: {{ ', '.join(output_file.input_filenames) }}
|
||||
# plugin: python-betterproto
|
||||
# This file has been @generated
|
||||
{% for i in output_file.python_module_imports|sort %}
|
||||
import {{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from dataclasses import dataclass
|
||||
else:
|
||||
from pydantic.dataclasses import dataclass
|
||||
{%- else -%}
|
||||
from dataclasses import dataclass
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.datetime_imports %}
|
||||
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif%}
|
||||
{% if output_file.typing_imports %}
|
||||
from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_imports %}
|
||||
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
|
||||
|
||||
{% endif %}
|
||||
|
||||
import betterproto
|
||||
{% if output_file.services %}
|
||||
from betterproto.grpc.grpclib_server import ServiceBase
|
||||
import grpclib
|
||||
{% endif %}
|
||||
|
||||
{% for i in output_file.imports|sort %}
|
||||
{{ i }}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.imports_type_checking_only %}
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
{% for i in output_file.imports_type_checking_only|sort %} {{ i }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.enums %}{% for enum in output_file.enums %}
|
||||
class {{ enum.py_name }}(betterproto.Enum):
|
||||
{% if enum.comment %}
|
||||
@@ -62,6 +12,13 @@ class {{ enum.py_name }}(betterproto.Enum):
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, _source_type, _handler):
|
||||
from pydantic_core import core_schema
|
||||
|
||||
return core_schema.int_schema(ge=0)
|
||||
{% endif %}
|
||||
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
@@ -96,7 +53,7 @@ class {{ message.py_name }}(betterproto.Message):
|
||||
{% endif %}
|
||||
|
||||
{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
|
||||
@root_validator()
|
||||
@model_validator(mode='after')
|
||||
def check_oneof(cls, values):
|
||||
return cls._validate_field_groups(values)
|
||||
{% endif %}
|
||||
@@ -116,14 +73,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
|
||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
|
||||
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }}
|
||||
{%- endif -%}
|
||||
,
|
||||
*
|
||||
, timeout: Optional[float] = None
|
||||
, deadline: Optional["Deadline"] = None
|
||||
, metadata: Optional["MetadataLike"] = None
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
, timeout: {{ output_file.typing_compiler.optional("float") }} = None
|
||||
, deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None
|
||||
, metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None
|
||||
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
@@ -191,9 +148,9 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
{%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%}
|
||||
{%- else -%}
|
||||
{# Client streaming: need a request iterator instead #}
|
||||
, {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"]
|
||||
, {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }}
|
||||
{%- endif -%}
|
||||
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}:
|
||||
{% if method.comment %}
|
||||
{{ method.comment }}
|
||||
|
||||
@@ -225,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
|
||||
{% endfor %}
|
||||
|
||||
def __mapping__(self) -> Dict[str, grpclib.const.Handler]:
|
||||
def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}:
|
||||
return {
|
||||
{% for method in service.methods %}
|
||||
"{{ method.route }}": grpclib.const.Handler(
|
||||
@@ -250,7 +207,7 @@ class {{ service.py_name }}Base(ServiceBase):
|
||||
{% if output_file.pydantic_dataclasses %}
|
||||
{% for message in output_file.messages %}
|
||||
{% if message.has_message_field %}
|
||||
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
|
||||
rebuild_dataclass({{ message.py_name }}) # type: ignore
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
Reference in New Issue
Block a user