diff --git a/aerich/__init__.py b/aerich/__init__.py index f3e045c..c6943c4 100644 --- a/aerich/__init__.py +++ b/aerich/__init__.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import List +from typing import TYPE_CHECKING, List, Optional, Type from tortoise import Tortoise, generate_schema_for_client from tortoise.exceptions import OperationalError @@ -20,6 +20,9 @@ from aerich.utils import ( import_py_file, ) +if TYPE_CHECKING: + from aerich.inspectdb import Inspect # noqa:F401 + class Command: def __init__( @@ -27,16 +30,16 @@ class Command: tortoise_config: dict, app: str = "models", location: str = "./migrations", - ): + ) -> None: self.tortoise_config = tortoise_config self.app = app self.location = location Migrate.app = app - async def init(self): + async def init(self) -> None: await Migrate.init(self.tortoise_config, self.app, self.location) - async def _upgrade(self, conn, version_file): + async def _upgrade(self, conn, version_file) -> None: file_path = Path(Migrate.migrate_location, version_file) m = import_py_file(file_path) upgrade = getattr(m, "upgrade") @@ -47,7 +50,7 @@ class Command: content=get_models_describe(self.app), ) - async def upgrade(self, run_in_transaction: bool = True): + async def upgrade(self, run_in_transaction: bool = True) -> List[str]: migrated = [] for version_file in Migrate.get_all_version_files(): try: @@ -65,8 +68,8 @@ class Command: migrated.append(version_file) return migrated - async def downgrade(self, version: int, delete: bool): - ret = [] + async def downgrade(self, version: int, delete: bool) -> List[str]: + ret: List[str] = [] if version == -1: specified_version = await Migrate.get_last_version() else: @@ -79,8 +82,8 @@ class Command: versions = [specified_version] else: versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk) - for version in versions: - file = version.version + for version_obj in versions: + file = version_obj.version async with in_transaction( get_app_connection_name(self.tortoise_config, self.app) ) as conn: @@ -91,13 +94,13 @@ class Command: if not downgrade_sql.strip(): raise DowngradeError("No downgrade items found") await conn.execute_script(downgrade_sql) - await version.delete() + await version_obj.delete() if delete: os.unlink(file_path) ret.append(file) return ret - async def heads(self): + async def heads(self) -> List[str]: ret = [] versions = Migrate.get_all_version_files() for version in versions: @@ -105,15 +108,15 @@ class Command: ret.append(version) return ret - async def history(self): + async def history(self) -> List[str]: versions = Migrate.get_all_version_files() return [version for version in versions] - async def inspectdb(self, tables: List[str] = None) -> str: + async def inspectdb(self, tables: Optional[List[str]] = None) -> str: connection = get_app_connection(self.tortoise_config, self.app) dialect = connection.schema_generator.DIALECT if dialect == "mysql": - cls = InspectMySQL + cls: Type["Inspect"] = InspectMySQL elif dialect == "postgres": cls = InspectPostgres elif dialect == "sqlite": @@ -126,7 +129,7 @@ class Command: async def migrate(self, name: str = "update", empty: bool = False) -> str: return await Migrate.migrate(name, empty) - async def init_db(self, safe: bool): + async def init_db(self, safe: bool) -> None: location = self.location app = self.app dirname = Path(location, app) diff --git a/aerich/cli.py b/aerich/cli.py index 924f044..c40ce8e 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -2,7 +2,7 @@ import asyncio import os from functools import wraps from pathlib import Path -from typing import List +from typing import Dict, List, cast import click import tomlkit @@ -23,7 +23,7 @@ CONFIG_DEFAULT_VALUES = { def coro(f): @wraps(f) - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> None: loop = asyncio.get_event_loop() # Close db connections at the end of all but the cli group function @@ -48,7 +48,7 @@ def coro(f): @click.option("--app", required=False, help="Tortoise-ORM app name.") @click.pass_context @coro -async def cli(ctx: Context, config, app): +async def cli(ctx: Context, config, app) -> None: ctx.ensure_object(dict) ctx.obj["config_file"] = config @@ -58,9 +58,9 @@ async def cli(ctx: Context, config, app): if not config_path.exists(): raise UsageError("You must exec init first", ctx=ctx) content = config_path.read_text() - doc = tomlkit.parse(content) + doc: dict = tomlkit.parse(content) try: - tool = doc["tool"]["aerich"] + tool = cast(Dict[str, str], doc["tool"]["aerich"]) location = tool["location"] tortoise_orm = tool["tortoise_orm"] src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"]) @@ -68,7 +68,9 @@ async def cli(ctx: Context, config, app): raise UsageError("You need run aerich init again when upgrade to 0.6.0+") add_src_path(src_folder) tortoise_config = get_tortoise_config(ctx, tortoise_orm) - app = app or list(tortoise_config.get("apps").keys())[0] + if not app: + apps_config = cast(dict, tortoise_config.get("apps")) + app = list(apps_config.keys())[0] command = Command(tortoise_config=tortoise_config, app=app, location=location) ctx.obj["command"] = command if invoked_subcommand != "init-db": @@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app): @click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.") @click.pass_context @coro -async def migrate(ctx: Context, name): +async def migrate(ctx: Context, name) -> None: command = ctx.obj["command"] ret = await command.migrate(name) if not ret: @@ -100,7 +102,7 @@ async def migrate(ctx: Context, name): ) @click.pass_context @coro -async def upgrade(ctx: Context, in_transaction: bool): +async def upgrade(ctx: Context, in_transaction: bool) -> None: command = ctx.obj["command"] migrated = await command.upgrade(run_in_transaction=in_transaction) if not migrated: @@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool): prompt="Downgrade is dangerous, which maybe lose your data, are you sure?", ) @coro -async def downgrade(ctx: Context, version: int, delete: bool): +async def downgrade(ctx: Context, version: int, delete: bool) -> None: command = ctx.obj["command"] try: files = await command.downgrade(version, delete) @@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool): @cli.command(help="Show current available heads in migrate location.") @click.pass_context @coro -async def heads(ctx: Context): +async def heads(ctx: Context) -> None: command = ctx.obj["command"] head_list = await command.heads() if not head_list: @@ -157,7 +159,7 @@ async def heads(ctx: Context): @cli.command(help="List all migrate items.") @click.pass_context @coro -async def history(ctx: Context): +async def history(ctx: Context) -> None: command = ctx.obj["command"] versions = await command.history() if not versions: @@ -188,7 +190,7 @@ async def history(ctx: Context): ) @click.pass_context @coro -async def init(ctx: Context, tortoise_orm, location, src_folder): +async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: config_file = ctx.obj["config_file"] if os.path.isabs(src_folder): @@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder): config_path = Path(config_file) if config_path.exists(): content = config_path.read_text() - doc = tomlkit.parse(content) else: - doc = tomlkit.parse("[tool.aerich]") + content = "[tool.aerich]" + doc: dict = tomlkit.parse(content) table = tomlkit.table() table["tortoise_orm"] = tortoise_orm table["location"] = location @@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder): ) @click.pass_context @coro -async def init_db(ctx: Context, safe: bool): +async def init_db(ctx: Context, safe: bool) -> None: command = ctx.obj["command"] app = command.app dirname = Path(command.location, app) @@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool): ) @click.pass_context @coro -async def inspectdb(ctx: Context, table: List[str]): +async def inspectdb(ctx: Context, table: List[str]) -> None: command = ctx.obj["command"] ret = await command.inspectdb(table) click.secho(ret) -def main(): +def main() -> None: cli() diff --git a/aerich/coder.py b/aerich/coder.py index 3501c24..870ee7a 100644 --- a/aerich/coder.py +++ b/aerich/coder.py @@ -1,12 +1,13 @@ import base64 import json import pickle # nosec: B301,B403 +from typing import Any, Union from tortoise.indexes import Index class JsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj) -> Any: if isinstance(obj, Index): return { "type": "index", @@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder): return super().default(obj) -def object_hook(obj): +def object_hook(obj) -> Any: _type = obj.get("type") if not _type: return obj return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301 -def encoder(obj: dict): +def encoder(obj: dict) -> str: return json.dumps(obj, cls=JsonEncoder) -def decoder(obj: str): +def decoder(obj: Union[str, bytes]) -> Any: return json.loads(obj, object_hook=object_hook) diff --git a/aerich/utils.py b/aerich/utils.py index ac5fb87..14664a0 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -3,7 +3,8 @@ import os import re import sys from pathlib import Path -from typing import Dict +from types import ModuleType +from typing import Dict, Optional from click import BadOptionUsage, ClickException, Context from tortoise import BaseDBAsyncClient, Tortoise @@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict: :return: """ ret = {} - for model in Tortoise.apps.get(app).values(): + for model in Tortoise.apps[app].values(): describe = model.describe() ret[describe.get("name")] = describe return ret -def is_default_function(string: str): +def is_default_function(string: str) -> Optional[re.Match]: return re.match(r"^$", str(string or "")) -def import_py_file(file: Path): +def import_py_file(file: Path) -> ModuleType: module_name, file_ext = os.path.splitext(os.path.split(file)[-1]) spec = importlib.util.spec_from_file_location(module_name, file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + module = importlib.util.module_from_spec(spec) # type:ignore[arg-type] + spec.loader.exec_module(module) # type:ignore[union-attr] return module