REF: Refactor plugin.py to use modular dataclasses in tree-like structure to represent parsed data (#121)
Refactor plugin to parse input into data-class based hierarchical structure
This commit is contained in:
committed by
GitHub
parent
cbd3437080
commit
b5dcac1250
1
src/betterproto/plugin/__init__.py
Normal file
1
src/betterproto/plugin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .main import main
|
||||
4
src/betterproto/plugin/__main__.py
Normal file
4
src/betterproto/plugin/__main__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .main import main
|
||||
|
||||
|
||||
main()
|
||||
48
src/betterproto/plugin/main.py
Normal file
48
src/betterproto/plugin/main.py
Normal file
@@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python
|
||||
import sys
|
||||
import os
|
||||
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
|
||||
from betterproto.plugin.parser import generate_code
|
||||
|
||||
|
||||
def main():
|
||||
"""The plugin's main entry point."""
|
||||
# Read request message from stdin
|
||||
data = sys.stdin.buffer.read()
|
||||
|
||||
# Parse request
|
||||
request = plugin.CodeGeneratorRequest()
|
||||
request.ParseFromString(data)
|
||||
|
||||
dump_file = os.getenv("BETTERPROTO_DUMP")
|
||||
if dump_file:
|
||||
dump_request(dump_file, request)
|
||||
|
||||
# Create response
|
||||
response = plugin.CodeGeneratorResponse()
|
||||
|
||||
# Generate code
|
||||
generate_code(request, response)
|
||||
|
||||
# Serialise response message
|
||||
output = response.SerializeToString()
|
||||
|
||||
# Write to stdout
|
||||
sys.stdout.buffer.write(output)
|
||||
|
||||
|
||||
def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest):
|
||||
"""
|
||||
For developers: Supports running plugin.py standalone so its possible to debug it.
|
||||
Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file.
|
||||
Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file.
|
||||
"""
|
||||
with open(str(dump_file), "wb") as fh:
|
||||
sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n")
|
||||
fh.write(request.SerializeToString())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
713
src/betterproto/plugin/models.py
Normal file
713
src/betterproto/plugin/models.py
Normal file
@@ -0,0 +1,713 @@
|
||||
"""Plugin model dataclasses.
|
||||
|
||||
These classes are meant to be an intermediate representation
|
||||
of protbuf objects. They are used to organize the data collected during parsing.
|
||||
|
||||
The general intention is to create a doubly-linked tree-like structure
|
||||
with the following types of references:
|
||||
- Downwards references: from message -> fields, from output package -> messages
|
||||
or from service -> service methods
|
||||
- Upwards references: from field -> message, message -> package.
|
||||
- Input/ouput message references: from a service method to it's corresponding
|
||||
input/output messages, which may even be in another package.
|
||||
|
||||
There are convenience methods to allow climbing up and down this tree, for
|
||||
example to retrieve the list of all messages that are in the same package as
|
||||
the current message.
|
||||
|
||||
Most of these classes take as inputs:
|
||||
- proto_obj: A reference to it's corresponding protobuf object as
|
||||
presented by the protoc plugin.
|
||||
- parent: a reference to the parent object in the tree.
|
||||
|
||||
With this information, the class is able to expose attributes,
|
||||
such as a pythonized name, that will be calculated from proto_obj.
|
||||
|
||||
The instantiation should also attach a reference to the new object
|
||||
into the corresponding place within it's parent object. For example,
|
||||
instantiating field `A` with parent message `B` should add a
|
||||
reference to `A` to `B`'s `fields` attirbute.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from typing import (
|
||||
Union,
|
||||
Type,
|
||||
List,
|
||||
Dict,
|
||||
Set,
|
||||
Text,
|
||||
)
|
||||
import textwrap
|
||||
|
||||
import betterproto
|
||||
from betterproto.compile.importing import (
|
||||
get_type_reference,
|
||||
parse_source_type_name,
|
||||
)
|
||||
from betterproto.compile.naming import (
|
||||
pythonize_class_name,
|
||||
pythonize_field_name,
|
||||
pythonize_method_name,
|
||||
)
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
FileDescriptorProto,
|
||||
MethodDescriptorProto,
|
||||
)
|
||||
except ImportError as err:
|
||||
missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1)
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
# Create a unique placeholder to deal with
|
||||
# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
|
||||
PLACEHOLDER = object()
|
||||
|
||||
# Organize proto types into categories
|
||||
PROTO_FLOAT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
)
|
||||
PROTO_INT_TYPES = (
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
)
|
||||
PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8
|
||||
PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9
|
||||
PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12
|
||||
PROTO_MESSAGE_TYPES = (
|
||||
FieldDescriptorProto.TYPE_MESSAGE, # 11
|
||||
FieldDescriptorProto.TYPE_ENUM, # 14
|
||||
)
|
||||
PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11
|
||||
PROTO_PACKED_TYPES = (
|
||||
FieldDescriptorProto.TYPE_DOUBLE, # 1
|
||||
FieldDescriptorProto.TYPE_FLOAT, # 2
|
||||
FieldDescriptorProto.TYPE_INT64, # 3
|
||||
FieldDescriptorProto.TYPE_UINT64, # 4
|
||||
FieldDescriptorProto.TYPE_INT32, # 5
|
||||
FieldDescriptorProto.TYPE_FIXED64, # 6
|
||||
FieldDescriptorProto.TYPE_FIXED32, # 7
|
||||
FieldDescriptorProto.TYPE_BOOL, # 8
|
||||
FieldDescriptorProto.TYPE_UINT32, # 13
|
||||
FieldDescriptorProto.TYPE_SFIXED32, # 15
|
||||
FieldDescriptorProto.TYPE_SFIXED64, # 16
|
||||
FieldDescriptorProto.TYPE_SINT32, # 17
|
||||
FieldDescriptorProto.TYPE_SINT64, # 18
|
||||
)
|
||||
|
||||
|
||||
def get_comment(proto_file, path: List[int], indent: int = 4) -> str:
|
||||
pad = " " * indent
|
||||
for sci in proto_file.source_code_info.location:
|
||||
# print(list(sci.path), path, file=sys.stderr)
|
||||
if list(sci.path) == path and sci.leading_comments:
|
||||
lines = textwrap.wrap(
|
||||
sci.leading_comments.strip().replace("\n", ""), width=79 - indent,
|
||||
)
|
||||
|
||||
if path[-2] == 2 and path[-4] != 6:
|
||||
# This is a field
|
||||
return f"{pad}# " + f"\n{pad}# ".join(lines)
|
||||
else:
|
||||
# This is a message, enum, service, or method
|
||||
if len(lines) == 1 and len(lines[0]) < 79 - indent - 6:
|
||||
lines[0] = lines[0].strip('"')
|
||||
return f'{pad}"""{lines[0]}"""'
|
||||
else:
|
||||
joined = f"\n{pad}".join(lines)
|
||||
return f'{pad}"""\n{pad}{joined}\n{pad}"""'
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class ProtoContentBase:
|
||||
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""
|
||||
|
||||
path: List[int]
|
||||
comment_indent: int = 4
|
||||
|
||||
def __post_init__(self):
|
||||
"""Checks that no fake default fields were left as placeholders."""
|
||||
for field_name, field_val in self.__dataclass_fields__.items():
|
||||
if field_val is PLACEHOLDER:
|
||||
raise ValueError(f"`{field_name}` is a required field.")
|
||||
|
||||
@property
|
||||
def output_file(self) -> "OutputTemplate":
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current
|
||||
|
||||
@property
|
||||
def proto_file(self) -> FieldDescriptorProto:
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current.package_proto_obj
|
||||
|
||||
@property
|
||||
def request(self) -> "PluginRequestCompiler":
|
||||
current = self
|
||||
while not isinstance(current, OutputTemplate):
|
||||
current = current.parent
|
||||
return current.parent_request
|
||||
|
||||
@property
|
||||
def comment(self) -> str:
|
||||
"""Crawl the proto source code and retrieve comments
|
||||
for this object.
|
||||
"""
|
||||
return get_comment(
|
||||
proto_file=self.proto_file, path=self.path, indent=self.comment_indent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginRequestCompiler:
|
||||
|
||||
plugin_request_obj: plugin.CodeGeneratorRequest
|
||||
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def all_messages(self) -> List["MessageCompiler"]:
|
||||
"""All of the messages in this request.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[MessageCompiler]
|
||||
List of all of the messages in this request.
|
||||
"""
|
||||
return [
|
||||
msg for output in self.output_packages.values() for msg in output.messages
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputTemplate:
|
||||
"""Representation of an output .py file.
|
||||
|
||||
Each output file corresponds to a .proto input file,
|
||||
but may need references to other .proto files to be
|
||||
built.
|
||||
"""
|
||||
|
||||
parent_request: PluginRequestCompiler
|
||||
package_proto_obj: FileDescriptorProto
|
||||
input_files: List[str] = field(default_factory=list)
|
||||
imports: Set[str] = field(default_factory=set)
|
||||
datetime_imports: Set[str] = field(default_factory=set)
|
||||
typing_imports: Set[str] = field(default_factory=set)
|
||||
messages: List["MessageCompiler"] = field(default_factory=list)
|
||||
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
|
||||
services: List["ServiceCompiler"] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def package(self) -> str:
|
||||
"""Name of input package.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Name of input package.
|
||||
"""
|
||||
return self.package_proto_obj.package
|
||||
|
||||
@property
|
||||
def input_filenames(self) -> List[str]:
|
||||
"""Names of the input files used to build this output.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
Names of the input files used to build this output.
|
||||
"""
|
||||
return [f.name for f in self.input_files]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageCompiler(ProtoContentBase):
|
||||
"""Representation of a protobuf message.
|
||||
"""
|
||||
|
||||
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
fields: List[Union["FieldCompiler", "MessageCompiler"]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Add message to output file
|
||||
if isinstance(self.parent, OutputTemplate):
|
||||
if isinstance(self, EnumDefinitionCompiler):
|
||||
self.output_file.enums.append(self)
|
||||
else:
|
||||
self.output_file.messages.append(self)
|
||||
super().__post_init__()
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
return pythonize_class_name(self.proto_name)
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
if self.repeated:
|
||||
return f"List[{self.py_name}]"
|
||||
return self.py_name
|
||||
|
||||
|
||||
def is_map(
|
||||
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
|
||||
) -> bool:
|
||||
"""True if proto_field_obj is a map, otherwise False.
|
||||
"""
|
||||
if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE:
|
||||
# This might be a map...
|
||||
message_type = proto_field_obj.type_name.split(".").pop().lower()
|
||||
map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry"
|
||||
if message_type == map_entry:
|
||||
for nested in parent_message.nested_type: # parent message
|
||||
if nested.name.replace("_", "").lower() == map_entry:
|
||||
if nested.options.map_entry:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool:
|
||||
"""True if proto_field_obj is a OneOf, otherwise False.
|
||||
"""
|
||||
if proto_field_obj.HasField("oneof_index"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldCompiler(MessageCompiler):
|
||||
parent: MessageCompiler = PLACEHOLDER
|
||||
proto_obj: FieldDescriptorProto = PLACEHOLDER
|
||||
|
||||
def __post_init__(self):
|
||||
# Add field to message
|
||||
self.parent.fields.append(self)
|
||||
# Check for new imports
|
||||
annotation = self.annotation
|
||||
if "Optional[" in annotation:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
if "List[" in annotation:
|
||||
self.output_file.typing_imports.add("List")
|
||||
if "Dict[" in annotation:
|
||||
self.output_file.typing_imports.add("Dict")
|
||||
if "timedelta" in annotation:
|
||||
self.output_file.datetime_imports.add("timedelta")
|
||||
if "datetime" in annotation:
|
||||
self.output_file.datetime_imports.add("datetime")
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
def get_field_string(self, indent: int = 4) -> str:
|
||||
"""Construct string representation of this field as a field."""
|
||||
name = f"{self.py_name}"
|
||||
annotations = f": {self.annotation}"
|
||||
betterproto_field_type = (
|
||||
f"betterproto.{self.field_type}_field({self.proto_obj.number}"
|
||||
+ f"{self.betterproto_field_args}"
|
||||
+ ")"
|
||||
)
|
||||
return name + annotations + " = " + betterproto_field_type
|
||||
|
||||
@property
|
||||
def betterproto_field_args(self):
|
||||
args = ""
|
||||
if self.field_wraps:
|
||||
args = args + f", wraps={self.field_wraps}"
|
||||
return args
|
||||
|
||||
@property
|
||||
def field_wraps(self) -> Union[str, None]:
|
||||
"""Returns betterproto wrapped field type or None.
|
||||
"""
|
||||
match_wrapper = re.match(
|
||||
r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name
|
||||
)
|
||||
if match_wrapper:
|
||||
wrapped_type = "TYPE_" + match_wrapper.group(1).upper()
|
||||
if hasattr(betterproto, wrapped_type):
|
||||
return f"betterproto.{wrapped_type}"
|
||||
return None
|
||||
|
||||
@property
|
||||
def repeated(self) -> bool:
|
||||
if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map(
|
||||
self.proto_obj, self.parent
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def mutable(self) -> bool:
|
||||
"""True if the field is a mutable type, otherwise False."""
|
||||
annotation = self.annotation
|
||||
return annotation.startswith("List[") or annotation.startswith("Dict[")
|
||||
|
||||
@property
|
||||
def field_type(self) -> str:
|
||||
"""String representation of proto field type."""
|
||||
return (
|
||||
self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "")
|
||||
)
|
||||
|
||||
@property
|
||||
def default_value_string(self) -> Union[Text, None, float, int]:
|
||||
"""Python representation of the default proto value.
|
||||
"""
|
||||
if self.repeated:
|
||||
return "[]"
|
||||
if self.py_type == "int":
|
||||
return "0"
|
||||
if self.py_type == "float":
|
||||
return "0.0"
|
||||
elif self.py_type == "bool":
|
||||
return "False"
|
||||
elif self.py_type == "str":
|
||||
return '""'
|
||||
elif self.py_type == "bytes":
|
||||
return 'b""'
|
||||
else:
|
||||
# Message type
|
||||
return "None"
|
||||
|
||||
@property
|
||||
def packed(self) -> bool:
|
||||
"""True if the wire representation is a packed format."""
|
||||
if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES:
|
||||
return True
|
||||
return False
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
"""Pythonized name."""
|
||||
return pythonize_field_name(self.proto_name)
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
"""Original protobuf name."""
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_type(self) -> str:
|
||||
"""String representation of Python type."""
|
||||
if self.proto_obj.type in PROTO_FLOAT_TYPES:
|
||||
return "float"
|
||||
elif self.proto_obj.type in PROTO_INT_TYPES:
|
||||
return "int"
|
||||
elif self.proto_obj.type in PROTO_BOOL_TYPES:
|
||||
return "bool"
|
||||
elif self.proto_obj.type in PROTO_STR_TYPES:
|
||||
return "str"
|
||||
elif self.proto_obj.type in PROTO_BYTES_TYPES:
|
||||
return "bytes"
|
||||
elif self.proto_obj.type in PROTO_MESSAGE_TYPES:
|
||||
# Type referencing another defined Message or a named enum
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.type_name,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown type {field.type}")
|
||||
|
||||
@property
|
||||
def annotation(self) -> str:
|
||||
if self.repeated:
|
||||
return f"List[{self.py_type}]"
|
||||
return self.py_type
|
||||
|
||||
|
||||
@dataclass
|
||||
class OneOfFieldCompiler(FieldCompiler):
|
||||
@property
|
||||
def betterproto_field_args(self) -> "str":
|
||||
args = super().betterproto_field_args
|
||||
group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name
|
||||
args = args + f', group="{group}"'
|
||||
return args
|
||||
|
||||
|
||||
@dataclass
|
||||
class MapEntryCompiler(FieldCompiler):
|
||||
py_k_type: Type = PLACEHOLDER
|
||||
py_v_type: Type = PLACEHOLDER
|
||||
proto_k_type: str = PLACEHOLDER
|
||||
proto_v_type: str = PLACEHOLDER
|
||||
|
||||
def __post_init__(self):
|
||||
"""Explore nested types and set k_type and v_type if unset."""
|
||||
map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry"
|
||||
for nested in self.parent.proto_obj.nested_type:
|
||||
if nested.name.replace("_", "").lower() == map_entry:
|
||||
if nested.options.map_entry:
|
||||
# Get Python types
|
||||
self.py_k_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[0], # key
|
||||
).py_type
|
||||
self.py_v_type = FieldCompiler(
|
||||
parent=self, proto_obj=nested.field[1], # value
|
||||
).py_type
|
||||
# Get proto types
|
||||
self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type)
|
||||
self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type)
|
||||
super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__
|
||||
|
||||
def get_field_string(self, indent: int = 4) -> str:
|
||||
"""Construct string representation of this field."""
|
||||
name = f"{self.py_name}"
|
||||
annotations = f": {self.annotation}"
|
||||
betterproto_field_type = (
|
||||
f"betterproto.map_field("
|
||||
f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, "
|
||||
f"betterproto.{self.proto_v_type})"
|
||||
)
|
||||
return name + annotations + " = " + betterproto_field_type
|
||||
|
||||
@property
|
||||
def annotation(self):
|
||||
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
|
||||
|
||||
@property
|
||||
def repeated(self):
|
||||
return False # maps cannot be repeated
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumDefinitionCompiler(MessageCompiler):
|
||||
"""Representation of a proto Enum definition."""
|
||||
|
||||
proto_obj: EnumDescriptorProto = PLACEHOLDER
|
||||
entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER
|
||||
|
||||
@dataclass(unsafe_hash=True)
|
||||
class EnumEntry:
|
||||
"""Representation of an Enum entry."""
|
||||
|
||||
name: str
|
||||
value: int
|
||||
comment: str
|
||||
|
||||
def __post_init__(self):
|
||||
# Get entries/allowed values for this Enum
|
||||
self.entries = [
|
||||
self.EnumEntry(
|
||||
name=entry_proto_value.name,
|
||||
value=entry_proto_value.number,
|
||||
comment=get_comment(
|
||||
proto_file=self.proto_file, path=self.path + [2, entry_number]
|
||||
),
|
||||
)
|
||||
for entry_number, entry_proto_value in enumerate(self.proto_obj.value)
|
||||
]
|
||||
super().__post_init__() # call MessageCompiler __post_init__
|
||||
|
||||
@property
|
||||
def default_value_string(self) -> int:
|
||||
"""Python representation of the default value for Enums.
|
||||
|
||||
As per the spec, this is the first value of the Enum.
|
||||
"""
|
||||
return str(self.entries[0].value) # ideally, should ALWAYS be int(0)!
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceCompiler(ProtoContentBase):
|
||||
parent: OutputTemplate = PLACEHOLDER
|
||||
proto_obj: DescriptorProto = PLACEHOLDER
|
||||
path: List[int] = PLACEHOLDER
|
||||
methods: List["ServiceMethodCompiler"] = field(default_factory=list)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add service to output file
|
||||
self.output_file.services.append(self)
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def proto_name(self):
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def py_name(self):
|
||||
return pythonize_class_name(self.proto_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceMethodCompiler(ProtoContentBase):
|
||||
|
||||
parent: ServiceCompiler
|
||||
proto_obj: MethodDescriptorProto
|
||||
path: List[int] = PLACEHOLDER
|
||||
comment_indent: int = 8
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# Add method to service
|
||||
self.parent.methods.append(self)
|
||||
|
||||
# Check for Optional import
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
if f.default_value_string == "None":
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
if "Optional" in self.py_output_message_type:
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
self.mutable_default_args # ensure this is called before rendering
|
||||
|
||||
# Check for Async imports
|
||||
if self.client_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterable")
|
||||
self.output_file.typing_imports.add("Iterable")
|
||||
self.output_file.typing_imports.add("Union")
|
||||
if self.server_streaming:
|
||||
self.output_file.typing_imports.add("AsyncIterator")
|
||||
|
||||
super().__post_init__() # check for unset fields
|
||||
|
||||
@property
|
||||
def mutable_default_args(self) -> Dict[str, str]:
|
||||
"""Handle mutable default arguments.
|
||||
|
||||
Returns a list of tuples containing the name and default value
|
||||
for arguments to this message who's default value is mutable.
|
||||
The defaults are swapped out for None and replaced back inside
|
||||
the method's body.
|
||||
Reference:
|
||||
https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, str]
|
||||
Name and actual default value (as a string)
|
||||
for each argument with mutable default values.
|
||||
"""
|
||||
mutable_default_args = dict()
|
||||
|
||||
if self.py_input_message:
|
||||
for f in self.py_input_message.fields:
|
||||
if (
|
||||
not self.client_streaming
|
||||
and f.default_value_string != "None"
|
||||
and f.mutable
|
||||
):
|
||||
mutable_default_args[f.py_name] = f.default_value_string
|
||||
self.output_file.typing_imports.add("Optional")
|
||||
|
||||
return mutable_default_args
|
||||
|
||||
@property
|
||||
def py_name(self) -> str:
|
||||
"""Pythonized method name."""
|
||||
return pythonize_method_name(self.proto_obj.name)
|
||||
|
||||
@property
|
||||
def proto_name(self) -> str:
|
||||
"""Original protobuf name."""
|
||||
return self.proto_obj.name
|
||||
|
||||
@property
|
||||
def route(self) -> str:
|
||||
return (
|
||||
f"/{self.output_file.package}."
|
||||
f"{self.parent.proto_name}/{self.proto_name}"
|
||||
)
|
||||
|
||||
@property
|
||||
def py_input_message(self) -> Union[None, MessageCompiler]:
|
||||
"""Find the input message object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, MessageCompiler]
|
||||
Method instance representing the input message.
|
||||
If not input message could be found or there are no
|
||||
input messages, None is returned.
|
||||
"""
|
||||
package, name = parse_source_type_name(self.proto_obj.input_type)
|
||||
|
||||
# Nested types are currently flattened without dots.
|
||||
# Todo: keep a fully quantified name in types, that is
|
||||
# comparable with method.input_type
|
||||
for msg in self.request.all_messages:
|
||||
if (
|
||||
msg.py_name == name.replace(".", "")
|
||||
and msg.output_file.package == package
|
||||
):
|
||||
return msg
|
||||
return None
|
||||
|
||||
@property
|
||||
def py_input_message_type(self) -> str:
|
||||
"""String representation of the Python type correspoding to the
|
||||
input message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type correspoding to the
|
||||
input message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.input_type,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
def py_output_message_type(self) -> str:
|
||||
"""String representation of the Python type correspoding to the
|
||||
output message.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
String representation of the Python type correspoding to the
|
||||
output message.
|
||||
"""
|
||||
return get_type_reference(
|
||||
package=self.output_file.package,
|
||||
imports=self.output_file.imports,
|
||||
source_type=self.proto_obj.output_type,
|
||||
unwrap=False,
|
||||
).strip('"')
|
||||
|
||||
@property
|
||||
def client_streaming(self) -> bool:
|
||||
return self.proto_obj.client_streaming
|
||||
|
||||
@property
|
||||
def server_streaming(self) -> bool:
|
||||
return self.proto_obj.server_streaming
|
||||
188
src/betterproto/plugin/parser.py
Normal file
188
src/betterproto/plugin/parser.py
Normal file
@@ -0,0 +1,188 @@
|
||||
import itertools
|
||||
import os.path
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import List, Iterator
|
||||
|
||||
try:
|
||||
# betterproto[compiler] specific dependencies
|
||||
import black
|
||||
from google.protobuf.compiler import plugin_pb2 as plugin
|
||||
from google.protobuf.descriptor_pb2 import (
|
||||
DescriptorProto,
|
||||
EnumDescriptorProto,
|
||||
FieldDescriptorProto,
|
||||
ServiceDescriptorProto,
|
||||
)
|
||||
import jinja2
|
||||
except ImportError as err:
|
||||
missing_import = err.args[0][17:-1]
|
||||
print(
|
||||
"\033[31m"
|
||||
f"Unable to import `{missing_import}` from betterproto plugin! "
|
||||
"Please ensure that you've installed betterproto as "
|
||||
'`pip install "betterproto[compiler]"` so that compiler dependencies '
|
||||
"are included."
|
||||
"\033[0m"
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
from betterproto.plugin.models import (
|
||||
PluginRequestCompiler,
|
||||
OutputTemplate,
|
||||
MessageCompiler,
|
||||
FieldCompiler,
|
||||
OneOfFieldCompiler,
|
||||
MapEntryCompiler,
|
||||
EnumDefinitionCompiler,
|
||||
ServiceCompiler,
|
||||
ServiceMethodCompiler,
|
||||
is_map,
|
||||
is_oneof,
|
||||
)
|
||||
|
||||
|
||||
def traverse(proto_file: FieldDescriptorProto) -> Iterator:
|
||||
# Todo: Keep information about nested hierarchy
|
||||
def _traverse(path, items, prefix=""):
|
||||
for i, item in enumerate(items):
|
||||
# Adjust the name since we flatten the hierarchy.
|
||||
# Todo: don't change the name, but include full name in returned tuple
|
||||
item.name = next_prefix = prefix + item.name
|
||||
yield item, path + [i]
|
||||
|
||||
if isinstance(item, DescriptorProto):
|
||||
for enum in item.enum_type:
|
||||
enum.name = next_prefix + enum.name
|
||||
yield enum, path + [i, 4]
|
||||
|
||||
if item.nested_type:
|
||||
for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix):
|
||||
yield n, p
|
||||
|
||||
return itertools.chain(
|
||||
_traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type)
|
||||
)
|
||||
|
||||
|
||||
def generate_code(
|
||||
request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse
|
||||
) -> None:
|
||||
plugin_options = request.parameter.split(",") if request.parameter else []
|
||||
|
||||
templates_folder = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "templates")
|
||||
)
|
||||
|
||||
env = jinja2.Environment(
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
loader=jinja2.FileSystemLoader(templates_folder),
|
||||
)
|
||||
template = env.get_template("template.py.j2")
|
||||
request_data = PluginRequestCompiler(plugin_request_obj=request)
|
||||
# Gather output packages
|
||||
for proto_file in request.proto_file:
|
||||
if (
|
||||
proto_file.package == "google.protobuf"
|
||||
and "INCLUDE_GOOGLE" not in plugin_options
|
||||
):
|
||||
# If not INCLUDE_GOOGLE,
|
||||
# skip re-compiling Google's well-known types
|
||||
continue
|
||||
|
||||
output_package_name = proto_file.package
|
||||
if output_package_name not in request_data.output_packages:
|
||||
# Create a new output if there is no output for this package
|
||||
request_data.output_packages[output_package_name] = OutputTemplate(
|
||||
parent_request=request_data, package_proto_obj=proto_file
|
||||
)
|
||||
# Add this input file to the output corresponding to this package
|
||||
request_data.output_packages[output_package_name].input_files.append(proto_file)
|
||||
|
||||
# Read Messages and Enums
|
||||
# We need to read Messages before Services in so that we can
|
||||
# get the references to input/output messages for each service
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
for proto_input_file in output_package.input_files:
|
||||
for item, path in traverse(proto_input_file):
|
||||
read_protobuf_type(item=item, path=path, output_package=output_package)
|
||||
|
||||
# Read Services
|
||||
for output_package_name, output_package in request_data.output_packages.items():
|
||||
for proto_input_file in output_package.input_files:
|
||||
for index, service in enumerate(proto_input_file.service):
|
||||
read_protobuf_service(service, index, output_package)
|
||||
|
||||
# Generate output files
|
||||
output_paths: pathlib.Path = set()
|
||||
for output_package_name, template_data in request_data.output_packages.items():
|
||||
|
||||
# Add files to the response object
|
||||
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
|
||||
output_paths.add(output_path)
|
||||
|
||||
f: response.File = response.file.add()
|
||||
f.name: str = str(output_path)
|
||||
|
||||
# Render and then format the output file
|
||||
f.content: str = black.format_str(
|
||||
template.render(description=template_data),
|
||||
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
|
||||
)
|
||||
|
||||
# Make each output directory a package with __init__ file
|
||||
init_files = (
|
||||
set(
|
||||
directory.joinpath("__init__.py")
|
||||
for path in output_paths
|
||||
for directory in path.parents
|
||||
)
|
||||
- output_paths
|
||||
)
|
||||
|
||||
for init_file in init_files:
|
||||
init = response.file.add()
|
||||
init.name = str(init_file)
|
||||
|
||||
for output_package_name in sorted(output_paths.union(init_files)):
|
||||
print(f"Writing {output_package_name}", file=sys.stderr)
|
||||
|
||||
|
||||
def read_protobuf_type(
|
||||
item: DescriptorProto, path: List[int], output_package: OutputTemplate
|
||||
) -> None:
|
||||
if isinstance(item, DescriptorProto):
|
||||
if item.options.map_entry:
|
||||
# Skip generated map entry messages since we just use dicts
|
||||
return
|
||||
# Process Message
|
||||
message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path)
|
||||
for index, field in enumerate(item.field):
|
||||
if is_map(field, item):
|
||||
MapEntryCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
)
|
||||
elif is_oneof(field):
|
||||
OneOfFieldCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
)
|
||||
else:
|
||||
FieldCompiler(
|
||||
parent=message_data, proto_obj=field, path=path + [2, index]
|
||||
)
|
||||
elif isinstance(item, EnumDescriptorProto):
|
||||
# Enum
|
||||
EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path)
|
||||
|
||||
|
||||
def read_protobuf_service(
|
||||
service: ServiceDescriptorProto, index: int, output_package: OutputTemplate
|
||||
) -> None:
|
||||
service_data = ServiceCompiler(
|
||||
parent=output_package, proto_obj=service, path=[6, index],
|
||||
)
|
||||
for j, method in enumerate(service.method):
|
||||
ServiceMethodCompiler(
|
||||
parent=service_data, proto_obj=method, path=[6, index, 2, j],
|
||||
)
|
||||
2
src/betterproto/plugin/plugin.bat
Normal file
2
src/betterproto/plugin/plugin.bat
Normal file
@@ -0,0 +1,2 @@
|
||||
@SET plugin_dir=%~dp0
|
||||
@python -m %plugin_dir% %*
|
||||
Reference in New Issue
Block a user