Factor code template compilation out into a separate module
This commit is contained in:
		
				
					committed by
					
						 GitHub
						GitHub
					
				
			
			
				
	
			
			
			
						parent
						
							80bef7c94f
						
					
				
				
					commit
					c93351ef21
				
			
							
								
								
									
										40
									
								
								src/betterproto/plugin/compiler.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								src/betterproto/plugin/compiler.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | |||||||
|  | import os.path | ||||||
|  |  | ||||||
|  | try: | ||||||
|  |     # betterproto[compiler] specific dependencies | ||||||
|  |     import black | ||||||
|  |     import jinja2 | ||||||
|  | except ImportError as err: | ||||||
|  |     missing_import = err.args[0][17:-1] | ||||||
|  |     print( | ||||||
|  |         "\033[31m" | ||||||
|  |         f"Unable to import `{missing_import}` from betterproto plugin! " | ||||||
|  |         "Please ensure that you've installed betterproto as " | ||||||
|  |         '`pip install "betterproto[compiler]"` so that compiler dependencies ' | ||||||
|  |         "are included." | ||||||
|  |         "\033[0m" | ||||||
|  |     ) | ||||||
|  |     raise SystemExit(1) | ||||||
|  |  | ||||||
|  | from betterproto.plugin.models import OutputTemplate | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def outputfile_compiler(output_file: OutputTemplate) -> str: | ||||||
|  |  | ||||||
|  |     templates_folder = os.path.abspath( | ||||||
|  |         os.path.join(os.path.dirname(__file__), "..", "templates") | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     env = jinja2.Environment( | ||||||
|  |         trim_blocks=True, | ||||||
|  |         lstrip_blocks=True, | ||||||
|  |         loader=jinja2.FileSystemLoader(templates_folder), | ||||||
|  |     ) | ||||||
|  |     template = env.get_template("template.py.j2") | ||||||
|  |  | ||||||
|  |     res = black.format_str( | ||||||
|  |         template.render(output_file=output_file), | ||||||
|  |         mode=black.FileMode(target_versions={black.TargetVersion.PY37}), | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     return res | ||||||
| @@ -1,12 +1,10 @@ | |||||||
| import itertools | import itertools | ||||||
| import os.path |  | ||||||
| import pathlib | import pathlib | ||||||
| import sys | import sys | ||||||
| from typing import List, Iterator | from typing import List, Iterator | ||||||
|  |  | ||||||
| try: | try: | ||||||
|     # betterproto[compiler] specific dependencies |     # betterproto[compiler] specific dependencies | ||||||
|     import black |  | ||||||
|     from google.protobuf.compiler import plugin_pb2 as plugin |     from google.protobuf.compiler import plugin_pb2 as plugin | ||||||
|     from google.protobuf.descriptor_pb2 import ( |     from google.protobuf.descriptor_pb2 import ( | ||||||
|         DescriptorProto, |         DescriptorProto, | ||||||
| @@ -14,7 +12,6 @@ try: | |||||||
|         FieldDescriptorProto, |         FieldDescriptorProto, | ||||||
|         ServiceDescriptorProto, |         ServiceDescriptorProto, | ||||||
|     ) |     ) | ||||||
|     import jinja2 |  | ||||||
| except ImportError as err: | except ImportError as err: | ||||||
|     missing_import = err.args[0][17:-1] |     missing_import = err.args[0][17:-1] | ||||||
|     print( |     print( | ||||||
| @@ -41,6 +38,8 @@ from betterproto.plugin.models import ( | |||||||
|     is_oneof, |     is_oneof, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  | from betterproto.plugin.compiler import outputfile_compiler | ||||||
|  |  | ||||||
|  |  | ||||||
| def traverse(proto_file: FieldDescriptorProto) -> Iterator: | def traverse(proto_file: FieldDescriptorProto) -> Iterator: | ||||||
|     # Todo: Keep information about nested hierarchy |     # Todo: Keep information about nested hierarchy | ||||||
| @@ -70,16 +69,6 @@ def generate_code( | |||||||
| ) -> None: | ) -> None: | ||||||
|     plugin_options = request.parameter.split(",") if request.parameter else [] |     plugin_options = request.parameter.split(",") if request.parameter else [] | ||||||
|  |  | ||||||
|     templates_folder = os.path.abspath( |  | ||||||
|         os.path.join(os.path.dirname(__file__), "..", "templates") |  | ||||||
|     ) |  | ||||||
|  |  | ||||||
|     env = jinja2.Environment( |  | ||||||
|         trim_blocks=True, |  | ||||||
|         lstrip_blocks=True, |  | ||||||
|         loader=jinja2.FileSystemLoader(templates_folder), |  | ||||||
|     ) |  | ||||||
|     template = env.get_template("template.py.j2") |  | ||||||
|     request_data = PluginRequestCompiler(plugin_request_obj=request) |     request_data = PluginRequestCompiler(plugin_request_obj=request) | ||||||
|     # Gather output packages |     # Gather output packages | ||||||
|     for proto_file in request.proto_file: |     for proto_file in request.proto_file: | ||||||
| @@ -116,7 +105,7 @@ def generate_code( | |||||||
|  |  | ||||||
|     # Generate output files |     # Generate output files | ||||||
|     output_paths: pathlib.Path = set() |     output_paths: pathlib.Path = set() | ||||||
|     for output_package_name, template_data in request_data.output_packages.items(): |     for output_package_name, output_package in request_data.output_packages.items(): | ||||||
|  |  | ||||||
|         # Add files to the response object |         # Add files to the response object | ||||||
|         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") |         output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") | ||||||
| @@ -126,10 +115,7 @@ def generate_code( | |||||||
|         f.name: str = str(output_path) |         f.name: str = str(output_path) | ||||||
|  |  | ||||||
|         # Render and then format the output file |         # Render and then format the output file | ||||||
|         f.content: str = black.format_str( |         f.content: str = outputfile_compiler(output_file=output_package) | ||||||
|             template.render(description=template_data), |  | ||||||
|             mode=black.FileMode(target_versions={black.TargetVersion.PY37}), |  | ||||||
|         ) |  | ||||||
|  |  | ||||||
|     # Make each output directory a package with __init__ file |     # Make each output directory a package with __init__ file | ||||||
|     init_files = ( |     init_files = ( | ||||||
|   | |||||||
| @@ -1,26 +1,26 @@ | |||||||
| # Generated by the protocol buffer compiler.  DO NOT EDIT! | # Generated by the protocol buffer compiler.  DO NOT EDIT! | ||||||
| # sources: {{ ', '.join(description.input_filenames) }} | # sources: {{ ', '.join(output_file.input_filenames) }} | ||||||
| # plugin: python-betterproto | # plugin: python-betterproto | ||||||
| {% for i in description.python_module_imports|sort %} | {% for i in output_file.python_module_imports|sort %} | ||||||
| import {{ i }} | import {{ i }} | ||||||
| {% endfor %} | {% endfor %} | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
| {% if description.datetime_imports %} | {% if output_file.datetime_imports %} | ||||||
| from datetime import {% for i in description.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||||
|  |  | ||||||
| {% endif%} | {% endif%} | ||||||
| {% if description.typing_imports %} | {% if output_file.typing_imports %} | ||||||
| from typing import {% for i in description.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} | ||||||
|  |  | ||||||
| {% endif %} | {% endif %} | ||||||
|  |  | ||||||
| import betterproto | import betterproto | ||||||
| {% if description.services %} | {% if output_file.services %} | ||||||
| import grpclib | import grpclib | ||||||
| {% endif %} | {% endif %} | ||||||
|  |  | ||||||
|  |  | ||||||
| {% if description.enums %}{% for enum in description.enums %} | {% if output_file.enums %}{% for enum in output_file.enums %} | ||||||
| class {{ enum.py_name }}(betterproto.Enum): | class {{ enum.py_name }}(betterproto.Enum): | ||||||
|     {% if enum.comment %} |     {% if enum.comment %} | ||||||
| {{ enum.comment }} | {{ enum.comment }} | ||||||
| @@ -36,7 +36,7 @@ class {{ enum.py_name }}(betterproto.Enum): | |||||||
|  |  | ||||||
| {% endfor %} | {% endfor %} | ||||||
| {% endif %} | {% endif %} | ||||||
| {% for message in description.messages %} | {% for message in output_file.messages %} | ||||||
| @dataclass | @dataclass | ||||||
| class {{ message.py_name }}(betterproto.Message): | class {{ message.py_name }}(betterproto.Message): | ||||||
|     {% if message.comment %} |     {% if message.comment %} | ||||||
| @@ -67,7 +67,7 @@ class {{ message.py_name }}(betterproto.Message): | |||||||
|  |  | ||||||
|  |  | ||||||
| {% endfor %} | {% endfor %} | ||||||
| {% for service in description.services %} | {% for service in output_file.services %} | ||||||
| class {{ service.py_name }}Stub(betterproto.ServiceStub): | class {{ service.py_name }}Stub(betterproto.ServiceStub): | ||||||
|     {% if service.comment %} |     {% if service.comment %} | ||||||
| {{ service.comment }} | {{ service.comment }} | ||||||
| @@ -154,6 +154,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): | |||||||
|     {% endfor %} |     {% endfor %} | ||||||
| {% endfor %} | {% endfor %} | ||||||
|  |  | ||||||
| {% for i in description.imports|sort %} | {% for i in output_file.imports|sort %} | ||||||
| {{ i }} | {{ i }} | ||||||
| {% endfor %} | {% endfor %} | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user