Factor code template compilation out into a separate module

This commit is contained in:
Adrian Garcia Badaracco 2020-08-09 13:06:39 -05:00 committed by GitHub
parent 80bef7c94f
commit c93351ef21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 29 deletions

View File

@ -0,0 +1,40 @@
import os.path
try:
# betterproto[compiler] specific dependencies
import black
import jinja2
except ImportError as err:
missing_import = err.args[0][17:-1]
print(
"\033[31m"
f"Unable to import `{missing_import}` from betterproto plugin! "
"Please ensure that you've installed betterproto as "
'`pip install "betterproto[compiler]"` so that compiler dependencies '
"are included."
"\033[0m"
)
raise SystemExit(1)
from betterproto.plugin.models import OutputTemplate
def outputfile_compiler(output_file: OutputTemplate) -> str:
templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates")
)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
res = black.format_str(
template.render(output_file=output_file),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)
return res

View File

@ -1,12 +1,10 @@
import itertools import itertools
import os.path
import pathlib import pathlib
import sys import sys
from typing import List, Iterator from typing import List, Iterator
try: try:
# betterproto[compiler] specific dependencies # betterproto[compiler] specific dependencies
import black
from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.compiler import plugin_pb2 as plugin
from google.protobuf.descriptor_pb2 import ( from google.protobuf.descriptor_pb2 import (
DescriptorProto, DescriptorProto,
@ -14,7 +12,6 @@ try:
FieldDescriptorProto, FieldDescriptorProto,
ServiceDescriptorProto, ServiceDescriptorProto,
) )
import jinja2
except ImportError as err: except ImportError as err:
missing_import = err.args[0][17:-1] missing_import = err.args[0][17:-1]
print( print(
@ -41,6 +38,8 @@ from betterproto.plugin.models import (
is_oneof, is_oneof,
) )
from betterproto.plugin.compiler import outputfile_compiler
def traverse(proto_file: FieldDescriptorProto) -> Iterator: def traverse(proto_file: FieldDescriptorProto) -> Iterator:
# Todo: Keep information about nested hierarchy # Todo: Keep information about nested hierarchy
@ -70,16 +69,6 @@ def generate_code(
) -> None: ) -> None:
plugin_options = request.parameter.split(",") if request.parameter else [] plugin_options = request.parameter.split(",") if request.parameter else []
templates_folder = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "templates")
)
env = jinja2.Environment(
trim_blocks=True,
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
request_data = PluginRequestCompiler(plugin_request_obj=request) request_data = PluginRequestCompiler(plugin_request_obj=request)
# Gather output packages # Gather output packages
for proto_file in request.proto_file: for proto_file in request.proto_file:
@ -116,7 +105,7 @@ def generate_code(
# Generate output files # Generate output files
output_paths: pathlib.Path = set() output_paths: pathlib.Path = set()
for output_package_name, template_data in request_data.output_packages.items(): for output_package_name, output_package in request_data.output_packages.items():
# Add files to the response object # Add files to the response object
output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") output_path = pathlib.Path(*output_package_name.split("."), "__init__.py")
@ -126,10 +115,7 @@ def generate_code(
f.name: str = str(output_path) f.name: str = str(output_path)
# Render and then format the output file # Render and then format the output file
f.content: str = black.format_str( f.content: str = outputfile_compiler(output_file=output_package)
template.render(description=template_data),
mode=black.FileMode(target_versions={black.TargetVersion.PY37}),
)
# Make each output directory a package with __init__ file # Make each output directory a package with __init__ file
init_files = ( init_files = (

View File

@ -1,26 +1,26 @@
# Generated by the protocol buffer compiler. DO NOT EDIT! # Generated by the protocol buffer compiler. DO NOT EDIT!
# sources: {{ ', '.join(description.input_filenames) }} # sources: {{ ', '.join(output_file.input_filenames) }}
# plugin: python-betterproto # plugin: python-betterproto
{% for i in description.python_module_imports|sort %} {% for i in output_file.python_module_imports|sort %}
import {{ i }} import {{ i }}
{% endfor %} {% endfor %}
from dataclasses import dataclass from dataclasses import dataclass
{% if description.datetime_imports %} {% if output_file.datetime_imports %}
from datetime import {% for i in description.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif%} {% endif%}
{% if description.typing_imports %} {% if output_file.typing_imports %}
from typing import {% for i in description.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}
{% endif %} {% endif %}
import betterproto import betterproto
{% if description.services %} {% if output_file.services %}
import grpclib import grpclib
{% endif %} {% endif %}
{% if description.enums %}{% for enum in description.enums %} {% if output_file.enums %}{% for enum in output_file.enums %}
class {{ enum.py_name }}(betterproto.Enum): class {{ enum.py_name }}(betterproto.Enum):
{% if enum.comment %} {% if enum.comment %}
{{ enum.comment }} {{ enum.comment }}
@ -36,7 +36,7 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endfor %} {% endfor %}
{% endif %} {% endif %}
{% for message in description.messages %} {% for message in output_file.messages %}
@dataclass @dataclass
class {{ message.py_name }}(betterproto.Message): class {{ message.py_name }}(betterproto.Message):
{% if message.comment %} {% if message.comment %}
@ -67,7 +67,7 @@ class {{ message.py_name }}(betterproto.Message):
{% endfor %} {% endfor %}
{% for service in description.services %} {% for service in output_file.services %}
class {{ service.py_name }}Stub(betterproto.ServiceStub): class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% if service.comment %} {% if service.comment %}
{{ service.comment }} {{ service.comment }}
@ -154,6 +154,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub):
{% endfor %} {% endfor %}
{% endfor %} {% endfor %}
{% for i in description.imports|sort %} {% for i in output_file.imports|sort %}
{{ i }} {{ i }}
{% endfor %} {% endfor %}