Make plugin use betterproto generated classes internally

This means the betterproto plugin no longer needs to depend directly on
protobuf.

This requires a small runtime hack to monkey patch some google types to
get around the fact that the compiler uses proto2, but betterproto
expects proto3.

Also:
- regenerate google.protobuf package
- fix a regex bug in the logic for determining whether to use a google
  wrapper type.
- fix a bug causing comments to get mixed up when multiple proto files
  generate code into a single python module
This commit is contained in:
Nat Noordanus
2020-10-18 22:47:58 +02:00
committed by Basileus
parent 7a358a63cf
commit fe1e712fdb
11 changed files with 2381 additions and 1122 deletions

17
src/betterproto/plugin/main.py Normal file → Executable file
View File

@@ -3,9 +3,13 @@
import os
import sys
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest,
CodeGeneratorResponse,
)
from betterproto.plugin.parser import generate_code
from betterproto.plugin.models import monkey_patch_oneof_index
def main() -> None:
@@ -13,16 +17,19 @@ def main() -> None:
# Read request message from stdin
data = sys.stdin.buffer.read()
# Apply workaround for proto2/3 difference in protoc messages
monkey_patch_oneof_index()
# Parse request
request = plugin.CodeGeneratorRequest()
request.ParseFromString(data)
request = CodeGeneratorRequest()
request.parse(data)
dump_file = os.getenv("BETTERPROTO_DUMP")
if dump_file:
dump_request(dump_file, request)
# Create response
response = plugin.CodeGeneratorResponse()
response = CodeGeneratorResponse()
# Generate code
generate_code(request, response)
@@ -34,7 +41,7 @@ def main() -> None:
sys.stdout.buffer.write(output)
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest) -> None:
def dump_request(dump_file: str, request: CodeGeneratorRequest) -> None:
"""
For developers: Supports running plugin.py standalone so it's possible to debug it.
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.

View File

@@ -29,12 +29,37 @@ instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
Field,
FieldDescriptorProto,
FieldDescriptorProtoType,
FieldDescriptorProtoLabel,
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest
import re
import textwrap
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Optional, Set, Text, Type, Union
import betterproto
import sys
from ..casing import sanitize_name
from ..compile.importing import get_type_reference, parse_source_type_name
@@ -44,26 +69,6 @@ from ..compile.naming import (
pythonize_method_name,
)
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
MethodDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@@ -71,54 +76,75 @@ PLACEHOLDER = object()
# Organize proto types into categories
PROTO_FLOAT_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
FieldDescriptorProtoType.TYPE_FLOAT, # 2
)
PROTO_INT_TYPES = (
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
FieldDescriptorProtoType.TYPE_INT64, # 3
FieldDescriptorProtoType.TYPE_UINT64, # 4
FieldDescriptorProtoType.TYPE_INT32, # 5
FieldDescriptorProtoType.TYPE_FIXED64, # 6
FieldDescriptorProtoType.TYPE_FIXED32, # 7
FieldDescriptorProtoType.TYPE_UINT32, # 13
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
FieldDescriptorProtoType.TYPE_SINT32, # 17
FieldDescriptorProtoType.TYPE_SINT64, # 18
)
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
PROTO_BOOL_TYPES = (FieldDescriptorProtoType.TYPE_BOOL,) # 8
PROTO_STR_TYPES = (FieldDescriptorProtoType.TYPE_STRING,) # 9
PROTO_BYTES_TYPES = (FieldDescriptorProtoType.TYPE_BYTES,) # 12
PROTO_MESSAGE_TYPES = (
FieldDescriptorProto.TYPE_MESSAGE, # 11
FieldDescriptorProto.TYPE_ENUM, # 14
FieldDescriptorProtoType.TYPE_MESSAGE, # 11
FieldDescriptorProtoType.TYPE_ENUM, # 14
)
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
PROTO_MAP_TYPES = (FieldDescriptorProtoType.TYPE_MESSAGE,) # 11
PROTO_PACKED_TYPES = (
FieldDescriptorProto.TYPE_DOUBLE, # 1
FieldDescriptorProto.TYPE_FLOAT, # 2
FieldDescriptorProto.TYPE_INT64, # 3
FieldDescriptorProto.TYPE_UINT64, # 4
FieldDescriptorProto.TYPE_INT32, # 5
FieldDescriptorProto.TYPE_FIXED64, # 6
FieldDescriptorProto.TYPE_FIXED32, # 7
FieldDescriptorProto.TYPE_BOOL, # 8
FieldDescriptorProto.TYPE_UINT32, # 13
FieldDescriptorProto.TYPE_SFIXED32, # 15
FieldDescriptorProto.TYPE_SFIXED64, # 16
FieldDescriptorProto.TYPE_SINT32, # 17
FieldDescriptorProto.TYPE_SINT64, # 18
FieldDescriptorProtoType.TYPE_DOUBLE, # 1
FieldDescriptorProtoType.TYPE_FLOAT, # 2
FieldDescriptorProtoType.TYPE_INT64, # 3
FieldDescriptorProtoType.TYPE_UINT64, # 4
FieldDescriptorProtoType.TYPE_INT32, # 5
FieldDescriptorProtoType.TYPE_FIXED64, # 6
FieldDescriptorProtoType.TYPE_FIXED32, # 7
FieldDescriptorProtoType.TYPE_BOOL, # 8
FieldDescriptorProtoType.TYPE_UINT32, # 13
FieldDescriptorProtoType.TYPE_SFIXED32, # 15
FieldDescriptorProtoType.TYPE_SFIXED64, # 16
FieldDescriptorProtoType.TYPE_SINT32, # 17
FieldDescriptorProtoType.TYPE_SINT64, # 18
)
def monkey_patch_oneof_index():
    """
    The compiler message types are written for proto2, but we read them as proto3.
    For this to work in the case of the oneof_index fields, which depend on being able
    to tell whether they were set, we have to treat them as oneof fields. This method
    monkey patches the generated classes after the fact to force this behaviour.
    """
    # `dataclasses` wraps field metadata in an immutable mapping, so plain
    # assignment would fail; `object.__setattr__` bypasses that protection to
    # inject group="oneof_index" into the betterproto field metadata.
    # Making oneof_index the sole member of its own oneof group lets
    # `which_one_of` distinguish "unset" from "set to 0" (proto2 semantics
    # that proto3 otherwise cannot express).
    object.__setattr__(
        FieldDescriptorProto.__dataclass_fields__["oneof_index"].metadata[
            "betterproto"
        ],
        "group",
        "oneof_index",
    )
    # Same patch for the runtime Field descriptor type.
    object.__setattr__(
        Field.__dataclass_fields__["oneof_index"].metadata["betterproto"],
        "group",
        "oneof_index",
    )
def get_comment(
proto_file: "FileDescriptorProto", path: List[int], indent: int = 4
) -> str:
pad = " " * indent
for sci in proto_file.source_code_info.location:
if list(sci.path) == path and sci.leading_comments:
for sci_loc in proto_file.source_code_info.location:
if list(sci_loc.path) == path and sci_loc.leading_comments:
lines = textwrap.wrap(
sci.leading_comments.strip().replace("\n", ""), width=79 - indent
sci_loc.leading_comments.strip().replace("\n", ""), width=79 - indent
)
if path[-2] == 2 and path[-4] != 6:
@@ -139,6 +165,7 @@ def get_comment(
class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
source_file: FileDescriptorProto
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
@@ -156,13 +183,6 @@ class ProtoContentBase:
current = current.parent
return current
@property
def proto_file(self) -> FieldDescriptorProto:
current = self
while not isinstance(current, OutputTemplate):
current = current.parent
return current.package_proto_obj
@property
def request(self) -> "PluginRequestCompiler":
current = self
@@ -176,14 +196,14 @@ class ProtoContentBase:
for this object.
"""
return get_comment(
proto_file=self.proto_file, path=self.path, indent=self.comment_indent
proto_file=self.source_file, path=self.path, indent=self.comment_indent
)
@dataclass
class PluginRequestCompiler:
plugin_request_obj: plugin.CodeGeneratorRequest
plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
@property
@@ -253,6 +273,7 @@ class OutputTemplate:
class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""
source_file: FileDescriptorProto
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
@@ -296,7 +317,7 @@ def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
"""True if proto_field_obj is a map, otherwise False."""
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
if proto_field_obj.type == FieldDescriptorProtoType.TYPE_MESSAGE:
# This might be a map...
message_type = proto_field_obj.type_name.split(".").pop().lower()
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
@@ -311,8 +332,20 @@ def is_map(
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
"""True if proto_field_obj is a OneOf, otherwise False."""
return proto_field_obj.HasField("oneof_index")
"""
True if proto_field_obj is a OneOf, otherwise False.
.. warning::
Because the message from protoc is defined in proto2, and betterproto works with
proto3, and interpreting the FieldDescriptorProto.oneof_index field requires
distinguishing between default and unset values (which proto3 doesn't support),
we have to hack the generated FieldDescriptorProto class for this to work.
The hack consists of setting group="oneof_index" in the field metadata,
essentially making oneof_index the sole member of a one_of group, which allows
us to tell whether it was set, via the which_one_of interface.
"""
return which_one_of(proto_field_obj, "oneof_index")[0] == "oneof_index"
@dataclass
@@ -377,7 +410,7 @@ class FieldCompiler(MessageCompiler):
def field_wraps(self) -> Optional[str]:
"""Returns betterproto wrapped field type or None."""
match_wrapper = re.match(
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
r"\.google\.protobuf\.(.+)Value$", self.proto_obj.type_name
)
if match_wrapper:
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
@@ -388,7 +421,7 @@ class FieldCompiler(MessageCompiler):
@property
def repeated(self) -> bool:
return (
self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED
self.proto_obj.label == FieldDescriptorProtoLabel.LABEL_REPEATED
and not is_map(self.proto_obj, self.parent)
)
@@ -401,7 +434,9 @@ class FieldCompiler(MessageCompiler):
def field_type(self) -> str:
"""String representation of proto field type."""
return (
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
FieldDescriptorProtoType(self.proto_obj.type)
.name.lower()
.replace("type_", "")
)
@property
@@ -495,14 +530,19 @@ class MapEntryCompiler(FieldCompiler):
):
# Get Python types
self.py_k_type = FieldCompiler(
parent=self, proto_obj=nested.field[0] # key
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
).py_type
self.py_v_type = FieldCompiler(
parent=self, proto_obj=nested.field[1] # value
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
).py_type
# Get proto types
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
self.proto_k_type = FieldDescriptorProtoType(nested.field[0].type).name
self.proto_v_type = FieldDescriptorProtoType(nested.field[1].type).name
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
@property
@@ -544,7 +584,7 @@ class EnumDefinitionCompiler(MessageCompiler):
name=sanitize_name(entry_proto_value.name),
value=entry_proto_value.number,
comment=get_comment(
proto_file=self.proto_file, path=self.path + [2, entry_number]
proto_file=self.source_file, path=self.path + [2, entry_number]
),
)
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)

View File

@@ -1,28 +1,19 @@
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
from betterproto.lib.google.protobuf.compiler import (
CodeGeneratorRequest,
CodeGeneratorResponse,
CodeGeneratorResponseFile,
)
import itertools
import pathlib
import sys
from typing import TYPE_CHECKING, Iterator, List, Tuple, Union, Set
try:
# betterproto[compiler] specific dependencies
from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
ServiceDescriptorProto,
)
except ImportError as err:
print(
"\033[31m"
f"Unable to import `{err.name}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
from typing import Iterator, List, Tuple, TYPE_CHECKING, Union
from .compiler import outputfile_compiler
from .models import (
EnumDefinitionCompiler,
@@ -70,7 +61,7 @@ def traverse(
def generate_code(
request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
request: CodeGeneratorRequest, response: CodeGeneratorResponse
) -> None:
plugin_options = request.parameter.split(",") if request.parameter else []
@@ -100,7 +91,12 @@ def generate_code(
for output_package_name, output_package in request_data.output_packages.items():
for proto_input_file in output_package.input_files:
for item, path in traverse(proto_input_file):
read_protobuf_type(item=item, path=path, output_package=output_package)
read_protobuf_type(
source_file=proto_input_file,
item=item,
path=path,
output_package=output_package,
)
# Read Services
for output_package_name, output_package in request_data.output_packages.items():
@@ -116,11 +112,13 @@ def generate_code(
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
output_paths.add(output_path)
f: response.File = response.file.add()
f.name = str(output_path)
# Render and then format the output file
f.content = outputfile_compiler(output_file=output_package)
response.file.append(
CodeGeneratorResponseFile(
name=str(output_path),
# Render and then format the output file
content=outputfile_compiler(output_file=output_package),
)
)
# Make each output directory a package with __init__ file
init_files = {
@@ -130,38 +128,53 @@ def generate_code(
} - output_paths
for init_file in init_files:
init = response.file.add()
init.name = str(init_file)
response.file.append(CodeGeneratorResponseFile(name=str(init_file)))
for output_package_name in sorted(output_paths.union(init_files)):
print(f"Writing {output_package_name}", file=sys.stderr)
def read_protobuf_type(
item: DescriptorProto, path: List[int], output_package: OutputTemplate
item: DescriptorProto,
path: List[int],
source_file: "FileDescriptorProto",
output_package: OutputTemplate,
) -> None:
if isinstance(item, DescriptorProto):
if item.options.map_entry:
# Skip generated map entry messages since we just use dicts
return
# Process Message
message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
message_data = MessageCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
for index, field in enumerate(item.field):
if is_map(field, item):
MapEntryCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
elif is_oneof(field):
OneOfFieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
else:
FieldCompiler(
parent=message_data, proto_obj=field, path=path + [2, index]
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
)
elif isinstance(item, EnumDescriptorProto):
# Enum
EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
EnumDefinitionCompiler(
source_file=source_file, parent=output_package, proto_obj=item, path=path
)
def read_protobuf_service(