From ac847ba6160d532d5e5ec705f82f5d3a2ecd8bcb Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Sat, 4 Jan 2025 09:08:14 +0800 Subject: [PATCH] refactor: avoid updating inited config file (#402) * refactor: avoid updating config file if init config items not changed * fix unittest error with tortoise develop branch * Remove extra space * fix mysql test error * fix mysql create index error --- aerich/cli.py | 37 ++++++++++++++++----------- aerich/ddl/__init__.py | 48 +++++++++++++++++++----------------- conftest.py | 12 ++++++++- tests/test_ddl.py | 8 +++--- tests/test_migrate.py | 18 +++++++------- tests/test_sqlite_migrate.py | 4 +++ 6 files changed, 76 insertions(+), 51 deletions(-) diff --git a/aerich/cli.py b/aerich/cli.py index 9b6f452..ddf88c4 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -157,6 +157,19 @@ async def history(ctx: Context) -> None: click.secho(version, fg=Color.green) +def _write_config(config_path, doc, table) -> None: + try: + import tomli_w as tomlkit + except ImportError: + import tomlkit # type: ignore + + try: + doc["tool"]["aerich"] = table + except KeyError: + doc["tool"] = {"aerich": table} + config_path.write_text(tomlkit.dumps(doc)) + + @cli.command(help="Initialize aerich config and create migrations folder.") @click.option( "-t", @@ -179,10 +192,6 @@ async def history(ctx: Context) -> None: ) @click.pass_context async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: - try: - import tomli_w as tomlkit - except ImportError: - import tomlkit # type: ignore config_file = ctx.obj["config_file"] if os.path.isabs(src_folder): @@ -197,20 +206,18 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: config_path = Path(config_file) content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]" doc: dict = tomllib.loads(content) - table: dict = getattr(tomlkit, "table", dict)() - table["tortoise_orm"] = tortoise_orm - table["location"] = location - table["src_folder"] = src_folder - try: - doc["tool"]["aerich"] = table - except KeyError: - doc["tool"] = {"aerich": table} - config_path.write_text(tomlkit.dumps(doc)) + + table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder} + if (aerich_config := doc.get("tool", {}).get("aerich")) and all( + aerich_config.get(k) == v for k, v in table.items() + ): + click.echo(f"Aerich config {config_file} already inited.") + else: + _write_config(config_path, doc, table) + click.secho(f"Success writing aerich config to {config_file}", fg=Color.green) Path(location).mkdir(parents=True, exist_ok=True) - click.secho(f"Success creating migrations folder {location}", fg=Color.green) - click.secho(f"Success writing aerich config to {config_file}", fg=Color.green) @cli.command(help="Generate schema and generate app migration folder.") diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 11b355d..50d2ea5 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -1,6 +1,8 @@ +import re from enum import Enum from typing import Any, List, Type, cast +import tortoise from tortoise import BaseDBAsyncClient, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator @@ -41,9 +43,11 @@ class BaseDDL: self.schema_generator = self.schema_generator_cls(client) def create_table(self, model: "Type[Model]") -> str: - return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( - ";" - ) + schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"] + if tortoise.__version__ <= "0.23.0": + # Remove extra space + schema = re.sub(r'(["()A-Za-z]) (["()A-Za-z])', r"\1 \2", schema) + return schema.rstrip(";") def drop_table(self, table_name: str) -> str: return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) @@ -125,31 +129,31 @@ class BaseDDL: else: # sqlite does not support alter table to add unique column unique = ( - "UNIQUE" + " UNIQUE" if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT else "" ) template = self._ADD_COLUMN_TEMPLATE - return template.format( - table_name=db_table, - column=self.schema_generator._create_string( - db_column=db_column, - field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), - nullable="NOT NULL" if not field_describe.get("nullable") else "", - unique=unique, - comment=( - self.schema_generator._column_comment_generator( - table=db_table, - column=db_column, - comment=description, - ) - if description - else "" - ), - is_primary_key=is_pk, - default=default, + column = self.schema_generator._create_string( + db_column=db_column, + field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), + nullable=" NOT NULL" if not field_describe.get("nullable") else "", + unique=unique, + comment=( + self.schema_generator._column_comment_generator( + table=db_table, + column=db_column, + comment=description, + ) + if description + else "" ), + is_primary_key=is_pk, + default=default, ) + if tortoise.__version__ <= "0.23.0": + column = column.replace(" ", " ") + return template.format(table_name=db_table, column=column) def drop_column(self, model: "Type[Model]", column_name: str) -> str: return self._DROP_COLUMN_TEMPLATE.format( diff --git a/conftest.py b/conftest.py index f87aa26..f4843df 100644 --- a/conftest.py +++ b/conftest.py @@ -61,7 +61,17 @@ async def initialize_tests(event_loop, request) -> None: with contextlib.suppress(DBConnectionError, OperationalError): await Tortoise._drop_databases() await Tortoise.init(config=tortoise_orm, _create_db=True) - await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) + try: + await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) + except OperationalError as e: + if (s := "IF NOT EXISTS") not in str(e): + raise e + # MySQL does not support `CREATE INDEX IF NOT EXISTS` syntax + client = Tortoise.get_connection("default") + generator = client.schema_generator(client) + schema = generator.get_create_schema_sql(safe=True) + schema = schema.replace(f" INDEX {s}", " INDEX") + await generator.generate_from_string(schema) client = Tortoise.get_connection("default") if client.schema_generator is MySQLSchemaGenerator: diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 417eb92..09e01a8 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -15,7 +15,7 @@ def test_create_table(): `slug` VARCHAR(100) NOT NULL, `name` VARCHAR(200), `title` VARCHAR(20) NOT NULL, - `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `owner_id` INT NOT NULL COMMENT 'User', CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE ) CHARACTER SET utf8mb4""" @@ -29,7 +29,7 @@ def test_create_table(): "slug" VARCHAR(100) NOT NULL, "name" VARCHAR(200), "title" VARCHAR(20) NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ )""" ) @@ -42,7 +42,7 @@ def test_create_table(): "slug" VARCHAR(100) NOT NULL, "name" VARCHAR(200), "title" VARCHAR(20) NOT NULL, - "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE ); COMMENT ON COLUMN "category"."owner_id" IS 'User'""" @@ -85,7 +85,7 @@ def test_modify_column(): assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)" assert ( ret1 - == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" + == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" ) elif isinstance(Migrate.ddl, PostgresDDL): assert ( diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 9e85a2e..4838b39 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -955,7 +955,7 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `config` DROP COLUMN `name`", "ALTER TABLE `config` DROP INDEX `name`", - "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", + "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL", @@ -971,18 +971,18 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", - "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", + "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`", "ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`", "ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", - "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", + "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL", "ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)", "CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4", - "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", + "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", @@ -1014,7 +1014,7 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`", "ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`", - "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", + "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` DROP INDEX `username`", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL", "DROP TABLE IF EXISTS `email_user`", @@ -1022,9 +1022,9 @@ def test_migrate(mocker: MockerFixture): "DROP TABLE IF EXISTS `product_user`", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", "ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL", - "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", - "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", - "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", + "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", + "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", + "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", @@ -1104,7 +1104,7 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"', 'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"', - 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', + 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)', 'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ', 'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT', diff --git a/tests/test_sqlite_migrate.py b/tests/test_sqlite_migrate.py index 01b9582..b629635 100644 --- a/tests/test_sqlite_migrate.py +++ b/tests/test_sqlite_migrate.py @@ -181,7 +181,11 @@ def test_sqlite_migrate(tmp_path: Path) -> None: if (db_file := Path("db.sqlite3")).exists(): db_file.unlink() run_aerich("aerich init -t settings.TORTOISE_ORM") + config_file = Path("pyproject.toml") + modify_time = config_file.stat().st_mtime run_aerich("aerich init-db") + run_aerich("aerich init -t settings.TORTOISE_ORM") + assert modify_time == config_file.stat().st_mtime r = run_shell("pytest _test.py::test_allow_duplicate") assert r.returncode == 0 # Add index