REF: Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data (#121)

Refactor plugin to parse input into data-class based hierarchical structure
This commit is contained in:
Adrian Garcia Badaracco 2020-07-25 10:44:02 -07:00 committed by GitHub
parent cbd3437080
commit b5dcac1250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 986 additions and 507 deletions

View File

@ -1,2 +0,0 @@
@REM Remember the drive and directory this script lives in.
@SET plugin_dir=%~dp0
@REM Run plugin.py from that directory, forwarding every command-line argument.
@python %plugin_dir%/plugin.py %*

View File

@ -1,480 +0,0 @@
#!/usr/bin/env python
import collections
import itertools
import os.path
import pathlib
import re
import sys
import textwrap
from typing import List, Union
from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest
import betterproto
from betterproto.compile.importing import get_type_reference, parse_source_type_name
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
from betterproto.lib.google.protobuf import ServiceDescriptorProto
# A failed import of the compiler-only dependencies is reported with an
# installation hint and a non-zero exit instead of a raw traceback.
try:
    # betterproto[compiler] specific dependencies
    import black
    from google.protobuf.compiler import plugin_pb2 as plugin
    from google.protobuf.descriptor_pb2 import (
        DescriptorProto,
        EnumDescriptorProto,
        FieldDescriptorProto,
    )
    import google.protobuf.wrappers_pb2 as google_wrappers
    import jinja2
except ImportError as err:
    # Drops the first 17 characters and the last one of the error message;
    # presumably this strips "No module named '" and the closing quote.
    missing_import = err.args[0][17:-1]
    print(
        "\033[31m"
        f"Unable to import `{missing_import}` from betterproto plugin! "
        "Please ensure that you've installed betterproto as "
        '`pip install "betterproto[compiler]"` so that compiler dependencies '
        "are included."
        "\033[0m"
    )
    raise SystemExit(1)
def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str:
    """Return the Python type name used for a protobuf field.

    Scalar protobuf type numbers map onto ``float``, ``int``, ``bool``,
    ``str`` or ``bytes``. Message (11) and enum (14) fields are resolved
    through ``get_type_reference``, which receives ``imports``.

    Raises:
        NotImplementedError: if the field's type number is not recognised.
    """
    type_number = field.type

    if type_number in (11, 14):
        # Type referencing another defined Message or a named enum
        return get_type_reference(package, imports, field.type_name)

    scalar_types = (
        ("float", (1, 2)),
        ("int", (3, 4, 5, 6, 7, 13, 15, 16, 17, 18)),
        ("bool", (8,)),
        ("str", (9,)),
        ("bytes", (12,)),
    )
    for python_name, type_numbers in scalar_types:
        if type_number in type_numbers:
            return python_name

    raise NotImplementedError(f"Unknown type {field.type}")
def get_py_zero(type_num: int) -> Union[str, float]:
    """Return the zero value emitted into generated code for a proto type.

    Args:
        type_num: Protobuf field type number.

    Returns:
        ``0.0`` for the floating point types (1 and 2), the source-code
        strings ``"False"``, ``'""'``, ``"None"`` and ``'b""'`` for bool (8),
        string (9), message (11) and bytes (12), and the integer ``0`` for
        every other type number.
    """
    # The float branch used to test membership in an empty list and could
    # never match, so float fields fell through to the integer zero.
    if type_num in (1, 2):
        return 0.0
    if type_num == 8:
        return "False"
    if type_num == 9:
        return '""'
    if type_num == 11:
        return "None"
    if type_num == 12:
        return 'b""'
    return 0
def traverse(proto_file):
    """Iterate over every enum and message of ``proto_file`` with its path.

    Yields ``(item, path)`` tuples: top-level enums first, then top-level
    messages, each message followed by its enums and (recursively) its
    nested messages. Items are renamed in place so that a nested item's name
    is prefixed with the names of its enclosing messages.
    """
    # Todo: Keep information about nested hierarchy

    def _walk(path, items, prefix=""):
        for index, item in enumerate(items):
            # Adjust the name since we flatten the hierarchy.
            # Todo: don't change the name, but include full name in returned tuple
            flat_name = prefix + item.name
            item.name = flat_name
            item_path = path + [index]
            yield item, item_path

            if not isinstance(item, DescriptorProto):
                continue

            for enum in item.enum_type:
                enum.name = flat_name + enum.name
                yield enum, item_path + [4]

            if item.nested_type:
                yield from _walk(item_path + [3], item.nested_type, flat_name)

    top_level_enums = _walk([5], proto_file.enum_type)
    top_level_messages = _walk([4], proto_file.message_type)
    return itertools.chain(top_level_enums, top_level_messages)
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
    """Return the leading comment stored for ``path``, formatted as source.

    The first source location whose path equals ``path`` and that carries a
    leading comment is used. Its text is wrapped to ``79 - indent`` columns.
    Fields (second-to-last path element is 2, and the element four from the
    end is not 6) become ``#`` comments; everything else becomes a docstring.
    An empty string is returned when no location matches.
    """
    pad = " " * indent
    width = 79 - indent

    for location in proto_file.source_code_info.location:
        if list(location.path) != path or not location.leading_comments:
            continue

        text = location.leading_comments.strip().replace("\n", "")
        wrapped = textwrap.wrap(text, width=width)

        if path[-2] == 2 and path[-4] != 6:
            # This is a field
            hash_prefix = f"{pad}# "
            return hash_prefix + f"\n{pad}# ".join(wrapped)

        # This is a message, enum, service, or method
        if len(wrapped) == 1 and len(wrapped[0]) < width - 6:
            single_line = wrapped[0].strip('"')
            return f'{pad}"""{single_line}"""'

        joined = f"\n{pad}".join(wrapped)
        return f'{pad}"""\n{pad}{joined}\n{pad}"""'

    return ""
def generate_code(request, response):
    """Fill ``response`` with one generated ``__init__.py`` per proto package.

    The proto files of ``request`` are grouped by package, their messages,
    enums and services are collected into per-package template data, and the
    Jinja template ``template.py.j2`` is rendered and formatted with black.
    Empty ``__init__.py`` entries are added for every parent directory of an
    output file that does not already have one. Each written path is also
    printed to stderr.
    """
    # Plugin options arrive as one comma-separated string.
    plugin_options = request.parameter.split(",") if request.parameter else []
    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)),
    )
    template = env.get_template("template.py.j2")
    # Gather output packages
    # (no default factory is passed, entries are created with setdefault below)
    output_package_files = collections.defaultdict()
    for proto_file in request.proto_file:
        # google.protobuf itself is only generated when INCLUDE_GOOGLE is set.
        if (
            proto_file.package == "google.protobuf"
            and "INCLUDE_GOOGLE" not in plugin_options
        ):
            continue
        output_package = proto_file.package
        output_package_files.setdefault(
            output_package, {"input_package": proto_file.package, "files": []}
        )
        output_package_files[output_package]["files"].append(proto_file)
    # Initialize Template data for each package
    for output_package_name, output_package_content in output_package_files.items():
        template_data = {
            "input_package": output_package_content["input_package"],
            "files": [f.name for f in output_package_content["files"]],
            "imports": set(),
            "datetime_imports": set(),
            "typing_imports": set(),
            "messages": [],
            "enums": [],
            "services": [],
        }
        output_package_content["template_data"] = template_data
    # Read Messages and Enums
    # Every result is collected, including the None that read_protobuf_type
    # returns for map-entry messages.
    output_types = []
    for output_package_name, output_package_content in output_package_files.items():
        for proto_file in output_package_content["files"]:
            for item, path in traverse(proto_file):
                type_data = read_protobuf_type(
                    item, path, proto_file, output_package_content
                )
                output_types.append(type_data)
    # Read Services
    # Done after all types are known so that method input types can be looked
    # up across packages.
    for output_package_name, output_package_content in output_package_files.items():
        for proto_file in output_package_content["files"]:
            for index, service in enumerate(proto_file.service):
                read_protobuf_service(
                    service, index, proto_file, output_package_content, output_types
                )
    # Render files
    output_paths = set()
    for output_package_name, output_package_content in output_package_files.items():
        template_data = output_package_content["template_data"]
        # The import sets are replaced by sorted lists before rendering.
        template_data["imports"] = sorted(template_data["imports"])
        template_data["datetime_imports"] = sorted(template_data["datetime_imports"])
        template_data["typing_imports"] = sorted(template_data["typing_imports"])
        # Fill response
        # Package "a.b" is written to "a/b/__init__.py".
        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
        output_paths.add(output_path)
        f = response.file.add()
        f.name = str(output_path)
        # Render and then format the output file.
        f.content = black.format_str(
            template.render(description=template_data),
            mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
        )
    # Make each output directory a package with __init__ file
    init_files = (
        set(
            directory.joinpath("__init__.py")
            for path in output_paths
            for directory in path.parents
        )
        - output_paths
    )
    # These extra files are added with a name only; no content is set.
    for init_file in init_files:
        init = response.file.add()
        init.name = str(init_file)
    for output_package_name in sorted(output_paths.union(init_files)):
        print(f"Writing {output_package_name}", file=sys.stderr)
def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content):
    """Convert one message or enum descriptor into a template data dict.

    The dict is appended to ``content["template_data"]["messages"]`` or
    ``["enums"]`` and returned. Imports required by the field types are
    recorded in the template data as a side effect.

    Returns:
        The data dict, or ``None`` for map-entry messages (which are skipped)
        and for items that are neither a message nor an enum descriptor.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    data = {
        "name": item.name,
        "py_name": pythonize_class_name(item.name),
        "descriptor": item,
        "package": input_package_name,
    }
    if isinstance(item, DescriptorProto):
        # print(item, file=sys.stderr)
        if item.options.map_entry:
            # Skip generated map entry messages since we just use dicts
            return
        data.update(
            {
                "type": "Message",
                "comment": get_comment(proto_file, path),
                "properties": [],
            }
        )
        for i, f in enumerate(item.field):
            t = py_type(input_package_name, template_data["imports"], f)
            zero = get_py_zero(f.type)
            repeated = False
            packed = False
            # Enum value name lowercased with its first five characters
            # dropped (presumably the "TYPE_" prefix).
            field_type = f.Type.Name(f.type).lower()[5:]
            field_wraps = ""
            # google.protobuf.<X>Value types map to betterproto.TYPE_<X> when
            # betterproto defines such a constant.
            match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name)
            if match_wrapper:
                wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
                if hasattr(betterproto, wrapped_type):
                    field_wraps = f"betterproto.{wrapped_type}"
            map_types = None
            if f.type == 11:
                # This might be a map...
                message_type = f.type_name.split(".").pop().lower()
                # message_type = py_type(package)
                map_entry = f"{f.name.replace('_', '').lower()}entry"
                # A map field references a nested "<field name>entry" message
                # (compared case-insensitively, ignoring underscores) whose
                # map_entry option is set.
                if message_type == map_entry:
                    for nested in item.nested_type:
                        if nested.name.replace("_", "").lower() == map_entry:
                            if nested.options.map_entry:
                                # print("Found a map!", file=sys.stderr)
                                # The first nested field is used as the key
                                # type, the second as the value type.
                                k = py_type(
                                    input_package_name,
                                    template_data["imports"],
                                    nested.field[0],
                                )
                                v = py_type(
                                    input_package_name,
                                    template_data["imports"],
                                    nested.field[1],
                                )
                                t = f"Dict[{k}, {v}]"
                                field_type = "map"
                                map_types = (
                                    f.Type.Name(nested.field[0].type),
                                    f.Type.Name(nested.field[1].type),
                                )
                                template_data["typing_imports"].add("Dict")
            if f.label == 3 and field_type != "map":
                # Repeated field
                repeated = True
                t = f"List[{t}]"
                zero = "[]"
                template_data["typing_imports"].add("List")
                # Only these numeric/bool type numbers are marked as packed.
                if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]:
                    packed = True
            one_of = ""
            if f.HasField("oneof_index"):
                one_of = item.oneof_decl[f.oneof_index].name
            # Record the typing/datetime names the generated annotation needs.
            if "Optional[" in t:
                template_data["typing_imports"].add("Optional")
            if "timedelta" in t:
                template_data["datetime_imports"].add("timedelta")
            elif "datetime" in t:
                template_data["datetime_imports"].add("datetime")
            data["properties"].append(
                {
                    "name": f.name,
                    "py_name": pythonize_field_name(f.name),
                    "number": f.number,
                    "comment": get_comment(proto_file, path + [2, i]),
                    "proto_type": int(f.type),
                    "field_type": field_type,
                    "field_wraps": field_wraps,
                    "map_types": map_types,
                    "type": t,
                    "zero": zero,
                    "repeated": repeated,
                    "packed": packed,
                    "one_of": one_of,
                }
            )
            # print(f, file=sys.stderr)
        template_data["messages"].append(data)
        return data
    elif isinstance(item, EnumDescriptorProto):
        # print(item.name, path, file=sys.stderr)
        data.update(
            {
                "type": "Enum",
                "comment": get_comment(proto_file, path),
                "entries": [
                    {
                        "name": v.name,
                        "value": v.number,
                        "comment": get_comment(proto_file, path + [2, i]),
                    }
                    for i, v in enumerate(item.value)
                ],
            }
        )
        template_data["enums"].append(data)
        return data
def lookup_method_input_type(method, types):
    """Find the message data dict that is the input type of ``method``.

    Args:
        method: Method descriptor; its ``input_type`` is split into package
            and type name with ``parse_source_type_name``.
        types: Results of ``read_protobuf_type``. Entries may be ``None``:
            that function returns nothing for skipped map-entry messages and
            the caller collects every result.

    Returns:
        The first entry of type ``"Message"`` whose package and flattened
        name match, or ``None`` when there is no such entry.
    """
    package, name = parse_source_type_name(method.input_type)
    # Nested types are currently flattened without dots.
    # Todo: keep a fully quantified name in types, that is comparable with method.input_type
    flattened_name = name.replace(".", "")

    for known_type in types:
        # Skip the None placeholders left by map-entry messages; subscripting
        # them would raise TypeError.
        if known_type is None or known_type["type"] != "Message":
            continue
        if package == known_type["package"] and flattened_name == known_type["name"]:
            return known_type

    return None
def is_mutable_field_type(field_type: str) -> bool:
    """Return True when the annotation string starts with ``List[`` or ``Dict[``."""
    mutable_prefixes = ("List[", "Dict[")
    return field_type.startswith(mutable_prefixes)
def read_protobuf_service(
    service: ServiceDescriptorProto, index, proto_file, content, output_types
):
    """Convert a service descriptor into template data.

    The resulting dict (name, comment and one entry per method) is appended
    to ``content["template_data"]["services"]``. Typing imports required by
    the generated method signatures are recorded as a side effect, and the
    ``zero`` value of mutable input-message fields is rewritten to ``"None"``
    in place.

    Args:
        service: The service descriptor to read.
        index: Position of the service in the proto file, used to build the
            comment path ``[6, index]``.
        proto_file: The file the service is defined in (comment lookup).
        content: Per-package dict holding ``input_package`` and
            ``template_data``.
        output_types: Type data dicts searched for each method's input type.
    """
    input_package_name = content["input_package"]
    template_data = content["template_data"]
    # print(service, file=sys.stderr)
    data = {
        "name": service.name,
        "py_name": pythonize_class_name(service.name),
        "comment": get_comment(proto_file, [6, index]),
        "methods": [],
    }
    for j, method in enumerate(service.method):
        method_input_message = lookup_method_input_type(method, output_types)
        # This section ensures that method arguments having a default
        # value that is initialised as a List/Dict (mutable) is replaced
        # with None and initialisation is deferred to the beginning of the
        # method definition. This is done so to avoid any side-effects.
        # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
        mutable_default_args = []
        if method_input_message:
            for field in method_input_message["properties"]:
                if (
                    not method.client_streaming
                    and field["zero"] != "None"
                    and is_mutable_field_type(field["type"])
                ):
                    # Remember the real default, then overwrite it on the
                    # shared message data dict.
                    mutable_default_args.append((field["py_name"], field["zero"]))
                    field["zero"] = "None"
                if field["zero"] == "None":
                    template_data["typing_imports"].add("Optional")
        data["methods"].append(
            {
                "name": method.name,
                "py_name": pythonize_method_name(method.name),
                "comment": get_comment(proto_file, [6, index, 2, j], indent=8),
                "route": f"/{input_package_name}.{service.name}/{method.name}",
                "input": get_type_reference(
                    input_package_name, template_data["imports"], method.input_type
                ).strip('"'),
                "input_message": method_input_message,
                "output": get_type_reference(
                    input_package_name,
                    template_data["imports"],
                    method.output_type,
                    unwrap=False,
                ),
                "client_streaming": method.client_streaming,
                "server_streaming": method.server_streaming,
                "mutable_default_args": mutable_default_args,
            }
        )
        # Streaming methods need additional typing names in the output file.
        if method.client_streaming:
            template_data["typing_imports"].add("AsyncIterable")
            template_data["typing_imports"].add("Iterable")
            template_data["typing_imports"].add("Union")
        if method.server_streaming:
            template_data["typing_imports"].add("AsyncIterator")
    template_data["services"].append(data)
def main():
    """The plugin's main entry point.

    Reads a serialized CodeGeneratorRequest from stdin, optionally dumps it
    to the file named by the BETTERPROTO_DUMP environment variable, generates
    the code and writes the serialized CodeGeneratorResponse to stdout.
    """
    # Request: raw bytes from stdin, parsed into a protobuf message
    raw_request = sys.stdin.buffer.read()
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    dump_path = os.environ.get("BETTERPROTO_DUMP")
    if dump_path:
        dump_request(dump_path, request)

    # Response: filled by the generator, serialized back to stdout
    response = plugin.CodeGeneratorResponse()
    generate_code(request, response)
    sys.stdout.buffer.write(response.SerializeToString())
def dump_request(dump_file: str, request: CodeGeneratorRequest):
    """Write the serialized request to ``dump_file``.

    For developers: this makes it possible to debug plugin.py standalone.
    Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write
    the request to a file, then run plugin.py from your IDE in debugging mode
    with stdin redirected from that file.
    """
    notice = f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n"
    with open(str(dump_file), "wb") as dump:
        sys.stderr.write(notice)
        serialized = request.SerializeToString()
        dump.write(serialized)
# Run the plugin only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1 @@
from .main import main

View File

@ -0,0 +1,4 @@
from .main import main

# Invoke the plugin entry point as soon as this module is executed.
main()

View File

@ -0,0 +1,48 @@
#!/usr/bin/env python
import sys
import os
from google.protobuf.compiler import plugin_pb2 as plugin
from betterproto.plugin.parser import generate_code
def main():
    """The plugin's main entry point.

    Reads a serialized CodeGeneratorRequest from stdin, optionally dumps it
    to the file named by the BETTERPROTO_DUMP environment variable, generates
    the code and writes the serialized CodeGeneratorResponse to stdout.
    """
    # Request: raw bytes from stdin, parsed into a protobuf message
    raw_request = sys.stdin.buffer.read()
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(raw_request)

    dump_path = os.environ.get("BETTERPROTO_DUMP")
    if dump_path:
        dump_request(dump_path, request)

    # Response: filled by the generator, serialized back to stdout
    response = plugin.CodeGeneratorResponse()
    generate_code(request, response)
    sys.stdout.buffer.write(response.SerializeToString())
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
    """Write the serialized request to ``dump_file``.

    For developers: this makes it possible to debug the plugin standalone.
    Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write
    the request to a file, then run the plugin from your IDE in debugging
    mode with stdin redirected from that file.
    """
    notice = f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n"
    with open(str(dump_file), "wb") as dump:
        sys.stderr.write(notice)
        serialized = request.SerializeToString()
        dump.write(serialized)
# Run the plugin only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

View File

@ -0,0 +1,713 @@
"""Plugin model dataclasses.
These classes are meant to be an intermediate representation
of protobuf objects. They are used to organize the data collected during parsing.
The general intention is to create a doubly-linked tree-like structure
with the following types of references:
- Downwards references: from message -> fields, from output package -> messages
or from service -> service methods
- Upwards references: from field -> message, message -> package.
- Input/output message references: from a service method to its corresponding
input/output messages, which may even be in another package.
There are convenience methods to allow climbing up and down this tree, for
example to retrieve the list of all messages that are in the same package as
the current message.
Most of these classes take as inputs:
- proto_obj: A reference to its corresponding protobuf object as
presented by the protoc plugin.
- parent: a reference to the parent object in the tree.
With this information, the class is able to expose attributes,
such as a pythonized name, that will be calculated from proto_obj.
The instantiation should also attach a reference to the new object
into the corresponding place within its parent object. For example,
instantiating field `A` with parent message `B` should add a
reference to `A` to `B`'s `fields` attribute.
"""
import re
from dataclasses import dataclass
from dataclasses import field
from typing import (
Union,
Type,
List,
Dict,
Set,
Text,
)
import textwrap
import betterproto
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
pythonize_method_name,
)
# A failed import of the compiler-only dependencies is reported with an
# installation hint and a non-zero exit instead of a raw traceback.
try:
    # betterproto[compiler] specific dependencies
    from google.protobuf.compiler import plugin_pb2 as plugin
    from google.protobuf.descriptor_pb2 import (
        DescriptorProto,
        EnumDescriptorProto,
        FieldDescriptorProto,
        FileDescriptorProto,
        MethodDescriptorProto,
    )
except ImportError as err:
    # Not every ImportError message contains "cannot import name" (a missing
    # package reads "No module named '...'"), so fall back to the whole
    # message instead of calling .group() on a failed match.
    error_message = str(err.args[0]) if err.args else str(err)
    import_match = re.match(r".*(cannot import name .*$)", error_message)
    missing_import = import_match.group(1) if import_match else error_message
    print(
        "\033[31m"
        f"Unable to import `{missing_import}` from betterproto plugin! "
        "Please ensure that you've installed betterproto as "
        '`pip install "betterproto[compiler]"` so that compiler dependencies '
        "are included."
        "\033[0m"
    )
    raise SystemExit(1)
# Create a unique placeholder to deal with
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
# It is used below as the default of dataclass fields that are conceptually
# required.
PLACEHOLDER = object()
# Organize proto types into categories
# Types rendered as Python ``float``.
PROTO_FLOAT_TYPES = (
    FieldDescriptorProto.TYPE_DOUBLE,  # 1
    FieldDescriptorProto.TYPE_FLOAT,  # 2
)
# Types rendered as Python ``int``.
PROTO_INT_TYPES = (
    FieldDescriptorProto.TYPE_INT64,  # 3
    FieldDescriptorProto.TYPE_UINT64,  # 4
    FieldDescriptorProto.TYPE_INT32,  # 5
    FieldDescriptorProto.TYPE_FIXED64,  # 6
    FieldDescriptorProto.TYPE_FIXED32,  # 7
    FieldDescriptorProto.TYPE_UINT32,  # 13
    FieldDescriptorProto.TYPE_SFIXED32,  # 15
    FieldDescriptorProto.TYPE_SFIXED64,  # 16
    FieldDescriptorProto.TYPE_SINT32,  # 17
    FieldDescriptorProto.TYPE_SINT64,  # 18
)
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,)  # 8
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,)  # 9
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,)  # 12
# Types whose Python type is a reference to another generated class.
PROTO_MESSAGE_TYPES = (
    FieldDescriptorProto.TYPE_MESSAGE,  # 11
    FieldDescriptorProto.TYPE_ENUM,  # 14
)
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,)  # 11
# Types for which a repeated field is reported as packed.
PROTO_PACKED_TYPES = (
    FieldDescriptorProto.TYPE_DOUBLE,  # 1
    FieldDescriptorProto.TYPE_FLOAT,  # 2
    FieldDescriptorProto.TYPE_INT64,  # 3
    FieldDescriptorProto.TYPE_UINT64,  # 4
    FieldDescriptorProto.TYPE_INT32,  # 5
    FieldDescriptorProto.TYPE_FIXED64,  # 6
    FieldDescriptorProto.TYPE_FIXED32,  # 7
    FieldDescriptorProto.TYPE_BOOL,  # 8
    FieldDescriptorProto.TYPE_UINT32,  # 13
    FieldDescriptorProto.TYPE_SFIXED32,  # 15
    FieldDescriptorProto.TYPE_SFIXED64,  # 16
    FieldDescriptorProto.TYPE_SINT32,  # 17
    FieldDescriptorProto.TYPE_SINT64,  # 18
)
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
    """Return the leading comment stored for ``path``, formatted as source.

    The first source location whose path equals ``path`` and that carries a
    leading comment is used. Its text is wrapped to ``79 - indent`` columns.
    Fields (second-to-last path element is 2, and the element four from the
    end is not 6) become ``#`` comments; everything else becomes a docstring.
    An empty string is returned when no location matches.
    """
    pad = " " * indent
    width = 79 - indent

    for location in proto_file.source_code_info.location:
        if list(location.path) != path or not location.leading_comments:
            continue

        text = location.leading_comments.strip().replace("\n", "")
        wrapped = textwrap.wrap(text, width=width)

        if path[-2] == 2 and path[-4] != 6:
            # This is a field
            hash_prefix = f"{pad}# "
            return hash_prefix + f"\n{pad}# ".join(wrapped)

        # This is a message, enum, service, or method
        if len(wrapped) == 1 and len(wrapped[0]) < width - 6:
            single_line = wrapped[0].strip('"')
            return f'{pad}"""{single_line}"""'

        joined = f"\n{pad}".join(wrapped)
        return f'{pad}"""\n{pad}{joined}\n{pad}"""'

    return ""
class ProtoContentBase:
    """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.

    Subclasses are expected to provide a ``parent`` attribute; the properties
    below follow ``parent`` links until an ``OutputTemplate`` is reached.
    """

    path: List[int]
    comment_indent: int = 4

    def __post_init__(self):
        """Checks that no fake default fields were left as placeholders."""
        # NOTE(review): the values of ``__dataclass_fields__`` are dataclass
        # Field objects, so this identity test against PLACEHOLDER looks like
        # it can never be true. Comparing ``getattr(self, field_name)``
        # instead would reject the FieldCompiler instances that
        # MapEntryCompiler creates without a ``path`` - confirm the intent
        # before tightening it.
        for field_name, field_val in self.__dataclass_fields__.items():
            if field_val is PLACEHOLDER:
                raise ValueError(f"`{field_name}` is a required field.")

    @property
    def output_file(self) -> "OutputTemplate":
        """The OutputTemplate found by climbing the ``parent`` links."""
        current = self
        while not isinstance(current, OutputTemplate):
            current = current.parent
        return current

    @property
    def proto_file(self) -> FileDescriptorProto:
        """The ``package_proto_obj`` of the enclosing output file."""
        return self.output_file.package_proto_obj

    @property
    def request(self) -> "PluginRequestCompiler":
        """The ``parent_request`` of the enclosing output file."""
        return self.output_file.parent_request

    @property
    def comment(self) -> str:
        """Crawl the proto source code and retrieve comments
        for this object.
        """
        return get_comment(
            proto_file=self.proto_file, path=self.path, indent=self.comment_indent,
        )
@dataclass
class PluginRequestCompiler:
    """Root of the parsed tree: the plugin request and its output packages."""

    plugin_request_obj: plugin.CodeGeneratorRequest
    output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)

    @property
    def all_messages(self) -> List["MessageCompiler"]:
        """All of the messages in this request.

        Returns
        -------
        List[MessageCompiler]
            The messages of every output package, in package order.
        """
        collected = []
        for output in self.output_packages.values():
            collected.extend(output.messages)
        return collected
@dataclass
class OutputTemplate:
    """Representation of an output .py file.

    Holds the protobuf file object that provides the package name, the input
    files that contribute to the output, the imports the generated module
    needs, and the messages, enums and services collected for it.
    """

    parent_request: PluginRequestCompiler
    package_proto_obj: FileDescriptorProto
    # Elements must expose ``.name`` (see ``input_filenames``), i.e. they are
    # file descriptors rather than plain strings.
    input_files: List[FileDescriptorProto] = field(default_factory=list)
    imports: Set[str] = field(default_factory=set)
    datetime_imports: Set[str] = field(default_factory=set)
    typing_imports: Set[str] = field(default_factory=set)
    messages: List["MessageCompiler"] = field(default_factory=list)
    enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
    services: List["ServiceCompiler"] = field(default_factory=list)

    @property
    def package(self) -> str:
        """Name of input package.

        Returns
        -------
        str
            Name of input package.
        """
        return self.package_proto_obj.package

    @property
    def input_filenames(self) -> List[str]:
        """Names of the input files used to build this output.

        Returns
        -------
        List[str]
            Names of the input files used to build this output.
        """
        return [f.name for f in self.input_files]
@dataclass
class MessageCompiler(ProtoContentBase):
    """Representation of a protobuf message.

    Creating an instance whose parent is an ``OutputTemplate`` registers it
    in that output file's ``messages`` list (or ``enums`` list for enum
    definitions).
    """

    parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
    fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
        default_factory=list
    )

    def __post_init__(self):
        """Register top-level messages/enums, then run the base checks."""
        # Add message to output file
        if isinstance(self.parent, OutputTemplate):
            if isinstance(self, EnumDefinitionCompiler):
                self.output_file.enums.append(self)
            else:
                self.output_file.messages.append(self)
        super().__post_init__()

    @property
    def proto_name(self) -> str:
        """Original protobuf name."""
        return self.proto_obj.name

    @property
    def py_name(self) -> str:
        """Pythonized class name."""
        return pythonize_class_name(self.proto_name)

    @property
    def annotation(self) -> str:
        """Type annotation string, wrapped in ``List[...]`` when repeated."""
        # ``repeated`` is only defined by FieldCompiler; objects without it
        # are treated as not repeated instead of raising AttributeError.
        if getattr(self, "repeated", False):
            return f"List[{self.py_name}]"
        return self.py_name
def is_map(
    proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
) -> bool:
    """True if proto_field_obj is a map, otherwise False.

    A field counts as a map when it is a message field whose type name is
    ``<field name>entry`` (compared lowercased, underscores removed from the
    field name) and ``parent_message`` has a nested type of that name with
    the ``map_entry`` option set.
    """
    if proto_field_obj.type != FieldDescriptorProto.TYPE_MESSAGE:
        return False

    # This might be a map...
    referenced_type = proto_field_obj.type_name.split(".")[-1].lower()
    map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
    if referenced_type != map_entry:
        return False

    return any(
        nested.options.map_entry
        for nested in parent_message.nested_type  # parent message
        if nested.name.replace("_", "").lower() == map_entry
    )
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
    """True if proto_field_obj has its ``oneof_index`` set, otherwise False."""
    return bool(proto_field_obj.HasField("oneof_index"))
@dataclass
class FieldCompiler(MessageCompiler):
    """Representation of a single field of a protobuf message.

    Creating an instance appends it to ``parent.fields`` and records the
    typing/datetime imports its annotation needs on the output file.
    """

    parent: MessageCompiler = PLACEHOLDER
    proto_obj: FieldDescriptorProto = PLACEHOLDER

    def __post_init__(self):
        """Attach the field to its message and collect required imports."""
        # Add field to message
        self.parent.fields.append(self)
        # Check for new imports
        annotation = self.annotation
        if "Optional[" in annotation:
            self.output_file.typing_imports.add("Optional")
        if "List[" in annotation:
            self.output_file.typing_imports.add("List")
        if "Dict[" in annotation:
            self.output_file.typing_imports.add("Dict")
        if "timedelta" in annotation:
            self.output_file.datetime_imports.add("timedelta")
        if "datetime" in annotation:
            self.output_file.datetime_imports.add("datetime")
        super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

    def get_field_string(self, indent: int = 4) -> str:
        """Construct string representation of this field as a field.

        ``indent`` is accepted but not used by this implementation.
        """
        name = f"{self.py_name}"
        annotations = f": {self.annotation}"
        betterproto_field_type = (
            f"betterproto.{self.field_type}_field({self.proto_obj.number}"
            + f"{self.betterproto_field_args}"
            + ")"
        )
        return name + annotations + " = " + betterproto_field_type

    @property
    def betterproto_field_args(self):
        """Extra arguments (each starting with ``", "``) for the field call."""
        args = ""
        if self.field_wraps:
            args = args + f", wraps={self.field_wraps}"
        return args

    @property
    def field_wraps(self) -> Union[str, None]:
        """Returns betterproto wrapped field type or None.

        ``.google.protobuf.<X>Value`` maps to ``betterproto.TYPE_<X>`` when
        betterproto defines that constant.
        """
        match_wrapper = re.match(
            r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
        )
        if match_wrapper:
            wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
            if hasattr(betterproto, wrapped_type):
                return f"betterproto.{wrapped_type}"
        return None

    @property
    def repeated(self) -> bool:
        """True for repeated fields that are not maps."""
        # is_map reads ``nested_type`` from its second argument, which only
        # exists on the parent's protobuf descriptor, not on the
        # MessageCompiler wrapper itself.
        if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
            self.proto_obj, self.parent.proto_obj
        ):
            return True
        return False

    @property
    def mutable(self) -> bool:
        """True if the field is a mutable type, otherwise False."""
        annotation = self.annotation
        return annotation.startswith("List[") or annotation.startswith("Dict[")

    @property
    def field_type(self) -> str:
        """String representation of proto field type."""
        return (
            self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
        )

    @property
    def default_value_string(self) -> Union[Text, None, float, int]:
        """Python representation of the default proto value."""
        if self.repeated:
            return "[]"
        if self.py_type == "int":
            return "0"
        if self.py_type == "float":
            return "0.0"
        elif self.py_type == "bool":
            return "False"
        elif self.py_type == "str":
            return '""'
        elif self.py_type == "bytes":
            return 'b""'
        else:
            # Message type
            return "None"

    @property
    def packed(self) -> bool:
        """True if the wire representation is a packed format."""
        if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
            return True
        return False

    @property
    def py_name(self) -> str:
        """Pythonized name."""
        return pythonize_field_name(self.proto_name)

    @property
    def proto_name(self) -> str:
        """Original protobuf name."""
        return self.proto_obj.name

    @property
    def py_type(self) -> str:
        """String representation of Python type.

        Raises:
            NotImplementedError: if the proto type is in none of the known
                type categories.
        """
        if self.proto_obj.type in PROTO_FLOAT_TYPES:
            return "float"
        elif self.proto_obj.type in PROTO_INT_TYPES:
            return "int"
        elif self.proto_obj.type in PROTO_BOOL_TYPES:
            return "bool"
        elif self.proto_obj.type in PROTO_STR_TYPES:
            return "str"
        elif self.proto_obj.type in PROTO_BYTES_TYPES:
            return "bytes"
        elif self.proto_obj.type in PROTO_MESSAGE_TYPES:
            # Type referencing another defined Message or a named enum
            return get_type_reference(
                package=self.output_file.package,
                imports=self.output_file.imports,
                source_type=self.proto_obj.type_name,
            )
        else:
            # Report the descriptor's own type; the bare name ``field`` is
            # the dataclasses helper in this module and has no ``type``.
            raise NotImplementedError(f"Unknown type {self.proto_obj.type}")

    @property
    def annotation(self) -> str:
        """Type annotation string, wrapped in ``List[...]`` when repeated."""
        if self.repeated:
            return f"List[{self.py_type}]"
        return self.py_type
@dataclass
class OneOfFieldCompiler(FieldCompiler):
    """Field that belongs to a oneof group of its parent message."""

    @property
    def betterproto_field_args(self) -> "str":
        """The inherited field arguments followed by the oneof group name."""
        inherited_args = super().betterproto_field_args
        oneof_index = self.proto_obj.oneof_index
        group = self.parent.proto_obj.oneof_decl[oneof_index].name
        return inherited_args + f', group="{group}"'
@dataclass
class MapEntryCompiler(FieldCompiler):
    """Representation of a map field, rendered as a ``Dict[key, value]``."""

    # Python type strings of the key and value, filled in by __post_init__.
    py_k_type: Type = PLACEHOLDER
    py_v_type: Type = PLACEHOLDER
    # Protobuf type names of the key and value, filled in by __post_init__.
    proto_k_type: str = PLACEHOLDER
    proto_v_type: str = PLACEHOLDER

    def __post_init__(self):
        """Explore nested types and set k_type and v_type if unset."""
        # Name of the nested entry message: the field name lowercased, without
        # underscores, followed by "entry".
        map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
        for nested in self.parent.proto_obj.nested_type:
            if nested.name.replace("_", "").lower() == map_entry:
                if nested.options.map_entry:
                    # Get Python types
                    # Temporary FieldCompiler objects are built only to read
                    # their py_type; constructing them with parent=self also
                    # appends them to this object's ``fields`` list.
                    self.py_k_type = FieldCompiler(
                        parent=self, proto_obj=nested.field[0],  # key
                    ).py_type
                    self.py_v_type = FieldCompiler(
                        parent=self, proto_obj=nested.field[1],  # value
                    ).py_type
                    # Get proto types
                    self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
                    self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
        super().__post_init__()  # call FieldCompiler-> MessageCompiler __post_init__

    def get_field_string(self, indent: int = 4) -> str:
        """Construct string representation of this field."""
        name = f"{self.py_name}"
        annotations = f": {self.annotation}"
        betterproto_field_type = (
            f"betterproto.map_field("
            f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, "
            f"betterproto.{self.proto_v_type})"
        )
        return name + annotations + " = " + betterproto_field_type

    @property
    def annotation(self):
        # Maps are always annotated as a Dict of the key and value types.
        return f"Dict[{self.py_k_type}, {self.py_v_type}]"

    @property
    def repeated(self):
        return False  # maps cannot be repeated
@dataclass
class EnumDefinitionCompiler(MessageCompiler):
    """Representation of a proto Enum definition."""

    proto_obj: EnumDescriptorProto = PLACEHOLDER
    entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER

    @dataclass(unsafe_hash=True)
    class EnumEntry:
        """Representation of an Enum entry."""

        name: str
        value: int
        comment: str

    def __post_init__(self):
        """Build ``entries`` from the enum's values, then run the base setup."""
        # Get entries/allowed values for this Enum
        self.entries = [
            self.EnumEntry(
                name=entry_proto_value.name,
                value=entry_proto_value.number,
                comment=get_comment(
                    proto_file=self.proto_file, path=self.path + [2, entry_number]
                ),
            )
            for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
        ]
        super().__post_init__()  # call MessageCompiler __post_init__

    @property
    def default_value_string(self) -> str:
        """Python representation of the default value for Enums.

        As per the spec, this is the first value of the Enum. The number is
        returned as a string.
        """
        return str(self.entries[0].value)  # ideally, should ALWAYS be int(0)!
@dataclass
class ServiceCompiler(ProtoContentBase):
    """Representation of a protobuf service and its methods."""

    parent: OutputTemplate = PLACEHOLDER
    proto_obj: DescriptorProto = PLACEHOLDER
    path: List[int] = PLACEHOLDER
    methods: List["ServiceMethodCompiler"] = field(default_factory=list)

    def __post_init__(self) -> None:
        """Register the service on its output file, then run the base checks."""
        registered_services = self.output_file.services
        registered_services.append(self)
        # check for unset fields
        super().__post_init__()

    @property
    def proto_name(self):
        """Original protobuf name."""
        descriptor = self.proto_obj
        return descriptor.name

    @property
    def py_name(self):
        """Pythonized class name."""
        original_name = self.proto_name
        return pythonize_class_name(original_name)
@dataclass
class ServiceMethodCompiler(ProtoContentBase):
    """Representation of a single method of a protobuf service.

    Creating an instance appends it to ``parent.methods`` and records the
    typing imports its generated signature needs on the output file.
    """

    parent: ServiceCompiler
    proto_obj: MethodDescriptorProto
    path: List[int] = PLACEHOLDER
    comment_indent: int = 8

    def __post_init__(self) -> None:
        # Add method to service
        self.parent.methods.append(self)
        # Check for Optional import
        if self.py_input_message:
            for f in self.py_input_message.fields:
                if f.default_value_string == "None":
                    self.output_file.typing_imports.add("Optional")
        if "Optional" in self.py_output_message_type:
            self.output_file.typing_imports.add("Optional")
        # Evaluated only for its side effect of registering the Optional import.
        self.mutable_default_args  # ensure this is called before rendering
        # Check for Async imports
        if self.client_streaming:
            self.output_file.typing_imports.add("AsyncIterable")
            self.output_file.typing_imports.add("Iterable")
            self.output_file.typing_imports.add("Union")
        if self.server_streaming:
            self.output_file.typing_imports.add("AsyncIterator")
        super().__post_init__()  # check for unset fields

    @property
    def mutable_default_args(self) -> Dict[str, str]:
        """Handle mutable default arguments.

        Returns a mapping from argument name to default value for the
        arguments of this method whose default value is mutable.
        The defaults are swapped out for None and replaced back inside
        the method's body.
        Each such argument also adds ``Optional`` to the output file's
        typing imports. Client-streaming methods yield an empty mapping.
        Reference:
        https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments

        Returns
        -------
        Dict[str, str]
            Name and actual default value (as a string)
            for each argument with mutable default values.
        """
        mutable_default_args = dict()
        if self.py_input_message:
            for f in self.py_input_message.fields:
                if (
                    not self.client_streaming
                    and f.default_value_string != "None"
                    and f.mutable
                ):
                    mutable_default_args[f.py_name] = f.default_value_string
                    self.output_file.typing_imports.add("Optional")
        return mutable_default_args

    @property
    def py_name(self) -> str:
        """Pythonized method name."""
        return pythonize_method_name(self.proto_obj.name)

    @property
    def proto_name(self) -> str:
        """Original protobuf name."""
        return self.proto_obj.name

    @property
    def route(self) -> str:
        """Route string of the form ``/<package>.<service name>/<method name>``."""
        return (
            f"/{self.output_file.package}."
            f"{self.parent.proto_name}/{self.proto_name}"
        )

    @property
    def py_input_message(self) -> Union[None, MessageCompiler]:
        """Find the input message object.

        Returns
        -------
        Union[None, MessageCompiler]
            Message instance representing the input message.
            If no input message could be found or there are no
            input messages, None is returned.
        """
        package, name = parse_source_type_name(self.proto_obj.input_type)
        # Nested types are currently flattened without dots.
        # Todo: keep a fully quantified name in types, that is
        # comparable with method.input_type
        # The search covers the messages of every output package of the request.
        for msg in self.request.all_messages:
            if (
                msg.py_name == name.replace(".", "")
                and msg.output_file.package == package
            ):
                return msg
        return None

    @property
    def py_input_message_type(self) -> str:
        """String representation of the Python type corresponding to the
        input message.

        Returns
        -------
        str
            String representation of the Python type corresponding to the
            input message.
        """
        return get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.input_type,
        ).strip('"')

    @property
    def py_output_message_type(self) -> str:
        """String representation of the Python type corresponding to the
        output message.

        Returns
        -------
        str
            String representation of the Python type corresponding to the
            output message.
        """
        return get_type_reference(
            package=self.output_file.package,
            imports=self.output_file.imports,
            source_type=self.proto_obj.output_type,
            unwrap=False,
        ).strip('"')

    @property
    def client_streaming(self) -> bool:
        """The ``client_streaming`` flag of the method descriptor."""
        return self.proto_obj.client_streaming

    @property
    def server_streaming(self) -> bool:
        """The ``server_streaming`` flag of the method descriptor."""
        return self.proto_obj.server_streaming

View File

@ -0,0 +1,188 @@
import itertools
import os.path
import pathlib
import sys
from typing import List, Iterator
try:
    # betterproto[compiler] specific dependencies
    import black
    from google.protobuf.compiler import plugin_pb2 as plugin
    from google.protobuf.descriptor_pb2 import (
        DescriptorProto,
        EnumDescriptorProto,
        FieldDescriptorProto,
        ServiceDescriptorProto,
    )
    import jinja2
except ImportError as err:
    # `ImportError.name` carries the module that could not be imported.
    # Slicing the message text only works for the exact
    # "No module named '...'" wording and produces garbage for other
    # import failures (e.g. "cannot import name ..."), so fall back to the
    # full message when no name is attached.
    missing_import = err.name or str(err)
    print(
        "\033[31m"
        f"Unable to import `{missing_import}` from betterproto plugin! "
        "Please ensure that you've installed betterproto as "
        '`pip install "betterproto[compiler]"` so that compiler dependencies '
        "are included."
        "\033[0m"
    )
    raise SystemExit(1)
from betterproto.plugin.models import (
PluginRequestCompiler,
OutputTemplate,
MessageCompiler,
FieldCompiler,
OneOfFieldCompiler,
MapEntryCompiler,
EnumDefinitionCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
is_oneof,
)
def traverse(proto_file: FieldDescriptorProto) -> Iterator:
    """Walk all enums and messages of a proto file, flattening nested types.

    Top-level enums are visited first, then top-level messages; for every
    message its nested enums and then its nested messages follow.

    The traversal renames descriptors in place: every nested item has the
    names of its enclosing messages prepended to `name`.

    NOTE(review): the parameter is annotated `FieldDescriptorProto`, but the
    body reads `enum_type` and `message_type` and the caller in this module
    passes entries of `request.proto_file`, so a file descriptor is what is
    actually expected. `FileDescriptorProto` is not imported in this module,
    which is why the annotation is left as it is here.

    Parameters
    ----------
    proto_file
        Descriptor whose `enum_type` and `message_type` are traversed.

    Returns
    -------
    Iterator
        Yields ``(item, path)`` tuples where ``item`` is a message or enum
        descriptor and ``path`` is the list of field numbers and indexes
        locating it inside the file descriptor.
    """
    # Todo: Keep information about nested hierarchy
    def _traverse(path, items, prefix=""):
        for i, item in enumerate(items):
            # Adjust the name since we flatten the hierarchy.
            # Todo: don't change the name, but include full name in returned tuple
            item.name = next_prefix = prefix + item.name
            yield item, path + [i]

            if isinstance(item, DescriptorProto):
                # 4 is the `enum_type` field of a message descriptor. The
                # enum's own index must be part of the path; without it all
                # nested enums share one path that matches no location.
                for j, enum in enumerate(item.enum_type):
                    enum.name = next_prefix + enum.name
                    yield enum, path + [i, 4, j]

                # 3 is the `nested_type` field of a message descriptor.
                if item.nested_type:
                    yield from _traverse(
                        path + [i, 3], item.nested_type, next_prefix
                    )

    # 5 / 4 are the `enum_type` / `message_type` fields of a file descriptor.
    return itertools.chain(
        _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
    )
def generate_code(
    request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
) -> None:
    """Compile the proto files of `request` and add the results to `response`.

    The input files are grouped by proto package. Messages and enums are
    read first, then services, and finally one `__init__.py` per package is
    rendered from `template.py.j2` and formatted with black. An empty
    `__init__.py` is added for every parent directory that did not receive
    rendered output. Each written path is reported on stderr.

    Parameters
    ----------
    request
        Plugin request. `request.parameter` is split on commas into plugin
        options; unless `INCLUDE_GOOGLE` is among them, files of the
        `google.protobuf` package are skipped.
    response
        Plugin response that receives the generated files.
    """
    plugin_options = request.parameter.split(",") if request.parameter else []

    templates_folder = os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "templates")
    )

    env = jinja2.Environment(
        trim_blocks=True,
        lstrip_blocks=True,
        loader=jinja2.FileSystemLoader(templates_folder),
    )
    template = env.get_template("template.py.j2")

    request_data = PluginRequestCompiler(plugin_request_obj=request)
    # Gather output packages
    for proto_file in request.proto_file:
        if (
            proto_file.package == "google.protobuf"
            and "INCLUDE_GOOGLE" not in plugin_options
        ):
            # If not INCLUDE_GOOGLE,
            # skip re-compiling Google's well-known types
            continue

        output_package_name = proto_file.package
        if output_package_name not in request_data.output_packages:
            # Create a new output if there is no output for this package
            request_data.output_packages[output_package_name] = OutputTemplate(
                parent_request=request_data, package_proto_obj=proto_file
            )
        # Add this input file to the output corresponding to this package
        request_data.output_packages[output_package_name].input_files.append(proto_file)

    # Read Messages and Enums
    # We need to read Messages before Services so that we can
    # get the references to input/output messages for each service
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for item, path in traverse(proto_input_file):
                read_protobuf_type(item=item, path=path, output_package=output_package)

    # Read Services
    for output_package in request_data.output_packages.values():
        for proto_input_file in output_package.input_files:
            for index, service in enumerate(proto_input_file.service):
                read_protobuf_service(service, index, output_package)

    # Generate output files
    output_paths = set()  # pathlib.Path of every rendered package module
    for output_package_name, template_data in request_data.output_packages.items():

        # Add files to the response object
        output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
        output_paths.add(output_path)

        output_file = response.file.add()
        output_file.name = str(output_path)

        # Render and then format the output file
        output_file.content = black.format_str(
            template.render(description=template_data),
            mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
        )

    # Make each output directory a package with __init__ file
    init_files = {
        directory.joinpath("__init__.py")
        for path in output_paths
        for directory in path.parents
    } - output_paths

    for init_file in init_files:
        init = response.file.add()
        init.name = str(init_file)

    for written_path in sorted(output_paths.union(init_files)):
        print(f"Writing {written_path}", file=sys.stderr)
def read_protobuf_type(
    item: DescriptorProto, path: List[int], output_package: OutputTemplate
) -> None:
    """Build the compiler model for one message or enum descriptor.

    Messages get a `MessageCompiler` plus one field compiler per field;
    messages flagged as map entries are ignored. Enums get an
    `EnumDefinitionCompiler`. Any other item is ignored.

    Parameters
    ----------
    item
        Message or enum descriptor to read.
    path
        Location of `item`; field paths are derived from it.
    output_package
        Output that becomes the parent of the created compiler object.
    """
    if isinstance(item, DescriptorProto):
        if item.options.map_entry:
            # Skip generated map entry messages since we just use dicts
            return

        message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
        for index, field in enumerate(item.field):
            # Choose the compiler class matching the kind of field.
            if is_map(field, item):
                compiler_cls = MapEntryCompiler
            elif is_oneof(field):
                compiler_cls = OneOfFieldCompiler
            else:
                compiler_cls = FieldCompiler
            # 2 is the `field` field number of a message descriptor.
            compiler_cls(parent=message_data, proto_obj=field, path=path + [2, index])
        return

    if isinstance(item, EnumDescriptorProto):
        EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
def read_protobuf_service(
    service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
) -> None:
    """Build the compiler models for a service and each of its methods.

    Parameters
    ----------
    service
        Service descriptor to read.
    index
        Position of the service within its proto file; used in the paths.
    output_package
        Output that becomes the parent of the created `ServiceCompiler`.
    """
    # 6 is the `service` field number of a file descriptor.
    compiled_service = ServiceCompiler(
        parent=output_package, proto_obj=service, path=[6, index]
    )
    # 2 is the `method` field number of a service descriptor.
    for method_index, method in enumerate(service.method):
        method_path = [6, index, 2, method_index]
        ServiceMethodCompiler(
            parent=compiled_service, proto_obj=method, path=method_path
        )

View File

@ -0,0 +1,2 @@
@REM Runs Python on the directory that contains this batch file.
@REM plugin_dir is set to the drive and directory of this script, and all
@REM command-line arguments are forwarded unchanged to Python.
@REM NOTE(review): the -m switch of python expects an importable module name,
@REM while plugin_dir holds a filesystem path - confirm this invocation works.
@SET plugin_dir=%~dp0
@python -m %plugin_dir% %*

View File

@ -1,13 +1,13 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(description.files) }}
# sources: {{ ', '.join(description.input_filenames) }}
# plugin: python-betterproto
from dataclasses import dataclass
{% if description.datetime_imports %}
from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
from datetime import {% for i in description.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%}
{% if description.typing_imports %}
from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
from typing import {% for i in description.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %}
@ -40,13 +40,13 @@ class {{ message.py_name }}(betterproto.Message):
{{ message.comment }}
{% endif %}
{% for field in message.properties %}
{% for field in message.fields %}
{% if field.comment %}
{{ field.comment }}
{% endif %}
{{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %})
{{ field.get_field_string() }}
{% endfor %}
{% if not message.properties %}
{% if not message.fields %}
pass
{% endif %}
@ -61,32 +61,37 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% for method in service.methods %}
async def {{ method.py_name }}(self
{%- if not method.client_streaming -%}
{%- if method.input_message and method.input_message.properties -%}, *,
{%- for field in method.input_message.properties -%}
{{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%}
Optional[{{ field.type }}]
{%- if method.py_input_message and method.py_input_message.fields -%}, *,
{%- for field in method.py_input_message.fields -%}
{{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%}
Optional[{{ field.annotation }}]
{%- else -%}
{{ field.type }}
{%- endif -%} = {{ field.zero }}
{{ field.annotation }}
{%- endif -%} =
{%- if field.py_name not in method.mutable_default_args -%}
{{ field.default_value_string }}
{%- else -%}
None
{% endif -%}
{%- if not loop.last %}, {% endif -%}
{%- endfor -%}
{%- endif -%}
{%- else -%}
{# Client streaming: need a request iterator instead #}
, request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]]
, request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]]
{%- endif -%}
) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}:
) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}:
{% if method.comment %}
{{ method.comment }}
{% endif %}
{%- for py_name, zero in method.mutable_default_args %}
{%- for py_name, zero in method.mutable_default_args.items() %}
{{ py_name }} = {{ py_name }} or {{ zero }}
{% endfor %}
{% if not method.client_streaming %}
request = {{ method.input }}()
{% for field in method.input_message.properties %}
request = {{ method.py_input_message_type }}()
{% for field in method.py_input_message.fields %}
{% if field.field_type == 'message' %}
if {{ field.py_name }} is not None:
request.{{ field.py_name }} = {{ field.py_name }}
@ -101,15 +106,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
async for response in self._stream_stream(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output.strip('"') }},
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }},
):
yield response
{% else %}{# i.e. not client streaming #}
async for response in self._unary_stream(
"{{ method.route }}",
request,
{{ method.output.strip('"') }},
{{ method.py_output_message_type.strip('"') }},
):
yield response
@ -119,14 +124,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
return await self._stream_unary(
"{{ method.route }}",
request_iterator,
{{ method.input }},
{{ method.output.strip('"') }}
{{ method.py_input_message_type }},
{{ method.py_output_message_type.strip('"') }}
)
{% else %}{# i.e. not client streaming #}
return await self._unary_unary(
"{{ method.route }}",
request,
{{ method.output.strip('"') }}
{{ method.py_output_message_type.strip('"') }}
)
{% endif %}{# client streaming #}
{% endif %}
@ -134,6 +139,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %}
{% endfor %}
{% for i in description.imports %}
{% for i in description.imports|sort %}
{{ i }}
{% endfor %}
{% endfor %}