Make plugin use betterproto generated classes internally
This means the betterproto plugin no longer needs to depend durectly on protobuf. This requires a small runtime hack to monkey patch some google types to get around the fact that the compiler uses proto2, but betterproto expects proto3. Also: - regenerate google.protobuf package - fix a regex bug in the logic for determining whether to use a google wrapper type. - fix a bug causing comments to get mixed up when multiple proto files generate code into a single python module
This commit is contained in:
@@ -1172,6 +1172,7 @@ from .lib.google.protobuf import ( # noqa
|
||||
BytesValue,
|
||||
DoubleValue,
|
||||
Duration,
|
||||
EnumValue,
|
||||
FloatValue,
|
||||
Int32Value,
|
||||
Int64Value,
|
||||
@@ -1238,14 +1239,17 @@ class _WrappedMessage(Message):
|
||||
|
||||
def _get_wrapper(proto_type: str) -> Type:
|
||||
"""Get the wrapper message class for a wrapped type."""
|
||||
|
||||
# TODO: include ListValue and NullValue?
|
||||
return {
|
||||
TYPE_BOOL: BoolValue,
|
||||
TYPE_INT32: Int32Value,
|
||||
TYPE_UINT32: UInt32Value,
|
||||
TYPE_INT64: Int64Value,
|
||||
TYPE_UINT64: UInt64Value,
|
||||
TYPE_FLOAT: FloatValue,
|
||||
TYPE_DOUBLE: DoubleValue,
|
||||
TYPE_STRING: StringValue,
|
||||
TYPE_BYTES: BytesValue,
|
||||
TYPE_DOUBLE: DoubleValue,
|
||||
TYPE_FLOAT: FloatValue,
|
||||
TYPE_ENUM: EnumValue,
|
||||
TYPE_INT32: Int32Value,
|
||||
TYPE_INT64: Int64Value,
|
||||
TYPE_STRING: StringValue,
|
||||
TYPE_UINT32: UInt32Value,
|
||||
TYPE_UINT64: UInt64Value,
|
||||
}[proto_type]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
125
src/betterproto/lib/google/protobuf/compiler/__init__.py
Normal file
125
src/betterproto/lib/google/protobuf/compiler/__init__.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
||||
# sources: google/protobuf/compiler/plugin.proto
|
||||
# plugin: python-betterproto
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import betterproto
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class Version(betterproto.Message):
|
||||
"""The version number of protocol compiler."""
|
||||
|
||||
major: int = betterproto.int32_field(1)
|
||||
minor: int = betterproto.int32_field(2)
|
||||
patch: int = betterproto.int32_field(3)
|
||||
# A suffix for alpha, beta or rc release, e.g., "alpha-1", "rc2". It should
|
||||
# be empty for mainline stable releases.
|
||||
suffix: str = betterproto.string_field(4)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class CodeGeneratorRequest(betterproto.Message):
|
||||
"""An encoded CodeGeneratorRequest is written to the plugin's stdin."""
|
||||
|
||||
# The .proto files that were explicitly listed on the command-line. The code
|
||||
# generator should generate code only for these files. Each file's
|
||||
# descriptor will be included in proto_file, below.
|
||||
file_to_generate: List[str] = betterproto.string_field(1)
|
||||
# The generator parameter passed on the command-line.
|
||||
parameter: str = betterproto.string_field(2)
|
||||
# FileDescriptorProtos for all files in files_to_generate and everything they
|
||||
# import. The files will appear in topological order, so each file appears
|
||||
# before any file that imports it. protoc guarantees that all proto_files
|
||||
# will be written after the fields above, even though this is not technically
|
||||
# guaranteed by the protobuf wire format. This theoretically could allow a
|
||||
# plugin to stream in the FileDescriptorProtos and handle them one by one
|
||||
# rather than read the entire set into memory at once. However, as of this
|
||||
# writing, this is not similarly optimized on protoc's end -- it will store
|
||||
# all fields in memory at once before sending them to the plugin. Type names
|
||||
# of fields and extensions in the FileDescriptorProto are always fully
|
||||
# qualified.
|
||||
proto_file: List[
|
||||
"betterproto_lib_google_protobuf.FileDescriptorProto"
|
||||
] = betterproto.message_field(15)
|
||||
# The version number of protocol compiler.
|
||||
compiler_version: "Version" = betterproto.message_field(3)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class CodeGeneratorResponse(betterproto.Message):
|
||||
"""The plugin writes an encoded CodeGeneratorResponse to stdout."""
|
||||
|
||||
# Error message. If non-empty, code generation failed. The plugin process
|
||||
# should exit with status code zero even if it reports an error in this way.
|
||||
# This should be used to indicate errors in .proto files which prevent the
|
||||
# code generator from generating correct code. Errors which indicate a
|
||||
# problem in protoc itself -- such as the input CodeGeneratorRequest being
|
||||
# unparseable -- should be reported by writing a message to stderr and
|
||||
# exiting with a non-zero status code.
|
||||
error: str = betterproto.string_field(1)
|
||||
file: List["CodeGeneratorResponseFile"] = betterproto.message_field(15)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@dataclass(eq=False, repr=False)
|
||||
class CodeGeneratorResponseFile(betterproto.Message):
|
||||
"""Represents a single generated file."""
|
||||
|
||||
# The file name, relative to the output directory. The name must not contain
|
||||
# "." or ".." components and must be relative, not be absolute (so, the file
|
||||
# cannot lie outside the output directory). "/" must be used as the path
|
||||
# separator, not "\". If the name is omitted, the content will be appended to
|
||||
# the previous file. This allows the generator to break large files into
|
||||
# small chunks, and allows the generated text to be streamed back to protoc
|
||||
# so that large files need not reside completely in memory at one time. Note
|
||||
# that as of this writing protoc does not optimize for this -- it will read
|
||||
# the entire CodeGeneratorResponse before writing files to disk.
|
||||
name: str = betterproto.string_field(1)
|
||||
# If non-empty, indicates that the named file should already exist, and the
|
||||
# content here is to be inserted into that file at a defined insertion point.
|
||||
# This feature allows a code generator to extend the output produced by
|
||||
# another code generator. The original generator may provide insertion
|
||||
# points by placing special annotations in the file that look like:
|
||||
# @@protoc_insertion_point(NAME) The annotation can have arbitrary text
|
||||
# before and after it on the line, which allows it to be placed in a comment.
|
||||
# NAME should be replaced with an identifier naming the point -- this is what
|
||||
# other generators will use as the insertion_point. Code inserted at this
|
||||
# point will be placed immediately above the line containing the insertion
|
||||
# point (thus multiple insertions to the same point will come out in the
|
||||
# order they were added). The double-@ is intended to make it unlikely that
|
||||
# the generated code could contain things that look like insertion points by
|
||||
# accident. For example, the C++ code generator places the following line in
|
||||
# the .pb.h files that it generates: //
|
||||
# @@protoc_insertion_point(namespace_scope) This line appears within the
|
||||
# scope of the file's package namespace, but outside of any particular class.
|
||||
# Another plugin can then specify the insertion_point "namespace_scope" to
|
||||
# generate additional classes or other declarations that should be placed in
|
||||
# this scope. Note that if the line containing the insertion point begins
|
||||
# with whitespace, the same whitespace will be added to every line of the
|
||||
# inserted text. This is useful for languages like Python, where indentation
|
||||
# matters. In these languages, the insertion point comment should be
|
||||
# indented the same amount as any inserted code will need to be in order to
|
||||
# work correctly in that context. The code generator that generates the
|
||||
# initial file and the one which inserts into it must both run as part of a
|
||||
# single invocation of protoc. Code generators are executed in the order in
|
||||
# which they appear on the command line. If |insertion_point| is present,
|
||||
# |name| must also be present.
|
||||
insertion_point: str = betterproto.string_field(2)
|
||||
# The file contents.
|
||||
content: str = betterproto.string_field(15)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf
|
||||
17
src/betterproto/plugin/main.py
Normal file → Executable file
17
src/betterproto/plugin/main.py
Normal file → Executable file
@@ -3,9 +3,13 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from betterproto.lib.google.protobuf.compiler import (
|
||||
CodeGeneratorRequest,
|
||||
CodeGeneratorResponse,
|
||||
)
|
||||
|
||||
from betterproto.plugin.parser import generate_code
|
||||
from betterproto.plugin.models import monkey_patch_oneof_index
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -13,16 +17,19 @@ def main() -> None:
|
||||
# Read request message from stdin
|
||||
data = sys.stdin.buffer.read()
|
||||
|
||||
# Apply Work around for proto2/3 difference in protoc messages
|
||||
monkey_patch_oneof_index()
|
||||
|
||||
# Parse request
|
||||
request = plugin.CodeGeneratorRequest()
|
||||
request.ParseFromString(data)
|
||||
request = CodeGeneratorRequest()
|
||||
request.parse(data)
|
||||
|
||||
dump_file = os.getenv("BETTERPROTO_DUMP")
|
||||
if dump_file:
|
||||
dump_request(dump_file, request)
|
||||
|
||||
# Create response
|
||||
response = plugin.CodeGeneratorResponse()
|
||||
response = CodeGeneratorResponse()
|
||||
|
||||
# Generate code
|
||||
generate_code(request, response)
|
||||
@@ -34,7 +41,7 @@ def main() -> None:
|
||||
sys.stdout.buffer.write(output)
|
||||
|
||||
|
||||
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
|
||||
def dump_request(dump_file: str, request: CodeGeneratorRequest) -> None:
|
||||
"""
|
||||
For developers: Supports running plugin.py standalone so its possible to debug it.
|
||||
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
|
||||
|
||||
@@ -29,12 +29,37 @@ instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attribute.
|
||||
"""
|
||||
|
||||
|
||||
import betterproto
|
||||
from betterproto import which_one_of
|
||||
from betterproto.casing import sanitize_name
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
Field,
|
||||
FieldDescriptorProto,
|
||||
FieldDescriptorProtoType,
|
||||
FieldDescriptorProtoLabel,
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
|
||||
|
||||
import betterproto
|
||||
import sys
|
||||
|
||||
from ..casing import sanitize_name
|
||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||
@@ -44,26 +69,6 @@ from ..compile.naming import (
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
except ImportError as err:
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
@@ -71,54 +76,75 @@ PLACEHOLDER = object()
|
||||
|
||||
# Organize proto types into categories
|
||||
PROTO_FLOAT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
)
|
||||
PROTO_INT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12
|
||||
PROTO_MESSAGE_TYPES = (
|
||||
FieldDescriptorProto.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProto.TYPE_ENUM, # 14
|
||||
FieldDescriptorProtoType.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProtoType.TYPE_ENUM, # 14
|
||||
)
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11
|
||||
PROTO_PACKED_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_BOOL, # 8
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_BOOL, # 8
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
|
||||
|
||||
def monkey_patch_oneof_index():
|
||||
"""
|
||||
The compiler message types are written for proto2, but we read them as proto3.
|
||||
For this to work in the case of the oneof_index fields, which depend on being able
|
||||
to tell whether they were set, we have to treat them as oneof fields. This method
|
||||
monkey patches the generated classes after the fact to force this behaviour.
|
||||
"""
|
||||
object.__setattr__(
|
||||
FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[
|
||||
"betterproto"
|
||||
],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
object.__setattr__(
|
||||
Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
|
||||
|
||||
def get_comment(
|
||||
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
|
||||
) -> str:
|
||||
pad = " " * indent
|
||||
for sci in proto_file.source_code_info.location:
|
||||
if list(sci.path) == path and sci.leading_comments:
|
||||
for sci_loc in proto_file.source_code_info.location:
|
||||
if list(sci_loc.path) == path and sci_loc.leading_comments:
|
||||
lines = textwrap.wrap(
|
||||
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
)
|
||||
|
||||
if path[-2] == 2 and path[-4] != 6:
|
||||
@@ -139,6 +165,7 @@ def get_comment(
|
||||
class ProtoContentBase:
|
||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||
@@ -156,13 +183,6 @@ class ProtoContentBase:
|
||||
current = current.parent
|
||||
return current
|
||||
|
||||
@property
|
||||
def proto_file(self) -> FieldDescriptorProto:
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current.package_proto_obj
|
||||
|
||||
@property
|
||||
def request(self) -> "PluginRequestCompiler":
|
||||
current = self
|
||||
@@ -176,14 +196,14 @@ class ProtoContentBase:
|
||||
for this object.
|
||||
"""
|
||||
return get_comment(
|
||||
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
|
||||
proto_file=self.source_file, path=self.path, indent=self.comment_indent
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
|
||||
plugin_request_obj: plugin.CodeGeneratorRequest
|
||||
plugin_request_obj: CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
@@ -253,6 +273,7 @@ class OutputTemplate:
|
||||
class MessageCompiler(ProtoContentBase):
|
||||
"""Representation of a protobuf message."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
@@ -296,7 +317,7 @@ def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
) -> bool:
|
||||
"""True if proto_field_obj is a map, otherwise False."""
|
||||
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
|
||||
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
|
||||
# This might be a map...
|
||||
message_type = proto_field_obj.type_name.split(".").pop().lower()
|
||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||
@@ -311,8 +332,20 @@ def is_map(
|
||||
|
||||
|
||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
"""True if proto_field_obj is a OneOf, otherwise False."""
|
||||
return proto_field_obj.HasField("oneof_index")
|
||||
"""
|
||||
True if proto_field_obj is a OneOf, otherwise False.
|
||||
|
||||
.. warning::
|
||||
Becuase the message from protoc is defined in proto2, and betterproto works with
|
||||
proto3, and interpreting the FieldDescriptorProto.oneof_index field requires
|
||||
distinguishing between default and unset values (which proto3 doesn't support),
|
||||
we have to hack the generated FieldDescriptorProto class for this to work.
|
||||
The hack consists of setting group="oneof_index" in the field metadata,
|
||||
essentially making oneof_index the sole member of a one_of group, which allows
|
||||
us to tell whether it was set, via the which_one_of interface.
|
||||
"""
|
||||
|
||||
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -377,7 +410,7 @@ class FieldCompiler(MessageCompiler):
|
||||
def field_wraps(self) -> Optional[str]:
|
||||
"""Returns betterproto wrapped field type or None."""
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
||||
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
|
||||
)
|
||||
if match_wrapper:
|
||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||
@@ -388,7 +421,7 @@ class FieldCompiler(MessageCompiler):
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
return (
|
||||
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
|
||||
and not is_map(self.proto_obj, self.parent)
|
||||
)
|
||||
|
||||
@@ -401,7 +434,9 @@ class FieldCompiler(MessageCompiler):
|
||||
def field_type(self) -> str:
|
||||
"""String representation of proto field type."""
|
||||
return (
|
||||
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
|
||||
FieldDescriptorProtoType(self.proto_obj.type)
|
||||
.name.lower()
|
||||
.replace("type_", "")
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -495,14 +530,19 @@ class MapEntryCompiler(FieldCompiler):
|
||||
):
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[0] # key
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[0], # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[1] # value
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[1], # value
|
||||
).py_type
|
||||
|
||||
# Get proto types
|
||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
|
||||
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
@@ -544,7 +584,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
||||
name=sanitize_name(entry_proto_value.name),
|
||||
value=entry_proto_value.number,
|
||||
comment=get_comment(
|
||||
proto_file=self.proto_file, path=self.path + [2, entry_number]
|
||||
proto_file=self.source_file, path=self.path + [2, entry_number]
|
||||
),
|
||||
)
|
||||
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import (
|
||||
CodeGeneratorRequest,
|
||||
CodeGeneratorResponse,
|
||||
CodeGeneratorResponseFile,
|
||||
)
|
||||
import itertools
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
except ImportError as err:
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
from typing import Iterator, List, Tuple, TYPE_CHECKING, Union
|
||||
from .compiler import outputfile_compiler
|
||||
from .models import (
|
||||
EnumDefinitionCompiler,
|
||||
@@ -70,7 +61,7 @@ def traverse(
|
||||
|
||||
|
||||
def generate_code(
|
||||
request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
|
||||
request: CodeGeneratorRequest, response: CodeGeneratorResponse
|
||||
) -> None:
|
||||
plugin_options = request.parameter.split(",") if request.parameter else []
|
||||
|
||||
@@ -100,7 +91,12 @@ def generate_code(
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
for proto_input_file in output_package.input_files:
|
||||
for item, path in traverse(proto_input_file):
|
||||
read_protobuf_type(item=item, path=path, output_package=output_package)
|
||||
read_protobuf_type(
|
||||
source_file=proto_input_file,
|
||||
item=item,
|
||||
path=path,
|
||||
output_package=output_package,
|
||||
)
|
||||
|
||||
# Read Services
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
@@ -116,11 +112,13 @@ def generate_code(
|
||||
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
|
||||
output_paths.add(output_path)
|
||||
|
||||
f: response.File = response.file.add()
|
||||
f.name = str(output_path)
|
||||
|
||||
# Render and then format the output file
|
||||
f.content = outputfile_compiler(output_file=output_package)
|
||||
response.file.append(
|
||||
CodeGeneratorResponseFile(
|
||||
name=str(output_path),
|
||||
# Render and then format the output file
|
||||
content=outputfile_compiler(output_file=output_package),
|
||||
)
|
||||
)
|
||||
|
||||
# Make each output directory a package with __init__ file
|
||||
init_files = {
|
||||
@@ -130,38 +128,53 @@ def generate_code(
|
||||
} - output_paths
|
||||
|
||||
for init_file in init_files:
|
||||
init = response.file.add()
|
||||
init.name = str(init_file)
|
||||
response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
|
||||
|
||||
for output_package_name in sorted(output_paths.union(init_files)):
|
||||
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||
|
||||
|
||||
def read_protobuf_type(
|
||||
item: DescriptorProto, path: List[int], output_package: OutputTemplate
|
||||
item: DescriptorProto,
|
||||
path: List[int],
|
||||
source_file: "FileDescriptorProto",
|
||||
output_package: OutputTemplate,
|
||||
) -> None:
|
||||
if isinstance(item, DescriptorProto):
|
||||
if item.options.map_entry:
|
||||
# Skip generated map entry messages since we just use dicts
|
||||
return
|
||||
# Process Message
|
||||
message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
|
||||
message_data = MessageCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
)
|
||||
for index, field in enumerate(item.field):
|
||||
if is_map(field, item):
|
||||
MapEntryCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif is_oneof(field):
|
||||
OneOfFieldCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
)
|
||||
else:
|
||||
FieldCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
source_file=source_file,
|
||||
parent=message_data,
|
||||
proto_obj=field,
|
||||
path=path + [2, index],
|
||||
)
|
||||
elif isinstance(item, EnumDescriptorProto):
|
||||
# Enum
|
||||
EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
|
||||
EnumDefinitionCompiler(
|
||||
source_file=source_file, parent=output_package, proto_obj=item, path=path
|
||||
)
|
||||
|
||||
|
||||
def read_protobuf_service(
|
||||
|
||||
Reference in New Issue
Block a user