refactor: avoid updating inited config file (#402)

* refactor: avoid updating config file if init config items not changed

* fix unittest error with tortoise develop branch

* Remove extra space

* fix mysql test error

* fix mysql create index error
This commit is contained in:
Waket Zheng 2025-01-04 09:08:14 +08:00 committed by GitHub
parent f5d7d56fa5
commit ac847ba616
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 51 deletions

View File

@ -157,6 +157,19 @@ async def history(ctx: Context) -> None:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
def _write_config(config_path, doc, table) -> None:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit # type: ignore
try:
doc["tool"]["aerich"] = table
except KeyError:
doc["tool"] = {"aerich": table}
config_path.write_text(tomlkit.dumps(doc))
@cli.command(help="Initialize aerich config and create migrations folder.") @cli.command(help="Initialize aerich config and create migrations folder.")
@click.option( @click.option(
"-t", "-t",
@ -179,10 +192,6 @@ async def history(ctx: Context) -> None:
) )
@click.pass_context @click.pass_context
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit # type: ignore
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder): if os.path.isabs(src_folder):
@ -197,20 +206,18 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_path = Path(config_file) config_path = Path(config_file)
content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]" content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
doc: dict = tomllib.loads(content) doc: dict = tomllib.loads(content)
table: dict = getattr(tomlkit, "table", dict)()
table["tortoise_orm"] = tortoise_orm table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder}
table["location"] = location if (aerich_config := doc.get("tool", {}).get("aerich")) and all(
table["src_folder"] = src_folder aerich_config.get(k) == v for k, v in table.items()
try: ):
doc["tool"]["aerich"] = table click.echo(f"Aerich config {config_file} already inited.")
except KeyError: else:
doc["tool"] = {"aerich": table} _write_config(config_path, doc, table)
config_path.write_text(tomlkit.dumps(doc)) click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success creating migrations folder {location}", fg=Color.green) click.secho(f"Success creating migrations folder {location}", fg=Color.green)
click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migration folder.") @cli.command(help="Generate schema and generate app migration folder.")

View File

@ -1,6 +1,8 @@
import re
from enum import Enum from enum import Enum
from typing import Any, List, Type, cast from typing import Any, List, Type, cast
import tortoise
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
@ -41,9 +43,11 @@ class BaseDDL:
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]") -> str: def create_table(self, model: "Type[Model]") -> str:
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
";" if tortoise.__version__ <= "0.23.0":
) # Remove extra space
schema = re.sub(r'(["()A-Za-z]) (["()A-Za-z])', r"\1 \2", schema)
return schema.rstrip(";")
def drop_table(self, table_name: str) -> str: def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
@ -125,31 +129,31 @@ class BaseDDL:
else: else:
# sqlite does not support alter table to add unique column # sqlite does not support alter table to add unique column
unique = ( unique = (
"UNIQUE" " UNIQUE"
if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT
else "" else ""
) )
template = self._ADD_COLUMN_TEMPLATE template = self._ADD_COLUMN_TEMPLATE
return template.format( column = self.schema_generator._create_string(
table_name=db_table, db_column=db_column,
column=self.schema_generator._create_string( field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
db_column=db_column, nullable=" NOT NULL" if not field_describe.get("nullable") else "",
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), unique=unique,
nullable="NOT NULL" if not field_describe.get("nullable") else "", comment=(
unique=unique, self.schema_generator._column_comment_generator(
comment=( table=db_table,
self.schema_generator._column_comment_generator( column=db_column,
table=db_table, comment=description,
column=db_column, )
comment=description, if description
) else ""
if description
else ""
),
is_primary_key=is_pk,
default=default,
), ),
is_primary_key=is_pk,
default=default,
) )
if tortoise.__version__ <= "0.23.0":
column = column.replace(" ", " ")
return template.format(table_name=db_table, column=column)
def drop_column(self, model: "Type[Model]", column_name: str) -> str: def drop_column(self, model: "Type[Model]", column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(

View File

@ -61,7 +61,17 @@ async def initialize_tests(event_loop, request) -> None:
with contextlib.suppress(DBConnectionError, OperationalError): with contextlib.suppress(DBConnectionError, OperationalError):
await Tortoise._drop_databases() await Tortoise._drop_databases()
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) try:
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
except OperationalError as e:
if (s := "IF NOT EXISTS") not in str(e):
raise e
# MySQL does not support `CREATE INDEX IF NOT EXISTS` syntax
client = Tortoise.get_connection("default")
generator = client.schema_generator(client)
schema = generator.get_create_schema_sql(safe=True)
schema = schema.replace(f" INDEX {s}", " INDEX")
await generator.generate_from_string(schema)
client = Tortoise.get_connection("default") client = Tortoise.get_connection("default")
if client.schema_generator is MySQLSchemaGenerator: if client.schema_generator is MySQLSchemaGenerator:

View File

@ -15,7 +15,7 @@ def test_create_table():
`slug` VARCHAR(100) NOT NULL, `slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200), `name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL, `title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User', `owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4""" ) CHARACTER SET utf8mb4"""
@ -29,7 +29,7 @@ def test_create_table():
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL, "title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
)""" )"""
) )
@ -42,7 +42,7 @@ def test_create_table():
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL, "title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
COMMENT ON COLUMN "category"."owner_id" IS 'User'""" COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
@ -85,7 +85,7 @@ def test_modify_column():
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)" assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
assert ( assert (
ret1 ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (

View File

@ -955,7 +955,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP COLUMN `name`", "ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` DROP INDEX `name`", "ALTER TABLE `config` DROP INDEX `name`",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL", "ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
@ -971,18 +971,18 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`", "ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`",
"ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`", "ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`",
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
"CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4", "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
@ -1014,7 +1014,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`", "ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`",
"ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`", "ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `username`", "ALTER TABLE `user` DROP INDEX `username`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`", "DROP TABLE IF EXISTS `email_user`",
@ -1022,9 +1022,9 @@ def test_migrate(mocker: MockerFixture):
"DROP TABLE IF EXISTS `product_user`", "DROP TABLE IF EXISTS `product_user`",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL", "ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
@ -1104,7 +1104,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"', 'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',
'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"', 'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ', 'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT', 'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',

View File

@ -181,7 +181,11 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
if (db_file := Path("db.sqlite3")).exists(): if (db_file := Path("db.sqlite3")).exists():
db_file.unlink() db_file.unlink()
run_aerich("aerich init -t settings.TORTOISE_ORM") run_aerich("aerich init -t settings.TORTOISE_ORM")
config_file = Path("pyproject.toml")
modify_time = config_file.stat().st_mtime
run_aerich("aerich init-db") run_aerich("aerich init-db")
run_aerich("aerich init -t settings.TORTOISE_ORM")
assert modify_time == config_file.stat().st_mtime
r = run_shell("pytest _test.py::test_allow_duplicate") r = run_shell("pytest _test.py::test_allow_duplicate")
assert r.returncode == 0 assert r.returncode == 0
# Add index # Add index