Simple type hints for aerich/

This commit is contained in:
Waket Zheng 2024-06-01 21:16:53 +08:00
parent 51117867a6
commit 8756f64e3f
4 changed files with 49 additions and 42 deletions

View File

@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import List from typing import TYPE_CHECKING, List, Optional, Type
from tortoise import Tortoise, generate_schema_for_client from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@ -20,6 +20,9 @@ from aerich.utils import (
import_py_file, import_py_file,
) )
if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401
class Command: class Command:
def __init__( def __init__(
@ -27,16 +30,16 @@ class Command:
tortoise_config: dict, tortoise_config: dict,
app: str = "models", app: str = "models",
location: str = "./migrations", location: str = "./migrations",
): ) -> None:
self.tortoise_config = tortoise_config self.tortoise_config = tortoise_config
self.app = app self.app = app
self.location = location self.location = location
Migrate.app = app Migrate.app = app
async def init(self): async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location) await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file): async def _upgrade(self, conn, version_file) -> None:
file_path = Path(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path) m = import_py_file(file_path)
upgrade = getattr(m, "upgrade") upgrade = getattr(m, "upgrade")
@ -47,7 +50,7 @@ class Command:
content=get_models_describe(self.app), content=get_models_describe(self.app),
) )
async def upgrade(self, run_in_transaction: bool = True): async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
migrated = [] migrated = []
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@ -65,8 +68,8 @@ class Command:
migrated.append(version_file) migrated.append(version_file)
return migrated return migrated
async def downgrade(self, version: int, delete: bool): async def downgrade(self, version: int, delete: bool) -> List[str]:
ret = [] ret: List[str] = []
if version == -1: if version == -1:
specified_version = await Migrate.get_last_version() specified_version = await Migrate.get_last_version()
else: else:
@ -79,8 +82,8 @@ class Command:
versions = [specified_version] versions = [specified_version]
else: else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk) versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions: for version_obj in versions:
file = version.version file = version_obj.version
async with in_transaction( async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
@ -91,13 +94,13 @@ class Command:
if not downgrade_sql.strip(): if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found") raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql) await conn.execute_script(downgrade_sql)
await version.delete() await version_obj.delete()
if delete: if delete:
os.unlink(file_path) os.unlink(file_path)
ret.append(file) ret.append(file)
return ret return ret
async def heads(self): async def heads(self) -> List[str]:
ret = [] ret = []
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
for version in versions: for version in versions:
@ -105,15 +108,15 @@ class Command:
ret.append(version) ret.append(version)
return ret return ret
async def history(self): async def history(self) -> List[str]:
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
return [version for version in versions] return [version for version in versions]
async def inspectdb(self, tables: List[str] = None) -> str: async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app) connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT dialect = connection.schema_generator.DIALECT
if dialect == "mysql": if dialect == "mysql":
cls = InspectMySQL cls: Type["Inspect"] = InspectMySQL
elif dialect == "postgres": elif dialect == "postgres":
cls = InspectPostgres cls = InspectPostgres
elif dialect == "sqlite": elif dialect == "sqlite":
@ -126,7 +129,7 @@ class Command:
async def migrate(self, name: str = "update", empty: bool = False) -> str: async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty) return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool): async def init_db(self, safe: bool) -> None:
location = self.location location = self.location
app = self.app app = self.app
dirname = Path(location, app) dirname = Path(location, app)

View File

@ -2,7 +2,7 @@ import asyncio
import os import os
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import List from typing import Dict, List, cast
import click import click
import tomlkit import tomlkit
@ -23,7 +23,7 @@ CONFIG_DEFAULT_VALUES = {
def coro(f): def coro(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs) -> None:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Close db connections at the end of all but the cli group function # Close db connections at the end of all but the cli group function
@ -48,7 +48,7 @@ def coro(f):
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.pass_context @click.pass_context
@coro @coro
async def cli(ctx: Context, config, app): async def cli(ctx: Context, config, app) -> None:
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config_file"] = config ctx.obj["config_file"] = config
@ -58,9 +58,9 @@ async def cli(ctx: Context, config, app):
if not config_path.exists(): if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError("You must exec init first", ctx=ctx)
content = config_path.read_text() content = config_path.read_text()
doc = tomlkit.parse(content) doc: dict = tomlkit.parse(content)
try: try:
tool = doc["tool"]["aerich"] tool = cast(Dict[str, str], doc["tool"]["aerich"])
location = tool["location"] location = tool["location"]
tortoise_orm = tool["tortoise_orm"] tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"]) src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
@ -68,7 +68,9 @@ async def cli(ctx: Context, config, app):
raise UsageError("You need run aerich init again when upgrade to 0.6.0+") raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder) add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] if not app:
apps_config = cast(dict, tortoise_config.get("apps"))
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location) command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command ctx.obj["command"] = command
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app):
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.") @click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
@click.pass_context @click.pass_context
@coro @coro
async def migrate(ctx: Context, name): async def migrate(ctx: Context, name) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
ret = await command.migrate(name) ret = await command.migrate(name)
if not ret: if not ret:
@ -100,7 +102,7 @@ async def migrate(ctx: Context, name):
) )
@click.pass_context @click.pass_context
@coro @coro
async def upgrade(ctx: Context, in_transaction: bool): async def upgrade(ctx: Context, in_transaction: bool) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
migrated = await command.upgrade(run_in_transaction=in_transaction) migrated = await command.upgrade(run_in_transaction=in_transaction)
if not migrated: if not migrated:
@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool):
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?", prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
) )
@coro @coro
async def downgrade(ctx: Context, version: int, delete: bool): async def downgrade(ctx: Context, version: int, delete: bool) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
try: try:
files = await command.downgrade(version, delete) files = await command.downgrade(version, delete)
@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
@cli.command(help="Show current available heads in migrate location.") @cli.command(help="Show current available heads in migrate location.")
@click.pass_context @click.pass_context
@coro @coro
async def heads(ctx: Context): async def heads(ctx: Context) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
head_list = await command.heads() head_list = await command.heads()
if not head_list: if not head_list:
@ -157,7 +159,7 @@ async def heads(ctx: Context):
@cli.command(help="List all migrate items.") @cli.command(help="List all migrate items.")
@click.pass_context @click.pass_context
@coro @coro
async def history(ctx: Context): async def history(ctx: Context) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
versions = await command.history() versions = await command.history()
if not versions: if not versions:
@ -188,7 +190,7 @@ async def history(ctx: Context):
) )
@click.pass_context @click.pass_context
@coro @coro
async def init(ctx: Context, tortoise_orm, location, src_folder): async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder): if os.path.isabs(src_folder):
@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
config_path = Path(config_file) config_path = Path(config_file)
if config_path.exists(): if config_path.exists():
content = config_path.read_text() content = config_path.read_text()
doc = tomlkit.parse(content)
else: else:
doc = tomlkit.parse("[tool.aerich]") content = "[tool.aerich]"
doc: dict = tomlkit.parse(content)
table = tomlkit.table() table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm table["tortoise_orm"] = tortoise_orm
table["location"] = location table["location"] = location
@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
) )
@click.pass_context @click.pass_context
@coro @coro
async def init_db(ctx: Context, safe: bool): async def init_db(ctx: Context, safe: bool) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
app = command.app app = command.app
dirname = Path(command.location, app) dirname = Path(command.location, app)
@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool):
) )
@click.pass_context @click.pass_context
@coro @coro
async def inspectdb(ctx: Context, table: List[str]): async def inspectdb(ctx: Context, table: List[str]) -> None:
command = ctx.obj["command"] command = ctx.obj["command"]
ret = await command.inspectdb(table) ret = await command.inspectdb(table)
click.secho(ret) click.secho(ret)
def main(): def main() -> None:
cli() cli()

View File

@ -1,12 +1,13 @@
import base64 import base64
import json import json
import pickle # nosec: B301,B403 import pickle # nosec: B301,B403
from typing import Any, Union
from tortoise.indexes import Index from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj) -> Any:
if isinstance(obj, Index): if isinstance(obj, Index):
return { return {
"type": "index", "type": "index",
@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder):
return super().default(obj) return super().default(obj)
def object_hook(obj): def object_hook(obj) -> Any:
_type = obj.get("type") _type = obj.get("type")
if not _type: if not _type:
return obj return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301 return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict): def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder) return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str): def decoder(obj: Union[str, bytes]) -> Any:
return json.loads(obj, object_hook=object_hook) return json.loads(obj, object_hook=object_hook)

View File

@ -3,7 +3,8 @@ import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict from types import ModuleType
from typing import Dict, Optional
from click import BadOptionUsage, ClickException, Context from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict:
:return: :return:
""" """
ret = {} ret = {}
for model in Tortoise.apps.get(app).values(): for model in Tortoise.apps[app].values():
describe = model.describe() describe = model.describe()
ret[describe.get("name")] = describe ret[describe.get("name")] = describe
return ret return ret
def is_default_function(string: str): def is_default_function(string: str) -> Optional[re.Match]:
return re.match(r"^<function.+>$", str(string or "")) return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path): def import_py_file(file: Path) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1]) module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file) spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) spec.loader.exec_module(module) # type:ignore[union-attr]
return module return module