Added an configuration option to specify the path of the source folder.

This will make aerich work with various folder structures (e.g. ./src/MyPythonModule)
Additionally this will try to import in init and show the user the error message on failure.
This commit is contained in:
- 2021-06-10 16:52:03 +02:00
parent 0c763c6024
commit 10b7272ca8
2 changed files with 47 additions and 7 deletions

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import sys
from configparser import ConfigParser from configparser import ConfigParser
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
@ -16,6 +15,7 @@ from tortoise.utils import get_schema_sql
from aerich.inspectdb import InspectDb from aerich.inspectdb import InspectDb
from aerich.migrate import Migrate from aerich.migrate import Migrate
from aerich.utils import ( from aerich.utils import (
add_src_path,
get_app_connection, get_app_connection,
get_app_connection_name, get_app_connection_name,
get_models_describe, get_models_describe,
@ -23,7 +23,6 @@ from aerich.utils import (
get_version_content_from_file, get_version_content_from_file,
write_version_file, write_version_file,
) )
from . import __version__ from . import __version__
from .enums import Color from .enums import Color
from .models import Aerich from .models import Aerich
@ -74,6 +73,10 @@ async def cli(ctx: Context, config, app, name):
location = parser[name]["location"] location = parser[name]["location"]
tortoise_orm = parser[name]["tortoise_orm"] tortoise_orm = parser[name]["tortoise_orm"]
src_folder = parser[name]["src_folder"]
# Add specified source folder to path
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] app = app or list(tortoise_config.get("apps").keys())[0]
@ -214,19 +217,34 @@ async def history(ctx: Context):
@click.option( @click.option(
"--location", default="./migrations", show_default=True, help="Migrate store location.", "--location", default="./migrations", show_default=True, help="Migrate store location.",
) )
@click.option(
"-s",
"--src_folder", default=".", show_default=False, help="Folder of the source, relative to the project root."
)
@click.pass_context @click.pass_context
@coro @coro
async def init( async def init(
ctx: Context, tortoise_orm, location, ctx: Context, tortoise_orm, location, src_folder
): ):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"] name = ctx.obj["name"]
if Path(config_file).exists(): if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow) return click.secho("Configuration file already created", fg=Color.yellow)
if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder)
# Add ./ so it's clear that this is relative path
if not src_folder.startswith('./'):
src_folder = './' + src_folder
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
parser.add_section(name) parser.add_section(name)
parser.set(name, "tortoise_orm", tortoise_orm) parser.set(name, "tortoise_orm", tortoise_orm)
parser.set(name, "location", location) parser.set(name, "location", location)
parser.set(name, "src_folder", src_folder)
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
parser.write(f) parser.write(f)
@ -294,7 +312,6 @@ async def inspectdb(ctx: Context, table: List[str]):
def main(): def main():
sys.path.insert(0, ".")
cli() cli()

View File

@ -1,12 +1,30 @@
import importlib import importlib
import os
import re import re
import sys
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict
from click import BadOptionUsage, Context from click import BadOptionUsage, Context, ClickException
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str:
"""
add a folder to the paths so we can import from there
:param path: path to add
:return: absolute path
"""
if not os.path.isabs(path):
# use the absolute path, otherwise some other things (e.g. __file__) won't work properly
path = os.path.abspath(path)
if not os.path.isdir(path):
raise ClickException(f"Specified source folder does not exist: {path}")
if path not in sys.path:
sys.path.insert(0, path)
return path
def get_app_connection_name(config, app_name: str) -> str: def get_app_connection_name(config, app_name: str) -> str:
""" """
get connection name get connection name
@ -42,7 +60,12 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
splits = tortoise_orm.split(".") splits = tortoise_orm.split(".")
config_path = ".".join(splits[:-1]) config_path = ".".join(splits[:-1])
tortoise_config = splits[-1] tortoise_config = splits[-1]
config_module = importlib.import_module(config_path)
try:
config_module = importlib.import_module(config_path)
except ModuleNotFoundError as e:
raise ClickException(f'Error while importing configuration module: {e}') from None
config = getattr(config_module, tortoise_config, None) config = getattr(config_module, tortoise_config, None)
if not config: if not config:
raise BadOptionUsage( raise BadOptionUsage(