From 6466a852c8b57a577ac9a11a2e04155bd1541e5c Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Sun, 2 Jun 2024 18:09:34 +0800 Subject: [PATCH] Add type hints for aerich.migrate --- aerich/migrate.py | 284 ++++++++++++++++++++++++---------------------- 1 file changed, 146 insertions(+), 138 deletions(-) diff --git a/aerich/migrate.py b/aerich/migrate.py index 32cad7f..8a88685 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -1,9 +1,9 @@ +import hashlib import importlib import os from datetime import datetime -from hashlib import md5 from pathlib import Path -from typing import Dict, List, Optional, Tuple, Type, Union +from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast import click from dictdiffer import diff @@ -37,16 +37,21 @@ class Migrate: _upgrade_m2m: List[str] = [] _downgrade_m2m: List[str] = [] _aerich = Aerich.__name__ - _rename_old = [] - _rename_new = [] + _rename_old: List[str] = [] + _rename_new: List[str] = [] ddl: BaseDDL + ddl_class: Type[BaseDDL] _last_version_content: Optional[dict] = None app: str migrate_location: Path dialect: str _db_version: Optional[str] = None + @staticmethod + def get_field_by_name(name: str, fields: List[dict]) -> dict: + return next(filter(lambda x: x.get("name") == name, fields)) + @classmethod def get_all_version_files(cls) -> List[str]: return sorted( @@ -56,35 +61,35 @@ class Migrate: @classmethod def _get_model(cls, model: str) -> Type[Model]: - return Tortoise.apps.get(cls.app).get(model) + return Tortoise.apps[cls.app][model] @classmethod async def get_last_version(cls) -> Optional[Aerich]: try: return await Aerich.filter(app=cls.app).first() except OperationalError: - pass + return None @classmethod - async def _get_db_version(cls, connection: BaseDBAsyncClient): + async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None: if cls.dialect == "mysql": sql = "select version() as version" ret = await connection.execute_query(sql) cls._db_version = ret[1][0].get("version") @classmethod - async def load_ddl_class(cls): + async def load_ddl_class(cls) -> Type[BaseDDL]: ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}") return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL") @classmethod - async def init(cls, config: dict, app: str, location: str): + async def init(cls, config: dict, app: str, location: str) -> None: await Tortoise.init(config=config) last_version = await cls.get_last_version() cls.app = app cls.migrate_location = Path(location, app) if last_version: - cls._last_version_content = last_version.content + cls._last_version_content = cast(dict, last_version.content) connection = get_app_connection(config, app) cls.dialect = connection.schema_generator.DIALECT @@ -93,7 +98,7 @@ class Migrate: await cls._get_db_version(connection) @classmethod - async def _get_last_version_num(cls): + async def _get_last_version_num(cls) -> Optional[int]: last_version = await cls.get_last_version() if not last_version: return None @@ -101,7 +106,7 @@ class Migrate: return int(version.split("_", 1)[0]) @classmethod - async def generate_version(cls, name=None): + async def generate_version(cls, name=None) -> str: now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") last_version_num = await cls._get_last_version_num() if last_version_num is None: @@ -112,18 +117,15 @@ class Migrate: return version @classmethod - async def _generate_diff_py(cls, name): + async def _generate_diff_py(cls, name) -> str: version = await cls.generate_version(name) # delete if same version exists for version_file in cls.get_all_version_files(): if version_file.startswith(version.split("_")[0]): os.unlink(Path(cls.migrate_location, version_file)) - version_file = Path(cls.migrate_location, version) content = cls._get_diff_file_content() - - with open(version_file, "w", encoding="utf-8") as f: - f.write(content) + Path(cls.migrate_location, version).write_text(content, encoding="utf-8") return version @classmethod @@ -136,10 +138,10 @@ class Migrate: """ if empty: return await cls._generate_diff_py(name) - new_version_content = get_models_describe(cls.app) - cls.diff_models(cls._last_version_content, new_version_content) - cls.diff_models(new_version_content, cls._last_version_content, False) + last_version = cast(dict, cls._last_version_content) + cls.diff_models(last_version, new_version_content) + cls.diff_models(new_version_content, last_version, False) cls._merge_operators() @@ -165,7 +167,7 @@ class Migrate: ) @classmethod - def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): + def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None: """ add operator,differentiate fk because fk is order limit :param operator: @@ -186,19 +188,37 @@ class Migrate: cls.downgrade_operators.append(operator) @classmethod - def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]): - ret = [] + def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list: + ret: list = [] + + def index_hash(self) -> str: + h = hashlib.new("MD5", usedforsecurity=False) + h.update( + self.index_name(cls.ddl.schema_generator, model).encode() + + self.__class__.__name__.encode() + ) + return h.hexdigest() + for index in indexes: if isinstance(index, Index): - index.__hash__ = lambda self: md5( # nosec: B303 - self.index_name(cls.ddl.schema_generator, model).encode() - + self.__class__.__name__.encode() - ).hexdigest() + index.__hash__ = index_hash # type:ignore[method-assign,assignment] ret.append(index) return ret @classmethod - def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True): + def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]: + indexes: Set[Union[Index, Tuple[str, ...]]] = set() + for x in cls._handle_indexes(model, model_describe.get("indexes", [])): + if isinstance(x, Index): + indexes.add(x) + else: + indexes.add(cast(Tuple[str, ...], tuple(x))) + return indexes + + @classmethod + def diff_models( + cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True + ) -> None: """ diff models and add operators :param old_models: @@ -211,39 +231,35 @@ class Migrate: new_models.pop(_aerich, None) for new_model_str, new_model_describe in new_models.items(): - model = cls._get_model(new_model_describe.get("name").split(".")[1]) + model = cls._get_model(new_model_describe["name"].split(".")[1]) - if new_model_str not in old_models.keys(): + if new_model_str not in old_models: if upgrade: cls._add_operator(cls.add_model(model), upgrade) else: # we can't find origin model when downgrade, so skip pass else: - old_model_describe = old_models.get(new_model_str) + old_model_describe = cast(dict, old_models.get(new_model_str)) # rename table - new_table = new_model_describe.get("table") - old_table = old_model_describe.get("table") + new_table = cast(str, new_model_describe.get("table")) + old_table = cast(str, old_model_describe.get("table")) if new_table != old_table: cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade) old_unique_together = set( - map(lambda x: tuple(x), old_model_describe.get("unique_together")) + map( + lambda x: tuple(x), + cast(List[Iterable[str]], old_model_describe.get("unique_together")), + ) ) new_unique_together = set( - map(lambda x: tuple(x), new_model_describe.get("unique_together")) - ) - old_indexes = set( map( - lambda x: x if isinstance(x, Index) else tuple(x), - cls._handle_indexes(model, old_model_describe.get("indexes", [])), - ) - ) - new_indexes = set( - map( - lambda x: x if isinstance(x, Index) else tuple(x), - cls._handle_indexes(model, new_model_describe.get("indexes", [])), + lambda x: tuple(x), + cast(List[Iterable[str]], new_model_describe.get("unique_together")), ) ) + old_indexes = cls._get_indexes(model, old_model_describe) + new_indexes = cls._get_indexes(model, new_model_describe) old_pk_field = old_model_describe.get("pk_field") new_pk_field = new_model_describe.get("pk_field") # pk field @@ -253,18 +269,19 @@ class Migrate: if action == "change" and option == "name": cls._add_operator(cls._rename_field(model, *change), upgrade) # m2m fields - old_m2m_fields = old_model_describe.get("m2m_fields") - new_m2m_fields = new_model_describe.get("m2m_fields") + old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) + new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) for action, option, change in diff(old_m2m_fields, new_m2m_fields): if change[0][0] == "db_constraint": continue - if isinstance(change[0][1], str): + new_value = change[0][1] + if isinstance(new_value, str): for new_m2m_field in new_m2m_fields: - if new_m2m_field["name"] == change[0][1]: - table = new_m2m_field.get("through") + if new_m2m_field["name"] == new_value: + table = cast(str, new_m2m_field.get("through")) break else: - table = change[0][1].get("through") + table = new_value.get("through") if action == "add": add = False if upgrade and table not in cls._upgrade_m2m: @@ -274,12 +291,9 @@ class Migrate: cls._downgrade_m2m.append(table) add = True if add: + ref_desc = cast(dict, new_models.get(new_value.get("model_name"))) cls._add_operator( - cls.create_m2m( - model, - change[0][1], - new_models.get(change[0][1].get("model_name")), - ), + cls.create_m2m(model, new_value, ref_desc), upgrade, fk_m2m_index=True, ) @@ -300,38 +314,36 @@ class Migrate: for index in old_unique_together.difference(new_unique_together): cls._add_operator(cls._drop_index(model, index, True), upgrade, True) # add indexes - for index in new_indexes.difference(old_indexes): - cls._add_operator(cls._add_index(model, index, False), upgrade, True) + for idx in new_indexes.difference(old_indexes): + cls._add_operator(cls._add_index(model, idx, False), upgrade, True) # remove indexes - for index in old_indexes.difference(new_indexes): - cls._add_operator(cls._drop_index(model, index, False), upgrade, True) + for idx in old_indexes.difference(new_indexes): + cls._add_operator(cls._drop_index(model, idx, False), upgrade, True) old_data_fields = list( filter( lambda x: x.get("db_field_types") is not None, - old_model_describe.get("data_fields"), + cast(List[dict], old_model_describe.get("data_fields")), ) ) new_data_fields = list( filter( lambda x: x.get("db_field_types") is not None, - new_model_describe.get("data_fields"), + cast(List[dict], new_model_describe.get("data_fields")), ) ) - old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) - new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) + old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields]) + new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields]) # add fields or rename fields for new_data_field_name in set(new_data_fields_name).difference( set(old_data_fields_name) ): - new_data_field = next( - filter(lambda x: x.get("name") == new_data_field_name, new_data_fields) - ) + new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields) is_rename = False for old_data_field in old_data_fields: changes = list(diff(old_data_field, new_data_field)) - old_data_field_name = old_data_field.get("name") + old_data_field_name = cast(str, old_data_field.get("name")) if len(changes) == 2: # rename field if ( @@ -392,7 +404,7 @@ class Migrate: if new_data_field["indexed"]: cls._add_operator( cls._add_index( - model, {new_data_field["db_column"]}, new_data_field["unique"] + model, (new_data_field["db_column"],), new_data_field["unique"] ), upgrade, True, @@ -406,45 +418,34 @@ class Migrate: not upgrade and old_data_field_name in cls._rename_new ): continue - old_data_field = next( - filter(lambda x: x.get("name") == old_data_field_name, old_data_fields) - ) - db_column = old_data_field["db_column"] + old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields) + db_column = cast(str, old_data_field["db_column"]) cls._add_operator( - cls._remove_field( - model, - db_column, - ), + cls._remove_field(model, db_column), upgrade, ) if old_data_field["indexed"]: cls._add_operator( - cls._drop_index( - model, - {db_column}, - ), + cls._drop_index(model, {db_column}), upgrade, True, ) - old_fk_fields = old_model_describe.get("fk_fields") - new_fk_fields = new_model_describe.get("fk_fields") + old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields")) + new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields")) - old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields)) - new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields)) + old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields] + new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields] # add fk for new_fk_field_name in set(new_fk_fields_name).difference( set(old_fk_fields_name) ): - fk_field = next( - filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields) - ) + fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields) if fk_field.get("db_constraint"): + ref_describe = cast(dict, new_models[fk_field["python_type"]]) cls._add_operator( - cls._add_fk( - model, fk_field, new_models.get(fk_field.get("python_type")) - ), + cls._add_fk(model, fk_field, ref_describe), upgrade, fk_m2m_index=True, ) @@ -452,25 +453,20 @@ class Migrate: for old_fk_field_name in set(old_fk_fields_name).difference( set(new_fk_fields_name) ): - old_fk_field = next( - filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) + old_fk_field = cls.get_field_by_name( + old_fk_field_name, cast(List[dict], old_fk_fields) ) if old_fk_field.get("db_constraint"): + ref_describe = cast(dict, old_models[old_fk_field["python_type"]]) cls._add_operator( - cls._drop_fk( - model, old_fk_field, old_models.get(old_fk_field.get("python_type")) - ), + cls._drop_fk(model, old_fk_field, ref_describe), upgrade, fk_m2m_index=True, ) # change fields for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): - old_data_field = next( - filter(lambda x: x.get("name") == field_name, old_data_fields) - ) - new_data_field = next( - filter(lambda x: x.get("name") == field_name, new_data_fields) - ) + old_data_field = cls.get_field_by_name(field_name, old_data_fields) + new_data_field = cls.get_field_by_name(field_name, new_data_fields) changes = diff(old_data_field, new_data_field) modified = False for change in changes: @@ -479,13 +475,14 @@ class Migrate: # change index unique = new_data_field.get("unique") if old_new[0] is False and old_new[1] is True: - cls._add_operator( - cls._add_index(model, (field_name,), unique), upgrade, True - ) + add_or_drop_index = cls._add_index else: - cls._add_operator( - cls._drop_index(model, (field_name,), unique), upgrade, True - ) + # For drop case, unique value should get from old data + # TODO: unique = old_data_field.get("unique") + add_or_drop_index = cls._drop_index + cls._add_operator( + add_or_drop_index(model, (field_name,), unique), upgrade, True + ) elif option == "db_field_types.": if new_data_field.get("field_type") == "DecimalField": # modify column @@ -522,32 +519,33 @@ class Migrate: ) modified = True - for old_model in old_models: - if old_model not in new_models.keys(): - cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade) + for old_model in old_models.keys() - new_models.keys(): + cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade) @classmethod - def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str): + def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str: return cls.ddl.rename_table(model, old_table_name, new_table_name) @classmethod - def add_model(cls, model: Type[Model]): + def add_model(cls, model: Type[Model]) -> str: return cls.ddl.create_table(model) @classmethod - def drop_model(cls, table_name: str): + def drop_model(cls, table_name: str) -> str: return cls.ddl.drop_table(table_name) @classmethod - def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): + def create_m2m( + cls, model: Type[Model], field_describe: dict, reference_table_describe: dict + ) -> str: return cls.ddl.create_m2m(model, field_describe, reference_table_describe) @classmethod - def drop_m2m(cls, table_name: str): + def drop_m2m(cls, table_name: str) -> str: return cls.ddl.drop_m2m(table_name) @classmethod - def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): + def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]: ret = [] for field_name in fields_name: field = model._meta.fields_map[field_name] @@ -560,65 +558,75 @@ class Migrate: return ret @classmethod - def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): + def _drop_index( + cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False + ) -> str: if isinstance(fields_name, Index): return cls.ddl.drop_index_by_name( model, fields_name.index_name(cls.ddl.schema_generator, model) ) - fields_name = cls._resolve_fk_fields_name(model, fields_name) - return cls.ddl.drop_index(model, fields_name, unique) + field_names = cls._resolve_fk_fields_name(model, fields_name) + return cls.ddl.drop_index(model, field_names, unique) @classmethod - def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): + def _add_index( + cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False + ) -> str: if isinstance(fields_name, Index): return fields_name.get_sql(cls.ddl.schema_generator, model, False) - fields_name = cls._resolve_fk_fields_name(model, fields_name) - return cls.ddl.add_index(model, fields_name, unique) + field_names = cls._resolve_fk_fields_name(model, fields_name) + return cls.ddl.add_index(model, field_names, unique) @classmethod - def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False): + def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str: return cls.ddl.add_column(model, field_describe, is_pk) @classmethod - def _alter_default(cls, model: Type[Model], field_describe: dict): + def _alter_default(cls, model: Type[Model], field_describe: dict) -> str: return cls.ddl.alter_column_default(model, field_describe) @classmethod - def _alter_null(cls, model: Type[Model], field_describe: dict): + def _alter_null(cls, model: Type[Model], field_describe: dict) -> str: return cls.ddl.alter_column_null(model, field_describe) @classmethod - def _set_comment(cls, model: Type[Model], field_describe: dict): + def _set_comment(cls, model: Type[Model], field_describe: dict) -> str: return cls.ddl.set_comment(model, field_describe) @classmethod - def _modify_field(cls, model: Type[Model], field_describe: dict): + def _modify_field(cls, model: Type[Model], field_describe: dict) -> str: return cls.ddl.modify_column(model, field_describe) @classmethod - def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): + def _drop_fk( + cls, model: Type[Model], field_describe: dict, reference_table_describe: dict + ) -> str: return cls.ddl.drop_fk(model, field_describe, reference_table_describe) @classmethod - def _remove_field(cls, model: Type[Model], column_name: str): + def _remove_field(cls, model: Type[Model], column_name: str) -> str: return cls.ddl.drop_column(model, column_name) @classmethod - def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str): + def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str: return cls.ddl.rename_column(model, old_field_name, new_field_name) @classmethod - def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): - db_field_types = new_field_describe.get("db_field_types") + def _change_field( + cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict + ) -> str: + db_field_types = cast(dict, new_field_describe.get("db_field_types")) return cls.ddl.change_column( model, - old_field_describe.get("db_column"), - new_field_describe.get("db_column"), - db_field_types.get(cls.dialect) or db_field_types.get(""), + cast(str, old_field_describe.get("db_column")), + cast(str, new_field_describe.get("db_column")), + cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")), ) @classmethod - def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): + def _add_fk( + cls, model: Type[Model], field_describe: dict, reference_table_describe: dict + ) -> str: """ add fk :param model: @@ -629,7 +637,7 @@ class Migrate: return cls.ddl.add_fk(model, field_describe, reference_table_describe) @classmethod - def _merge_operators(cls): + def _merge_operators(cls) -> None: """ fk/m2m/index must be last when add,first when drop :return: