Make plugin use betterproto generated classes internally
This means the betterproto plugin no longer needs to depend directly on protobuf. This requires a small runtime hack to monkey patch some google types to get around the fact that the compiler uses proto2, but betterproto expects proto3. Also: - regenerate google.protobuf package - fix a regex bug in the logic for determining whether to use a google wrapper type. - fix a bug causing comments to get mixed up when multiple proto files generate code into a single python module
This commit is contained in:
@@ -29,12 +29,37 @@ instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attribute.
|
||||
"""
|
||||
|
||||
|
||||
import betterproto
|
||||
from betterproto import which_one_of
|
||||
from betterproto.casing import sanitize_name
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
from betterproto.lib.google.protobuf import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
Field,
|
||||
FieldDescriptorProto,
|
||||
FieldDescriptorProtoType,
|
||||
FieldDescriptorProtoLabel,
|
||||
)
|
||||
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
|
||||
|
||||
|
||||
import re
|
||||
import textwrap
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
|
||||
|
||||
import betterproto
|
||||
import sys
|
||||
|
||||
from ..casing import sanitize_name
|
||||
from ..compile.importing import get_type_reference, parse_source_type_name
|
||||
@@ -44,26 +69,6 @@ from ..compile.naming import (
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
except ImportError as err:
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{err.name}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
@@ -71,54 +76,75 @@ PLACEHOLDER = object()
|
||||
|
||||
# Organize proto types into categories
|
||||
PROTO_FLOAT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
)
|
||||
PROTO_INT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12
|
||||
PROTO_MESSAGE_TYPES = (
|
||||
FieldDescriptorProto.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProto.TYPE_ENUM, # 14
|
||||
FieldDescriptorProtoType.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProtoType.TYPE_ENUM, # 14
|
||||
)
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11
|
||||
PROTO_PACKED_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_BOOL, # 8
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProtoType.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProtoType.TYPE_INT64, # 3
|
||||
FieldDescriptorProtoType.TYPE_UINT64, # 4
|
||||
FieldDescriptorProtoType.TYPE_INT32, # 5
|
||||
FieldDescriptorProtoType.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProtoType.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProtoType.TYPE_BOOL, # 8
|
||||
FieldDescriptorProtoType.TYPE_UINT32, # 13
|
||||
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProtoType.TYPE_SINT32, # 17
|
||||
FieldDescriptorProtoType.TYPE_SINT64, # 18
|
||||
)
|
||||
|
||||
|
||||
def monkey_patch_oneof_index():
|
||||
"""
|
||||
The compiler message types are written for proto2, but we read them as proto3.
|
||||
For this to work in the case of the oneof_index fields, which depend on being able
|
||||
to tell whether they were set, we have to treat them as oneof fields. This method
|
||||
monkey patches the generated classes after the fact to force this behaviour.
|
||||
"""
|
||||
object.__setattr__(
|
||||
FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[
|
||||
"betterproto"
|
||||
],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
object.__setattr__(
|
||||
Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
|
||||
"group",
|
||||
"oneof_index",
|
||||
)
|
||||
|
||||
|
||||
def get_comment(
|
||||
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
|
||||
) -> str:
|
||||
pad = " " * indent
|
||||
for sci in proto_file.source_code_info.location:
|
||||
if list(sci.path) == path and sci.leading_comments:
|
||||
for sci_loc in proto_file.source_code_info.location:
|
||||
if list(sci_loc.path) == path and sci_loc.leading_comments:
|
||||
lines = textwrap.wrap(
|
||||
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
|
||||
)
|
||||
|
||||
if path[-2] == 2 and path[-4] != 6:
|
||||
@@ -139,6 +165,7 @@ def get_comment(
|
||||
class ProtoContentBase:
|
||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
parent: Union["betterproto.Message", "OutputTemplate"]
|
||||
@@ -156,13 +183,6 @@ class ProtoContentBase:
|
||||
current = current.parent
|
||||
return current
|
||||
|
||||
@property
|
||||
def proto_file(self) -> FieldDescriptorProto:
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current.package_proto_obj
|
||||
|
||||
@property
|
||||
def request(self) -> "PluginRequestCompiler":
|
||||
current = self
|
||||
@@ -176,14 +196,14 @@ class ProtoContentBase:
|
||||
for this object.
|
||||
"""
|
||||
return get_comment(
|
||||
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
|
||||
proto_file=self.source_file, path=self.path, indent=self.comment_indent
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
|
||||
plugin_request_obj: plugin.CodeGeneratorRequest
|
||||
plugin_request_obj: CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
@@ -253,6 +273,7 @@ class OutputTemplate:
|
||||
class MessageCompiler(ProtoContentBase):
|
||||
"""Representation of a protobuf message."""
|
||||
|
||||
source_file: FileDescriptorProto
|
||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
@@ -296,7 +317,7 @@ def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
) -> bool:
|
||||
"""True if proto_field_obj is a map, otherwise False."""
|
||||
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
|
||||
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
|
||||
# This might be a map...
|
||||
message_type = proto_field_obj.type_name.split(".").pop().lower()
|
||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||
@@ -311,8 +332,20 @@ def is_map(
|
||||
|
||||
|
||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
"""True if proto_field_obj is a OneOf, otherwise False."""
|
||||
return proto_field_obj.HasField("oneof_index")
|
||||
"""
|
||||
True if proto_field_obj is a OneOf, otherwise False.
|
||||
|
||||
.. warning::
|
||||
Because the message from protoc is defined in proto2, and betterproto works with
|
||||
proto3, and interpreting the FieldDescriptorProto.oneof_index field requires
|
||||
distinguishing between default and unset values (which proto3 doesn't support),
|
||||
we have to hack the generated FieldDescriptorProto class for this to work.
|
||||
The hack consists of setting group="oneof_index" in the field metadata,
|
||||
essentially making oneof_index the sole member of a one_of group, which allows
|
||||
us to tell whether it was set, via the which_one_of interface.
|
||||
"""
|
||||
|
||||
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -377,7 +410,7 @@ class FieldCompiler(MessageCompiler):
|
||||
def field_wraps(self) -> Optional[str]:
|
||||
"""Returns betterproto wrapped field type or None."""
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
||||
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
|
||||
)
|
||||
if match_wrapper:
|
||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||
@@ -388,7 +421,7 @@ class FieldCompiler(MessageCompiler):
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
return (
|
||||
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
|
||||
self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
|
||||
and not is_map(self.proto_obj, self.parent)
|
||||
)
|
||||
|
||||
@@ -401,7 +434,9 @@ class FieldCompiler(MessageCompiler):
|
||||
def field_type(self) -> str:
|
||||
"""String representation of proto field type."""
|
||||
return (
|
||||
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
|
||||
FieldDescriptorProtoType(self.proto_obj.type)
|
||||
.name.lower()
|
||||
.replace("type_", "")
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -495,14 +530,19 @@ class MapEntryCompiler(FieldCompiler):
|
||||
):
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[0] # key
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[0], # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[1] # value
|
||||
source_file=self.source_file,
|
||||
parent=self,
|
||||
proto_obj=nested.field[1], # value
|
||||
).py_type
|
||||
|
||||
# Get proto types
|
||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
|
||||
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
@@ -544,7 +584,7 @@ class EnumDefinitionCompiler(MessageCompiler):
|
||||
name=sanitize_name(entry_proto_value.name),
|
||||
value=entry_proto_value.number,
|
||||
comment=get_comment(
|
||||
proto_file=self.proto_file, path=self.path + [2, entry_number]
|
||||
proto_file=self.source_file, path=self.path + [2, entry_number]
|
||||
),
|
||||
)
|
||||
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
|
||||
|
||||
Reference in New Issue
Block a user