feat: use .py version files

This commit is contained in:
long2ice
2022-09-23 10:29:48 +08:00
parent 8c2ecbaef1
commit e9b76bdd35
10 changed files with 91 additions and 111 deletions

View File

@@ -11,14 +11,13 @@ from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
get_version_content_from_file,
write_version_file,
import_py_file,
)
@@ -49,10 +48,9 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
await upgrade(conn)
await Aerich.create(
version=version_file,
app=self.app,
@@ -81,12 +79,11 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
m = import_py_file(file_path)
downgrade = getattr(m, "downgrade", None)
if not downgrade:
raise DowngradeError("No downgrade items found")
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await downgrade(conn)
await version.delete()
if delete:
os.unlink(file_path)
@@ -143,7 +140,7 @@ class Command:
app=app,
content=get_models_describe(app),
)
content = {
"upgrade": [schema],
}
write_version_file(Path(dirname, version), content)
version_file = Path(dirname, version)
content = MIGRATE_TEMPLATE.format(upgrade_sql=f'"""{schema}"""', downgrade_sql="")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)

View File

@@ -26,11 +26,11 @@ def coro(f):
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init_db"]:
if f.__name__ not in ["cli", "init_db", "init"]:
loop.run_until_complete(Tortoise.close_connections())
return wrapper

View File

@@ -13,12 +13,24 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import (
get_app_connection,
get_models_describe,
is_default_function,
write_version_file,
)
from aerich.utils import get_app_connection, get_models_describe, is_default_function
MIGRATE_TEMPLATE = """from typing import List
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{upgrade_sql}
]
async def downgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{downgrade_sql}
]
"""
class Migrate:
@@ -40,9 +52,9 @@ class Migrate:
_db_version: Optional[str] = None
@classmethod
def get_all_version_files(cls) -> List[str]:
def get_all_version_files(cls) -> list[str]:
return sorted(
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]),
)
@@ -97,24 +109,27 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
return f"0_{now}_init.sql"
version = f"{last_version_num + 1}_{now}_{name}.sql"
return f"0_{now}_init.py"
version = f"{last_version_num + 1}_{now}_{name}.py"
if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version
@classmethod
async def _generate_diff_sql(cls, name):
async def _generate_diff_py(cls, name):
version = await cls.generate_version(name)
# delete if same version exists
for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file))
content = {
"upgrade": list(dict.fromkeys(cls.upgrade_operators)),
"downgrade": list(dict.fromkeys(cls.downgrade_operators)),
}
write_version_file(Path(cls.migrate_location, version), content)
version_file = Path(cls.migrate_location, version)
content = MIGRATE_TEMPLATE.format(
upgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.upgrade_operators)),
downgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.downgrade_operators)),
)
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version
@classmethod
@@ -133,7 +148,7 @@ class Migrate:
if not cls.upgrade_operators:
return ""
return await cls._generate_diff_sql(name)
return await cls._generate_diff_py(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):

View File

@@ -1,9 +1,9 @@
import importlib
import importlib.util
import os
import re
import sys
from pathlib import Path
from typing import Dict, Union
from typing import Dict
from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise
@@ -11,7 +11,7 @@ from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str:
"""
add a folder to the paths so we can import from there
add a folder to the paths, so we can import from there
:param path: path to add
:return: absolute path
"""
@@ -77,60 +77,6 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config
_UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: Path, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
@@ -146,3 +92,11 @@ def get_models_describe(app: str) -> Dict:
def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module