From 9c81bc6036719e6904d7ce72ea763180065d5d63 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 10 Dec 2024 16:37:30 +0800 Subject: [PATCH] Fix sqlite create/drop index (#379) * Update add/drop index template for sqlite * tests: add sqlite migrate/upgrade command test * tests: add timeout for sqlite migrate command test * tests: add test cases for add/drop unique field for sqlite * fix: sqlite failed to add unique field --- CHANGELOG.md | 1 + aerich/ddl/__init__.py | 8 +- aerich/ddl/sqlite/__init__.py | 2 + tests/test_ddl.py | 19 +-- tests/test_sqlite_migrate.py | 217 ++++++++++++++++++++++++++++++++++ 5 files changed, 238 insertions(+), 9 deletions(-) create mode 100644 tests/test_sqlite_migrate.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6093d90..1b2fe83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### [0.8.1](Unreleased) #### Fixed +- sqlite: failed to create/drop index. (#302) - PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378) - Sort m2m fields before comparing them with diff. (#271) diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index cb57274..11b355d 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -3,6 +3,7 @@ from typing import Any, List, Type, cast from tortoise import BaseDBAsyncClient, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator +from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from aerich.utils import is_default_function @@ -122,7 +123,12 @@ class BaseDDL: unique = "" template = self._MODIFY_COLUMN_TEMPLATE else: - unique = "UNIQUE" if field_describe.get("unique") else "" + # sqlite does not support alter table to add unique column + unique = ( + "UNIQUE" + if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT + else "" + ) template = self._ADD_COLUMN_TEMPLATE return template.format( table_name=db_table, diff --git a/aerich/ddl/sqlite/__init__.py b/aerich/ddl/sqlite/__init__.py index 0ce1290..67dfd3a 100644 --- a/aerich/ddl/sqlite/__init__.py +++ b/aerich/ddl/sqlite/__init__.py @@ -10,6 +10,8 @@ from aerich.exceptions import NotSupportError class SqliteDDL(BaseDDL): schema_generator_cls = SqliteSchemaGenerator DIALECT = SqliteSchemaGenerator.DIALECT + _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})' + _DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"' def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True): raise NotSupportError("Modify column is unsupported in SQLite.") diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 338dfd2..5a8d659 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -63,6 +63,14 @@ def test_add_column(): assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)" else: assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)' + # add unique column + ret = Migrate.ddl.add_column(User, User._meta.fields_map.get("username").describe(False)) + if isinstance(Migrate.ddl, MysqlDDL): + assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE" + elif isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL UNIQUE' + else: + assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL' def test_modify_column(): @@ -155,14 +163,9 @@ def test_add_index(): if isinstance(Migrate.ddl, MysqlDDL): assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)" - elif isinstance(Migrate.ddl, PostgresDDL): + else: assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")' - else: - assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")' - assert ( - index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")' - ) def test_drop_index(): @@ -175,8 +178,8 @@ def test_drop_index(): assert ret == 'DROP INDEX "idx_category_name_8b0cb9"' assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"' else: - assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"' - assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"' + assert ret == 'DROP INDEX IF EXISTS "idx_category_name_8b0cb9"' + assert ret_u == 'DROP INDEX IF EXISTS "uid_category_name_8b0cb9"' def test_add_fk(): diff --git a/tests/test_sqlite_migrate.py b/tests/test_sqlite_migrate.py new file mode 100644 index 0000000..c14a4dd --- /dev/null +++ b/tests/test_sqlite_migrate.py @@ -0,0 +1,217 @@ +import contextlib +import os +import shlex +import shutil +import subprocess +import sys +from pathlib import Path + +from aerich.ddl.sqlite import SqliteDDL +from aerich.migrate import Migrate + +if sys.version_info >= (3, 11): + from contextlib import chdir +else: + + class chdir(contextlib.AbstractContextManager): # Copied from source code of Python3.13 + """Non thread-safe context manager to change the current working directory.""" + + def __init__(self, path): + self.path = path + self._old_cwd = [] + + def __enter__(self): + self._old_cwd.append(os.getcwd()) + os.chdir(self.path) + + def __exit__(self, *excinfo): + os.chdir(self._old_cwd.pop()) + + +MODELS = """from __future__ import annotations + +from tortoise import Model, fields + + +class Foo(Model): + name = fields.CharField(max_length=60, db_index=False) +""" + +SETTINGS = """from __future__ import annotations + +TORTOISE_ORM = { + "connections": {"default": "sqlite://db.sqlite3"}, + "apps": {"models": {"models": ["models", "aerich.models"]}}, +} +""" + +CONFTEST = """from __future__ import annotations + +import asyncio +from typing import Generator + +import pytest +import pytest_asyncio +from tortoise import Tortoise, connections + +import settings + + +@pytest.fixture(scope="session") +def event_loop() -> Generator: + policy = asyncio.get_event_loop_policy() + res = policy.new_event_loop() + asyncio.set_event_loop(res) + res._close = res.close # type:ignore[attr-defined] + res.close = lambda: None # type:ignore[method-assign] + + yield res + + res._close() # type:ignore[attr-defined] + + +@pytest_asyncio.fixture(scope="session", autouse=True) +async def api(event_loop, request): + await Tortoise.init(config=settings.TORTOISE_ORM) + request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True))) +""" + +TESTS = """from __future__ import annotations + +import uuid + +import pytest +from tortoise.exceptions import IntegrityError + +from models import Foo + + +@pytest.mark.asyncio +async def test_allow_duplicate() -> None: + await Foo.all().delete() + await Foo.create(name="foo") + obj = await Foo.create(name="foo") + assert (await Foo.all().count()) == 2 + await obj.delete() + + +@pytest.mark.asyncio +async def test_unique_is_true() -> None: + with pytest.raises(IntegrityError): + await Foo.create(name="foo") + + +@pytest.mark.asyncio +async def test_add_unique_field() -> None: + if not await Foo.filter(age=0).exists(): + await Foo.create(name="0_"+uuid.uuid4().hex, age=0) + with pytest.raises(IntegrityError): + await Foo.create(name=uuid.uuid4().hex, age=0) + + +@pytest.mark.asyncio +async def test_drop_unique_field() -> None: + name = "1_" + uuid.uuid4().hex + await Foo.create(name=name, age=0) + assert (await Foo.filter(name=name).exists()) + + +@pytest.mark.asyncio +async def test_with_age_field() -> None: + name = "2_" + uuid.uuid4().hex + await Foo.create(name=name, age=0) + obj = await Foo.get(name=name) + assert obj.age == 0 + + +@pytest.mark.asyncio +async def test_without_age_field() -> None: + name = "3_" + uuid.uuid4().hex + await Foo.create(name=name, age=0) + obj = await Foo.get(name=name) + assert getattr(obj, "age", None) is None +""" + + +def run_aerich(cmd: str) -> None: + with contextlib.suppress(subprocess.TimeoutExpired): + if not cmd.startswith("aerich"): + cmd = "aerich " + cmd + subprocess.run(shlex.split(cmd), timeout=2) + + +def run_shell(cmd: str) -> subprocess.CompletedProcess: + envs = dict(os.environ, PYTHONPATH=".") + return subprocess.run(shlex.split(cmd), env=envs) + + +def test_sqlite_migrate(tmp_path: Path) -> None: + if (ddl := getattr(Migrate, "ddl", None)) and not isinstance(ddl, SqliteDDL): + return + with chdir(tmp_path): + models_py = Path("models.py") + settings_py = Path("settings.py") + test_py = Path("_test.py") + models_py.write_text(MODELS) + settings_py.write_text(SETTINGS) + test_py.write_text(TESTS) + Path("conftest.py").write_text(CONFTEST) + run_aerich("aerich init -t settings.TORTOISE_ORM") + run_aerich("aerich init-db") + r = run_shell("pytest _test.py::test_allow_duplicate") + assert r.returncode == 0 + # Add index + models_py.write_text(MODELS.replace("index=False", "index=True")) + run_aerich("aerich migrate") # migrations/models/1_ + run_aerich("aerich upgrade") + r = run_shell("pytest -s _test.py::test_allow_duplicate") + assert r.returncode == 0 + # Drop index + models_py.write_text(MODELS) + run_aerich("aerich migrate") # migrations/models/2_ + run_aerich("aerich upgrade") + r = run_shell("pytest -s _test.py::test_allow_duplicate") + assert r.returncode == 0 + # Add unique index + models_py.write_text(MODELS.replace("index=False", "index=True, unique=True")) + run_aerich("aerich migrate") # migrations/models/3_ + run_aerich("aerich upgrade") + r = run_shell("pytest _test.py::test_unique_is_true") + assert r.returncode == 0 + # Drop unique index + models_py.write_text(MODELS) + run_aerich("aerich migrate") # migrations/models/4_ + run_aerich("aerich upgrade") + r = run_shell("pytest _test.py::test_allow_duplicate") + assert r.returncode == 0 + # Add field with unique=True + with models_py.open("a") as f: + f.write(" age = fields.IntField(unique=True, default=0)") + run_aerich("aerich migrate") # migrations/models/5_ + run_aerich("aerich upgrade") + r = run_shell("pytest _test.py::test_add_unique_field") + assert r.returncode == 0 + # Drop unique field + models_py.write_text(MODELS) + run_aerich("aerich migrate") # migrations/models/6_ + run_aerich("aerich upgrade") + r = run_shell("pytest -s _test.py::test_drop_unique_field") + assert r.returncode == 0 + + # Initial with indexed field and then drop it + shutil.rmtree("migrations") + Path("db.sqlite3").unlink() + models_py.write_text(MODELS + " age = fields.IntField(db_index=True)") + run_aerich("aerich init -t settings.TORTOISE_ORM") + run_aerich("aerich init-db") + migration_file = list(Path("migrations/models").glob("0_*.py"))[0] + assert "CREATE INDEX" in migration_file.read_text() + r = run_shell("pytest _test.py::test_with_age_field") + assert r.returncode == 0 + models_py.write_text(MODELS) + run_aerich("aerich migrate") + run_aerich("aerich upgrade") + migration_file_1 = list(Path("migrations/models").glob("1_*.py"))[0] + assert "DROP INDEX" in migration_file_1.read_text() + r = run_shell("pytest _test.py::test_without_age_field") + assert r.returncode == 0