Simple type hints for aerich/
This commit is contained in:
parent
51117867a6
commit
8756f64e3f
@ -1,6 +1,6 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List, Optional, Type
|
||||
|
||||
from tortoise import Tortoise, generate_schema_for_client
|
||||
from tortoise.exceptions import OperationalError
|
||||
@ -20,6 +20,9 @@ from aerich.utils import (
|
||||
import_py_file,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aerich.inspectdb import Inspect # noqa:F401
|
||||
|
||||
|
||||
class Command:
|
||||
def __init__(
|
||||
@ -27,16 +30,16 @@ class Command:
|
||||
tortoise_config: dict,
|
||||
app: str = "models",
|
||||
location: str = "./migrations",
|
||||
):
|
||||
) -> None:
|
||||
self.tortoise_config = tortoise_config
|
||||
self.app = app
|
||||
self.location = location
|
||||
Migrate.app = app
|
||||
|
||||
async def init(self):
|
||||
async def init(self) -> None:
|
||||
await Migrate.init(self.tortoise_config, self.app, self.location)
|
||||
|
||||
async def _upgrade(self, conn, version_file):
|
||||
async def _upgrade(self, conn, version_file) -> None:
|
||||
file_path = Path(Migrate.migrate_location, version_file)
|
||||
m = import_py_file(file_path)
|
||||
upgrade = getattr(m, "upgrade")
|
||||
@ -47,7 +50,7 @@ class Command:
|
||||
content=get_models_describe(self.app),
|
||||
)
|
||||
|
||||
async def upgrade(self, run_in_transaction: bool = True):
|
||||
async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
|
||||
migrated = []
|
||||
for version_file in Migrate.get_all_version_files():
|
||||
try:
|
||||
@ -65,8 +68,8 @@ class Command:
|
||||
migrated.append(version_file)
|
||||
return migrated
|
||||
|
||||
async def downgrade(self, version: int, delete: bool):
|
||||
ret = []
|
||||
async def downgrade(self, version: int, delete: bool) -> List[str]:
|
||||
ret: List[str] = []
|
||||
if version == -1:
|
||||
specified_version = await Migrate.get_last_version()
|
||||
else:
|
||||
@ -79,8 +82,8 @@ class Command:
|
||||
versions = [specified_version]
|
||||
else:
|
||||
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
|
||||
for version in versions:
|
||||
file = version.version
|
||||
for version_obj in versions:
|
||||
file = version_obj.version
|
||||
async with in_transaction(
|
||||
get_app_connection_name(self.tortoise_config, self.app)
|
||||
) as conn:
|
||||
@ -91,13 +94,13 @@ class Command:
|
||||
if not downgrade_sql.strip():
|
||||
raise DowngradeError("No downgrade items found")
|
||||
await conn.execute_script(downgrade_sql)
|
||||
await version.delete()
|
||||
await version_obj.delete()
|
||||
if delete:
|
||||
os.unlink(file_path)
|
||||
ret.append(file)
|
||||
return ret
|
||||
|
||||
async def heads(self):
|
||||
async def heads(self) -> List[str]:
|
||||
ret = []
|
||||
versions = Migrate.get_all_version_files()
|
||||
for version in versions:
|
||||
@ -105,15 +108,15 @@ class Command:
|
||||
ret.append(version)
|
||||
return ret
|
||||
|
||||
async def history(self):
|
||||
async def history(self) -> List[str]:
|
||||
versions = Migrate.get_all_version_files()
|
||||
return [version for version in versions]
|
||||
|
||||
async def inspectdb(self, tables: List[str] = None) -> str:
|
||||
async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
|
||||
connection = get_app_connection(self.tortoise_config, self.app)
|
||||
dialect = connection.schema_generator.DIALECT
|
||||
if dialect == "mysql":
|
||||
cls = InspectMySQL
|
||||
cls: Type["Inspect"] = InspectMySQL
|
||||
elif dialect == "postgres":
|
||||
cls = InspectPostgres
|
||||
elif dialect == "sqlite":
|
||||
@ -126,7 +129,7 @@ class Command:
|
||||
async def migrate(self, name: str = "update", empty: bool = False) -> str:
|
||||
return await Migrate.migrate(name, empty)
|
||||
|
||||
async def init_db(self, safe: bool):
|
||||
async def init_db(self, safe: bool) -> None:
|
||||
location = self.location
|
||||
app = self.app
|
||||
dirname = Path(location, app)
|
||||
|
@ -2,7 +2,7 @@ import asyncio
|
||||
import os
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import Dict, List, cast
|
||||
|
||||
import click
|
||||
import tomlkit
|
||||
@ -23,7 +23,7 @@ CONFIG_DEFAULT_VALUES = {
|
||||
|
||||
def coro(f):
|
||||
@wraps(f)
|
||||
def wrapper(*args, **kwargs):
|
||||
def wrapper(*args, **kwargs) -> None:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Close db connections at the end of all but the cli group function
|
||||
@ -48,7 +48,7 @@ def coro(f):
|
||||
@click.option("--app", required=False, help="Tortoise-ORM app name.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def cli(ctx: Context, config, app):
|
||||
async def cli(ctx: Context, config, app) -> None:
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["config_file"] = config
|
||||
|
||||
@ -58,9 +58,9 @@ async def cli(ctx: Context, config, app):
|
||||
if not config_path.exists():
|
||||
raise UsageError("You must exec init first", ctx=ctx)
|
||||
content = config_path.read_text()
|
||||
doc = tomlkit.parse(content)
|
||||
doc: dict = tomlkit.parse(content)
|
||||
try:
|
||||
tool = doc["tool"]["aerich"]
|
||||
tool = cast(Dict[str, str], doc["tool"]["aerich"])
|
||||
location = tool["location"]
|
||||
tortoise_orm = tool["tortoise_orm"]
|
||||
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
|
||||
@ -68,7 +68,9 @@ async def cli(ctx: Context, config, app):
|
||||
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
|
||||
add_src_path(src_folder)
|
||||
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
|
||||
app = app or list(tortoise_config.get("apps").keys())[0]
|
||||
if not app:
|
||||
apps_config = cast(dict, tortoise_config.get("apps"))
|
||||
app = list(apps_config.keys())[0]
|
||||
command = Command(tortoise_config=tortoise_config, app=app, location=location)
|
||||
ctx.obj["command"] = command
|
||||
if invoked_subcommand != "init-db":
|
||||
@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app):
|
||||
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def migrate(ctx: Context, name):
|
||||
async def migrate(ctx: Context, name) -> None:
|
||||
command = ctx.obj["command"]
|
||||
ret = await command.migrate(name)
|
||||
if not ret:
|
||||
@ -100,7 +102,7 @@ async def migrate(ctx: Context, name):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def upgrade(ctx: Context, in_transaction: bool):
|
||||
async def upgrade(ctx: Context, in_transaction: bool) -> None:
|
||||
command = ctx.obj["command"]
|
||||
migrated = await command.upgrade(run_in_transaction=in_transaction)
|
||||
if not migrated:
|
||||
@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool):
|
||||
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
|
||||
)
|
||||
@coro
|
||||
async def downgrade(ctx: Context, version: int, delete: bool):
|
||||
async def downgrade(ctx: Context, version: int, delete: bool) -> None:
|
||||
command = ctx.obj["command"]
|
||||
try:
|
||||
files = await command.downgrade(version, delete)
|
||||
@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
|
||||
@cli.command(help="Show current available heads in migrate location.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def heads(ctx: Context):
|
||||
async def heads(ctx: Context) -> None:
|
||||
command = ctx.obj["command"]
|
||||
head_list = await command.heads()
|
||||
if not head_list:
|
||||
@ -157,7 +159,7 @@ async def heads(ctx: Context):
|
||||
@cli.command(help="List all migrate items.")
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def history(ctx: Context):
|
||||
async def history(ctx: Context) -> None:
|
||||
command = ctx.obj["command"]
|
||||
versions = await command.history()
|
||||
if not versions:
|
||||
@ -188,7 +190,7 @@ async def history(ctx: Context):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def init(ctx: Context, tortoise_orm, location, src_folder):
|
||||
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
|
||||
config_file = ctx.obj["config_file"]
|
||||
|
||||
if os.path.isabs(src_folder):
|
||||
@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
|
||||
config_path = Path(config_file)
|
||||
if config_path.exists():
|
||||
content = config_path.read_text()
|
||||
doc = tomlkit.parse(content)
|
||||
else:
|
||||
doc = tomlkit.parse("[tool.aerich]")
|
||||
content = "[tool.aerich]"
|
||||
doc: dict = tomlkit.parse(content)
|
||||
table = tomlkit.table()
|
||||
table["tortoise_orm"] = tortoise_orm
|
||||
table["location"] = location
|
||||
@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def init_db(ctx: Context, safe: bool):
|
||||
async def init_db(ctx: Context, safe: bool) -> None:
|
||||
command = ctx.obj["command"]
|
||||
app = command.app
|
||||
dirname = Path(command.location, app)
|
||||
@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool):
|
||||
)
|
||||
@click.pass_context
|
||||
@coro
|
||||
async def inspectdb(ctx: Context, table: List[str]):
|
||||
async def inspectdb(ctx: Context, table: List[str]) -> None:
|
||||
command = ctx.obj["command"]
|
||||
ret = await command.inspectdb(table)
|
||||
click.secho(ret)
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
cli()
|
||||
|
||||
|
||||
|
@ -1,12 +1,13 @@
|
||||
import base64
|
||||
import json
|
||||
import pickle # nosec: B301,B403
|
||||
from typing import Any, Union
|
||||
|
||||
from tortoise.indexes import Index
|
||||
|
||||
|
||||
class JsonEncoder(json.JSONEncoder):
|
||||
def default(self, obj):
|
||||
def default(self, obj) -> Any:
|
||||
if isinstance(obj, Index):
|
||||
return {
|
||||
"type": "index",
|
||||
@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder):
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def object_hook(obj):
|
||||
def object_hook(obj) -> Any:
|
||||
_type = obj.get("type")
|
||||
if not _type:
|
||||
return obj
|
||||
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
|
||||
|
||||
|
||||
def encoder(obj: dict):
|
||||
def encoder(obj: dict) -> str:
|
||||
return json.dumps(obj, cls=JsonEncoder)
|
||||
|
||||
|
||||
def decoder(obj: str):
|
||||
def decoder(obj: Union[str, bytes]) -> Any:
|
||||
return json.loads(obj, object_hook=object_hook)
|
||||
|
@ -3,7 +3,8 @@ import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional
|
||||
|
||||
from click import BadOptionUsage, ClickException, Context
|
||||
from tortoise import BaseDBAsyncClient, Tortoise
|
||||
@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict:
|
||||
:return:
|
||||
"""
|
||||
ret = {}
|
||||
for model in Tortoise.apps.get(app).values():
|
||||
for model in Tortoise.apps[app].values():
|
||||
describe = model.describe()
|
||||
ret[describe.get("name")] = describe
|
||||
return ret
|
||||
|
||||
|
||||
def is_default_function(string: str):
|
||||
def is_default_function(string: str) -> Optional[re.Match]:
|
||||
return re.match(r"^<function.+>$", str(string or ""))
|
||||
|
||||
|
||||
def import_py_file(file: Path):
|
||||
def import_py_file(file: Path) -> ModuleType:
|
||||
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
|
||||
spec = importlib.util.spec_from_file_location(module_name, file)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
|
||||
spec.loader.exec_module(module) # type:ignore[union-attr]
|
||||
return module
|
||||
|
Loading…
x
Reference in New Issue
Block a user