Fix sqlite create/drop index (#379)
* Update add/drop index template for sqlite * tests: add sqlite migrate/upgrade command test * tests: add timeout for sqlite migrate command test * tests: add test cases for add/drop unique field for sqlite * fix: sqlite failed to add unique field
This commit is contained in:
parent
c2ebe9b5e4
commit
9c81bc6036
@ -5,6 +5,7 @@
|
||||
### [0.8.1](Unreleased)
|
||||
|
||||
#### Fixed
|
||||
- sqlite: failed to create/drop index. (#302)
|
||||
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. (#378)
|
||||
- Sort m2m fields before comparing them with diff. (#271)
|
||||
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, List, Type, cast
|
||||
|
||||
from tortoise import BaseDBAsyncClient, Model
|
||||
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
||||
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
|
||||
|
||||
from aerich.utils import is_default_function
|
||||
|
||||
@ -122,7 +123,12 @@ class BaseDDL:
|
||||
unique = ""
|
||||
template = self._MODIFY_COLUMN_TEMPLATE
|
||||
else:
|
||||
unique = "UNIQUE" if field_describe.get("unique") else ""
|
||||
# sqlite does not support alter table to add unique column
|
||||
unique = (
|
||||
"UNIQUE"
|
||||
if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT
|
||||
else ""
|
||||
)
|
||||
template = self._ADD_COLUMN_TEMPLATE
|
||||
return template.format(
|
||||
table_name=db_table,
|
||||
|
@ -10,6 +10,8 @@ from aerich.exceptions import NotSupportError
|
||||
class SqliteDDL(BaseDDL):
|
||||
schema_generator_cls = SqliteSchemaGenerator
|
||||
DIALECT = SqliteSchemaGenerator.DIALECT
|
||||
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
|
||||
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
|
||||
|
||||
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
|
||||
raise NotSupportError("Modify column is unsupported in SQLite.")
|
||||
|
@ -63,6 +63,14 @@ def test_add_column():
|
||||
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
|
||||
else:
|
||||
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'
|
||||
# add unique column
|
||||
ret = Migrate.ddl.add_column(User, User._meta.fields_map.get("username").describe(False))
|
||||
if isinstance(Migrate.ddl, MysqlDDL):
|
||||
assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE"
|
||||
elif isinstance(Migrate.ddl, PostgresDDL):
|
||||
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL UNIQUE'
|
||||
else:
|
||||
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL'
|
||||
|
||||
|
||||
def test_modify_column():
|
||||
@ -155,14 +163,9 @@ def test_add_index():
|
||||
if isinstance(Migrate.ddl, MysqlDDL):
|
||||
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
|
||||
assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
|
||||
elif isinstance(Migrate.ddl, PostgresDDL):
|
||||
else:
|
||||
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
|
||||
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
|
||||
else:
|
||||
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
|
||||
assert (
|
||||
index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")'
|
||||
)
|
||||
|
||||
|
||||
def test_drop_index():
|
||||
@ -175,8 +178,8 @@ def test_drop_index():
|
||||
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
|
||||
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
|
||||
else:
|
||||
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
|
||||
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
|
||||
assert ret == 'DROP INDEX IF EXISTS "idx_category_name_8b0cb9"'
|
||||
assert ret_u == 'DROP INDEX IF EXISTS "uid_category_name_8b0cb9"'
|
||||
|
||||
|
||||
def test_add_fk():
|
||||
|
217
tests/test_sqlite_migrate.py
Normal file
217
tests/test_sqlite_migrate.py
Normal file
@ -0,0 +1,217 @@
|
||||
import contextlib
|
||||
import os
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from aerich.ddl.sqlite import SqliteDDL
|
||||
from aerich.migrate import Migrate
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from contextlib import chdir
|
||||
else:
|
||||
|
||||
class chdir(contextlib.AbstractContextManager): # Copied from source code of Python3.13
|
||||
"""Non thread-safe context manager to change the current working directory."""
|
||||
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self._old_cwd = []
|
||||
|
||||
def __enter__(self):
|
||||
self._old_cwd.append(os.getcwd())
|
||||
os.chdir(self.path)
|
||||
|
||||
def __exit__(self, *excinfo):
|
||||
os.chdir(self._old_cwd.pop())
|
||||
|
||||
|
||||
MODELS = """from __future__ import annotations
|
||||
|
||||
from tortoise import Model, fields
|
||||
|
||||
|
||||
class Foo(Model):
|
||||
name = fields.CharField(max_length=60, db_index=False)
|
||||
"""
|
||||
|
||||
SETTINGS = """from __future__ import annotations
|
||||
|
||||
TORTOISE_ORM = {
|
||||
"connections": {"default": "sqlite://db.sqlite3"},
|
||||
"apps": {"models": {"models": ["models", "aerich.models"]}},
|
||||
}
|
||||
"""
|
||||
|
||||
CONFTEST = """from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from tortoise import Tortoise, connections
|
||||
|
||||
import settings
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop() -> Generator:
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
res = policy.new_event_loop()
|
||||
asyncio.set_event_loop(res)
|
||||
res._close = res.close # type:ignore[attr-defined]
|
||||
res.close = lambda: None # type:ignore[method-assign]
|
||||
|
||||
yield res
|
||||
|
||||
res._close() # type:ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def api(event_loop, request):
|
||||
await Tortoise.init(config=settings.TORTOISE_ORM)
|
||||
request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True)))
|
||||
"""
|
||||
|
||||
TESTS = """from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from tortoise.exceptions import IntegrityError
|
||||
|
||||
from models import Foo
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allow_duplicate() -> None:
|
||||
await Foo.all().delete()
|
||||
await Foo.create(name="foo")
|
||||
obj = await Foo.create(name="foo")
|
||||
assert (await Foo.all().count()) == 2
|
||||
await obj.delete()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unique_is_true() -> None:
|
||||
with pytest.raises(IntegrityError):
|
||||
await Foo.create(name="foo")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_unique_field() -> None:
|
||||
if not await Foo.filter(age=0).exists():
|
||||
await Foo.create(name="0_"+uuid.uuid4().hex, age=0)
|
||||
with pytest.raises(IntegrityError):
|
||||
await Foo.create(name=uuid.uuid4().hex, age=0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_unique_field() -> None:
|
||||
name = "1_" + uuid.uuid4().hex
|
||||
await Foo.create(name=name, age=0)
|
||||
assert (await Foo.filter(name=name).exists())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_age_field() -> None:
|
||||
name = "2_" + uuid.uuid4().hex
|
||||
await Foo.create(name=name, age=0)
|
||||
obj = await Foo.get(name=name)
|
||||
assert obj.age == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_without_age_field() -> None:
|
||||
name = "3_" + uuid.uuid4().hex
|
||||
await Foo.create(name=name, age=0)
|
||||
obj = await Foo.get(name=name)
|
||||
assert getattr(obj, "age", None) is None
|
||||
"""
|
||||
|
||||
|
||||
def run_aerich(cmd: str) -> None:
|
||||
with contextlib.suppress(subprocess.TimeoutExpired):
|
||||
if not cmd.startswith("aerich"):
|
||||
cmd = "aerich " + cmd
|
||||
subprocess.run(shlex.split(cmd), timeout=2)
|
||||
|
||||
|
||||
def run_shell(cmd: str) -> subprocess.CompletedProcess:
|
||||
envs = dict(os.environ, PYTHONPATH=".")
|
||||
return subprocess.run(shlex.split(cmd), env=envs)
|
||||
|
||||
|
||||
def test_sqlite_migrate(tmp_path: Path) -> None:
|
||||
if (ddl := getattr(Migrate, "ddl", None)) and not isinstance(ddl, SqliteDDL):
|
||||
return
|
||||
with chdir(tmp_path):
|
||||
models_py = Path("models.py")
|
||||
settings_py = Path("settings.py")
|
||||
test_py = Path("_test.py")
|
||||
models_py.write_text(MODELS)
|
||||
settings_py.write_text(SETTINGS)
|
||||
test_py.write_text(TESTS)
|
||||
Path("conftest.py").write_text(CONFTEST)
|
||||
run_aerich("aerich init -t settings.TORTOISE_ORM")
|
||||
run_aerich("aerich init-db")
|
||||
r = run_shell("pytest _test.py::test_allow_duplicate")
|
||||
assert r.returncode == 0
|
||||
# Add index
|
||||
models_py.write_text(MODELS.replace("index=False", "index=True"))
|
||||
run_aerich("aerich migrate") # migrations/models/1_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest -s _test.py::test_allow_duplicate")
|
||||
assert r.returncode == 0
|
||||
# Drop index
|
||||
models_py.write_text(MODELS)
|
||||
run_aerich("aerich migrate") # migrations/models/2_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest -s _test.py::test_allow_duplicate")
|
||||
assert r.returncode == 0
|
||||
# Add unique index
|
||||
models_py.write_text(MODELS.replace("index=False", "index=True, unique=True"))
|
||||
run_aerich("aerich migrate") # migrations/models/3_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest _test.py::test_unique_is_true")
|
||||
assert r.returncode == 0
|
||||
# Drop unique index
|
||||
models_py.write_text(MODELS)
|
||||
run_aerich("aerich migrate") # migrations/models/4_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest _test.py::test_allow_duplicate")
|
||||
assert r.returncode == 0
|
||||
# Add field with unique=True
|
||||
with models_py.open("a") as f:
|
||||
f.write(" age = fields.IntField(unique=True, default=0)")
|
||||
run_aerich("aerich migrate") # migrations/models/5_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest _test.py::test_add_unique_field")
|
||||
assert r.returncode == 0
|
||||
# Drop unique field
|
||||
models_py.write_text(MODELS)
|
||||
run_aerich("aerich migrate") # migrations/models/6_
|
||||
run_aerich("aerich upgrade")
|
||||
r = run_shell("pytest -s _test.py::test_drop_unique_field")
|
||||
assert r.returncode == 0
|
||||
|
||||
# Initial with indexed field and then drop it
|
||||
shutil.rmtree("migrations")
|
||||
Path("db.sqlite3").unlink()
|
||||
models_py.write_text(MODELS + " age = fields.IntField(db_index=True)")
|
||||
run_aerich("aerich init -t settings.TORTOISE_ORM")
|
||||
run_aerich("aerich init-db")
|
||||
migration_file = list(Path("migrations/models").glob("0_*.py"))[0]
|
||||
assert "CREATE INDEX" in migration_file.read_text()
|
||||
r = run_shell("pytest _test.py::test_with_age_field")
|
||||
assert r.returncode == 0
|
||||
models_py.write_text(MODELS)
|
||||
run_aerich("aerich migrate")
|
||||
run_aerich("aerich upgrade")
|
||||
migration_file_1 = list(Path("migrations/models").glob("1_*.py"))[0]
|
||||
assert "DROP INDEX" in migration_file_1.read_text()
|
||||
r = run_shell("pytest _test.py::test_without_age_field")
|
||||
assert r.returncode == 0
|
Loading…
x
Reference in New Issue
Block a user