diff --git a/src/betterproto/plugin.bat b/src/betterproto/plugin.bat deleted file mode 100644 index 9b837d7..0000000 --- a/src/betterproto/plugin.bat +++ /dev/null @@ -1,2 +0,0 @@ -@SET plugin_dir=%~dp0 -@python %plugin_dir%/plugin.py %* \ No newline at end of file diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py deleted file mode 100755 index 4f01c29..0000000 --- a/src/betterproto/plugin.py +++ /dev/null @@ -1,480 +0,0 @@ -#!/usr/bin/env python -import collections -import itertools -import os.path -import pathlib -import re -import sys -import textwrap -from typing import List, Union - -from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest - -import betterproto -from betterproto.compile.importing import get_type_reference, parse_source_type_name -from betterproto.compile.naming import ( - pythonize_class_name, - pythonize_field_name, - pythonize_method_name, -) -from betterproto.lib.google.protobuf import ServiceDescriptorProto - -try: - # betterproto[compiler] specific dependencies - import black - from google.protobuf.compiler import plugin_pb2 as plugin - from google.protobuf.descriptor_pb2 import ( - DescriptorProto, - EnumDescriptorProto, - FieldDescriptorProto, - ) - import google.protobuf.wrappers_pb2 as google_wrappers - import jinja2 -except ImportError as err: - missing_import = err.args[0][17:-1] - print( - "\033[31m" - f"Unable to import `{missing_import}` from betterproto plugin! " - "Please ensure that you've installed betterproto as " - '`pip install "betterproto[compiler]"` so that compiler dependencies ' - "are included." 
- "\033[0m" - ) - raise SystemExit(1) - - -def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: - if field.type in [1, 2]: - return "float" - elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]: - return "int" - elif field.type == 8: - return "bool" - elif field.type == 9: - return "str" - elif field.type in [11, 14]: - # Type referencing another defined Message or a named enum - return get_type_reference(package, imports, field.type_name) - elif field.type == 12: - return "bytes" - else: - raise NotImplementedError(f"Unknown type {field.type}") - - -def get_py_zero(type_num: int) -> Union[str, float]: - zero: Union[str, float] = 0 - if type_num in []: - zero = 0.0 - elif type_num == 8: - zero = "False" - elif type_num == 9: - zero = '""' - elif type_num == 11: - zero = "None" - elif type_num == 12: - zero = 'b""' - - return zero - - -def traverse(proto_file): - # Todo: Keep information about nested hierarchy - def _traverse(path, items, prefix=""): - for i, item in enumerate(items): - # Adjust the name since we flatten the hierarchy. 
- # Todo: don't change the name, but include full name in returned tuple - item.name = next_prefix = prefix + item.name - yield item, path + [i] - - if isinstance(item, DescriptorProto): - for enum in item.enum_type: - enum.name = next_prefix + enum.name - yield enum, path + [i, 4] - - if item.nested_type: - for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): - yield n, p - - return itertools.chain( - _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) - ) - - -def get_comment(proto_file, path: List[int], indent: int = 4) -> str: - pad = " " * indent - for sci in proto_file.source_code_info.location: - # print(list(sci.path), path, file=sys.stderr) - if list(sci.path) == path and sci.leading_comments: - lines = textwrap.wrap( - sci.leading_comments.strip().replace("\n", ""), width=79 - indent - ) - - if path[-2] == 2 and path[-4] != 6: - # This is a field - return f"{pad}# " + f"\n{pad}# ".join(lines) - else: - # This is a message, enum, service, or method - if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: - lines[0] = lines[0].strip('"') - return f'{pad}"""{lines[0]}"""' - else: - joined = f"\n{pad}".join(lines) - return f'{pad}"""\n{pad}{joined}\n{pad}"""' - - return "" - - -def generate_code(request, response): - plugin_options = request.parameter.split(",") if request.parameter else [] - - env = jinja2.Environment( - trim_blocks=True, - lstrip_blocks=True, - loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)), - ) - template = env.get_template("template.py.j2") - - # Gather output packages - output_package_files = collections.defaultdict() - for proto_file in request.proto_file: - if ( - proto_file.package == "google.protobuf" - and "INCLUDE_GOOGLE" not in plugin_options - ): - continue - - output_package = proto_file.package - output_package_files.setdefault( - output_package, {"input_package": proto_file.package, "files": []} - ) - 
output_package_files[output_package]["files"].append(proto_file) - - # Initialize Template data for each package - for output_package_name, output_package_content in output_package_files.items(): - template_data = { - "input_package": output_package_content["input_package"], - "files": [f.name for f in output_package_content["files"]], - "imports": set(), - "datetime_imports": set(), - "typing_imports": set(), - "messages": [], - "enums": [], - "services": [], - } - output_package_content["template_data"] = template_data - - # Read Messages and Enums - output_types = [] - for output_package_name, output_package_content in output_package_files.items(): - for proto_file in output_package_content["files"]: - for item, path in traverse(proto_file): - type_data = read_protobuf_type( - item, path, proto_file, output_package_content - ) - output_types.append(type_data) - - # Read Services - for output_package_name, output_package_content in output_package_files.items(): - for proto_file in output_package_content["files"]: - for index, service in enumerate(proto_file.service): - read_protobuf_service( - service, index, proto_file, output_package_content, output_types - ) - - # Render files - output_paths = set() - for output_package_name, output_package_content in output_package_files.items(): - template_data = output_package_content["template_data"] - template_data["imports"] = sorted(template_data["imports"]) - template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) - template_data["typing_imports"] = sorted(template_data["typing_imports"]) - - # Fill response - output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") - output_paths.add(output_path) - - f = response.file.add() - f.name = str(output_path) - - # Render and then format the output file. 
- f.content = black.format_str( - template.render(description=template_data), - mode=black.FileMode(target_versions={black.TargetVersion.PY37}), - ) - - # Make each output directory a package with __init__ file - init_files = ( - set( - directory.joinpath("__init__.py") - for path in output_paths - for directory in path.parents - ) - - output_paths - ) - - for init_file in init_files: - init = response.file.add() - init.name = str(init_file) - - for output_package_name in sorted(output_paths.union(init_files)): - print(f"Writing {output_package_name}", file=sys.stderr) - - -def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): - input_package_name = content["input_package"] - template_data = content["template_data"] - data = { - "name": item.name, - "py_name": pythonize_class_name(item.name), - "descriptor": item, - "package": input_package_name, - } - if isinstance(item, DescriptorProto): - # print(item, file=sys.stderr) - if item.options.map_entry: - # Skip generated map entry messages since we just use dicts - return - - data.update( - { - "type": "Message", - "comment": get_comment(proto_file, path), - "properties": [], - } - ) - - for i, f in enumerate(item.field): - t = py_type(input_package_name, template_data["imports"], f) - zero = get_py_zero(f.type) - - repeated = False - packed = False - - field_type = f.Type.Name(f.type).lower()[5:] - - field_wraps = "" - match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto, wrapped_type): - field_wraps = f"betterproto.{wrapped_type}" - - map_types = None - if f.type == 11: - # This might be a map... 
- message_type = f.type_name.split(".").pop().lower() - # message_type = py_type(package) - map_entry = f"{f.name.replace('_', '').lower()}entry" - - if message_type == map_entry: - for nested in item.nested_type: - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - # print("Found a map!", file=sys.stderr) - k = py_type( - input_package_name, - template_data["imports"], - nested.field[0], - ) - v = py_type( - input_package_name, - template_data["imports"], - nested.field[1], - ) - t = f"Dict[{k}, {v}]" - field_type = "map" - map_types = ( - f.Type.Name(nested.field[0].type), - f.Type.Name(nested.field[1].type), - ) - template_data["typing_imports"].add("Dict") - - if f.label == 3 and field_type != "map": - # Repeated field - repeated = True - t = f"List[{t}]" - zero = "[]" - template_data["typing_imports"].add("List") - - if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: - packed = True - - one_of = "" - if f.HasField("oneof_index"): - one_of = item.oneof_decl[f.oneof_index].name - - if "Optional[" in t: - template_data["typing_imports"].add("Optional") - - if "timedelta" in t: - template_data["datetime_imports"].add("timedelta") - elif "datetime" in t: - template_data["datetime_imports"].add("datetime") - - data["properties"].append( - { - "name": f.name, - "py_name": pythonize_field_name(f.name), - "number": f.number, - "comment": get_comment(proto_file, path + [2, i]), - "proto_type": int(f.type), - "field_type": field_type, - "field_wraps": field_wraps, - "map_types": map_types, - "type": t, - "zero": zero, - "repeated": repeated, - "packed": packed, - "one_of": one_of, - } - ) - # print(f, file=sys.stderr) - - template_data["messages"].append(data) - return data - elif isinstance(item, EnumDescriptorProto): - # print(item.name, path, file=sys.stderr) - data.update( - { - "type": "Enum", - "comment": get_comment(proto_file, path), - "entries": [ - { - "name": v.name, - "value": v.number, - "comment": 
get_comment(proto_file, path + [2, i]), - } - for i, v in enumerate(item.value) - ], - } - ) - - template_data["enums"].append(data) - return data - - -def lookup_method_input_type(method, types): - package, name = parse_source_type_name(method.input_type) - - for known_type in types: - if known_type["type"] != "Message": - continue - - # Nested types are currently flattened without dots. - # Todo: keep a fully quantified name in types, that is comparable with method.input_type - if ( - package == known_type["package"] - and name.replace(".", "") == known_type["name"] - ): - return known_type - - -def is_mutable_field_type(field_type: str) -> bool: - return field_type.startswith("List[") or field_type.startswith("Dict[") - - -def read_protobuf_service( - service: ServiceDescriptorProto, index, proto_file, content, output_types -): - input_package_name = content["input_package"] - template_data = content["template_data"] - # print(service, file=sys.stderr) - data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, index]), - "methods": [], - } - for j, method in enumerate(service.method): - method_input_message = lookup_method_input_type(method, output_types) - - # This section ensures that method arguments having a default - # value that is initialised as a List/Dict (mutable) is replaced - # with None and initialisation is deferred to the beginning of the - # method definition. This is done so to avoid any side-effects. 
- # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments - mutable_default_args = [] - - if method_input_message: - for field in method_input_message["properties"]: - if ( - not method.client_streaming - and field["zero"] != "None" - and is_mutable_field_type(field["type"]) - ): - mutable_default_args.append((field["py_name"], field["zero"])) - field["zero"] = "None" - - if field["zero"] == "None": - template_data["typing_imports"].add("Optional") - - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, index, 2, j], indent=8), - "route": f"/{input_package_name}.{service.name}/{method.name}", - "input": get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"'), - "input_message": method_input_message, - "output": get_type_reference( - input_package_name, - template_data["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - "mutable_default_args": mutable_default_args, - } - ) - - if method.client_streaming: - template_data["typing_imports"].add("AsyncIterable") - template_data["typing_imports"].add("Iterable") - template_data["typing_imports"].add("Union") - if method.server_streaming: - template_data["typing_imports"].add("AsyncIterator") - template_data["services"].append(data) - - -def main(): - """The plugin's main entry point.""" - # Read request message from stdin - data = sys.stdin.buffer.read() - - # Parse request - request = plugin.CodeGeneratorRequest() - request.ParseFromString(data) - - dump_file = os.getenv("BETTERPROTO_DUMP") - if dump_file: - dump_request(dump_file, request) - - # Create response - response = plugin.CodeGeneratorResponse() - - # Generate code - generate_code(request, response) - - # Serialise response message - output = response.SerializeToString() - - # Write to stdout - 
sys.stdout.buffer.write(output) - - -def dump_request(dump_file: str, request: CodeGeneratorRequest): - """ - For developers: Supports running plugin.py standalone so its possible to debug it. - Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. - Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. - """ - with open(str(dump_file), "wb") as fh: - sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") - fh.write(request.SerializeToString()) - - -if __name__ == "__main__": - main() diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py new file mode 100644 index 0000000..c28a133 --- /dev/null +++ b/src/betterproto/plugin/__init__.py @@ -0,0 +1 @@ +from .main import main diff --git a/src/betterproto/plugin/__main__.py b/src/betterproto/plugin/__main__.py new file mode 100644 index 0000000..bd95dae --- /dev/null +++ b/src/betterproto/plugin/__main__.py @@ -0,0 +1,4 @@ +from .main import main + + +main() diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py new file mode 100644 index 0000000..2604af2 --- /dev/null +++ b/src/betterproto/plugin/main.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +import sys +import os + +from google.protobuf.compiler import plugin_pb2 as plugin + +from betterproto.plugin.parser import generate_code + + +def main(): + """The plugin's main entry point.""" + # Read request message from stdin + data = sys.stdin.buffer.read() + + # Parse request + request = plugin.CodeGeneratorRequest() + request.ParseFromString(data) + + dump_file = os.getenv("BETTERPROTO_DUMP") + if dump_file: + dump_request(dump_file, request) + + # Create response + response = plugin.CodeGeneratorResponse() + + # Generate code + generate_code(request, response) + + # Serialise response message + output = response.SerializeToString() + + # Write to stdout + sys.stdout.buffer.write(output) + + +def dump_request(dump_file: 
str, request: plugin.CodeGeneratorRequest): + """ + For developers: Supports running plugin.py standalone so its possible to debug it. + Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. + Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. + """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py new file mode 100644 index 0000000..8e19961 --- /dev/null +++ b/src/betterproto/plugin/models.py @@ -0,0 +1,713 @@ +"""Plugin model dataclasses. + +These classes are meant to be an intermediate representation +of protbuf objects. They are used to organize the data collected during parsing. + +The general intention is to create a doubly-linked tree-like structure +with the following types of references: +- Downwards references: from message -> fields, from output package -> messages +or from service -> service methods +- Upwards references: from field -> message, message -> package. +- Input/ouput message references: from a service method to it's corresponding +input/output messages, which may even be in another package. + +There are convenience methods to allow climbing up and down this tree, for +example to retrieve the list of all messages that are in the same package as +the current message. + +Most of these classes take as inputs: +- proto_obj: A reference to it's corresponding protobuf object as +presented by the protoc plugin. +- parent: a reference to the parent object in the tree. + +With this information, the class is able to expose attributes, +such as a pythonized name, that will be calculated from proto_obj. + +The instantiation should also attach a reference to the new object +into the corresponding place within it's parent object. 
For example, +instantiating field `A` with parent message `B` should add a +reference to `A` to `B`'s `fields` attirbute. +""" + +import re +from dataclasses import dataclass +from dataclasses import field +from typing import ( + Union, + Type, + List, + Dict, + Set, + Text, +) +import textwrap + +import betterproto +from betterproto.compile.importing import ( + get_type_reference, + parse_source_type_name, +) +from betterproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) + +try: + # betterproto[compiler] specific dependencies + from google.protobuf.compiler import plugin_pb2 as plugin + from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + EnumDescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, + MethodDescriptorProto, + ) +except ImportError as err: + missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) + print( + "\033[31m" + f"Unable to import `{missing_import}` from betterproto plugin! " + "Please ensure that you've installed betterproto as " + '`pip install "betterproto[compiler]"` so that compiler dependencies ' + "are included." 
+ "\033[0m" + ) + raise SystemExit(1) + +# Create a unique placeholder to deal with +# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +PLACEHOLDER = object() + +# Organize proto types into categories +PROTO_FLOAT_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 +) +PROTO_INT_TYPES = ( + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) +PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8 +PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9 +PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12 +PROTO_MESSAGE_TYPES = ( + FieldDescriptorProto.TYPE_MESSAGE, # 11 + FieldDescriptorProto.TYPE_ENUM, # 14 +) +PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11 +PROTO_PACKED_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_BOOL, # 8 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) + + +def get_comment(proto_file, path: List[int], indent: int = 4) -> str: + pad = " " * indent + for sci in proto_file.source_code_info.location: + # print(list(sci.path), path, file=sys.stderr) + if list(sci.path) == path and sci.leading_comments: + lines = textwrap.wrap( + 
sci.leading_comments.strip().replace("\n", ""), width=79 - indent, + ) + + if path[-2] == 2 and path[-4] != 6: + # This is a field + return f"{pad}# " + f"\n{pad}# ".join(lines) + else: + # This is a message, enum, service, or method + if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: + lines[0] = lines[0].strip('"') + return f'{pad}"""{lines[0]}"""' + else: + joined = f"\n{pad}".join(lines) + return f'{pad}"""\n{pad}{joined}\n{pad}"""' + + return "" + + +class ProtoContentBase: + """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" + + path: List[int] + comment_indent: int = 4 + + def __post_init__(self): + """Checks that no fake default fields were left as placeholders.""" + for field_name, field_val in self.__dataclass_fields__.items(): + if field_val is PLACEHOLDER: + raise ValueError(f"`{field_name}` is a required field.") + + @property + def output_file(self) -> "OutputTemplate": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current + + @property + def proto_file(self) -> FieldDescriptorProto: + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.package_proto_obj + + @property + def request(self) -> "PluginRequestCompiler": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.parent_request + + @property + def comment(self) -> str: + """Crawl the proto source code and retrieve comments + for this object. + """ + return get_comment( + proto_file=self.proto_file, path=self.path, indent=self.comment_indent, + ) + + +@dataclass +class PluginRequestCompiler: + + plugin_request_obj: plugin.CodeGeneratorRequest + output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) + + @property + def all_messages(self) -> List["MessageCompiler"]: + """All of the messages in this request. 
+ + Returns + ------- + List[MessageCompiler] + List of all of the messages in this request. + """ + return [ + msg for output in self.output_packages.values() for msg in output.messages + ] + + +@dataclass +class OutputTemplate: + """Representation of an output .py file. + + Each output file corresponds to a .proto input file, + but may need references to other .proto files to be + built. + """ + + parent_request: PluginRequestCompiler + package_proto_obj: FileDescriptorProto + input_files: List[str] = field(default_factory=list) + imports: Set[str] = field(default_factory=set) + datetime_imports: Set[str] = field(default_factory=set) + typing_imports: Set[str] = field(default_factory=set) + messages: List["MessageCompiler"] = field(default_factory=list) + enums: List["EnumDefinitionCompiler"] = field(default_factory=list) + services: List["ServiceCompiler"] = field(default_factory=list) + + @property + def package(self) -> str: + """Name of input package. + + Returns + ------- + str + Name of input package. + """ + return self.package_proto_obj.package + + @property + def input_filenames(self) -> List[str]: + """Names of the input files used to build this output. + + Returns + ------- + List[str] + Names of the input files used to build this output. + """ + return [f.name for f in self.input_files] + + +@dataclass +class MessageCompiler(ProtoContentBase): + """Representation of a protobuf message. 
+ """ + + parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + fields: List[Union["FieldCompiler", "MessageCompiler"]] = field( + default_factory=list + ) + + def __post_init__(self): + # Add message to output file + if isinstance(self.parent, OutputTemplate): + if isinstance(self, EnumDefinitionCompiler): + self.output_file.enums.append(self) + else: + self.output_file.messages.append(self) + super().__post_init__() + + @property + def proto_name(self) -> str: + return self.proto_obj.name + + @property + def py_name(self) -> str: + return pythonize_class_name(self.proto_name) + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_name}]" + return self.py_name + + +def is_map( + proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto +) -> bool: + """True if proto_field_obj is a map, otherwise False. + """ + if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE: + # This might be a map... + message_type = proto_field_obj.type_name.split(".").pop().lower() + map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" + if message_type == map_entry: + for nested in parent_message.nested_type: # parent message + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + return True + return False + + +def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: + """True if proto_field_obj is a OneOf, otherwise False. 
+ """ + if proto_field_obj.HasField("oneof_index"): + return True + return False + + +@dataclass +class FieldCompiler(MessageCompiler): + parent: MessageCompiler = PLACEHOLDER + proto_obj: FieldDescriptorProto = PLACEHOLDER + + def __post_init__(self): + # Add field to message + self.parent.fields.append(self) + # Check for new imports + annotation = self.annotation + if "Optional[" in annotation: + self.output_file.typing_imports.add("Optional") + if "List[" in annotation: + self.output_file.typing_imports.add("List") + if "Dict[" in annotation: + self.output_file.typing_imports.add("Dict") + if "timedelta" in annotation: + self.output_file.datetime_imports.add("timedelta") + if "datetime" in annotation: + self.output_file.datetime_imports.add("datetime") + super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field as a field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = ( + f"betterproto.{self.field_type}_field({self.proto_obj.number}" + + f"{self.betterproto_field_args}" + + ")" + ) + return name + annotations + " = " + betterproto_field_type + + @property + def betterproto_field_args(self): + args = "" + if self.field_wraps: + args = args + f", wraps={self.field_wraps}" + return args + + @property + def field_wraps(self) -> Union[str, None]: + """Returns betterproto wrapped field type or None. 
+ """ + match_wrapper = re.match( + r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name + ) + if match_wrapper: + wrapped_type = "TYPE_" + match_wrapper.group(1).upper() + if hasattr(betterproto, wrapped_type): + return f"betterproto.{wrapped_type}" + return None + + @property + def repeated(self) -> bool: + if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( + self.proto_obj, self.parent + ): + return True + return False + + @property + def mutable(self) -> bool: + """True if the field is a mutable type, otherwise False.""" + annotation = self.annotation + return annotation.startswith("List[") or annotation.startswith("Dict[") + + @property + def field_type(self) -> str: + """String representation of proto field type.""" + return ( + self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "") + ) + + @property + def default_value_string(self) -> Union[Text, None, float, int]: + """Python representation of the default proto value. + """ + if self.repeated: + return "[]" + if self.py_type == "int": + return "0" + if self.py_type == "float": + return "0.0" + elif self.py_type == "bool": + return "False" + elif self.py_type == "str": + return '""' + elif self.py_type == "bytes": + return 'b""' + else: + # Message type + return "None" + + @property + def packed(self) -> bool: + """True if the wire representation is a packed format.""" + if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: + return True + return False + + @property + def py_name(self) -> str: + """Pythonized name.""" + return pythonize_field_name(self.proto_name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def py_type(self) -> str: + """String representation of Python type.""" + if self.proto_obj.type in PROTO_FLOAT_TYPES: + return "float" + elif self.proto_obj.type in PROTO_INT_TYPES: + return "int" + elif self.proto_obj.type in PROTO_BOOL_TYPES: + return "bool" + 
elif self.proto_obj.type in PROTO_STR_TYPES: + return "str" + elif self.proto_obj.type in PROTO_BYTES_TYPES: + return "bytes" + elif self.proto_obj.type in PROTO_MESSAGE_TYPES: + # Type referencing another defined Message or a named enum + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.type_name, + ) + else: + raise NotImplementedError(f"Unknown type {field.type}") + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_type}]" + return self.py_type + + +@dataclass +class OneOfFieldCompiler(FieldCompiler): + @property + def betterproto_field_args(self) -> "str": + args = super().betterproto_field_args + group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name + args = args + f', group="{group}"' + return args + + +@dataclass +class MapEntryCompiler(FieldCompiler): + py_k_type: Type = PLACEHOLDER + py_v_type: Type = PLACEHOLDER + proto_k_type: str = PLACEHOLDER + proto_v_type: str = PLACEHOLDER + + def __post_init__(self): + """Explore nested types and set k_type and v_type if unset.""" + map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" + for nested in self.parent.proto_obj.nested_type: + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + # Get Python types + self.py_k_type = FieldCompiler( + parent=self, proto_obj=nested.field[0], # key + ).py_type + self.py_v_type = FieldCompiler( + parent=self, proto_obj=nested.field[1], # value + ).py_type + # Get proto types + self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) + self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) + super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = ( + 
f"betterproto.map_field(" + f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, " + f"betterproto.{self.proto_v_type})" + ) + return name + annotations + " = " + betterproto_field_type + + @property + def annotation(self): + return f"Dict[{self.py_k_type}, {self.py_v_type}]" + + @property + def repeated(self): + return False # maps cannot be repeated + + +@dataclass +class EnumDefinitionCompiler(MessageCompiler): + """Representation of a proto Enum definition.""" + + proto_obj: EnumDescriptorProto = PLACEHOLDER + entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER + + @dataclass(unsafe_hash=True) + class EnumEntry: + """Representation of an Enum entry.""" + + name: str + value: int + comment: str + + def __post_init__(self): + # Get entries/allowed values for this Enum + self.entries = [ + self.EnumEntry( + name=entry_proto_value.name, + value=entry_proto_value.number, + comment=get_comment( + proto_file=self.proto_file, path=self.path + [2, entry_number] + ), + ) + for entry_number, entry_proto_value in enumerate(self.proto_obj.value) + ] + super().__post_init__() # call MessageCompiler __post_init__ + + @property + def default_value_string(self) -> int: + """Python representation of the default value for Enums. + + As per the spec, this is the first value of the Enum. + """ + return str(self.entries[0].value) # ideally, should ALWAYS be int(0)! 
+ + +@dataclass +class ServiceCompiler(ProtoContentBase): + parent: OutputTemplate = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + methods: List["ServiceMethodCompiler"] = field(default_factory=list) + + def __post_init__(self) -> None: + # Add service to output file + self.output_file.services.append(self) + super().__post_init__() # check for unset fields + + @property + def proto_name(self): + return self.proto_obj.name + + @property + def py_name(self): + return pythonize_class_name(self.proto_name) + + +@dataclass +class ServiceMethodCompiler(ProtoContentBase): + + parent: ServiceCompiler + proto_obj: MethodDescriptorProto + path: List[int] = PLACEHOLDER + comment_indent: int = 8 + + def __post_init__(self) -> None: + # Add method to service + self.parent.methods.append(self) + + # Check for Optional import + if self.py_input_message: + for f in self.py_input_message.fields: + if f.default_value_string == "None": + self.output_file.typing_imports.add("Optional") + if "Optional" in self.py_output_message_type: + self.output_file.typing_imports.add("Optional") + self.mutable_default_args # ensure this is called before rendering + + # Check for Async imports + if self.client_streaming: + self.output_file.typing_imports.add("AsyncIterable") + self.output_file.typing_imports.add("Iterable") + self.output_file.typing_imports.add("Union") + if self.server_streaming: + self.output_file.typing_imports.add("AsyncIterator") + + super().__post_init__() # check for unset fields + + @property + def mutable_default_args(self) -> Dict[str, str]: + """Handle mutable default arguments. + + Returns a list of tuples containing the name and default value + for arguments to this message who's default value is mutable. + The defaults are swapped out for None and replaced back inside + the method's body. 
+ Reference: + https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments + + Returns + ------- + Dict[str, str] + Name and actual default value (as a string) + for each argument with mutable default values. + """ + mutable_default_args = dict() + + if self.py_input_message: + for f in self.py_input_message.fields: + if ( + not self.client_streaming + and f.default_value_string != "None" + and f.mutable + ): + mutable_default_args[f.py_name] = f.default_value_string + self.output_file.typing_imports.add("Optional") + + return mutable_default_args + + @property + def py_name(self) -> str: + """Pythonized method name.""" + return pythonize_method_name(self.proto_obj.name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def route(self) -> str: + return ( + f"/{self.output_file.package}." + f"{self.parent.proto_name}/{self.proto_name}" + ) + + @property + def py_input_message(self) -> Union[None, MessageCompiler]: + """Find the input message object. + + Returns + ------- + Union[None, MessageCompiler] + Method instance representing the input message. + If not input message could be found or there are no + input messages, None is returned. + """ + package, name = parse_source_type_name(self.proto_obj.input_type) + + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is + # comparable with method.input_type + for msg in self.request.all_messages: + if ( + msg.py_name == name.replace(".", "") + and msg.output_file.package == package + ): + return msg + return None + + @property + def py_input_message_type(self) -> str: + """String representation of the Python type correspoding to the + input message. + + Returns + ------- + str + String representation of the Python type correspoding to the + input message. 
+ """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.input_type, + ).strip('"') + + @property + def py_output_message_type(self) -> str: + """String representation of the Python type correspoding to the + output message. + + Returns + ------- + str + String representation of the Python type correspoding to the + output message. + """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.output_type, + unwrap=False, + ).strip('"') + + @property + def client_streaming(self) -> bool: + return self.proto_obj.client_streaming + + @property + def server_streaming(self) -> bool: + return self.proto_obj.server_streaming diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py new file mode 100644 index 0000000..33991ec --- /dev/null +++ b/src/betterproto/plugin/parser.py @@ -0,0 +1,188 @@ +import itertools +import os.path +import pathlib +import sys +from typing import List, Iterator + +try: + # betterproto[compiler] specific dependencies + import black + from google.protobuf.compiler import plugin_pb2 as plugin + from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + EnumDescriptorProto, + FieldDescriptorProto, + ServiceDescriptorProto, + ) + import jinja2 +except ImportError as err: + missing_import = err.args[0][17:-1] + print( + "\033[31m" + f"Unable to import `{missing_import}` from betterproto plugin! " + "Please ensure that you've installed betterproto as " + '`pip install "betterproto[compiler]"` so that compiler dependencies ' + "are included." 
+ "\033[0m" + ) + raise SystemExit(1) + +from betterproto.plugin.models import ( + PluginRequestCompiler, + OutputTemplate, + MessageCompiler, + FieldCompiler, + OneOfFieldCompiler, + MapEntryCompiler, + EnumDefinitionCompiler, + ServiceCompiler, + ServiceMethodCompiler, + is_map, + is_oneof, +) + + +def traverse(proto_file: FieldDescriptorProto) -> Iterator: + # Todo: Keep information about nested hierarchy + def _traverse(path, items, prefix=""): + for i, item in enumerate(items): + # Adjust the name since we flatten the hierarchy. + # Todo: don't change the name, but include full name in returned tuple + item.name = next_prefix = prefix + item.name + yield item, path + [i] + + if isinstance(item, DescriptorProto): + for enum in item.enum_type: + enum.name = next_prefix + enum.name + yield enum, path + [i, 4] + + if item.nested_type: + for n, p in _traverse(path + [i, 3], item.nested_type, next_prefix): + yield n, p + + return itertools.chain( + _traverse([5], proto_file.enum_type), _traverse([4], proto_file.message_type) + ) + + +def generate_code( + request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse +) -> None: + plugin_options = request.parameter.split(",") if request.parameter else [] + + templates_folder = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "templates") + ) + + env = jinja2.Environment( + trim_blocks=True, + lstrip_blocks=True, + loader=jinja2.FileSystemLoader(templates_folder), + ) + template = env.get_template("template.py.j2") + request_data = PluginRequestCompiler(plugin_request_obj=request) + # Gather output packages + for proto_file in request.proto_file: + if ( + proto_file.package == "google.protobuf" + and "INCLUDE_GOOGLE" not in plugin_options + ): + # If not INCLUDE_GOOGLE, + # skip re-compiling Google's well-known types + continue + + output_package_name = proto_file.package + if output_package_name not in request_data.output_packages: + # Create a new output if there is no output for this 
package + request_data.output_packages[output_package_name] = OutputTemplate( + parent_request=request_data, package_proto_obj=proto_file + ) + # Add this input file to the output corresponding to this package + request_data.output_packages[output_package_name].input_files.append(proto_file) + + # Read Messages and Enums + # We need to read Messages before Services so that we can + # get the references to input/output messages for each service + for output_package_name, output_package in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for item, path in traverse(proto_input_file): + read_protobuf_type(item=item, path=path, output_package=output_package) + + # Read Services + for output_package_name, output_package in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for index, service in enumerate(proto_input_file.service): + read_protobuf_service(service, index, output_package) + + # Generate output files + output_paths = set() + for output_package_name, template_data in request_data.output_packages.items(): + + # Add files to the response object + output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") + output_paths.add(output_path) + + f: response.File = response.file.add() + f.name: str = str(output_path) + + # Render and then format the output file + f.content: str = black.format_str( + template.render(description=template_data), + mode=black.FileMode(target_versions={black.TargetVersion.PY37}), + ) + + # Make each output directory a package with __init__ file + init_files = ( + set( + directory.joinpath("__init__.py") + for path in output_paths + for directory in path.parents + ) + - output_paths + ) + + for init_file in init_files: + init = response.file.add() + init.name = str(init_file) + + for output_package_name in sorted(output_paths.union(init_files)): + print(f"Writing {output_package_name}", file=sys.stderr) + + +def 
read_protobuf_type( + item: DescriptorProto, path: List[int], output_package: OutputTemplate +) -> None: + if isinstance(item, DescriptorProto): + if item.options.map_entry: + # Skip generated map entry messages since we just use dicts + return + # Process Message + message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path) + for index, field in enumerate(item.field): + if is_map(field, item): + MapEntryCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) + elif is_oneof(field): + OneOfFieldCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) + else: + FieldCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) + elif isinstance(item, EnumDescriptorProto): + # Enum + EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path) + + +def read_protobuf_service( + service: ServiceDescriptorProto, index: int, output_package: OutputTemplate +) -> None: + service_data = ServiceCompiler( + parent=output_package, proto_obj=service, path=[6, index], + ) + for j, method in enumerate(service.method): + ServiceMethodCompiler( + parent=service_data, proto_obj=method, path=[6, index, 2, j], + ) diff --git a/src/betterproto/plugin/plugin.bat b/src/betterproto/plugin/plugin.bat new file mode 100644 index 0000000..2a4444d --- /dev/null +++ b/src/betterproto/plugin/plugin.bat @@ -0,0 +1,2 @@ +@SET plugin_dir=%~dp0 +@python -m %plugin_dir% %* \ No newline at end of file diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index b7ca89c..bbd7cc5 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -1,13 +1,13 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! 
-# sources: {{ ', '.join(description.files) }} +# sources: {{ ', '.join(description.input_filenames) }} # plugin: python-betterproto from dataclasses import dataclass {% if description.datetime_imports %} -from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} +from datetime import {% for i in description.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} {% endif%} {% if description.typing_imports %} -from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} +from typing import {% for i in description.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} {% endif %} @@ -40,13 +40,13 @@ class {{ message.py_name }}(betterproto.Message): {{ message.comment }} {% endif %} - {% for field in message.properties %} + {% for field in message.fields %} {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %}) + {{ field.get_field_string() }} {% endfor %} - {% if not message.properties %} + {% if not message.fields %} pass {% endif %} @@ -61,32 +61,37 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} - {%- if method.input_message and method.input_message.properties -%}, *, - {%- for field in method.input_message.properties -%} - {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} - Optional[{{ field.type }}] + {%- if method.py_input_message and method.py_input_message.fields -%}, *, + {%- for field in 
method.py_input_message.fields -%} + {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} + Optional[{{ field.annotation }}] {%- else -%} - {{ field.type }} - {%- endif -%} = {{ field.zero }} + {{ field.annotation }} + {%- endif -%} = + {%- if field.py_name not in method.mutable_default_args -%} + {{ field.default_value_string }} + {%- else -%} + None + {% endif -%} {%- if not loop.last %}, {% endif -%} {%- endfor -%} {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]] + , request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] {%- endif -%} - ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}: + ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} {% endif %} - {%- for py_name, zero in method.mutable_default_args %} + {%- for py_name, zero in method.mutable_default_args.items() %} {{ py_name }} = {{ py_name }} or {{ zero }} {% endfor %} {% if not method.client_streaming %} - request = {{ method.input }}() - {% for field in method.input_message.properties %} + request = {{ method.py_input_message_type }}() + {% for field in method.py_input_message.fields %} {% if field.field_type == 'message' %} if {{ field.py_name }} is not None: request.{{ field.py_name }} = {{ field.py_name }} @@ -101,15 +106,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): async for response in self._stream_stream( "{{ method.route }}", request_iterator, - {{ method.input }}, - {{ method.output.strip('"') }}, + {{ method.py_input_message_type }}, + {{ method.py_output_message_type.strip('"') }}, ): yield response {% 
else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, - {{ method.output.strip('"') }}, + {{ method.py_output_message_type.strip('"') }}, ): yield response @@ -119,14 +124,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): return await self._stream_unary( "{{ method.route }}", request_iterator, - {{ method.input }}, - {{ method.output.strip('"') }} + {{ method.py_input_message_type }}, + {{ method.py_output_message_type.strip('"') }} ) {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output.strip('"') }} + {{ method.py_output_message_type.strip('"') }} ) {% endif %}{# client streaming #} {% endif %} @@ -134,6 +139,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% endfor %} -{% for i in description.imports %} +{% for i in description.imports|sort %} {{ i }} -{% endfor %} \ No newline at end of file +{% endfor %}