chore: upgrade deps and fix ruff lint issues (#374)
* chore: upgrade deps and apply ruff lint for tests/ * style: fix ruff lint issues
This commit is contained in:
@@ -42,7 +42,7 @@ class Command:
|
||||
async def _upgrade(self, conn, version_file) -> None:
|
||||
file_path = Path(Migrate.migrate_location, version_file)
|
||||
m = import_py_file(file_path)
|
||||
upgrade = getattr(m, "upgrade")
|
||||
upgrade = m.upgrade
|
||||
await conn.execute_script(await upgrade(conn))
|
||||
await Aerich.create(
|
||||
version=version_file,
|
||||
@@ -89,7 +89,7 @@ class Command:
|
||||
) as conn:
|
||||
file_path = Path(Migrate.migrate_location, file)
|
||||
m = import_py_file(file_path)
|
||||
downgrade = getattr(m, "downgrade")
|
||||
downgrade = m.downgrade
|
||||
downgrade_sql = await downgrade(conn)
|
||||
if not downgrade_sql.strip():
|
||||
raise DowngradeError("No downgrade items found")
|
||||
|
||||
@@ -47,8 +47,10 @@ async def cli(ctx: Context, config, app) -> None:
|
||||
location = tool["location"]
|
||||
tortoise_orm = tool["tortoise_orm"]
|
||||
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
|
||||
except NonExistentKey:
|
||||
raise UsageError("You need run `aerich init` again when upgrading to aerich 0.6.0+.")
|
||||
except NonExistentKey as e:
|
||||
raise UsageError(
|
||||
"You need run `aerich init` again when upgrading to aerich 0.6.0+."
|
||||
) from e
|
||||
add_src_path(src_folder)
|
||||
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
|
||||
if not app:
|
||||
@@ -182,10 +184,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
|
||||
add_src_path(src_folder)
|
||||
get_tortoise_config(ctx, tortoise_orm)
|
||||
config_path = Path(config_file)
|
||||
if config_path.exists():
|
||||
content = config_path.read_text()
|
||||
else:
|
||||
content = "[tool.aerich]"
|
||||
content = config_path.read_bytes() if config_path.exists() else "[tool.aerich]"
|
||||
doc: dict = tomlkit.parse(content)
|
||||
table = tomlkit.table()
|
||||
table["tortoise_orm"] = tortoise_orm
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from typing import Any, Callable, Dict, Optional, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -61,8 +62,8 @@ class Column(BaseModel):
|
||||
elif self.data_type == "bool":
|
||||
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
||||
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
|
||||
if "CURRENT_TIMESTAMP" == self.default:
|
||||
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
||||
if self.default == "CURRENT_TIMESTAMP":
|
||||
if self.extra == "DEFAULT_GENERATED on update CURRENT_TIMESTAMP":
|
||||
default = "auto_now=True, "
|
||||
else:
|
||||
default = "auto_now_add=True, "
|
||||
@@ -94,10 +95,8 @@ class Inspect:
|
||||
|
||||
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
|
||||
self.conn = conn
|
||||
try:
|
||||
with contextlib.suppress(AttributeError):
|
||||
self.database = conn.database # type:ignore[attr-defined]
|
||||
except AttributeError:
|
||||
pass
|
||||
self.tables = tables
|
||||
|
||||
@property
|
||||
|
||||
@@ -40,16 +40,11 @@ where c.TABLE_SCHEMA = %s
|
||||
and c.TABLE_NAME = %s"""
|
||||
ret = await self.conn.execute_query_dict(sql, [self.database, table])
|
||||
for row in ret:
|
||||
non_unique = row["NON_UNIQUE"]
|
||||
if non_unique is None:
|
||||
unique = False
|
||||
else:
|
||||
unique = index = False
|
||||
if (non_unique := row["NON_UNIQUE"]) is not None:
|
||||
unique = not non_unique
|
||||
index_name = row["INDEX_NAME"]
|
||||
if index_name is None:
|
||||
index = False
|
||||
else:
|
||||
index = row["INDEX_NAME"] != "PRIMARY"
|
||||
if (index_name := row["INDEX_NAME"]) is not None:
|
||||
index = index_name != "PRIMARY"
|
||||
columns.append(
|
||||
Column(
|
||||
name=row["COLUMN_NAME"],
|
||||
|
||||
@@ -271,7 +271,7 @@ class Migrate:
|
||||
# m2m fields
|
||||
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
||||
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
||||
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
|
||||
for action, _, change in diff(old_m2m_fields, new_m2m_fields):
|
||||
if change[0][0] == "db_constraint":
|
||||
continue
|
||||
new_value = change[0][1]
|
||||
@@ -346,22 +346,14 @@ class Migrate:
|
||||
old_data_field_name = cast(str, old_data_field.get("name"))
|
||||
if len(changes) == 2:
|
||||
# rename field
|
||||
name_diff = (old_data_field_name, new_data_field_name)
|
||||
column_diff = (
|
||||
old_data_field.get("db_column"),
|
||||
new_data_field.get("db_column"),
|
||||
)
|
||||
if (
|
||||
changes[0]
|
||||
== (
|
||||
"change",
|
||||
"name",
|
||||
(old_data_field_name, new_data_field_name),
|
||||
)
|
||||
and changes[1]
|
||||
== (
|
||||
"change",
|
||||
"db_column",
|
||||
(
|
||||
old_data_field.get("db_column"),
|
||||
new_data_field.get("db_column"),
|
||||
),
|
||||
)
|
||||
changes[0] == ("change", "name", name_diff)
|
||||
and changes[1] == ("change", "db_column", column_diff)
|
||||
and old_data_field_name not in new_data_fields_name
|
||||
):
|
||||
if upgrade:
|
||||
|
||||
Reference in New Issue
Block a user