fix: no migration occurs when adding unique true to indexed field (#414)

* feat: alter unique for indexed column

* chore: update docs and change some var names
This commit is contained in:
Waket Zheng 2025-02-20 16:58:32 +08:00 committed by GitHub
parent c35282c2a3
commit 41df464e8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 130 additions and 34 deletions

View File

@ -10,6 +10,7 @@
#### Fixed #### Fixed
- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415]) - fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
- No migration occurs as expected when adding `unique=True` to indexed field. ([#404])
- fix: inspectdb raise KeyError 'int2' for smallint. ([#401]) - fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
- fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187]) - fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187])
@ -18,6 +19,7 @@
[#398]: https://github.com/tortoise/aerich/pull/398 [#398]: https://github.com/tortoise/aerich/pull/398
[#401]: https://github.com/tortoise/aerich/pull/401 [#401]: https://github.com/tortoise/aerich/pull/401
[#404]: https://github.com/tortoise/aerich/pull/404
[#412]: https://github.com/tortoise/aerich/pull/412 [#412]: https://github.com/tortoise/aerich/pull/412
[#415]: https://github.com/tortoise/aerich/pull/415 [#415]: https://github.com/tortoise/aerich/pull/415
[#417]: https://github.com/tortoise/aerich/pull/417 [#417]: https://github.com/tortoise/aerich/pull/417

View File

@ -43,6 +43,10 @@ class BaseDDL:
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
@staticmethod
def get_table_name(model: type[Model]) -> str:
return model._meta.db_table
def create_table(self, model: type[Model]) -> str: def create_table(self, model: type[Model]) -> str:
schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"] schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
if tortoise.__version__ <= "0.23.0": if tortoise.__version__ <= "0.23.0":
@ -109,8 +113,6 @@ class BaseDDL:
) )
except NotImplementedError: except NotImplementedError:
default = "" default = ""
else:
default = None
return default return default
def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str: def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
@ -276,3 +278,17 @@ class BaseDDL:
return self._RENAME_TABLE_TEMPLATE.format( return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
) )
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
"""Change unique constraint for indexed field, e.g.: Field(db_index=True) --> Field(unique=True)"""
fields = [field_name]
if drop:
drop_unique = self.drop_index(model, fields, unique=True)
add_normal_index = self.add_index(model, fields, unique=False)
return [drop_unique, add_normal_index]
else:
drop_index = self.drop_index(model, fields, unique=False)
add_unique_index = self.add_index(model, fields, unique=True)
return [drop_index, add_unique_index]

View File

@ -25,6 +25,12 @@ class MysqlDDL(BaseDDL):
) )
_ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}" _ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}"
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)"
)
_DROP_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)"
)
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = ( _M2M_TABLE_TEMPLATE = (
@ -47,3 +53,13 @@ class MysqlDDL(BaseDDL):
else: else:
index_prefix = "idx" index_prefix = "idx"
return self.schema_generator._generate_index_name(index_prefix, model, field_names) return self.schema_generator._generate_index_name(index_prefix, model, field_names)
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
# if drop is false: Drop index and add unique
# else: Drop unique index and add normal index
template = self._DROP_INDEXED_UNIQUE_TEMPLATE if drop else self._ADD_INDEXED_UNIQUE_TEMPLATE
table = self.get_table_name(model)
index = self._index_name(unique=False, model=model, field_names=[field_name])
return [template.format(table_name=table, index_name=index, column_name=field_name)]

View File

@ -632,7 +632,9 @@ class Migrate:
# indexed include it # indexed include it
continue continue
# Change unique for indexed field, e.g.: `db_index=True, unique=False` --> `db_index=True, unique=True` # Change unique for indexed field, e.g.: `db_index=True, unique=False` --> `db_index=True, unique=True`
# TODO drop_unique = old_new[0] is True and old_new[1] is False
for sql in cls.ddl.alter_indexed_column_unique(model, field_name, drop_unique):
cls._add_operator(sql, upgrade, True)
elif option == "nullable": elif option == "nullable":
# change nullable # change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade) cls._add_operator(cls._alter_null(model, new_data_field), upgrade)

View File

@ -18,6 +18,7 @@ async def test_allow_duplicate() -> None:
async def test_unique_is_true() -> None: async def test_unique_is_true() -> None:
with pytest.raises(IntegrityError): with pytest.raises(IntegrityError):
await Foo.create(name="foo") await Foo.create(name="foo")
await Foo.create(name="foo")
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@ -49,6 +49,7 @@ class User(Model):
class Email(Model): class Email(Model):
email_id = fields.IntField(primary_key=True) email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, db_index=True) email = fields.CharField(max_length=200, db_index=True)
company = fields.CharField(max_length=100, db_index=True, unique=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) address = fields.CharField(max_length=200)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User") users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")

View File

@ -40,6 +40,7 @@ class User(Model):
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
company = fields.CharField(max_length=100, db_index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False "models.User", db_constraint=False

View File

@ -365,6 +365,21 @@ old_models_describe = {
"constraints": {"max_length": 200}, "constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"}, "db_field_types": {"": "VARCHAR(200)"},
}, },
{
"name": "company",
"field_type": "CharField",
"db_column": "company",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{ {
"name": "is_primary", "name": "is_primary",
"field_type": "BooleanField", "field_type": "BooleanField",
@ -929,6 +944,7 @@ def test_migrate(mocker: MockerFixture):
- drop fk field: Email.user - drop fk field: Email.user
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- add unique to indexed field: Email.company
- change index type for indexed field: Email.slug - change index type for indexed field: Email.slug
- add many to many: Email.users - add many to many: Email.users
- add one to one: Email.config - add one to one: Email.config
@ -977,6 +993,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE", "ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE", "ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE",
"ALTER TABLE `email` DROP INDEX `idx_email_company_1c9234`, ADD UNIQUE (`company`)",
"ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` DROP COLUMN `uuid`", "ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`", "ALTER TABLE `product` DROP INDEX `uuid`",
@ -1019,20 +1036,21 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)", "ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`", "ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `email` DROP COLUMN `config_id`", "ALTER TABLE `email` DROP COLUMN `config_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`", "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` DROP INDEX `company`, ADD INDEX (`idx_email_company_1c9234`)",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE", "ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `product` DROP COLUMN `price`", "ALTER TABLE `product` DROP COLUMN `price`",
"ALTER TABLE `product` DROP COLUMN `no`", "ALTER TABLE `product` DROP COLUMN `no`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`", "ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
@ -1074,6 +1092,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE',
'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE', 'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE',
'DROP INDEX IF EXISTS "idx_email_company_1c9234"',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_email_company_1c9234" ON "email" ("company")',
'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"', 'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"',
'ALTER TABLE "product" DROP COLUMN "uuid"', 'ALTER TABLE "product" DROP COLUMN "uuid"',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
@ -1121,6 +1141,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', 'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" DROP COLUMN "config_id"', 'ALTER TABLE "email" DROP COLUMN "config_id"',
'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"', 'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"',
'CREATE INDEX IF NOT EXISTS "idx_email_company_1c9234" ON "email" ("company")',
'DROP INDEX IF EXISTS "uid_email_company_1c9234"',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE', 'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',

View File

@ -1,13 +1,15 @@
from __future__ import annotations
import contextlib import contextlib
import os import os
import shlex import shlex
import shutil import shutil
import subprocess import subprocess
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from aerich.ddl.sqlite import SqliteDDL from tests._utils import Dialect, chdir, copy_files
from aerich.migrate import Migrate
from tests._utils import chdir, copy_files
def run_aerich(cmd: str) -> None: def run_aerich(cmd: str) -> None:
@ -22,9 +24,14 @@ def run_shell(cmd: str) -> subprocess.CompletedProcess:
return subprocess.run(shlex.split(cmd), env=envs) return subprocess.run(shlex.split(cmd), env=envs)
def test_sqlite_migrate(tmp_path: Path) -> None: def _get_empty_db() -> Path:
if (ddl := getattr(Migrate, "ddl", None)) and not isinstance(ddl, SqliteDDL): if (db_file := Path("db.sqlite3")).exists():
return db_file.unlink()
return db_file
@contextmanager
def prepare_sqlite_project(tmp_path: Path) -> Generator[tuple[Path, str]]:
test_dir = Path(__file__).parent test_dir = Path(__file__).parent
asset_dir = test_dir / "assets" / "sqlite_migrate" asset_dir = test_dir / "assets" / "sqlite_migrate"
with chdir(tmp_path): with chdir(tmp_path):
@ -32,9 +39,52 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
copy_files(*(asset_dir / f for f in files), target_dir=Path()) copy_files(*(asset_dir / f for f in files), target_dir=Path())
models_py, settings_py, test_py = (Path(f) for f in files) models_py, settings_py, test_py = (Path(f) for f in files)
copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py")) copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py"))
if (db_file := Path("db.sqlite3")).exists(): _get_empty_db()
db_file.unlink() yield models_py, models_py.read_text("utf-8")
MODELS = models_py.read_text("utf-8")
def test_sqlite_migrate_alter_indexed_unique(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "unique=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_unique_is_true")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
M2M_WITH_CUSTOM_THROUGH = """
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
def test_sqlite_migrate(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
MODELS = models_text
run_aerich("aerich init -t settings.TORTOISE_ORM") run_aerich("aerich init -t settings.TORTOISE_ORM")
config_file = Path("pyproject.toml") config_file = Path("pyproject.toml")
modify_time = config_file.stat().st_mtime modify_time = config_file.stat().st_mtime
@ -84,7 +134,7 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
# Initial with indexed field and then drop it # Initial with indexed field and then drop it
migrations_dir = Path("migrations/models") migrations_dir = Path("migrations/models")
shutil.rmtree(migrations_dir) shutil.rmtree(migrations_dir)
db_file.unlink() db_file = _get_empty_db()
models_py.write_text(MODELS + " age = fields.IntField(db_index=True)") models_py.write_text(MODELS + " age = fields.IntField(db_index=True)")
run_aerich("aerich init -t settings.TORTOISE_ORM") run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db") run_aerich("aerich init-db")
@ -119,21 +169,7 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
assert "[tool.aerich]" in config_file.read_text() assert "[tool.aerich]" in config_file.read_text()
# add m2m with custom model for through # add m2m with custom model for through
new = """ models_py.write_text(MODELS + M2M_WITH_CUSTOM_THROUGH)
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
models_py.write_text(MODELS + new)
run_aerich("aerich migrate") run_aerich("aerich migrate")
run_aerich("aerich upgrade") run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0] migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
@ -148,8 +184,7 @@ class FooGroup(Model):
class Group(Model): class Group(Model):
name = fields.CharField(max_length=60) name = fields.CharField(max_length=60)
""" """
if db_file.exists(): _get_empty_db()
db_file.unlink()
if migrations_dir.exists(): if migrations_dir.exists():
shutil.rmtree(migrations_dir) shutil.rmtree(migrations_dir)
models_py.write_text(MODELS) models_py.write_text(MODELS)