Merge pull request #327 from ar0ne/dev

Added an option to generate empty migration file
This commit is contained in:
long2ice 2024-01-18 09:46:02 +08:00 committed by GitHub
commit 4370b5ed08
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 49 additions and 9 deletions

View File

@ -113,6 +113,14 @@ If `aerich` guesses you are renaming a column, it will ask `Rename {old_column}
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data. lose data.
If you need to manually write migration, you could generate empty file:
```shell
> aerich migrate --name add_index --empty
Success migrate 1_202326122220101229_add_index.py
```
### Upgrade to latest version ### Upgrade to latest version
```shell ```shell

View File

@ -123,8 +123,8 @@ class Command:
inspect = cls(connection, tables) inspect = cls(connection, tables)
return await inspect.inspect() return await inspect.inspect()
async def migrate(self, name: str = "update"): async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name) return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool): async def init_db(self, safe: bool):
location = self.location location = self.location

View File

@ -79,6 +79,7 @@ async def cli(ctx: Context, config, app):
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.") @click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.option("--empty", default=False, is_flag=True, help="Generate empty migration file.")
@click.pass_context @click.pass_context
@coro @coro
async def migrate(ctx: Context, name): async def migrate(ctx: Context, name):

View File

@ -120,22 +120,23 @@ class Migrate:
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
version_file = Path(cls.migrate_location, version) version_file = Path(cls.migrate_location, version)
content = MIGRATE_TEMPLATE.format( content = cls._get_diff_file_content()
upgrade_sql=";\n ".join(cls.upgrade_operators) + ";",
downgrade_sql=";\n ".join(cls.downgrade_operators) + ";",
)
with open(version_file, "w", encoding="utf-8") as f: with open(version_file, "w", encoding="utf-8") as f:
f.write(content) f.write(content)
return version return version
@classmethod @classmethod
async def migrate(cls, name) -> str: async def migrate(cls, name: str, empty: bool) -> str:
""" """
diff old models and new models to generate diff content diff old models and new models to generate diff content
:param name: :param name: str name for migration
:param empty: bool if True generates empty migration
:return: :return:
""" """
if empty:
return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app) new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content) cls.diff_models(cls._last_version_content, new_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False) cls.diff_models(new_version_content, cls._last_version_content, False)
@ -147,6 +148,21 @@ class Migrate:
return await cls._generate_diff_py(name) return await cls._generate_diff_py(name)
@classmethod
def _get_diff_file_content(cls) -> str:
"""
builds content for diff file from template
"""
def join_lines(lines: List[str]) -> str:
if not lines:
return ""
return ";\n ".join(lines) + ";"
return MIGRATE_TEMPLATE.format(
upgrade_sql=join_lines(cls.upgrade_operators),
downgrade_sql=join_lines(cls.downgrade_operators)
)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
""" """

View File

@ -1,3 +1,6 @@
import tempfile
from pathlib import Path
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@ -5,7 +8,7 @@ from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe from aerich.utils import get_models_describe
old_models_describe = { old_models_describe = {
@ -966,3 +969,15 @@ def test_sort_all_version_files(mocker):
"10_datetime_update.py", "10_datetime_update.py",
"11_datetime_update.py", "11_datetime_update.py",
] ]
async def test_empty_migration(mocker) -> None:
mocker.patch("os.listdir", return_value=[])
Migrate.app = "foo"
expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
with tempfile.TemporaryDirectory() as temp_dir:
Migrate.migrate_location = temp_dir
migration_file = await Migrate.migrate("update", True)
with open(Path(temp_dir, migration_file), "r") as f:
assert f.read() == expected_content