Simple type hints for aerich/

This commit is contained in:
Waket Zheng 2024-06-01 21:16:53 +08:00
parent 51117867a6
commit 8756f64e3f
4 changed files with 49 additions and 42 deletions

View File

@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import List
from typing import TYPE_CHECKING, List, Optional, Type
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
@ -20,6 +20,9 @@ from aerich.utils import (
import_py_file,
)
if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401
class Command:
def __init__(
@ -27,16 +30,16 @@ class Command:
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
):
) -> None:
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self):
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file):
async def _upgrade(self, conn, version_file) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
@ -47,7 +50,7 @@ class Command:
content=get_models_describe(self.app),
)
async def upgrade(self, run_in_transaction: bool = True):
async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
migrated = []
for version_file in Migrate.get_all_version_files():
try:
@ -65,8 +68,8 @@ class Command:
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool):
ret = []
async def downgrade(self, version: int, delete: bool) -> List[str]:
ret: List[str] = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
@ -79,8 +82,8 @@ class Command:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
for version_obj in versions:
file = version_obj.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
@ -91,13 +94,13 @@ class Command:
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql)
await version.delete()
await version_obj.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self):
async def heads(self) -> List[str]:
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
@ -105,15 +108,15 @@ class Command:
ret.append(version)
return ret
async def history(self):
async def history(self) -> List[str]:
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: List[str] = None) -> str:
async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
cls: Type["Inspect"] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
@ -126,7 +129,7 @@ class Command:
async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool):
async def init_db(self, safe: bool) -> None:
location = self.location
app = self.app
dirname = Path(location, app)

View File

@ -2,7 +2,7 @@ import asyncio
import os
from functools import wraps
from pathlib import Path
from typing import List
from typing import Dict, List, cast
import click
import tomlkit
@ -23,7 +23,7 @@ CONFIG_DEFAULT_VALUES = {
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
def wrapper(*args, **kwargs) -> None:
loop = asyncio.get_event_loop()
# Close db connections at the end of all but the cli group function
@ -48,7 +48,7 @@ def coro(f):
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.pass_context
@coro
async def cli(ctx: Context, config, app):
async def cli(ctx: Context, config, app) -> None:
ctx.ensure_object(dict)
ctx.obj["config_file"] = config
@ -58,9 +58,9 @@ async def cli(ctx: Context, config, app):
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx)
content = config_path.read_text()
doc = tomlkit.parse(content)
doc: dict = tomlkit.parse(content)
try:
tool = doc["tool"]["aerich"]
tool = cast(Dict[str, str], doc["tool"]["aerich"])
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
@ -68,7 +68,9 @@ async def cli(ctx: Context, config, app):
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
if not app:
apps_config = cast(dict, tortoise_config.get("apps"))
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db":
@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app):
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
@click.pass_context
@coro
async def migrate(ctx: Context, name):
async def migrate(ctx: Context, name) -> None:
command = ctx.obj["command"]
ret = await command.migrate(name)
if not ret:
@ -100,7 +102,7 @@ async def migrate(ctx: Context, name):
)
@click.pass_context
@coro
async def upgrade(ctx: Context, in_transaction: bool):
async def upgrade(ctx: Context, in_transaction: bool) -> None:
command = ctx.obj["command"]
migrated = await command.upgrade(run_in_transaction=in_transaction)
if not migrated:
@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool):
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
)
@coro
async def downgrade(ctx: Context, version: int, delete: bool):
async def downgrade(ctx: Context, version: int, delete: bool) -> None:
command = ctx.obj["command"]
try:
files = await command.downgrade(version, delete)
@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
@coro
async def heads(ctx: Context):
async def heads(ctx: Context) -> None:
command = ctx.obj["command"]
head_list = await command.heads()
if not head_list:
@ -157,7 +159,7 @@ async def heads(ctx: Context):
@cli.command(help="List all migrate items.")
@click.pass_context
@coro
async def history(ctx: Context):
async def history(ctx: Context) -> None:
command = ctx.obj["command"]
versions = await command.history()
if not versions:
@ -188,7 +190,7 @@ async def history(ctx: Context):
)
@click.pass_context
@coro
async def init(ctx: Context, tortoise_orm, location, src_folder):
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder):
@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
content = "[tool.aerich]"
doc: dict = tomlkit.parse(content)
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
)
@click.pass_context
@coro
async def init_db(ctx: Context, safe: bool):
async def init_db(ctx: Context, safe: bool) -> None:
command = ctx.obj["command"]
app = command.app
dirname = Path(command.location, app)
@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool):
)
@click.pass_context
@coro
async def inspectdb(ctx: Context, table: List[str]):
async def inspectdb(ctx: Context, table: List[str]) -> None:
command = ctx.obj["command"]
ret = await command.inspectdb(table)
click.secho(ret)
def main():
def main() -> None:
cli()

View File

@ -1,12 +1,13 @@
import base64
import json
import pickle # nosec: B301,B403
from typing import Any, Union
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
def default(self, obj) -> Any:
if isinstance(obj, Index):
return {
"type": "index",
@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder):
return super().default(obj)
def object_hook(obj):
def object_hook(obj) -> Any:
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict):
def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
def decoder(obj: Union[str, bytes]) -> Any:
return json.loads(obj, object_hook=object_hook)

View File

@ -3,7 +3,8 @@ import os
import re
import sys
from pathlib import Path
from typing import Dict
from types import ModuleType
from typing import Dict, Optional
from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise
@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
for model in Tortoise.apps[app].values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret
def is_default_function(string: str):
def is_default_function(string: str) -> Optional[re.Match]:
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
def import_py_file(file: Path) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module