refactor: avoid updating inited config file (#402)

* refactor: avoid updating config file if init config items not changed

* fix unittest error with tortoise develop branch

* Remove extra space

* fix mysql test error

* fix mysql create index error
This commit is contained in:
Waket Zheng 2025-01-04 09:08:14 +08:00 committed by GitHub
parent f5d7d56fa5
commit ac847ba616
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 76 additions and 51 deletions

View File

@ -157,6 +157,19 @@ async def history(ctx: Context) -> None:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
def _write_config(config_path, doc, table) -> None:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit # type: ignore
try:
doc["tool"]["aerich"] = table
except KeyError:
doc["tool"] = {"aerich": table}
config_path.write_text(tomlkit.dumps(doc))
@cli.command(help="Initialize aerich config and create migrations folder.") @cli.command(help="Initialize aerich config and create migrations folder.")
@click.option( @click.option(
"-t", "-t",
@ -179,10 +192,6 @@ async def history(ctx: Context) -> None:
) )
@click.pass_context @click.pass_context
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit # type: ignore
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder): if os.path.isabs(src_folder):
@ -197,20 +206,18 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_path = Path(config_file) config_path = Path(config_file)
content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]" content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
doc: dict = tomllib.loads(content) doc: dict = tomllib.loads(content)
table: dict = getattr(tomlkit, "table", dict)()
table["tortoise_orm"] = tortoise_orm table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder}
table["location"] = location if (aerich_config := doc.get("tool", {}).get("aerich")) and all(
table["src_folder"] = src_folder aerich_config.get(k) == v for k, v in table.items()
try: ):
doc["tool"]["aerich"] = table click.echo(f"Aerich config {config_file} already inited.")
except KeyError: else:
doc["tool"] = {"aerich": table} _write_config(config_path, doc, table)
config_path.write_text(tomlkit.dumps(doc)) click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success creating migrations folder {location}", fg=Color.green) click.secho(f"Success creating migrations folder {location}", fg=Color.green)
click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migration folder.") @cli.command(help="Generate schema and generate app migration folder.")

View File

@ -1,6 +1,8 @@
import re
from enum import Enum from enum import Enum
from typing import Any, List, Type, cast from typing import Any, List, Type, cast
import tortoise
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
@ -41,9 +43,11 @@ class BaseDDL:
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]") -> str: def create_table(self, model: "Type[Model]") -> str:
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
";" if tortoise.__version__ <= "0.23.0":
) # Remove extra space
schema = re.sub(r'(["()A-Za-z]) (["()A-Za-z])', r"\1 \2", schema)
return schema.rstrip(";")
def drop_table(self, table_name: str) -> str: def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
@ -130,8 +134,6 @@ class BaseDDL:
else "" else ""
) )
template = self._ADD_COLUMN_TEMPLATE template = self._ADD_COLUMN_TEMPLATE
return template.format(
table_name=db_table,
column = self.schema_generator._create_string( column = self.schema_generator._create_string(
db_column=db_column, db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
@ -148,8 +150,10 @@ class BaseDDL:
), ),
is_primary_key=is_pk, is_primary_key=is_pk,
default=default, default=default,
),
) )
if tortoise.__version__ <= "0.23.0":
column = column.replace(" ", " ")
return template.format(table_name=db_table, column=column)
def drop_column(self, model: "Type[Model]", column_name: str) -> str: def drop_column(self, model: "Type[Model]", column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(

View File

@ -61,7 +61,17 @@ async def initialize_tests(event_loop, request) -> None:
with contextlib.suppress(DBConnectionError, OperationalError): with contextlib.suppress(DBConnectionError, OperationalError):
await Tortoise._drop_databases() await Tortoise._drop_databases()
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm, _create_db=True)
try:
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
except OperationalError as e:
if (s := "IF NOT EXISTS") not in str(e):
raise e
# MySQL does not support `CREATE INDEX IF NOT EXISTS` syntax
client = Tortoise.get_connection("default")
generator = client.schema_generator(client)
schema = generator.get_create_schema_sql(safe=True)
schema = schema.replace(f" INDEX {s}", " INDEX")
await generator.generate_from_string(schema)
client = Tortoise.get_connection("default") client = Tortoise.get_connection("default")
if client.schema_generator is MySQLSchemaGenerator: if client.schema_generator is MySQLSchemaGenerator:

View File

@ -181,7 +181,11 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
if (db_file := Path("db.sqlite3")).exists(): if (db_file := Path("db.sqlite3")).exists():
db_file.unlink() db_file.unlink()
run_aerich("aerich init -t settings.TORTOISE_ORM") run_aerich("aerich init -t settings.TORTOISE_ORM")
config_file = Path("pyproject.toml")
modify_time = config_file.stat().st_mtime
run_aerich("aerich init-db") run_aerich("aerich init-db")
run_aerich("aerich init -t settings.TORTOISE_ORM")
assert modify_time == config_file.stat().st_mtime
r = run_shell("pytest _test.py::test_allow_duplicate") r = run_shell("pytest _test.py::test_allow_duplicate")
assert r.returncode == 0 assert r.returncode == 0
# Add index # Add index