feat: use .py version files
This commit is contained in:
@@ -13,12 +13,24 @@ from tortoise.indexes import Index
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.models import MAX_VERSION_LENGTH, Aerich
|
||||
from aerich.utils import (
|
||||
get_app_connection,
|
||||
get_models_describe,
|
||||
is_default_function,
|
||||
write_version_file,
|
||||
)
|
||||
from aerich.utils import get_app_connection, get_models_describe, is_default_function
|
||||
|
||||
MIGRATE_TEMPLATE = """from typing import List
|
||||
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
|
||||
async def upgrade(db: BaseDBAsyncClient) -> List[str]:
|
||||
return [
|
||||
{upgrade_sql}
|
||||
]
|
||||
|
||||
|
||||
async def downgrade(db: BaseDBAsyncClient) -> List[str]:
|
||||
return [
|
||||
{downgrade_sql}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
class Migrate:
|
||||
@@ -40,9 +52,9 @@ class Migrate:
|
||||
_db_version: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def get_all_version_files(cls) -> List[str]:
|
||||
def get_all_version_files(cls) -> list[str]:
|
||||
return sorted(
|
||||
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
|
||||
filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
|
||||
key=lambda x: int(x.split("_")[0]),
|
||||
)
|
||||
|
||||
@@ -97,24 +109,27 @@ class Migrate:
|
||||
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
|
||||
last_version_num = await cls._get_last_version_num()
|
||||
if last_version_num is None:
|
||||
return f"0_{now}_init.sql"
|
||||
version = f"{last_version_num + 1}_{now}_{name}.sql"
|
||||
return f"0_{now}_init.py"
|
||||
version = f"{last_version_num + 1}_{now}_{name}.py"
|
||||
if len(version) > MAX_VERSION_LENGTH:
|
||||
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
async def _generate_diff_sql(cls, name):
|
||||
async def _generate_diff_py(cls, name):
|
||||
version = await cls.generate_version(name)
|
||||
# delete if same version exists
|
||||
for version_file in cls.get_all_version_files():
|
||||
if version_file.startswith(version.split("_")[0]):
|
||||
os.unlink(Path(cls.migrate_location, version_file))
|
||||
content = {
|
||||
"upgrade": list(dict.fromkeys(cls.upgrade_operators)),
|
||||
"downgrade": list(dict.fromkeys(cls.downgrade_operators)),
|
||||
}
|
||||
write_version_file(Path(cls.migrate_location, version), content)
|
||||
|
||||
version_file = Path(cls.migrate_location, version)
|
||||
content = MIGRATE_TEMPLATE.format(
|
||||
upgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.upgrade_operators)),
|
||||
downgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.downgrade_operators)),
|
||||
)
|
||||
with open(version_file, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
@@ -133,7 +148,7 @@ class Migrate:
|
||||
if not cls.upgrade_operators:
|
||||
return ""
|
||||
|
||||
return await cls._generate_diff_sql(name)
|
||||
return await cls._generate_diff_py(name)
|
||||
|
||||
@classmethod
|
||||
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
|
||||
|
||||
Reference in New Issue
Block a user