Added an configuration option to specify the path of the source folder.
This will make aerich work with various folder structures (e.g. ./src/MyPythonModule) Additionally this will try to import in init and show the user the error message on failure.
This commit is contained in:
		| @@ -1,6 +1,5 @@ | |||||||
| import asyncio | import asyncio | ||||||
| import os | import os | ||||||
| import sys |  | ||||||
| from configparser import ConfigParser | from configparser import ConfigParser | ||||||
| from functools import wraps | from functools import wraps | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| @@ -16,6 +15,7 @@ from tortoise.utils import get_schema_sql | |||||||
| from aerich.inspectdb import InspectDb | from aerich.inspectdb import InspectDb | ||||||
| from aerich.migrate import Migrate | from aerich.migrate import Migrate | ||||||
| from aerich.utils import ( | from aerich.utils import ( | ||||||
|  |     add_src_path, | ||||||
|     get_app_connection, |     get_app_connection, | ||||||
|     get_app_connection_name, |     get_app_connection_name, | ||||||
|     get_models_describe, |     get_models_describe, | ||||||
| @@ -23,7 +23,6 @@ from aerich.utils import ( | |||||||
|     get_version_content_from_file, |     get_version_content_from_file, | ||||||
|     write_version_file, |     write_version_file, | ||||||
| ) | ) | ||||||
|  |  | ||||||
| from . import __version__ | from . import __version__ | ||||||
| from .enums import Color | from .enums import Color | ||||||
| from .models import Aerich | from .models import Aerich | ||||||
| @@ -74,6 +73,10 @@ async def cli(ctx: Context, config, app, name): | |||||||
|  |  | ||||||
|         location = parser[name]["location"] |         location = parser[name]["location"] | ||||||
|         tortoise_orm = parser[name]["tortoise_orm"] |         tortoise_orm = parser[name]["tortoise_orm"] | ||||||
|  |         src_folder = parser[name]["src_folder"] | ||||||
|  |  | ||||||
|  |         # Add specified source folder to path | ||||||
|  |         add_src_path(src_folder) | ||||||
|  |  | ||||||
|         tortoise_config = get_tortoise_config(ctx, tortoise_orm) |         tortoise_config = get_tortoise_config(ctx, tortoise_orm) | ||||||
|         app = app or list(tortoise_config.get("apps").keys())[0] |         app = app or list(tortoise_config.get("apps").keys())[0] | ||||||
| @@ -214,19 +217,34 @@ async def history(ctx: Context): | |||||||
| @click.option( | @click.option( | ||||||
|     "--location", default="./migrations", show_default=True, help="Migrate store location.", |     "--location", default="./migrations", show_default=True, help="Migrate store location.", | ||||||
| ) | ) | ||||||
|  | @click.option( | ||||||
|  |     "-s", | ||||||
|  |     "--src_folder", default=".", show_default=False, help="Folder of the source, relative to the project root." | ||||||
|  | ) | ||||||
| @click.pass_context | @click.pass_context | ||||||
| @coro | @coro | ||||||
| async def init( | async def init( | ||||||
|     ctx: Context, tortoise_orm, location, |     ctx: Context, tortoise_orm, location, src_folder | ||||||
| ): | ): | ||||||
|     config_file = ctx.obj["config_file"] |     config_file = ctx.obj["config_file"] | ||||||
|     name = ctx.obj["name"] |     name = ctx.obj["name"] | ||||||
|     if Path(config_file).exists(): |     if Path(config_file).exists(): | ||||||
|         return click.secho("You have inited", fg=Color.yellow) |         return click.secho("Configuration file already created", fg=Color.yellow) | ||||||
|  |  | ||||||
|  |     if os.path.isabs(src_folder): | ||||||
|  |         src_folder = os.path.relpath(os.getcwd(), src_folder) | ||||||
|  |     # Add ./ so it's clear that this is relative path | ||||||
|  |     if not src_folder.startswith('./'): | ||||||
|  |         src_folder = './' + src_folder | ||||||
|  |  | ||||||
|  |     # check that we can find the configuration, if not we can fail before the config file gets created | ||||||
|  |     add_src_path(src_folder) | ||||||
|  |     get_tortoise_config(ctx, tortoise_orm) | ||||||
|  |  | ||||||
|     parser.add_section(name) |     parser.add_section(name) | ||||||
|     parser.set(name, "tortoise_orm", tortoise_orm) |     parser.set(name, "tortoise_orm", tortoise_orm) | ||||||
|     parser.set(name, "location", location) |     parser.set(name, "location", location) | ||||||
|  |     parser.set(name, "src_folder", src_folder) | ||||||
|  |  | ||||||
|     with open(config_file, "w", encoding="utf-8") as f: |     with open(config_file, "w", encoding="utf-8") as f: | ||||||
|         parser.write(f) |         parser.write(f) | ||||||
| @@ -294,7 +312,6 @@ async def inspectdb(ctx: Context, table: List[str]): | |||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def main(): | ||||||
|     sys.path.insert(0, ".") |  | ||||||
|     cli() |     cli() | ||||||
|  |  | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1,12 +1,30 @@ | |||||||
| import importlib | import importlib | ||||||
|  | import os | ||||||
| import re | import re | ||||||
|  | import sys | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| from typing import Dict | from typing import Dict | ||||||
|  |  | ||||||
| from click import BadOptionUsage, Context | from click import BadOptionUsage, Context, ClickException | ||||||
| from tortoise import BaseDBAsyncClient, Tortoise | from tortoise import BaseDBAsyncClient, Tortoise | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def add_src_path(path: str) -> str: | ||||||
|  |     """ | ||||||
|  |     add a folder to the paths so we can import from there | ||||||
|  |     :param path: path to add | ||||||
|  |     :return: absolute path | ||||||
|  |     """ | ||||||
|  |     if not os.path.isabs(path): | ||||||
|  |         # use the absolute path, otherwise some other things (e.g. __file__) won't work properly | ||||||
|  |         path = os.path.abspath(path) | ||||||
|  |     if not os.path.isdir(path): | ||||||
|  |         raise ClickException(f"Specified source folder does not exist: {path}") | ||||||
|  |     if path not in sys.path: | ||||||
|  |         sys.path.insert(0, path) | ||||||
|  |     return path | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_app_connection_name(config, app_name: str) -> str: | def get_app_connection_name(config, app_name: str) -> str: | ||||||
|     """ |     """ | ||||||
|     get connection name |     get connection name | ||||||
| @@ -42,7 +60,12 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict: | |||||||
|     splits = tortoise_orm.split(".") |     splits = tortoise_orm.split(".") | ||||||
|     config_path = ".".join(splits[:-1]) |     config_path = ".".join(splits[:-1]) | ||||||
|     tortoise_config = splits[-1] |     tortoise_config = splits[-1] | ||||||
|  |  | ||||||
|  |     try: | ||||||
|         config_module = importlib.import_module(config_path) |         config_module = importlib.import_module(config_path) | ||||||
|  |     except ModuleNotFoundError as e: | ||||||
|  |         raise ClickException(f'Error while importing configuration module: {e}') from None | ||||||
|  |  | ||||||
|     config = getattr(config_module, tortoise_config, None) |     config = getattr(config_module, tortoise_config, None) | ||||||
|     if not config: |     if not config: | ||||||
|         raise BadOptionUsage( |         raise BadOptionUsage( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user