Add type hints for aerich.migrate

This commit is contained in:
Waket Zheng 2024-06-02 18:09:34 +08:00
parent 51117867a6
commit 6466a852c8

View File

@ -1,9 +1,9 @@
import hashlib
import importlib import importlib
import os import os
from datetime import datetime from datetime import datetime
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import click import click
from dictdiffer import diff from dictdiffer import diff
@ -37,16 +37,21 @@ class Migrate:
_upgrade_m2m: List[str] = [] _upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = [] _downgrade_m2m: List[str] = []
_aerich = Aerich.__name__ _aerich = Aerich.__name__
_rename_old = [] _rename_old: List[str] = []
_rename_new = [] _rename_new: List[str] = []
ddl: BaseDDL ddl: BaseDDL
ddl_class: Type[BaseDDL]
_last_version_content: Optional[dict] = None _last_version_content: Optional[dict] = None
app: str app: str
migrate_location: Path migrate_location: Path
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@staticmethod
def get_field_by_name(name: str, fields: List[dict]) -> dict:
return next(filter(lambda x: x.get("name") == name, fields))
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( return sorted(
@ -56,35 +61,35 @@ class Migrate:
@classmethod @classmethod
def _get_model(cls, model: str) -> Type[Model]: def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model) return Tortoise.apps[cls.app][model]
@classmethod @classmethod
async def get_last_version(cls) -> Optional[Aerich]: async def get_last_version(cls) -> Optional[Aerich]:
try: try:
return await Aerich.filter(app=cls.app).first() return await Aerich.filter(app=cls.app).first()
except OperationalError: except OperationalError:
pass return None
@classmethod @classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient): async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
if cls.dialect == "mysql": if cls.dialect == "mysql":
sql = "select version() as version" sql = "select version() as version"
ret = await connection.execute_query(sql) ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version") cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def load_ddl_class(cls): async def load_ddl_class(cls) -> Type[BaseDDL]:
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}") ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL") return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod @classmethod
async def init(cls, config: dict, app: str, location: str): async def init(cls, config: dict, app: str, location: str) -> None:
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = Path(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
cls._last_version_content = last_version.content cls._last_version_content = cast(dict, last_version.content)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT cls.dialect = connection.schema_generator.DIALECT
@ -93,7 +98,7 @@ class Migrate:
await cls._get_db_version(connection) await cls._get_db_version(connection)
@classmethod @classmethod
async def _get_last_version_num(cls): async def _get_last_version_num(cls) -> Optional[int]:
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
if not last_version: if not last_version:
return None return None
@ -101,7 +106,7 @@ class Migrate:
return int(version.split("_", 1)[0]) return int(version.split("_", 1)[0])
@classmethod @classmethod
async def generate_version(cls, name=None): async def generate_version(cls, name=None) -> str:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num() last_version_num = await cls._get_last_version_num()
if last_version_num is None: if last_version_num is None:
@ -112,18 +117,15 @@ class Migrate:
return version return version
@classmethod @classmethod
async def _generate_diff_py(cls, name): async def _generate_diff_py(cls, name) -> str:
version = await cls.generate_version(name) version = await cls.generate_version(name)
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
version_file = Path(cls.migrate_location, version)
content = cls._get_diff_file_content() content = cls._get_diff_file_content()
Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version return version
@classmethod @classmethod
@ -136,10 +138,10 @@ class Migrate:
""" """
if empty: if empty:
return await cls._generate_diff_py(name) return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app) new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content) last_version = cast(dict, cls._last_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False) cls.diff_models(last_version, new_version_content)
cls.diff_models(new_version_content, last_version, False)
cls._merge_operators() cls._merge_operators()
@ -165,7 +167,7 @@ class Migrate:
) )
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
@ -186,19 +188,37 @@ class Migrate:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod @classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]): def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret = [] ret: list = []
def index_hash(self) -> str:
h = hashlib.new("MD5", usedforsecurity=False)
h.update(
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
)
return h.hexdigest()
for index in indexes: for index in indexes:
if isinstance(index, Index): if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303 index.__hash__ = index_hash # type:ignore[method-assign,assignment]
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
ret.append(index) ret.append(index)
return ret return ret
@classmethod @classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True): def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
indexes: Set[Union[Index, Tuple[str, ...]]] = set()
for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
if isinstance(x, Index):
indexes.add(x)
else:
indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
) -> None:
""" """
diff models and add operators diff models and add operators
:param old_models: :param old_models:
@ -211,39 +231,35 @@ class Migrate:
new_models.pop(_aerich, None) new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items(): for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1]) model = cls._get_model(new_model_describe["name"].split(".")[1])
if new_model_str not in old_models.keys(): if new_model_str not in old_models:
if upgrade: if upgrade:
cls._add_operator(cls.add_model(model), upgrade) cls._add_operator(cls.add_model(model), upgrade)
else: else:
# we can't find origin model when downgrade, so skip # we can't find origin model when downgrade, so skip
pass pass
else: else:
old_model_describe = old_models.get(new_model_str) old_model_describe = cast(dict, old_models.get(new_model_str))
# rename table # rename table
new_table = new_model_describe.get("table") new_table = cast(str, new_model_describe.get("table"))
old_table = old_model_describe.get("table") old_table = cast(str, old_model_describe.get("table"))
if new_table != old_table: if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade) cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set( old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together")) map(
lambda x: tuple(x),
cast(List[Iterable[str]], old_model_describe.get("unique_together")),
)
) )
new_unique_together = set( new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(
map( map(
lambda x: x if isinstance(x, Index) else tuple(x), lambda x: tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])), cast(List[Iterable[str]], new_model_describe.get("unique_together")),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
) )
) )
old_indexes = cls._get_indexes(model, old_model_describe)
new_indexes = cls._get_indexes(model, new_model_describe)
old_pk_field = old_model_describe.get("pk_field") old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field") new_pk_field = new_model_describe.get("pk_field")
# pk field # pk field
@ -253,18 +269,19 @@ class Migrate:
if action == "change" and option == "name": if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade) cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields # m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields") old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = new_model_describe.get("m2m_fields") new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
for action, option, change in diff(old_m2m_fields, new_m2m_fields): for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint": if change[0][0] == "db_constraint":
continue continue
if isinstance(change[0][1], str): new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields: for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == change[0][1]: if new_m2m_field["name"] == new_value:
table = new_m2m_field.get("through") table = cast(str, new_m2m_field.get("through"))
break break
else: else:
table = change[0][1].get("through") table = new_value.get("through")
if action == "add": if action == "add":
add = False add = False
if upgrade and table not in cls._upgrade_m2m: if upgrade and table not in cls._upgrade_m2m:
@ -274,12 +291,9 @@ class Migrate:
cls._downgrade_m2m.append(table) cls._downgrade_m2m.append(table)
add = True add = True
if add: if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator( cls._add_operator(
cls.create_m2m( cls.create_m2m(model, new_value, ref_desc),
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade, upgrade,
fk_m2m_index=True, fk_m2m_index=True,
) )
@ -300,38 +314,36 @@ class Migrate:
for index in old_unique_together.difference(new_unique_together): for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True) cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes # add indexes
for index in new_indexes.difference(old_indexes): for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, index, False), upgrade, True) cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
# remove indexes # remove indexes
for index in old_indexes.difference(new_indexes): for idx in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True) cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
old_data_fields = list( old_data_fields = list(
filter( filter(
lambda x: x.get("db_field_types") is not None, lambda x: x.get("db_field_types") is not None,
old_model_describe.get("data_fields"), cast(List[dict], old_model_describe.get("data_fields")),
) )
) )
new_data_fields = list( new_data_fields = list(
filter( filter(
lambda x: x.get("db_field_types") is not None, lambda x: x.get("db_field_types") is not None,
new_model_describe.get("data_fields"), cast(List[dict], new_model_describe.get("data_fields")),
) )
) )
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
# add fields or rename fields # add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
new_data_field = next( new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False is_rename = False
for old_data_field in old_data_fields: for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field)) changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name") old_data_field_name = cast(str, old_data_field.get("name"))
if len(changes) == 2: if len(changes) == 2:
# rename field # rename field
if ( if (
@ -392,7 +404,7 @@ class Migrate:
if new_data_field["indexed"]: if new_data_field["indexed"]:
cls._add_operator( cls._add_operator(
cls._add_index( cls._add_index(
model, {new_data_field["db_column"]}, new_data_field["unique"] model, (new_data_field["db_column"],), new_data_field["unique"]
), ),
upgrade, upgrade,
True, True,
@ -406,45 +418,34 @@ class Migrate:
not upgrade and old_data_field_name in cls._rename_new not upgrade and old_data_field_name in cls._rename_new
): ):
continue continue
old_data_field = next( old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields) db_column = cast(str, old_data_field["db_column"])
)
db_column = old_data_field["db_column"]
cls._add_operator( cls._add_operator(
cls._remove_field( cls._remove_field(model, db_column),
model,
db_column,
),
upgrade, upgrade,
) )
if old_data_field["indexed"]: if old_data_field["indexed"]:
cls._add_operator( cls._add_operator(
cls._drop_index( cls._drop_index(model, {db_column}),
model,
{db_column},
),
upgrade, upgrade,
True, True,
) )
old_fk_fields = old_model_describe.get("fk_fields") old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields = new_model_describe.get("fk_fields") new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields)) old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields)) new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
# add fk # add fk
for new_fk_field_name in set(new_fk_fields_name).difference( for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name) set(old_fk_fields_name)
): ):
fk_field = next( fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
if fk_field.get("db_constraint"): if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator( cls._add_operator(
cls._add_fk( cls._add_fk(model, fk_field, ref_describe),
model, fk_field, new_models.get(fk_field.get("python_type"))
),
upgrade, upgrade,
fk_m2m_index=True, fk_m2m_index=True,
) )
@ -452,25 +453,20 @@ class Migrate:
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name) set(new_fk_fields_name)
): ):
old_fk_field = next( old_fk_field = cls.get_field_by_name(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) old_fk_field_name, cast(List[dict], old_fk_fields)
) )
if old_fk_field.get("db_constraint"): if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator( cls._add_operator(
cls._drop_fk( cls._drop_fk(model, old_fk_field, ref_describe),
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade, upgrade,
fk_m2m_index=True, fk_m2m_index=True,
) )
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next( old_data_field = cls.get_field_by_name(field_name, old_data_fields)
filter(lambda x: x.get("name") == field_name, old_data_fields) new_data_field = cls.get_field_by_name(field_name, new_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field) changes = diff(old_data_field, new_data_field)
modified = False modified = False
for change in changes: for change in changes:
@ -479,13 +475,14 @@ class Migrate:
# change index # change index
unique = new_data_field.get("unique") unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True: if old_new[0] is False and old_new[1] is True:
cls._add_operator( add_or_drop_index = cls._add_index
cls._add_index(model, (field_name,), unique), upgrade, True
)
else: else:
cls._add_operator( # For drop case, unique value should get from old data
cls._drop_index(model, (field_name,), unique), upgrade, True # TODO: unique = old_data_field.get("unique")
) add_or_drop_index = cls._drop_index
cls._add_operator(
add_or_drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.": elif option == "db_field_types.":
if new_data_field.get("field_type") == "DecimalField": if new_data_field.get("field_type") == "DecimalField":
# modify column # modify column
@ -522,32 +519,33 @@ class Migrate:
) )
modified = True modified = True
for old_model in old_models: for old_model in old_models.keys() - new_models.keys():
if old_model not in new_models.keys(): cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod @classmethod
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str): def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
return cls.ddl.rename_table(model, old_table_name, new_table_name) return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod @classmethod
def add_model(cls, model: Type[Model]): def add_model(cls, model: Type[Model]) -> str:
return cls.ddl.create_table(model) return cls.ddl.create_table(model)
@classmethod @classmethod
def drop_model(cls, table_name: str): def drop_model(cls, table_name: str) -> str:
return cls.ddl.drop_table(table_name) return cls.ddl.drop_table(table_name)
@classmethod @classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def create_m2m(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.create_m2m(model, field_describe, reference_table_describe) return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod @classmethod
def drop_m2m(cls, table_name: str): def drop_m2m(cls, table_name: str) -> str:
return cls.ddl.drop_m2m(table_name) return cls.ddl.drop_m2m(table_name)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
ret = [] ret = []
for field_name in fields_name: for field_name in fields_name:
field = model._meta.fields_map[field_name] field = model._meta.fields_map[field_name]
@ -560,65 +558,75 @@ class Migrate:
return ret return ret
@classmethod @classmethod
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): def _drop_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name( return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model) model, fields_name.index_name(cls.ddl.schema_generator, model)
) )
fields_name = cls._resolve_fk_fields_name(model, fields_name) field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique) return cls.ddl.drop_index(model, field_names, unique)
@classmethod @classmethod
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): def _add_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False) return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name) field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique) return cls.ddl.add_index(model, field_names, unique)
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False): def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
return cls.ddl.add_column(model, field_describe, is_pk) return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field_describe: dict): def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_default(model, field_describe) return cls.ddl.alter_column_default(model, field_describe)
@classmethod @classmethod
def _alter_null(cls, model: Type[Model], field_describe: dict): def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_null(model, field_describe) return cls.ddl.alter_column_null(model, field_describe)
@classmethod @classmethod
def _set_comment(cls, model: Type[Model], field_describe: dict): def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.set_comment(model, field_describe) return cls.ddl.set_comment(model, field_describe)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field_describe: dict): def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.modify_column(model, field_describe) return cls.ddl.modify_column(model, field_describe)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _drop_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.drop_fk(model, field_describe, reference_table_describe) return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], column_name: str): def _remove_field(cls, model: Type[Model], column_name: str) -> str:
return cls.ddl.drop_column(model, column_name) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str): def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
return cls.ddl.rename_column(model, old_field_name, new_field_name) return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod @classmethod
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): def _change_field(
db_field_types = new_field_describe.get("db_field_types") cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
) -> str:
db_field_types = cast(dict, new_field_describe.get("db_field_types"))
return cls.ddl.change_column( return cls.ddl.change_column(
model, model,
old_field_describe.get("db_column"), cast(str, old_field_describe.get("db_column")),
new_field_describe.get("db_column"), cast(str, new_field_describe.get("db_column")),
db_field_types.get(cls.dialect) or db_field_types.get(""), cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
) )
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _add_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
""" """
add fk add fk
:param model: :param model:
@ -629,7 +637,7 @@ class Migrate:
return cls.ddl.add_fk(model, field_describe, reference_table_describe) return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _merge_operators(cls): def _merge_operators(cls) -> None:
""" """
fk/m2m/index must be last when add,first when drop fk/m2m/index must be last when add,first when drop
:return: :return: