From 41df464e8b0c1e1c340b281f6d5f75c55f79ad4f Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Thu, 20 Feb 2025 16:58:32 +0800 Subject: [PATCH] fix: no migration occurs when adding unique true to indexed field (#414) * feat: alter unique for indexed column * chore: update docs and change some var names --- CHANGELOG.md | 2 + aerich/ddl/__init__.py | 20 +++++- aerich/ddl/mysql/__init__.py | 16 +++++ aerich/migrate.py | 4 +- tests/assets/sqlite_migrate/_tests.py | 1 + tests/models.py | 1 + tests/old_models.py | 1 + tests/test_migrate.py | 30 +++++++-- tests/test_sqlite_migrate.py | 89 +++++++++++++++++++-------- 9 files changed, 130 insertions(+), 34 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 657e1a4..416994a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ #### Fixed - fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415]) +- No migration occurs as expected when adding `unique=True` to indexed field. ([#404]) - fix: inspectdb raise KeyError 'int2' for smallint. ([#401]) - fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187]) @@ -18,6 +19,7 @@ [#398]: https://github.com/tortoise/aerich/pull/398 [#401]: https://github.com/tortoise/aerich/pull/401 +[#404]: https://github.com/tortoise/aerich/pull/404 [#412]: https://github.com/tortoise/aerich/pull/412 [#415]: https://github.com/tortoise/aerich/pull/415 [#417]: https://github.com/tortoise/aerich/pull/417 diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index dc7e09c..e88c338 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -43,6 +43,10 @@ class BaseDDL: self.client = client self.schema_generator = self.schema_generator_cls(client) + @staticmethod + def get_table_name(model: type[Model]) -> str: + return model._meta.db_table + def create_table(self, model: type[Model]) -> str: schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"] if tortoise.__version__ <= "0.23.0": @@ -109,8 +113,6 @@ class BaseDDL: ) except NotImplementedError: default = "" - else: - default = None return default def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str: @@ -276,3 +278,17 @@ class BaseDDL: return self._RENAME_TABLE_TEMPLATE.format( table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name ) + + def alter_indexed_column_unique( + self, model: type[Model], field_name: str, drop: bool = False + ) -> list[str]: + """Change unique constraint for indexed field, e.g.: Field(db_index=True) --> Field(unique=True)""" + fields = [field_name] + if drop: + drop_unique = self.drop_index(model, fields, unique=True) + add_normal_index = self.add_index(model, fields, unique=False) + return [drop_unique, add_normal_index] + else: + drop_index = self.drop_index(model, fields, unique=False) + add_unique_index = self.add_index(model, fields, unique=True) + return [drop_index, add_unique_index] diff --git a/aerich/ddl/mysql/__init__.py b/aerich/ddl/mysql/__init__.py index 4e7e397..1d1e8c4 100644 --- a/aerich/ddl/mysql/__init__.py +++ b/aerich/ddl/mysql/__init__.py @@ -25,6 +25,12 @@ class MysqlDDL(BaseDDL): ) _ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" + _ADD_INDEXED_UNIQUE_TEMPLATE = ( + "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)" + ) + _DROP_INDEXED_UNIQUE_TEMPLATE = ( + "ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)" + ) _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _M2M_TABLE_TEMPLATE = ( @@ -47,3 +53,13 @@ class MysqlDDL(BaseDDL): else: index_prefix = "idx" return self.schema_generator._generate_index_name(index_prefix, model, field_names) + + def alter_indexed_column_unique( + self, model: type[Model], field_name: str, drop: bool = False + ) -> list[str]: + # if drop is false: Drop index and add unique + # else: Drop unique index and add normal index + template = self._DROP_INDEXED_UNIQUE_TEMPLATE if drop else self._ADD_INDEXED_UNIQUE_TEMPLATE + table = self.get_table_name(model) + index = self._index_name(unique=False, model=model, field_names=[field_name]) + return [template.format(table_name=table, index_name=index, column_name=field_name)] diff --git a/aerich/migrate.py b/aerich/migrate.py index 8332d75..08190b2 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -632,7 +632,9 @@ class Migrate: # indexed include it continue # Change unique for indexed field, e.g.: `db_index=True, unique=False` --> `db_index=True, unique=True` - # TODO + drop_unique = old_new[0] is True and old_new[1] is False + for sql in cls.ddl.alter_indexed_column_unique(model, field_name, drop_unique): + cls._add_operator(sql, upgrade, True) elif option == "nullable": # change nullable cls._add_operator(cls._alter_null(model, new_data_field), upgrade) diff --git a/tests/assets/sqlite_migrate/_tests.py b/tests/assets/sqlite_migrate/_tests.py index 6fe8d46..4664c3b 100644 --- a/tests/assets/sqlite_migrate/_tests.py +++ b/tests/assets/sqlite_migrate/_tests.py @@ -18,6 +18,7 @@ async def test_allow_duplicate() -> None: async def test_unique_is_true() -> None: with pytest.raises(IntegrityError): await Foo.create(name="foo") + await Foo.create(name="foo") @pytest.mark.asyncio diff --git a/tests/models.py b/tests/models.py index 6879da7..e89a496 100644 --- a/tests/models.py +++ b/tests/models.py @@ -49,6 +49,7 @@ class User(Model): class Email(Model): email_id = fields.IntField(primary_key=True) email = fields.CharField(max_length=200, db_index=True) + company = fields.CharField(max_length=100, db_index=True, unique=True) is_primary = fields.BooleanField(default=False) address = fields.CharField(max_length=200) users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User") diff --git a/tests/old_models.py b/tests/old_models.py index c7cd4ff..cbe3131 100644 --- a/tests/old_models.py +++ b/tests/old_models.py @@ -40,6 +40,7 @@ class User(Model): class Email(Model): email = fields.CharField(max_length=200) + company = fields.CharField(max_length=100, db_index=True) is_primary = fields.BooleanField(default=False) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( "models.User", db_constraint=False diff --git a/tests/test_migrate.py b/tests/test_migrate.py index b4152f0..e078904 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -365,6 +365,21 @@ old_models_describe = { "constraints": {"max_length": 200}, "db_field_types": {"": "VARCHAR(200)"}, }, + { + "name": "company", + "field_type": "CharField", + "db_column": "company", + "python_type": "str", + "generated": False, + "nullable": False, + "unique": False, + "indexed": True, + "default": None, + "description": None, + "docstring": None, + "constraints": {"max_length": 100}, + "db_field_types": {"": "VARCHAR(100)"}, + }, { "name": "is_primary", "field_type": "BooleanField", @@ -929,6 +944,7 @@ def test_migrate(mocker: MockerFixture): - drop fk field: Email.user - drop field: User.avatar - add index: Email.email + - add unique to indexed field: Email.company - change index type for indexed field: Email.slug - add many to many: Email.users - add one to one: Email.config @@ -977,6 +993,7 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE", "ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE", + "ALTER TABLE `email` DROP INDEX `idx_email_company_1c9234`, ADD UNIQUE (`company`)", "ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `product` DROP COLUMN `uuid`", "ALTER TABLE `product` DROP INDEX `uuid`", @@ -1019,20 +1036,21 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)", "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", - "ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `config` DROP COLUMN `user_id`", + "ALTER TABLE `config` RENAME TO `configs`", + "ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `config_id`", "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`", - "ALTER TABLE `config` RENAME TO `configs`", - "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", + "ALTER TABLE `email` DROP INDEX `company`, ADD INDEX (`idx_email_company_1c9234`)", + "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", + "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`", "ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE", "ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`", "ALTER TABLE `product` DROP COLUMN `price`", "ALTER TABLE `product` DROP COLUMN `no`", - "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`", "ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", @@ -1074,6 +1092,8 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE', + 'DROP INDEX IF EXISTS "idx_email_company_1c9234"', + 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_email_company_1c9234" ON "email" ("company")', 'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"', 'ALTER TABLE "product" DROP COLUMN "uuid"', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', @@ -1121,6 +1141,8 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', 'ALTER TABLE "email" DROP COLUMN "config_id"', 'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"', + 'CREATE INDEX IF NOT EXISTS "idx_email_company_1c9234" ON "email" ("company")', + 'DROP INDEX IF EXISTS "uid_email_company_1c9234"', 'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', diff --git a/tests/test_sqlite_migrate.py b/tests/test_sqlite_migrate.py index 735defb..91fb79f 100644 --- a/tests/test_sqlite_migrate.py +++ b/tests/test_sqlite_migrate.py @@ -1,13 +1,15 @@ +from __future__ import annotations + import contextlib import os import shlex import shutil import subprocess +from collections.abc import Generator +from contextlib import contextmanager from pathlib import Path -from aerich.ddl.sqlite import SqliteDDL -from aerich.migrate import Migrate -from tests._utils import chdir, copy_files +from tests._utils import Dialect, chdir, copy_files def run_aerich(cmd: str) -> None: @@ -22,9 +24,14 @@ def run_shell(cmd: str) -> subprocess.CompletedProcess: return subprocess.run(shlex.split(cmd), env=envs) -def test_sqlite_migrate(tmp_path: Path) -> None: - if (ddl := getattr(Migrate, "ddl", None)) and not isinstance(ddl, SqliteDDL): - return +def _get_empty_db() -> Path: + if (db_file := Path("db.sqlite3")).exists(): + db_file.unlink() + return db_file + + +@contextmanager +def prepare_sqlite_project(tmp_path: Path) -> Generator[tuple[Path, str]]: test_dir = Path(__file__).parent asset_dir = test_dir / "assets" / "sqlite_migrate" with chdir(tmp_path): @@ -32,9 +39,52 @@ def test_sqlite_migrate(tmp_path: Path) -> None: copy_files(*(asset_dir / f for f in files), target_dir=Path()) models_py, settings_py, test_py = (Path(f) for f in files) copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py")) - if (db_file := Path("db.sqlite3")).exists(): - db_file.unlink() - MODELS = models_py.read_text("utf-8") + _get_empty_db() + yield models_py, models_py.read_text("utf-8") + + +def test_sqlite_migrate_alter_indexed_unique(tmp_path: Path) -> None: + if not Dialect.is_sqlite(): + return + with prepare_sqlite_project(tmp_path) as (models_py, models_text): + models_py.write_text(models_text.replace("db_index=False", "db_index=True")) + run_aerich("aerich init -t settings.TORTOISE_ORM") + run_aerich("aerich init-db") + r = run_shell("pytest -s _tests.py::test_allow_duplicate") + assert r.returncode == 0 + models_py.write_text(models_text.replace("db_index=False", "unique=True")) + run_aerich("aerich migrate") # migrations/models/1_ + run_aerich("aerich upgrade") + r = run_shell("pytest _tests.py::test_unique_is_true") + assert r.returncode == 0 + models_py.write_text(models_text.replace("db_index=False", "db_index=True")) + run_aerich("aerich migrate") # migrations/models/2_ + run_aerich("aerich upgrade") + r = run_shell("pytest -s _tests.py::test_allow_duplicate") + assert r.returncode == 0 + + +M2M_WITH_CUSTOM_THROUGH = """ + groups = fields.ManyToManyField("models.Group", through="foo_group") + +class Group(Model): + name = fields.CharField(max_length=60) + +class FooGroup(Model): + foo = fields.ForeignKeyField("models.Foo") + group = fields.ForeignKeyField("models.Group") + is_active = fields.BooleanField(default=False) + + class Meta: + table = "foo_group" +""" + + +def test_sqlite_migrate(tmp_path: Path) -> None: + if not Dialect.is_sqlite(): + return + with prepare_sqlite_project(tmp_path) as (models_py, models_text): + MODELS = models_text run_aerich("aerich init -t settings.TORTOISE_ORM") config_file = Path("pyproject.toml") modify_time = config_file.stat().st_mtime @@ -84,7 +134,7 @@ def test_sqlite_migrate(tmp_path: Path) -> None: # Initial with indexed field and then drop it migrations_dir = Path("migrations/models") shutil.rmtree(migrations_dir) - db_file.unlink() + db_file = _get_empty_db() models_py.write_text(MODELS + " age = fields.IntField(db_index=True)") run_aerich("aerich init -t settings.TORTOISE_ORM") run_aerich("aerich init-db") @@ -119,21 +169,7 @@ def test_sqlite_migrate(tmp_path: Path) -> None: assert "[tool.aerich]" in config_file.read_text() # add m2m with custom model for through - new = """ - groups = fields.ManyToManyField("models.Group", through="foo_group") - -class Group(Model): - name = fields.CharField(max_length=60) - -class FooGroup(Model): - foo = fields.ForeignKeyField("models.Foo") - group = fields.ForeignKeyField("models.Group") - is_active = fields.BooleanField(default=False) - - class Meta: - table = "foo_group" - """ - models_py.write_text(MODELS + new) + models_py.write_text(MODELS + M2M_WITH_CUSTOM_THROUGH) run_aerich("aerich migrate") run_aerich("aerich upgrade") migration_file_1 = list(migrations_dir.glob("1_*.py"))[0] @@ -148,8 +184,7 @@ class FooGroup(Model): class Group(Model): name = fields.CharField(max_length=60) """ - if db_file.exists(): - db_file.unlink() + _get_empty_db() if migrations_dir.exists(): shutil.rmtree(migrations_dir) models_py.write_text(MODELS)