Simple type hints for aerich/
This commit is contained in:
parent
51117867a6
commit
8756f64e3f
@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import TYPE_CHECKING, List, Optional, Type
|
||||||
|
|
||||||
from tortoise import Tortoise, generate_schema_for_client
|
from tortoise import Tortoise, generate_schema_for_client
|
||||||
from tortoise.exceptions import OperationalError
|
from tortoise.exceptions import OperationalError
|
||||||
@ -20,6 +20,9 @@ from aerich.utils import (
|
|||||||
import_py_file,
|
import_py_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from aerich.inspectdb import Inspect # noqa:F401
|
||||||
|
|
||||||
|
|
||||||
class Command:
|
class Command:
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -27,16 +30,16 @@ class Command:
|
|||||||
tortoise_config: dict,
|
tortoise_config: dict,
|
||||||
app: str = "models",
|
app: str = "models",
|
||||||
location: str = "./migrations",
|
location: str = "./migrations",
|
||||||
):
|
) -> None:
|
||||||
self.tortoise_config = tortoise_config
|
self.tortoise_config = tortoise_config
|
||||||
self.app = app
|
self.app = app
|
||||||
self.location = location
|
self.location = location
|
||||||
Migrate.app = app
|
Migrate.app = app
|
||||||
|
|
||||||
async def init(self):
|
async def init(self) -> None:
|
||||||
await Migrate.init(self.tortoise_config, self.app, self.location)
|
await Migrate.init(self.tortoise_config, self.app, self.location)
|
||||||
|
|
||||||
async def _upgrade(self, conn, version_file):
|
async def _upgrade(self, conn, version_file) -> None:
|
||||||
file_path = Path(Migrate.migrate_location, version_file)
|
file_path = Path(Migrate.migrate_location, version_file)
|
||||||
m = import_py_file(file_path)
|
m = import_py_file(file_path)
|
||||||
upgrade = getattr(m, "upgrade")
|
upgrade = getattr(m, "upgrade")
|
||||||
@ -47,7 +50,7 @@ class Command:
|
|||||||
content=get_models_describe(self.app),
|
content=get_models_describe(self.app),
|
||||||
)
|
)
|
||||||
|
|
||||||
async def upgrade(self, run_in_transaction: bool = True):
|
async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
|
||||||
migrated = []
|
migrated = []
|
||||||
for version_file in Migrate.get_all_version_files():
|
for version_file in Migrate.get_all_version_files():
|
||||||
try:
|
try:
|
||||||
@ -65,8 +68,8 @@ class Command:
|
|||||||
migrated.append(version_file)
|
migrated.append(version_file)
|
||||||
return migrated
|
return migrated
|
||||||
|
|
||||||
async def downgrade(self, version: int, delete: bool):
|
async def downgrade(self, version: int, delete: bool) -> List[str]:
|
||||||
ret = []
|
ret: List[str] = []
|
||||||
if version == -1:
|
if version == -1:
|
||||||
specified_version = await Migrate.get_last_version()
|
specified_version = await Migrate.get_last_version()
|
||||||
else:
|
else:
|
||||||
@ -79,8 +82,8 @@ class Command:
|
|||||||
versions = [specified_version]
|
versions = [specified_version]
|
||||||
else:
|
else:
|
||||||
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
|
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
|
||||||
for version in versions:
|
for version_obj in versions:
|
||||||
file = version.version
|
file = version_obj.version
|
||||||
async with in_transaction(
|
async with in_transaction(
|
||||||
get_app_connection_name(self.tortoise_config, self.app)
|
get_app_connection_name(self.tortoise_config, self.app)
|
||||||
) as conn:
|
) as conn:
|
||||||
@ -91,13 +94,13 @@ class Command:
|
|||||||
if not downgrade_sql.strip():
|
if not downgrade_sql.strip():
|
||||||
raise DowngradeError("No downgrade items found")
|
raise DowngradeError("No downgrade items found")
|
||||||
await conn.execute_script(downgrade_sql)
|
await conn.execute_script(downgrade_sql)
|
||||||
await version.delete()
|
await version_obj.delete()
|
||||||
if delete:
|
if delete:
|
||||||
os.unlink(file_path)
|
os.unlink(file_path)
|
||||||
ret.append(file)
|
ret.append(file)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def heads(self):
|
async def heads(self) -> List[str]:
|
||||||
ret = []
|
ret = []
|
||||||
versions = Migrate.get_all_version_files()
|
versions = Migrate.get_all_version_files()
|
||||||
for version in versions:
|
for version in versions:
|
||||||
@ -105,15 +108,15 @@ class Command:
|
|||||||
ret.append(version)
|
ret.append(version)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def history(self):
|
async def history(self) -> List[str]:
|
||||||
versions = Migrate.get_all_version_files()
|
versions = Migrate.get_all_version_files()
|
||||||
return [version for version in versions]
|
return [version for version in versions]
|
||||||
|
|
||||||
async def inspectdb(self, tables: List[str] = None) -> str:
|
async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
|
||||||
connection = get_app_connection(self.tortoise_config, self.app)
|
connection = get_app_connection(self.tortoise_config, self.app)
|
||||||
dialect = connection.schema_generator.DIALECT
|
dialect = connection.schema_generator.DIALECT
|
||||||
if dialect == "mysql":
|
if dialect == "mysql":
|
||||||
cls = InspectMySQL
|
cls: Type["Inspect"] = InspectMySQL
|
||||||
elif dialect == "postgres":
|
elif dialect == "postgres":
|
||||||
cls = InspectPostgres
|
cls = InspectPostgres
|
||||||
elif dialect == "sqlite":
|
elif dialect == "sqlite":
|
||||||
@ -126,7 +129,7 @@ class Command:
|
|||||||
async def migrate(self, name: str = "update", empty: bool = False) -> str:
|
async def migrate(self, name: str = "update", empty: bool = False) -> str:
|
||||||
return await Migrate.migrate(name, empty)
|
return await Migrate.migrate(name, empty)
|
||||||
|
|
||||||
async def init_db(self, safe: bool):
|
async def init_db(self, safe: bool) -> None:
|
||||||
location = self.location
|
location = self.location
|
||||||
app = self.app
|
app = self.app
|
||||||
dirname = Path(location, app)
|
dirname = Path(location, app)
|
||||||
|
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import os
|
import os
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Dict, List, cast
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import tomlkit
|
import tomlkit
|
||||||
@ -23,7 +23,7 @@ CONFIG_DEFAULT_VALUES = {
|
|||||||
|
|
||||||
def coro(f):
|
def coro(f):
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs) -> None:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# Close db connections at the end of all but the cli group function
|
# Close db connections at the end of all but the cli group function
|
||||||
@ -48,7 +48,7 @@ def coro(f):
|
|||||||
@click.option("--app", required=False, help="Tortoise-ORM app name.")
|
@click.option("--app", required=False, help="Tortoise-ORM app name.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def cli(ctx: Context, config, app):
|
async def cli(ctx: Context, config, app) -> None:
|
||||||
ctx.ensure_object(dict)
|
ctx.ensure_object(dict)
|
||||||
ctx.obj["config_file"] = config
|
ctx.obj["config_file"] = config
|
||||||
|
|
||||||
@ -58,9 +58,9 @@ async def cli(ctx: Context, config, app):
|
|||||||
if not config_path.exists():
|
if not config_path.exists():
|
||||||
raise UsageError("You must exec init first", ctx=ctx)
|
raise UsageError("You must exec init first", ctx=ctx)
|
||||||
content = config_path.read_text()
|
content = config_path.read_text()
|
||||||
doc = tomlkit.parse(content)
|
doc: dict = tomlkit.parse(content)
|
||||||
try:
|
try:
|
||||||
tool = doc["tool"]["aerich"]
|
tool = cast(Dict[str, str], doc["tool"]["aerich"])
|
||||||
location = tool["location"]
|
location = tool["location"]
|
||||||
tortoise_orm = tool["tortoise_orm"]
|
tortoise_orm = tool["tortoise_orm"]
|
||||||
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
|
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
|
||||||
@ -68,7 +68,9 @@ async def cli(ctx: Context, config, app):
|
|||||||
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
|
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
|
||||||
add_src_path(src_folder)
|
add_src_path(src_folder)
|
||||||
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
|
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
|
||||||
app = app or list(tortoise_config.get("apps").keys())[0]
|
if not app:
|
||||||
|
apps_config = cast(dict, tortoise_config.get("apps"))
|
||||||
|
app = list(apps_config.keys())[0]
|
||||||
command = Command(tortoise_config=tortoise_config, app=app, location=location)
|
command = Command(tortoise_config=tortoise_config, app=app, location=location)
|
||||||
ctx.obj["command"] = command
|
ctx.obj["command"] = command
|
||||||
if invoked_subcommand != "init-db":
|
if invoked_subcommand != "init-db":
|
||||||
@ -82,7 +84,7 @@ async def cli(ctx: Context, config, app):
|
|||||||
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
|
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def migrate(ctx: Context, name):
|
async def migrate(ctx: Context, name) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
ret = await command.migrate(name)
|
ret = await command.migrate(name)
|
||||||
if not ret:
|
if not ret:
|
||||||
@ -100,7 +102,7 @@ async def migrate(ctx: Context, name):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def upgrade(ctx: Context, in_transaction: bool):
|
async def upgrade(ctx: Context, in_transaction: bool) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
migrated = await command.upgrade(run_in_transaction=in_transaction)
|
migrated = await command.upgrade(run_in_transaction=in_transaction)
|
||||||
if not migrated:
|
if not migrated:
|
||||||
@ -132,7 +134,7 @@ async def upgrade(ctx: Context, in_transaction: bool):
|
|||||||
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
|
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
|
||||||
)
|
)
|
||||||
@coro
|
@coro
|
||||||
async def downgrade(ctx: Context, version: int, delete: bool):
|
async def downgrade(ctx: Context, version: int, delete: bool) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
try:
|
try:
|
||||||
files = await command.downgrade(version, delete)
|
files = await command.downgrade(version, delete)
|
||||||
@ -145,7 +147,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
|
|||||||
@cli.command(help="Show current available heads in migrate location.")
|
@cli.command(help="Show current available heads in migrate location.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def heads(ctx: Context):
|
async def heads(ctx: Context) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
head_list = await command.heads()
|
head_list = await command.heads()
|
||||||
if not head_list:
|
if not head_list:
|
||||||
@ -157,7 +159,7 @@ async def heads(ctx: Context):
|
|||||||
@cli.command(help="List all migrate items.")
|
@cli.command(help="List all migrate items.")
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def history(ctx: Context):
|
async def history(ctx: Context) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
versions = await command.history()
|
versions = await command.history()
|
||||||
if not versions:
|
if not versions:
|
||||||
@ -188,7 +190,7 @@ async def history(ctx: Context):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def init(ctx: Context, tortoise_orm, location, src_folder):
|
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
|
||||||
config_file = ctx.obj["config_file"]
|
config_file = ctx.obj["config_file"]
|
||||||
|
|
||||||
if os.path.isabs(src_folder):
|
if os.path.isabs(src_folder):
|
||||||
@ -203,9 +205,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
|
|||||||
config_path = Path(config_file)
|
config_path = Path(config_file)
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
content = config_path.read_text()
|
content = config_path.read_text()
|
||||||
doc = tomlkit.parse(content)
|
|
||||||
else:
|
else:
|
||||||
doc = tomlkit.parse("[tool.aerich]")
|
content = "[tool.aerich]"
|
||||||
|
doc: dict = tomlkit.parse(content)
|
||||||
table = tomlkit.table()
|
table = tomlkit.table()
|
||||||
table["tortoise_orm"] = tortoise_orm
|
table["tortoise_orm"] = tortoise_orm
|
||||||
table["location"] = location
|
table["location"] = location
|
||||||
@ -232,7 +234,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def init_db(ctx: Context, safe: bool):
|
async def init_db(ctx: Context, safe: bool) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
app = command.app
|
app = command.app
|
||||||
dirname = Path(command.location, app)
|
dirname = Path(command.location, app)
|
||||||
@ -256,13 +258,13 @@ async def init_db(ctx: Context, safe: bool):
|
|||||||
)
|
)
|
||||||
@click.pass_context
|
@click.pass_context
|
||||||
@coro
|
@coro
|
||||||
async def inspectdb(ctx: Context, table: List[str]):
|
async def inspectdb(ctx: Context, table: List[str]) -> None:
|
||||||
command = ctx.obj["command"]
|
command = ctx.obj["command"]
|
||||||
ret = await command.inspectdb(table)
|
ret = await command.inspectdb(table)
|
||||||
click.secho(ret)
|
click.secho(ret)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main() -> None:
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,12 +1,13 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import pickle # nosec: B301,B403
|
import pickle # nosec: B301,B403
|
||||||
|
from typing import Any, Union
|
||||||
|
|
||||||
from tortoise.indexes import Index
|
from tortoise.indexes import Index
|
||||||
|
|
||||||
|
|
||||||
class JsonEncoder(json.JSONEncoder):
|
class JsonEncoder(json.JSONEncoder):
|
||||||
def default(self, obj):
|
def default(self, obj) -> Any:
|
||||||
if isinstance(obj, Index):
|
if isinstance(obj, Index):
|
||||||
return {
|
return {
|
||||||
"type": "index",
|
"type": "index",
|
||||||
@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder):
|
|||||||
return super().default(obj)
|
return super().default(obj)
|
||||||
|
|
||||||
|
|
||||||
def object_hook(obj):
|
def object_hook(obj) -> Any:
|
||||||
_type = obj.get("type")
|
_type = obj.get("type")
|
||||||
if not _type:
|
if not _type:
|
||||||
return obj
|
return obj
|
||||||
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
|
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
|
||||||
|
|
||||||
|
|
||||||
def encoder(obj: dict):
|
def encoder(obj: dict) -> str:
|
||||||
return json.dumps(obj, cls=JsonEncoder)
|
return json.dumps(obj, cls=JsonEncoder)
|
||||||
|
|
||||||
|
|
||||||
def decoder(obj: str):
|
def decoder(obj: Union[str, bytes]) -> Any:
|
||||||
return json.loads(obj, object_hook=object_hook)
|
return json.loads(obj, object_hook=object_hook)
|
||||||
|
@ -3,7 +3,8 @@ import os
|
|||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from types import ModuleType
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
from click import BadOptionUsage, ClickException, Context
|
from click import BadOptionUsage, ClickException, Context
|
||||||
from tortoise import BaseDBAsyncClient, Tortoise
|
from tortoise import BaseDBAsyncClient, Tortoise
|
||||||
@ -84,19 +85,19 @@ def get_models_describe(app: str) -> Dict:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
ret = {}
|
ret = {}
|
||||||
for model in Tortoise.apps.get(app).values():
|
for model in Tortoise.apps[app].values():
|
||||||
describe = model.describe()
|
describe = model.describe()
|
||||||
ret[describe.get("name")] = describe
|
ret[describe.get("name")] = describe
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def is_default_function(string: str):
|
def is_default_function(string: str) -> Optional[re.Match]:
|
||||||
return re.match(r"^<function.+>$", str(string or ""))
|
return re.match(r"^<function.+>$", str(string or ""))
|
||||||
|
|
||||||
|
|
||||||
def import_py_file(file: Path):
|
def import_py_file(file: Path) -> ModuleType:
|
||||||
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
|
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
|
||||||
spec = importlib.util.spec_from_file_location(module_name, file)
|
spec = importlib.util.spec_from_file_location(module_name, file)
|
||||||
module = importlib.util.module_from_spec(spec)
|
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
|
||||||
spec.loader.exec_module(module)
|
spec.loader.exec_module(module) # type:ignore[union-attr]
|
||||||
return module
|
return module
|
||||||
|
Loading…
x
Reference in New Issue
Block a user