Add type hints for aerich.migrate

Waket Zheng 2024-06-02 18:09:34 +08:00
parent 51117867a6
commit 6466a852c8


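The _handle_indexes hunk below also replaces the inline md5 lambda with a named index_hash helper built on hashlib.new("MD5", usedforsecurity=False), so pypika Index objects get a content-derived hash that diff_models can use when building index sets. A rough, self-contained sketch of that idea, using a made-up FakeIndex class instead of the real Index:

    import hashlib

    class FakeIndex:
        # hypothetical stand-in for pypika's Index, for illustration only
        def __init__(self, name: str) -> None:
            self.name = name

    def index_hash(self: FakeIndex) -> int:
        # usedforsecurity=False (Python 3.9+) marks the MD5 use as
        # non-cryptographic; the digest only serves as a hash value.
        h = hashlib.new("MD5", usedforsecurity=False)
        h.update(self.name.encode() + self.__class__.__name__.encode())
        return int(h.hexdigest(), 16)

    # assign at class level so hash(FakeIndex(...)) uses the helper
    FakeIndex.__hash__ = index_hash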
@@ -1,9 +1,9 @@
import hashlib
import importlib
import os
from datetime import datetime
from hashlib import md5
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import click
from dictdiffer import diff
@@ -37,16 +37,21 @@ class Migrate:
_upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = []
_aerich = Aerich.__name__
_rename_old = []
_rename_new = []
_rename_old: List[str] = []
_rename_new: List[str] = []
ddl: BaseDDL
ddl_class: Type[BaseDDL]
_last_version_content: Optional[dict] = None
app: str
migrate_location: Path
dialect: str
_db_version: Optional[str] = None
@staticmethod
def get_field_by_name(name: str, fields: List[dict]) -> dict:
return next(filter(lambda x: x.get("name") == name, fields))
@classmethod
def get_all_version_files(cls) -> List[str]:
return sorted(
@@ -56,35 +61,35 @@ class Migrate:
@classmethod
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
return Tortoise.apps[cls.app][model]
@classmethod
async def get_last_version(cls) -> Optional[Aerich]:
try:
return await Aerich.filter(app=cls.app).first()
except OperationalError:
pass
return None
@classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient):
async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
if cls.dialect == "mysql":
sql = "select version() as version"
ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version")
@classmethod
async def load_ddl_class(cls):
async def load_ddl_class(cls) -> Type[BaseDDL]:
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod
async def init(cls, config: dict, app: str, location: str):
async def init(cls, config: dict, app: str, location: str) -> None:
await Tortoise.init(config=config)
last_version = await cls.get_last_version()
cls.app = app
cls.migrate_location = Path(location, app)
if last_version:
cls._last_version_content = last_version.content
cls._last_version_content = cast(dict, last_version.content)
connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
@@ -93,7 +98,7 @@ class Migrate:
await cls._get_db_version(connection)
@classmethod
async def _get_last_version_num(cls):
async def _get_last_version_num(cls) -> Optional[int]:
last_version = await cls.get_last_version()
if not last_version:
return None
@@ -101,7 +106,7 @@ class Migrate:
return int(version.split("_", 1)[0])
@classmethod
async def generate_version(cls, name=None):
async def generate_version(cls, name=None) -> str:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
@@ -112,18 +117,15 @@ class Migrate:
return version
@classmethod
async def _generate_diff_py(cls, name):
async def _generate_diff_py(cls, name) -> str:
version = await cls.generate_version(name)
# delete if same version exists
for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file))
version_file = Path(cls.migrate_location, version)
content = cls._get_diff_file_content()
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
return version
@classmethod
@@ -136,10 +138,10 @@ class Migrate:
"""
if empty:
return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False)
last_version = cast(dict, cls._last_version_content)
cls.diff_models(last_version, new_version_content)
cls.diff_models(new_version_content, last_version, False)
cls._merge_operators()
@@ -165,7 +167,7 @@ class Migrate:
)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
"""
add operator,differentiate fk because fk is order limit
:param operator:
@@ -186,19 +188,37 @@ class Migrate:
cls.downgrade_operators.append(operator)
@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
ret = []
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret: list = []
def index_hash(self) -> str:
h = hashlib.new("MD5", usedforsecurity=False)
h.update(
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
)
return h.hexdigest()
for index in indexes:
if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
ret.append(index)
return ret
@classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
indexes: Set[Union[Index, Tuple[str, ...]]] = set()
for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
if isinstance(x, Index):
indexes.add(x)
else:
indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
) -> None:
"""
diff models and add operators
:param old_models:
@@ -211,39 +231,35 @@ class Migrate:
new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
model = cls._get_model(new_model_describe["name"].split(".")[1])
if new_model_str not in old_models.keys():
if new_model_str not in old_models:
if upgrade:
cls._add_operator(cls.add_model(model), upgrade)
else:
# we can't find origin model when downgrade, so skip
pass
else:
old_model_describe = old_models.get(new_model_str)
old_model_describe = cast(dict, old_models.get(new_model_str))
# rename table
new_table = new_model_describe.get("table")
old_table = old_model_describe.get("table")
new_table = cast(str, new_model_describe.get("table"))
old_table = cast(str, old_model_describe.get("table"))
if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together"))
map(
lambda x: tuple(x),
cast(List[Iterable[str]], old_model_describe.get("unique_together")),
)
)
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
lambda x: tuple(x),
cast(List[Iterable[str]], new_model_describe.get("unique_together")),
)
)
old_indexes = cls._get_indexes(model, old_model_describe)
new_indexes = cls._get_indexes(model, new_model_describe)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
@@ -253,18 +269,19 @@ class Migrate:
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint":
continue
if isinstance(change[0][1], str):
new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == change[0][1]:
table = new_m2m_field.get("through")
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = change[0][1].get("through")
table = new_value.get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
@@ -274,12 +291,9 @@ class Migrate:
cls._downgrade_m2m.append(table)
add = True
if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
cls.create_m2m(model, new_value, ref_desc),
upgrade,
fk_m2m_index=True,
)
@@ -300,38 +314,36 @@ class Migrate:
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes
for index in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, index, False), upgrade, True)
for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
# remove indexes
for index in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
for idx in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
old_data_fields = list(
filter(
lambda x: x.get("db_field_types") is not None,
old_model_describe.get("data_fields"),
cast(List[dict], old_model_describe.get("data_fields")),
)
)
new_data_fields = list(
filter(
lambda x: x.get("db_field_types") is not None,
new_model_describe.get("data_fields"),
cast(List[dict], new_model_describe.get("data_fields")),
)
)
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
# add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name)
):
new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
is_rename = False
for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name")
old_data_field_name = cast(str, old_data_field.get("name"))
if len(changes) == 2:
# rename field
if (
@@ -392,7 +404,7 @@ class Migrate:
if new_data_field["indexed"]:
cls._add_operator(
cls._add_index(
model, {new_data_field["db_column"]}, new_data_field["unique"]
model, (new_data_field["db_column"],), new_data_field["unique"]
),
upgrade,
True,
@@ -406,45 +418,34 @@ class Migrate:
not upgrade and old_data_field_name in cls._rename_new
):
continue
old_data_field = next(
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
)
db_column = old_data_field["db_column"]
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
db_column = cast(str, old_data_field["db_column"])
cls._add_operator(
cls._remove_field(
model,
db_column,
),
cls._remove_field(model, db_column),
upgrade,
)
if old_data_field["indexed"]:
cls._add_operator(
cls._drop_index(
model,
{db_column},
),
cls._drop_index(model, {db_column}),
upgrade,
True,
)
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator(
cls._add_fk(
model, fk_field, new_models.get(fk_field.get("python_type"))
),
cls._add_fk(model, fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
@@ -452,25 +453,20 @@ class Migrate:
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
old_fk_field = cls.get_field_by_name(
old_fk_field_name, cast(List[dict], old_fk_fields)
)
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator(
cls._drop_fk(
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
cls._drop_fk(model, old_fk_field, ref_describe),
upgrade,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next(
filter(lambda x: x.get("name") == field_name, old_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
old_data_field = cls.get_field_by_name(field_name, old_data_fields)
new_data_field = cls.get_field_by_name(field_name, new_data_fields)
changes = diff(old_data_field, new_data_field)
modified = False
for change in changes:
@@ -479,13 +475,14 @@ class Migrate:
# change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True:
cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True
)
add_or_drop_index = cls._add_index
else:
cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True
)
# For drop case, unique value should get from old data
# TODO: unique = old_data_field.get("unique")
add_or_drop_index = cls._drop_index
cls._add_operator(
add_or_drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.":
if new_data_field.get("field_type") == "DecimalField":
# modify column
@@ -522,32 +519,33 @@ class Migrate:
)
modified = True
for old_model in old_models:
if old_model not in new_models.keys():
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
for old_model in old_models.keys() - new_models.keys():
cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
@classmethod
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod
def add_model(cls, model: Type[Model]):
def add_model(cls, model: Type[Model]) -> str:
return cls.ddl.create_table(model)
@classmethod
def drop_model(cls, table_name: str):
def drop_model(cls, table_name: str) -> str:
return cls.ddl.drop_table(table_name)
@classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
def create_m2m(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod
def drop_m2m(cls, table_name: str):
def drop_m2m(cls, table_name: str) -> str:
return cls.ddl.drop_m2m(table_name)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
ret = []
for field_name in fields_name:
field = model._meta.fields_map[field_name]
@@ -560,65 +558,75 @@ class Migrate:
return ret
@classmethod
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
def _drop_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique)
field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, field_names, unique)
@classmethod
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
def _add_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique)
field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, field_names, unique)
@classmethod
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod
def _alter_default(cls, model: Type[Model], field_describe: dict):
def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_default(model, field_describe)
@classmethod
def _alter_null(cls, model: Type[Model], field_describe: dict):
def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_null(model, field_describe)
@classmethod
def _set_comment(cls, model: Type[Model], field_describe: dict):
def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.set_comment(model, field_describe)
@classmethod
def _modify_field(cls, model: Type[Model], field_describe: dict):
def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.modify_column(model, field_describe)
@classmethod
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
def _drop_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod
def _remove_field(cls, model: Type[Model], column_name: str):
def _remove_field(cls, model: Type[Model], column_name: str) -> str:
return cls.ddl.drop_column(model, column_name)
@classmethod
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
db_field_types = new_field_describe.get("db_field_types")
def _change_field(
cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
) -> str:
db_field_types = cast(dict, new_field_describe.get("db_field_types"))
return cls.ddl.change_column(
model,
old_field_describe.get("db_column"),
new_field_describe.get("db_column"),
db_field_types.get(cls.dialect) or db_field_types.get(""),
cast(str, old_field_describe.get("db_column")),
cast(str, new_field_describe.get("db_column")),
cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
)
@classmethod
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
def _add_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
"""
add fk
:param model:
@@ -629,7 +637,7 @@ class Migrate:
return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod
def _merge_operators(cls):
def _merge_operators(cls) -> None:
"""
fk/m2m/index must be last when add,first when drop
:return: