Merge branch 'refs/heads/master_gh'
This commit is contained in:
		| @@ -1852,7 +1852,9 @@ class Message(ABC): | ||||
|                     continue | ||||
|  | ||||
|             set_fields = [ | ||||
|                 field.name for field in field_set if values[field.name] is not None | ||||
|                 field.name | ||||
|                 for field in field_set | ||||
|                 if getattr(values, field.name, None) is not None | ||||
|             ] | ||||
|  | ||||
|             if not set_fields: | ||||
|   | ||||
| @@ -47,6 +47,7 @@ def get_type_reference( | ||||
|     package: str, | ||||
|     imports: set, | ||||
|     source_type: str, | ||||
|     typing_compiler: "TypingCompiler", | ||||
|     unwrap: bool = True, | ||||
|     pydantic: bool = False, | ||||
| ) -> str: | ||||
| @@ -57,7 +58,7 @@ def get_type_reference( | ||||
|     if unwrap: | ||||
|         if source_type in WRAPPER_TYPES: | ||||
|             wrapped_type = type(WRAPPER_TYPES[source_type]().value) | ||||
|             return f"Optional[{wrapped_type.__name__}]" | ||||
|             return typing_compiler.optional(wrapped_type.__name__) | ||||
|  | ||||
|         if source_type == ".google.protobuf.Duration": | ||||
|             return "timedelta" | ||||
|   | ||||
| @@ -1,9 +1,16 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: google/protobuf/any.proto, google/protobuf/api.proto, google/protobuf/descriptor.proto, google/protobuf/duration.proto, google/protobuf/empty.proto, google/protobuf/field_mask.proto, google/protobuf/source_context.proto, google/protobuf/struct.proto, google/protobuf/timestamp.proto, google/protobuf/type.proto, google/protobuf/wrappers.proto | ||||
| # plugin: python-betterproto | ||||
|  | ||||
| # This file has been @generated | ||||
| import warnings | ||||
| from typing import TYPE_CHECKING | ||||
| from typing import ( | ||||
|     TYPE_CHECKING, | ||||
|     Mapping, | ||||
| ) | ||||
|  | ||||
| from typing_extensions import Self | ||||
|  | ||||
| from betterproto import hybridmethod | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| @@ -14,15 +21,13 @@ else: | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     List, | ||||
|     Mapping, | ||||
|     Optional, | ||||
| ) | ||||
|  | ||||
| from pydantic import root_validator | ||||
| from typing_extensions import Self | ||||
| from pydantic import model_validator | ||||
| from pydantic.dataclasses import rebuild_dataclass | ||||
|  | ||||
| import betterproto | ||||
| from betterproto.utils import hybridmethod | ||||
|  | ||||
|  | ||||
| class Syntax(betterproto.Enum): | ||||
| @@ -37,6 +42,12 @@ class Syntax(betterproto.Enum): | ||||
|     EDITIONS = 2 | ||||
|     """Syntax `editions`.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldKind(betterproto.Enum): | ||||
|     """Basic field types.""" | ||||
| @@ -98,6 +109,12 @@ class FieldKind(betterproto.Enum): | ||||
|     TYPE_SINT64 = 18 | ||||
|     """Field type sint64.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldCardinality(betterproto.Enum): | ||||
|     """Whether a field is optional, required, or repeated.""" | ||||
| @@ -114,6 +131,12 @@ class FieldCardinality(betterproto.Enum): | ||||
|     CARDINALITY_REPEATED = 3 | ||||
|     """For repeated fields.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class Edition(betterproto.Enum): | ||||
|     """The full set of known editions.""" | ||||
| @@ -155,6 +178,12 @@ class Edition(betterproto.Enum): | ||||
|      support a new edition. | ||||
|     """ | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class ExtensionRangeOptionsVerificationState(betterproto.Enum): | ||||
|     """The verification state of the extension range.""" | ||||
| @@ -164,6 +193,12 @@ class ExtensionRangeOptionsVerificationState(betterproto.Enum): | ||||
|  | ||||
|     UNVERIFIED = 1 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldDescriptorProtoType(betterproto.Enum): | ||||
|     TYPE_DOUBLE = 1 | ||||
| @@ -210,6 +245,12 @@ class FieldDescriptorProtoType(betterproto.Enum): | ||||
|     TYPE_SINT32 = 17 | ||||
|     TYPE_SINT64 = 18 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldDescriptorProtoLabel(betterproto.Enum): | ||||
|     LABEL_OPTIONAL = 1 | ||||
| @@ -223,6 +264,12 @@ class FieldDescriptorProtoLabel(betterproto.Enum): | ||||
|      can be used to get this behavior. | ||||
|     """ | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FileOptionsOptimizeMode(betterproto.Enum): | ||||
|     """Generated classes can be optimized for speed or code size.""" | ||||
| @@ -233,6 +280,12 @@ class FileOptionsOptimizeMode(betterproto.Enum): | ||||
|  | ||||
|     LITE_RUNTIME = 3 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldOptionsCType(betterproto.Enum): | ||||
|     STRING = 0 | ||||
| @@ -250,6 +303,12 @@ class FieldOptionsCType(betterproto.Enum): | ||||
|  | ||||
|     STRING_PIECE = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldOptionsJsType(betterproto.Enum): | ||||
|     JS_NORMAL = 0 | ||||
| @@ -261,6 +320,12 @@ class FieldOptionsJsType(betterproto.Enum): | ||||
|     JS_NUMBER = 2 | ||||
|     """Use JavaScript numbers.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldOptionsOptionRetention(betterproto.Enum): | ||||
|     """ | ||||
| @@ -273,6 +338,12 @@ class FieldOptionsOptionRetention(betterproto.Enum): | ||||
|     RETENTION_RUNTIME = 1 | ||||
|     RETENTION_SOURCE = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FieldOptionsOptionTargetType(betterproto.Enum): | ||||
|     """ | ||||
| @@ -293,6 +364,12 @@ class FieldOptionsOptionTargetType(betterproto.Enum): | ||||
|     TARGET_TYPE_SERVICE = 8 | ||||
|     TARGET_TYPE_METHOD = 9 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class MethodOptionsIdempotencyLevel(betterproto.Enum): | ||||
|     """ | ||||
| @@ -305,6 +382,12 @@ class MethodOptionsIdempotencyLevel(betterproto.Enum): | ||||
|     NO_SIDE_EFFECTS = 1 | ||||
|     IDEMPOTENT = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetFieldPresence(betterproto.Enum): | ||||
|     FIELD_PRESENCE_UNKNOWN = 0 | ||||
| @@ -312,36 +395,72 @@ class FeatureSetFieldPresence(betterproto.Enum): | ||||
|     IMPLICIT = 2 | ||||
|     LEGACY_REQUIRED = 3 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetEnumType(betterproto.Enum): | ||||
|     ENUM_TYPE_UNKNOWN = 0 | ||||
|     OPEN = 1 | ||||
|     CLOSED = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetRepeatedFieldEncoding(betterproto.Enum): | ||||
|     REPEATED_FIELD_ENCODING_UNKNOWN = 0 | ||||
|     PACKED = 1 | ||||
|     EXPANDED = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetUtf8Validation(betterproto.Enum): | ||||
|     UTF8_VALIDATION_UNKNOWN = 0 | ||||
|     VERIFY = 2 | ||||
|     NONE = 3 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetMessageEncoding(betterproto.Enum): | ||||
|     MESSAGE_ENCODING_UNKNOWN = 0 | ||||
|     LENGTH_PREFIXED = 1 | ||||
|     DELIMITED = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class FeatureSetJsonFormat(betterproto.Enum): | ||||
|     JSON_FORMAT_UNKNOWN = 0 | ||||
|     ALLOW = 1 | ||||
|     LEGACY_BEST_EFFORT = 2 | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum): | ||||
|     """ | ||||
| @@ -358,6 +477,12 @@ class GeneratedCodeInfoAnnotationSemantic(betterproto.Enum): | ||||
|     ALIAS = 2 | ||||
|     """An alias to the element is returned.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| class NullValue(betterproto.Enum): | ||||
|     """ | ||||
| @@ -370,6 +495,12 @@ class NullValue(betterproto.Enum): | ||||
|     _ = 0 | ||||
|     """Null value.""" | ||||
|  | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False, repr=False) | ||||
| class Any(betterproto.Message): | ||||
| @@ -1176,16 +1307,12 @@ class FileOptions(betterproto.Message): | ||||
|  | ||||
|     java_string_check_utf8: bool = betterproto.bool_field(27) | ||||
|     """ | ||||
|     A proto2 file can set this to true to opt in to UTF-8 checking for Java, | ||||
|      which will throw an exception if invalid UTF-8 is parsed from the wire or | ||||
|      assigned to a string field. | ||||
|      | ||||
|      TODO: clarify exactly what kinds of field types this option | ||||
|      applies to, and update these docs accordingly. | ||||
|      | ||||
|      Proto3 files already perform these checks. Setting the option explicitly to | ||||
|      false has no effect: it cannot be used to opt proto3 files out of UTF-8 | ||||
|      checks. | ||||
|     If set true, then the Java2 code generator will generate code that | ||||
|      throws an exception whenever an attempt is made to assign a non-UTF-8 | ||||
|      byte sequence to a string field. | ||||
|      Message reflection will do the same. | ||||
|      However, an extension field still accepts non-UTF-8 byte sequences. | ||||
|      This option has no effect on when used with the lite runtime. | ||||
|     """ | ||||
|  | ||||
|     optimize_for: "FileOptionsOptimizeMode" = betterproto.enum_field(9) | ||||
| @@ -1477,7 +1604,6 @@ class FieldOptions(betterproto.Message): | ||||
|     features: "FeatureSet" = betterproto.message_field(21) | ||||
|     """Any features defined in the specific edition.""" | ||||
|  | ||||
|     feature_support: "FieldOptionsFeatureSupport" = betterproto.message_field(22) | ||||
|     uninterpreted_option: List["UninterpretedOption"] = betterproto.message_field(999) | ||||
|     """The parser stores options it doesn't recognize here. See above.""" | ||||
|  | ||||
| @@ -1488,37 +1614,6 @@ class FieldOptionsEditionDefault(betterproto.Message): | ||||
|     value: str = betterproto.string_field(2) | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False, repr=False) | ||||
| class FieldOptionsFeatureSupport(betterproto.Message): | ||||
|     """Information about the support window of a feature.""" | ||||
|  | ||||
|     edition_introduced: "Edition" = betterproto.enum_field(1) | ||||
|     """ | ||||
|     The edition that this feature was first available in.  In editions | ||||
|      earlier than this one, the default assigned to EDITION_LEGACY will be | ||||
|      used, and proto files will not be able to override it. | ||||
|     """ | ||||
|  | ||||
|     edition_deprecated: "Edition" = betterproto.enum_field(2) | ||||
|     """ | ||||
|     The edition this feature becomes deprecated in.  Using this after this | ||||
|      edition may trigger warnings. | ||||
|     """ | ||||
|  | ||||
|     deprecation_warning: str = betterproto.string_field(3) | ||||
|     """ | ||||
|     The deprecation warning text if this feature is used after the edition it | ||||
|      was marked deprecated in. | ||||
|     """ | ||||
|  | ||||
|     edition_removed: "Edition" = betterproto.enum_field(4) | ||||
|     """ | ||||
|     The edition this feature is no longer available in.  In editions after | ||||
|      this one, the last default assigned will be used, and proto files will | ||||
|      not be able to override it. | ||||
|     """ | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False, repr=False) | ||||
| class OneofOptions(betterproto.Message): | ||||
|     features: "FeatureSet" = betterproto.message_field(1) | ||||
| @@ -1723,17 +1818,7 @@ class FeatureSetDefaultsFeatureSetEditionDefault(betterproto.Message): | ||||
|     """ | ||||
|  | ||||
|     edition: "Edition" = betterproto.enum_field(3) | ||||
|     overridable_features: "FeatureSet" = betterproto.message_field(4) | ||||
|     """Defaults of features that can be overridden in this edition.""" | ||||
|  | ||||
|     fixed_features: "FeatureSet" = betterproto.message_field(5) | ||||
|     """Defaults of features that can't be overridden in this edition.""" | ||||
|  | ||||
|     features: "FeatureSet" = betterproto.message_field(2) | ||||
|     """ | ||||
|     TODO Deprecate and remove this field, which is just the | ||||
|      above two merged. | ||||
|     """ | ||||
|  | ||||
|  | ||||
| @dataclass(eq=False, repr=False) | ||||
| @@ -2314,7 +2399,7 @@ class Value(betterproto.Message): | ||||
|     ) | ||||
|     """Represents a repeated `Value`.""" | ||||
|  | ||||
|     @root_validator() | ||||
|     @model_validator(mode="after") | ||||
|     def check_oneof(cls, values): | ||||
|         return cls._validate_field_groups(values) | ||||
|  | ||||
| @@ -2549,41 +2634,40 @@ class BytesValue(betterproto.Message): | ||||
|     """The bytes value.""" | ||||
|  | ||||
|  | ||||
| Type.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Field.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Enum.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| EnumValue.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Option.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Api.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Method.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FileDescriptorSet.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FileDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| DescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| DescriptorProtoExtensionRange.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| ExtensionRangeOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FieldDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| OneofDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| EnumDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| EnumValueDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| ServiceDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| MethodDescriptorProto.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FileOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| MessageOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FieldOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FieldOptionsEditionDefault.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FieldOptionsFeatureSupport.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| OneofOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| EnumOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| EnumValueOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| ServiceOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| MethodOptions.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| UninterpretedOption.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FeatureSet.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FeatureSetDefaults.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| FeatureSetDefaultsFeatureSetEditionDefault.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| SourceCodeInfo.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| GeneratedCodeInfo.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| GeneratedCodeInfoAnnotation.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Struct.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| Value.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| ListValue.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| rebuild_dataclass(Type)  # type: ignore | ||||
| rebuild_dataclass(Field)  # type: ignore | ||||
| rebuild_dataclass(Enum)  # type: ignore | ||||
| rebuild_dataclass(EnumValue)  # type: ignore | ||||
| rebuild_dataclass(Option)  # type: ignore | ||||
| rebuild_dataclass(Api)  # type: ignore | ||||
| rebuild_dataclass(Method)  # type: ignore | ||||
| rebuild_dataclass(FileDescriptorSet)  # type: ignore | ||||
| rebuild_dataclass(FileDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(DescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(DescriptorProtoExtensionRange)  # type: ignore | ||||
| rebuild_dataclass(ExtensionRangeOptions)  # type: ignore | ||||
| rebuild_dataclass(FieldDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(OneofDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(EnumDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(EnumValueDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(ServiceDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(MethodDescriptorProto)  # type: ignore | ||||
| rebuild_dataclass(FileOptions)  # type: ignore | ||||
| rebuild_dataclass(MessageOptions)  # type: ignore | ||||
| rebuild_dataclass(FieldOptions)  # type: ignore | ||||
| rebuild_dataclass(FieldOptionsEditionDefault)  # type: ignore | ||||
| rebuild_dataclass(OneofOptions)  # type: ignore | ||||
| rebuild_dataclass(EnumOptions)  # type: ignore | ||||
| rebuild_dataclass(EnumValueOptions)  # type: ignore | ||||
| rebuild_dataclass(ServiceOptions)  # type: ignore | ||||
| rebuild_dataclass(MethodOptions)  # type: ignore | ||||
| rebuild_dataclass(UninterpretedOption)  # type: ignore | ||||
| rebuild_dataclass(FeatureSet)  # type: ignore | ||||
| rebuild_dataclass(FeatureSetDefaults)  # type: ignore | ||||
| rebuild_dataclass(FeatureSetDefaultsFeatureSetEditionDefault)  # type: ignore | ||||
| rebuild_dataclass(SourceCodeInfo)  # type: ignore | ||||
| rebuild_dataclass(GeneratedCodeInfo)  # type: ignore | ||||
| rebuild_dataclass(GeneratedCodeInfoAnnotation)  # type: ignore | ||||
| rebuild_dataclass(Struct)  # type: ignore | ||||
| rebuild_dataclass(Value)  # type: ignore | ||||
| rebuild_dataclass(ListValue)  # type: ignore | ||||
|   | ||||
| @@ -1,4 +1,7 @@ | ||||
| import os.path | ||||
| import sys | ||||
|  | ||||
| from .module_validation import ModuleValidator | ||||
|  | ||||
|  | ||||
| try: | ||||
| @@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|         lstrip_blocks=True, | ||||
|         loader=jinja2.FileSystemLoader(templates_folder), | ||||
|     ) | ||||
|     template = env.get_template("template.py.j2") | ||||
|     # Load the body first so we have a complete list of imports needed. | ||||
|     body_template = env.get_template("template.py.j2") | ||||
|     header_template = env.get_template("header.py.j2") | ||||
|  | ||||
|     code = template.render(output_file=output_file) | ||||
|     code = body_template.render(output_file=output_file) | ||||
|     code = header_template.render(output_file=output_file) + code | ||||
|     code = isort.api.sort_code_string( | ||||
|         code=code, | ||||
|         show_diff=False, | ||||
| @@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||
|         force_grid_wrap=2, | ||||
|         known_third_party=["grpclib", "betterproto"], | ||||
|     ) | ||||
|     return black.format_str( | ||||
|     code = black.format_str( | ||||
|         src_contents=code, | ||||
|         mode=black.Mode(), | ||||
|     ) | ||||
|  | ||||
|     # Validate the generated code. | ||||
|     validator = ModuleValidator(iter(code.splitlines())) | ||||
|     if not validator.validate(): | ||||
|         message_builder = ["[WARNING]: Generated code has collisions in the module:"] | ||||
|         for collision, lines in validator.collisions.items(): | ||||
|             message_builder.append(f'  "{collision}" on lines:') | ||||
|             for num, line in lines: | ||||
|                 message_builder.append(f"    {num}:{line}") | ||||
|         print("\n".join(message_builder), file=sys.stderr) | ||||
|     return code | ||||
|   | ||||
| @@ -29,10 +29,8 @@ instantiating field `A` with parent message `B` should add a | ||||
| reference to `A` to `B`'s `fields` attribute. | ||||
| """ | ||||
|  | ||||
|  | ||||
| import builtins | ||||
| import re | ||||
| import textwrap | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| @@ -49,12 +47,6 @@ from typing import ( | ||||
| ) | ||||
|  | ||||
| import betterproto | ||||
| from betterproto import which_one_of | ||||
| from betterproto.casing import sanitize_name | ||||
| from betterproto.compile.importing import ( | ||||
|     get_type_reference, | ||||
|     parse_source_type_name, | ||||
| ) | ||||
| from betterproto.compile.naming import ( | ||||
|     pythonize_class_name, | ||||
|     pythonize_field_name, | ||||
| @@ -72,6 +64,7 @@ from betterproto.lib.google.protobuf import ( | ||||
| ) | ||||
| from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest | ||||
|  | ||||
| from .. import which_one_of | ||||
| from ..compile.importing import ( | ||||
|     get_type_reference, | ||||
|     parse_source_type_name, | ||||
| @@ -82,6 +75,10 @@ from ..compile.naming import ( | ||||
|     pythonize_field_name, | ||||
|     pythonize_method_name, | ||||
| ) | ||||
| from .typing_compiler import ( | ||||
|     DirectImportTypingCompiler, | ||||
|     TypingCompiler, | ||||
| ) | ||||
|  | ||||
|  | ||||
| # Create a unique placeholder to deal with | ||||
| @@ -173,6 +170,7 @@ class ProtoContentBase: | ||||
|     """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" | ||||
|  | ||||
|     source_file: FileDescriptorProto | ||||
|     typing_compiler: TypingCompiler | ||||
|     path: List[int] | ||||
|     comment_indent: int = 4 | ||||
|     parent: Union["betterproto.Message", "OutputTemplate"] | ||||
| @@ -242,7 +240,6 @@ class OutputTemplate: | ||||
|     input_files: List[str] = field(default_factory=list) | ||||
|     imports: Set[str] = field(default_factory=set) | ||||
|     datetime_imports: Set[str] = field(default_factory=set) | ||||
|     typing_imports: Set[str] = field(default_factory=set) | ||||
|     pydantic_imports: Set[str] = field(default_factory=set) | ||||
|     builtins_import: bool = False | ||||
|     messages: List["MessageCompiler"] = field(default_factory=list) | ||||
| @@ -251,6 +248,7 @@ class OutputTemplate: | ||||
|     imports_type_checking_only: Set[str] = field(default_factory=set) | ||||
|     pydantic_dataclasses: bool = False | ||||
|     output: bool = True | ||||
|     typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler) | ||||
|  | ||||
|     @property | ||||
|     def package(self) -> str: | ||||
| @@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|     """Representation of a protobuf message.""" | ||||
|  | ||||
|     source_file: FileDescriptorProto | ||||
|     typing_compiler: TypingCompiler | ||||
|     parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER | ||||
|     proto_obj: DescriptorProto = PLACEHOLDER | ||||
|     path: List[int] = PLACEHOLDER | ||||
| @@ -319,7 +318,7 @@ class MessageCompiler(ProtoContentBase): | ||||
|     @property | ||||
|     def annotation(self) -> str: | ||||
|         if self.repeated: | ||||
|             return f"List[{self.py_name}]" | ||||
|             return self.typing_compiler.list(self.py_name) | ||||
|         return self.py_name | ||||
|  | ||||
|     @property | ||||
| @@ -434,18 +433,6 @@ class FieldCompiler(MessageCompiler): | ||||
|             imports.add("datetime") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def typing_imports(self) -> Set[str]: | ||||
|         imports = set() | ||||
|         annotation = self.annotation | ||||
|         if "Optional[" in annotation: | ||||
|             imports.add("Optional") | ||||
|         if "List[" in annotation: | ||||
|             imports.add("List") | ||||
|         if "Dict[" in annotation: | ||||
|             imports.add("Dict") | ||||
|         return imports | ||||
|  | ||||
|     @property | ||||
|     def pydantic_imports(self) -> Set[str]: | ||||
|         return set() | ||||
| @@ -458,7 +445,6 @@ class FieldCompiler(MessageCompiler): | ||||
|  | ||||
|     def add_imports_to(self, output_file: OutputTemplate) -> None: | ||||
|         output_file.datetime_imports.update(self.datetime_imports) | ||||
|         output_file.typing_imports.update(self.typing_imports) | ||||
|         output_file.pydantic_imports.update(self.pydantic_imports) | ||||
|         output_file.builtins_import = output_file.builtins_import or self.use_builtins | ||||
|  | ||||
| @@ -488,7 +474,9 @@ class FieldCompiler(MessageCompiler): | ||||
|     @property | ||||
|     def mutable(self) -> bool: | ||||
|         """True if the field is a mutable type, otherwise False.""" | ||||
|         return self.annotation.startswith(("List[", "Dict[")) | ||||
|         return self.annotation.startswith( | ||||
|             ("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[") | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def field_type(self) -> str: | ||||
| @@ -562,6 +550,7 @@ class FieldCompiler(MessageCompiler): | ||||
|                 package=self.output_file.package, | ||||
|                 imports=self.output_file.imports, | ||||
|                 source_type=self.proto_obj.type_name, | ||||
|                 typing_compiler=self.typing_compiler, | ||||
|                 pydantic=self.output_file.pydantic_dataclasses, | ||||
|             ) | ||||
|         else: | ||||
| @@ -573,9 +562,9 @@ class FieldCompiler(MessageCompiler): | ||||
|         if self.use_builtins: | ||||
|             py_type = f"builtins.{py_type}" | ||||
|         if self.repeated: | ||||
|             return f"List[{py_type}]" | ||||
|             return self.typing_compiler.list(py_type) | ||||
|         if self.optional: | ||||
|             return f"Optional[{py_type}]" | ||||
|             return self.typing_compiler.optional(py_type) | ||||
|         return py_type | ||||
|  | ||||
|  | ||||
| @@ -600,7 +589,7 @@ class PydanticOneOfFieldCompiler(OneOfFieldCompiler): | ||||
|  | ||||
|     @property | ||||
|     def pydantic_imports(self) -> Set[str]: | ||||
|         return {"root_validator"} | ||||
|         return {"model_validator"} | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| @@ -623,11 +612,13 @@ class MapEntryCompiler(FieldCompiler): | ||||
|                     source_file=self.source_file, | ||||
|                     parent=self, | ||||
|                     proto_obj=nested.field[0],  # key | ||||
|                     typing_compiler=self.typing_compiler, | ||||
|                 ).py_type | ||||
|                 self.py_v_type = FieldCompiler( | ||||
|                     source_file=self.source_file, | ||||
|                     parent=self, | ||||
|                     proto_obj=nested.field[1],  # value | ||||
|                     typing_compiler=self.typing_compiler, | ||||
|                 ).py_type | ||||
|  | ||||
|                 # Get proto types | ||||
| @@ -645,7 +636,7 @@ class MapEntryCompiler(FieldCompiler): | ||||
|  | ||||
|     @property | ||||
|     def annotation(self) -> str: | ||||
|         return f"Dict[{self.py_k_type}, {self.py_v_type}]" | ||||
|         return self.typing_compiler.dict(self.py_k_type, self.py_v_type) | ||||
|  | ||||
|     @property | ||||
|     def repeated(self) -> bool: | ||||
| @@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase): | ||||
|     def __post_init__(self) -> None: | ||||
|         # Add service to output file | ||||
|         self.output_file.services.append(self) | ||||
|         self.output_file.typing_imports.add("Dict") | ||||
|         super().__post_init__()  # check for unset fields | ||||
|  | ||||
|     @property | ||||
| @@ -725,22 +715,6 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|         # Add method to service | ||||
|         self.parent.methods.append(self) | ||||
|  | ||||
|         # Check for imports | ||||
|         if "Optional" in self.py_output_message_type: | ||||
|             self.output_file.typing_imports.add("Optional") | ||||
|  | ||||
|         # Check for Async imports | ||||
|         if self.client_streaming: | ||||
|             self.output_file.typing_imports.add("AsyncIterable") | ||||
|             self.output_file.typing_imports.add("Iterable") | ||||
|             self.output_file.typing_imports.add("Union") | ||||
|  | ||||
|         # Required by both client and server | ||||
|         if self.client_streaming or self.server_streaming: | ||||
|             self.output_file.typing_imports.add("AsyncIterator") | ||||
|  | ||||
|         # add imports required for request arguments timeout, deadline and metadata | ||||
|         self.output_file.typing_imports.add("Optional") | ||||
|         self.output_file.imports_type_checking_only.add("import grpclib.server") | ||||
|         self.output_file.imports_type_checking_only.add( | ||||
|             "from betterproto.grpc.grpclib_client import MetadataLike" | ||||
| @@ -806,6 +780,7 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|             package=self.output_file.package, | ||||
|             imports=self.output_file.imports, | ||||
|             source_type=self.proto_obj.input_type, | ||||
|             typing_compiler=self.output_file.typing_compiler, | ||||
|             unwrap=False, | ||||
|             pydantic=self.output_file.pydantic_dataclasses, | ||||
|         ).strip('"') | ||||
| @@ -835,6 +810,7 @@ class ServiceMethodCompiler(ProtoContentBase): | ||||
|             package=self.output_file.package, | ||||
|             imports=self.output_file.imports, | ||||
|             source_type=self.proto_obj.output_type, | ||||
|             typing_compiler=self.output_file.typing_compiler, | ||||
|             unwrap=False, | ||||
|             pydantic=self.output_file.pydantic_dataclasses, | ||||
|         ).strip('"') | ||||
|   | ||||
							
								
								
									
										163
									
								
								src/betterproto/plugin/module_validation.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										163
									
								
								src/betterproto/plugin/module_validation.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,163 @@ | ||||
| import re | ||||
| from collections import defaultdict | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| ) | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Iterator, | ||||
|     List, | ||||
|     Tuple, | ||||
| ) | ||||
|  | ||||
|  | ||||
@dataclass
class ModuleValidator:
    """Detects duplicate top-level names in a generated Python module.

    The validator consumes ``line_iterator`` line by line, recording every
    top-level binding (imports, class/function definitions, simple
    assignments).  :meth:`validate` then reports whether any name was bound
    more than once, which would indicate a collision in the generated code.
    """

    # Iterator yielding the module's source lines.
    line_iterator: Iterator[str]
    # 0-based index of the line currently being examined.
    line_number: int = field(init=False, default=0)

    # name -> list of (line number, full source line) that bind it; any entry
    # with more than one binding is a collision.
    collisions: Dict[str, List[Tuple[int, str]]] = field(
        init=False, default_factory=lambda: defaultdict(list)
    )

    def add_import(self, imp: str, number: int, full_line: str):
        """
        Adds an import to be tracked.
        """
        self.collisions[imp].append((number, full_line))

    def process_import(self, imp: str) -> str:
        """
        Filters out the import to its actual value.
        """
        if " as " in imp:
            # ``x as y`` binds the name ``y``, not ``x``.
            imp = imp[imp.index(" as ") + 4 :]

        imp = imp.strip()
        assert " " not in imp, imp
        return imp

    def evaluate_multiline_import(self, line: str):
        """
        Evaluates a multiline import from a starting line.
        """
        # Filter the first line and remove anything before the import statement.
        full_line = line
        line = line.split("import", 1)[1]
        # Parenthesised imports end at ")"; backslash continuations end on the
        # first line without a trailing "\".
        if "(" in line:
            conditional = lambda line: ")" not in line
        else:
            conditional = lambda line: "\\" in line

        # Remove open parenthesis if it exists.
        if "(" in line:
            line = line[line.index("(") + 1 :]

        # Choose the conditional based on how multiline imports are formatted.
        while conditional(line):
            # Split the line by commas
            imports = line.split(",")

            for imp in imports:
                # Add the import to the namespace
                imp = self.process_import(imp)
                if imp:
                    self.add_import(imp, self.line_number, full_line)
            # Get the next line
            full_line = line = next(self.line_iterator)
            # Increment the line number
            self.line_number += 1

        # validate the last line
        if ")" in line:
            line = line[: line.index(")")]
        imports = line.split(",")
        for imp in imports:
            imp = self.process_import(imp)
            if imp:
                self.add_import(imp, self.line_number, full_line)

    def evaluate_import(self, line: str):
        """
        Extracts an import from a line.
        """
        whole_line = line
        line = line[line.index("import") + 6 :]
        values = line.split(",")
        for v in values:
            self.add_import(self.process_import(v), self.line_number, whole_line)

    def next(self):
        """
        Evaluate each line for names in the module.

        Raises:
            StopIteration: when the underlying line iterator is exhausted.
        """
        line = next(self.line_iterator)

        # Skip lines with indentation or comments
        if (
            # Skip indents and whitespace.
            line.startswith(" ")
            or line == "\n"
            or line.startswith("\t")
            # Skip comments
            or line.startswith("#")
            # Skip decorators
            or line.startswith("@")
        ):
            self.line_number += 1
            return

        # Skip docstrings.
        if line.startswith('"""') or line.startswith("'''"):
            quote = line[0] * 3
            line = line[3:]
            while quote not in line:
                line = next(self.line_iterator)
                # Fix: count every consumed docstring line so that subsequent
                # line numbers stay accurate for multi-line docstrings.
                self.line_number += 1
            self.line_number += 1
            return

        # Evaluate Imports.
        if line.startswith("from ") or line.startswith("import "):
            if "(" in line or "\\" in line:
                self.evaluate_multiline_import(line)
            else:
                self.evaluate_import(line)

        # Evaluate Classes.
        elif line.startswith("class "):
            match = re.search(r"class (\w+)", line)
            if match:
                self.add_import(match.group(1), self.line_number, line)

        # Evaluate Functions.
        elif line.startswith("def "):
            match = re.search(r"def (\w+)", line)
            if match:
                self.add_import(match.group(1), self.line_number, line)

        # Evaluate direct assignments.
        elif "=" in line:
            # ``(?!=)`` avoids treating comparisons such as ``a == b`` as
            # assignments; guarding the match avoids an AttributeError when
            # no assignment target exists (e.g. ``a != b`` or ``a += 1``).
            match = re.search(r"(\w+)\s*=(?!=)", line)
            if match:
                self.add_import(match.group(1), self.line_number, line)

        self.line_number += 1

    def validate(self) -> bool:
        """
        Run Validation.

        Returns:
            True when no name is bound more than once; False otherwise, with
            the offending names left in ``self.collisions``.
        """
        try:
            while True:
                self.next()
        except StopIteration:
            pass

        # Filter collisions for those with more than one value.
        self.collisions = {k: v for k, v in self.collisions.items() if len(v) > 1}

        # Return True if no collisions are found.
        return not bool(self.collisions)
| @@ -37,6 +37,12 @@ from .models import ( | ||||
|     is_map, | ||||
|     is_oneof, | ||||
| ) | ||||
| from .typing_compiler import ( | ||||
|     DirectImportTypingCompiler, | ||||
|     NoTyping310TypingCompiler, | ||||
|     TypingCompiler, | ||||
|     TypingImportTypingCompiler, | ||||
| ) | ||||
|  | ||||
|  | ||||
| def traverse( | ||||
| @@ -98,6 +104,28 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse: | ||||
|                 output_package_name | ||||
|             ].pydantic_dataclasses = True | ||||
|  | ||||
|         # Gather any typing generation options. | ||||
|         typing_opts = [ | ||||
|             opt[len("typing.") :] for opt in plugin_options if opt.startswith("typing.") | ||||
|         ] | ||||
|  | ||||
|         if len(typing_opts) > 1: | ||||
|             raise ValueError("Multiple typing options provided") | ||||
|         # Set the compiler type. | ||||
|         typing_opt = typing_opts[0] if typing_opts else "direct" | ||||
|         if typing_opt == "direct": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = DirectImportTypingCompiler() | ||||
|         elif typing_opt == "root": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = TypingImportTypingCompiler() | ||||
|         elif typing_opt == "310": | ||||
|             request_data.output_packages[ | ||||
|                 output_package_name | ||||
|             ].typing_compiler = NoTyping310TypingCompiler() | ||||
|  | ||||
|     # Read Messages and Enums | ||||
|     # We need to read Messages before Services in so that we can | ||||
|     # get the references to input/output messages for each service | ||||
| @@ -166,6 +194,7 @@ def _make_one_of_field_compiler( | ||||
|         parent=parent, | ||||
|         proto_obj=proto_obj, | ||||
|         path=path, | ||||
|         typing_compiler=output_package.typing_compiler, | ||||
|     ) | ||||
|  | ||||
|  | ||||
| @@ -181,7 +210,11 @@ def read_protobuf_type( | ||||
|             return | ||||
|         # Process Message | ||||
|         message_data = MessageCompiler( | ||||
|             source_file=source_file, parent=output_package, proto_obj=item, path=path | ||||
|             source_file=source_file, | ||||
|             parent=output_package, | ||||
|             proto_obj=item, | ||||
|             path=path, | ||||
|             typing_compiler=output_package.typing_compiler, | ||||
|         ) | ||||
|         for index, field in enumerate(item.field): | ||||
|             if is_map(field, item): | ||||
| @@ -190,6 +223,7 @@ def read_protobuf_type( | ||||
|                     parent=message_data, | ||||
|                     proto_obj=field, | ||||
|                     path=path + [2, index], | ||||
|                     typing_compiler=output_package.typing_compiler, | ||||
|                 ) | ||||
|             elif is_oneof(field): | ||||
|                 _make_one_of_field_compiler( | ||||
| @@ -201,11 +235,16 @@ def read_protobuf_type( | ||||
|                     parent=message_data, | ||||
|                     proto_obj=field, | ||||
|                     path=path + [2, index], | ||||
|                     typing_compiler=output_package.typing_compiler, | ||||
|                 ) | ||||
|     elif isinstance(item, EnumDescriptorProto): | ||||
|         # Enum | ||||
|         EnumDefinitionCompiler( | ||||
|             source_file=source_file, parent=output_package, proto_obj=item, path=path | ||||
|             source_file=source_file, | ||||
|             parent=output_package, | ||||
|             proto_obj=item, | ||||
|             path=path, | ||||
|             typing_compiler=output_package.typing_compiler, | ||||
|         ) | ||||
|  | ||||
|  | ||||
|   | ||||
							
								
								
									
										167
									
								
								src/betterproto/plugin/typing_compiler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										167
									
								
								src/betterproto/plugin/typing_compiler.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,167 @@ | ||||
| import abc | ||||
| from collections import defaultdict | ||||
| from dataclasses import ( | ||||
|     dataclass, | ||||
|     field, | ||||
| ) | ||||
| from typing import ( | ||||
|     Dict, | ||||
|     Iterator, | ||||
|     Optional, | ||||
|     Set, | ||||
| ) | ||||
|  | ||||
|  | ||||
class TypingCompiler(metaclass=abc.ABCMeta):
    """Strategy interface for rendering typing constructs in generated code.

    Concrete compilers decide how annotations such as optionals, lists and
    unions are spelled, and track which imports the chosen spelling needs.
    """

    @abc.abstractmethod
    def optional(self, type: str) -> str:
        """Render an optional/nullable annotation for *type*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def list(self, type: str) -> str:
        """Render a homogeneous list annotation for *type*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def dict(self, key: str, value: str) -> str:
        """Render a mapping annotation from *key* to *value*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def union(self, *types: str) -> str:
        """Render a union annotation over *types*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def iterable(self, type: str) -> str:
        """Render an iterable annotation for *type*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterable(self, type: str) -> str:
        """Render an async-iterable annotation for *type*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def async_iterator(self, type: str) -> str:
        """Render an async-iterator annotation for *type*."""
        raise NotImplementedError()

    @abc.abstractmethod
    def imports(self) -> Dict[str, Optional[Set[str]]]:
        """
        Returns either the direct import as a key with None as value, or a set
        of values to import from the key.
        """
        raise NotImplementedError()

    def import_lines(self) -> Iterator:
        """Yield the source lines implied by :meth:`imports`."""
        for module, names in self.imports().items():
            if names is None:
                # Plain module import.
                yield f"import {module}"
                continue
            # Multi-line "from" import with one sorted name per line.
            yield f"from {module} import ("
            for name in sorted(names):
                yield f"    {name},"
            yield ")"
|  | ||||
|  | ||||
@dataclass
class DirectImportTypingCompiler(TypingCompiler):
    """Renders typing constructs as bare names (``Optional[...]``) backed by
    ``from typing import ...`` statements, collected lazily as they are used.
    """

    # module -> names required from it; populated as constructs are rendered.
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    def _use(self, name: str) -> None:
        """Record that *name* must be imported from ``typing``."""
        self._imports["typing"].add(name)

    def optional(self, type: str) -> str:
        self._use("Optional")
        return f"Optional[{type}]"

    def list(self, type: str) -> str:
        self._use("List")
        return f"List[{type}]"

    def dict(self, key: str, value: str) -> str:
        self._use("Dict")
        return f"Dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        self._use("Union")
        return f"Union[{', '.join(types)}]"

    def iterable(self, type: str) -> str:
        self._use("Iterable")
        return f"Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._use("AsyncIterable")
        return f"AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._use("AsyncIterator")
        return f"AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        # Empty name sets degrade to None (plain ``import module`` form).
        return {module: names or None for module, names in self._imports.items()}
|  | ||||
|  | ||||
@dataclass
class TypingImportTypingCompiler(TypingCompiler):
    """Renders typing constructs fully qualified (``typing.Optional[...]``)
    behind a single ``import typing`` statement.
    """

    # True once any construct has been rendered, i.e. ``import typing`` is needed.
    _imported: bool = False

    def _qualify(self, rendered: str) -> str:
        """Note the dependency on ``typing`` and return the qualified form."""
        self._imported = True
        return f"typing.{rendered}"

    def optional(self, type: str) -> str:
        return self._qualify(f"Optional[{type}]")

    def list(self, type: str) -> str:
        return self._qualify(f"List[{type}]")

    def dict(self, key: str, value: str) -> str:
        return self._qualify(f"Dict[{key}, {value}]")

    def union(self, *types: str) -> str:
        return self._qualify(f"Union[{', '.join(types)}]")

    def iterable(self, type: str) -> str:
        return self._qualify(f"Iterable[{type}]")

    def async_iterable(self, type: str) -> str:
        return self._qualify(f"AsyncIterable[{type}]")

    def async_iterator(self, type: str) -> str:
        return self._qualify(f"AsyncIterator[{type}]")

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {"typing": None} if self._imported else {}
|  | ||||
|  | ||||
@dataclass
class NoTyping310TypingCompiler(TypingCompiler):
    """
    Renders annotations with the syntax available from Python 3.10 onwards,
    avoiding the ``typing`` module entirely: PEP 604 unions (``X | None``),
    PEP 585 builtin generics (``list[...]``/``dict[...]``), and the iterable
    ABCs from ``collections.abc`` (subscriptable since Python 3.9), which
    keeps this compiler's "no typing" promise consistent.
    """

    # module -> names required from it (only ``collections.abc`` is used).
    _imports: Dict[str, Set[str]] = field(default_factory=lambda: defaultdict(set))

    def optional(self, type: str) -> str:
        # PEP 604 union-with-None syntax.
        return f"{type} | None"

    def list(self, type: str) -> str:
        # PEP 585 builtin generic.
        return f"list[{type}]"

    def dict(self, key: str, value: str) -> str:
        return f"dict[{key}, {value}]"

    def union(self, *types: str) -> str:
        return " | ".join(types)

    def iterable(self, type: str) -> str:
        # Fix: pull the ABCs from collections.abc rather than typing; the
        # rendered annotation text is identical, but the generated module no
        # longer depends on the typing module this compiler is meant to avoid.
        self._imports["collections.abc"].add("Iterable")
        return f"Iterable[{type}]"

    def async_iterable(self, type: str) -> str:
        self._imports["collections.abc"].add("AsyncIterable")
        return f"AsyncIterable[{type}]"

    def async_iterator(self, type: str) -> str:
        self._imports["collections.abc"].add("AsyncIterator")
        return f"AsyncIterator[{type}]"

    def imports(self) -> Dict[str, Optional[Set[str]]]:
        return {k: v if v else None for k, v in self._imports.items()}
							
								
								
									
										55
									
								
								src/betterproto/templates/header.py.j2
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								src/betterproto/templates/header.py.j2
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: {{ ', '.join(output_file.input_filenames) }} | ||||
| # plugin: python-betterproto | ||||
| # This file has been @generated | ||||
| {% for i in output_file.python_module_imports|sort %} | ||||
| import {{ i }} | ||||
| {% endfor %} | ||||
| {% set type_checking_imported = False %} | ||||
|  | ||||
| {% if output_file.pydantic_dataclasses %} | ||||
| from typing import TYPE_CHECKING | ||||
| {% set type_checking_imported = True %} | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from dataclasses import dataclass | ||||
| else: | ||||
|     from pydantic.dataclasses import dataclass | ||||
| from pydantic.dataclasses import rebuild_dataclass | ||||
| {%- else -%} | ||||
| from dataclasses import dataclass | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.datetime_imports %} | ||||
| from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif%} | ||||
| {% set typing_imports = output_file.typing_compiler.imports() %} | ||||
| {% if typing_imports %} | ||||
| {% for line in output_file.typing_compiler.import_lines() %} | ||||
| {{ line }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.pydantic_imports %} | ||||
| from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| import betterproto | ||||
| {% if output_file.services %} | ||||
| from betterproto.grpc.grpclib_server import ServiceBase | ||||
| import grpclib | ||||
| {% endif %} | ||||
|  | ||||
| {% for i in output_file.imports|sort %} | ||||
| {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.imports_type_checking_only and not type_checking_imported %} | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| {% for i in output_file.imports_type_checking_only|sort %}    {{ i }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| @@ -1,53 +1,3 @@ | ||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||
| # sources: {{ ', '.join(output_file.input_filenames) }} | ||||
| # plugin: python-betterproto | ||||
| # This file has been @generated | ||||
| {% for i in output_file.python_module_imports|sort %} | ||||
| import {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.pydantic_dataclasses %} | ||||
| from typing import TYPE_CHECKING | ||||
| if TYPE_CHECKING: | ||||
|     from dataclasses import dataclass | ||||
| else: | ||||
|     from pydantic.dataclasses import dataclass | ||||
| {%- else -%} | ||||
| from dataclasses import dataclass | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.datetime_imports %} | ||||
| from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif%} | ||||
| {% if output_file.typing_imports %} | ||||
| from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.pydantic_imports %} | ||||
| from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||
|  | ||||
| {% endif %} | ||||
|  | ||||
| import betterproto | ||||
| {% if output_file.services %} | ||||
| from betterproto.grpc.grpclib_server import ServiceBase | ||||
| import grpclib | ||||
| {% endif %} | ||||
|  | ||||
| {% for i in output_file.imports|sort %} | ||||
| {{ i }} | ||||
| {% endfor %} | ||||
|  | ||||
| {% if output_file.imports_type_checking_only %} | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| {% for i in output_file.imports_type_checking_only|sort %}    {{ i }} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
|  | ||||
| {% if output_file.enums %}{% for enum in output_file.enums %} | ||||
| class {{ enum.py_name }}(betterproto.Enum): | ||||
|     {% if enum.comment %} | ||||
| @@ -62,6 +12,13 @@ class {{ enum.py_name }}(betterproto.Enum): | ||||
|         {% endif %} | ||||
|     {% endfor %} | ||||
|  | ||||
|     {% if output_file.pydantic_dataclasses %} | ||||
|     @classmethod | ||||
|     def __get_pydantic_core_schema__(cls, _source_type, _handler): | ||||
|         from pydantic_core import core_schema | ||||
|  | ||||
|         return core_schema.int_schema(ge=0) | ||||
|     {% endif %} | ||||
|  | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
| @@ -96,7 +53,7 @@ class {{ message.py_name }}(betterproto.Message): | ||||
|     {%  endif %} | ||||
|  | ||||
|     {% if output_file.pydantic_dataclasses and message.has_oneof_fields %} | ||||
|     @root_validator() | ||||
|     @model_validator(mode='after') | ||||
|     def check_oneof(cls, values): | ||||
|         return cls._validate_field_groups(values) | ||||
|     {%  endif %} | ||||
| @@ -116,14 +73,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||
|             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||
|         {%- else -%} | ||||
|             {# Client streaming: need a request iterator instead #} | ||||
|             , {{ method.py_input_message_param }}_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] | ||||
|             , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.union(output_file.typing_compiler.async_iterable(method.py_input_message_type), output_file.typing_compiler.iterable(method.py_input_message_type)) }} | ||||
|         {%- endif -%} | ||||
|             , | ||||
|             * | ||||
|             , timeout: Optional[float] = None | ||||
|             , deadline: Optional["Deadline"] = None | ||||
|             , metadata: Optional["MetadataLike"] = None | ||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|             , timeout: {{ output_file.typing_compiler.optional("float") }} = None | ||||
|             , deadline: {{ output_file.typing_compiler.optional('"Deadline"') }} = None | ||||
|             , metadata: {{ output_file.typing_compiler.optional('"MetadataLike"') }} = None | ||||
|             ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type ) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|         {% if method.comment %} | ||||
| {{ method.comment }} | ||||
|  | ||||
| @@ -191,9 +148,9 @@ class {{ service.py_name }}Base(ServiceBase): | ||||
|             {%- if method.py_input_message -%}, {{ method.py_input_message_param }}: "{{ method.py_input_message_type }}"{%- endif -%} | ||||
|         {%- else -%} | ||||
|             {# Client streaming: need a request iterator instead #} | ||||
|             , {{ method.py_input_message_param }}_iterator: AsyncIterator["{{ method.py_input_message_type }}"] | ||||
|             , {{ method.py_input_message_param }}_iterator: {{ output_file.typing_compiler.async_iterator(method.py_input_message_type) }} | ||||
|         {%- endif -%} | ||||
|             ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|             ) -> {% if method.server_streaming %}{{ output_file.typing_compiler.async_iterator(method.py_output_message_type) }}{% else %}"{{ method.py_output_message_type }}"{% endif %}: | ||||
|         {% if method.comment %} | ||||
| {{ method.comment }} | ||||
|  | ||||
| @@ -225,7 +182,7 @@ class {{ service.py_name }}Base(ServiceBase): | ||||
|  | ||||
|     {% endfor %} | ||||
|  | ||||
|     def __mapping__(self) -> Dict[str, grpclib.const.Handler]: | ||||
|     def __mapping__(self) -> {{ output_file.typing_compiler.dict("str", "grpclib.const.Handler") }}: | ||||
|         return { | ||||
|         {% for method in service.methods %} | ||||
|         "{{ method.route }}": grpclib.const.Handler( | ||||
| @@ -250,7 +207,7 @@ class {{ service.py_name }}Base(ServiceBase): | ||||
| {% if output_file.pydantic_dataclasses %} | ||||
| {% for message in output_file.messages %} | ||||
| {% if message.has_message_field %} | ||||
| {{ message.py_name }}.__pydantic_model__.update_forward_refs()  # type: ignore | ||||
| rebuild_dataclass({{ message.py_name }})  # type: ignore | ||||
| {% endif %} | ||||
| {% endfor %} | ||||
| {% endif %} | ||||
|   | ||||
		Reference in New Issue
	
	Block a user