Add type hints for aerich.migrate
This commit is contained in:
parent
51117867a6
commit
6466a852c8
@ -1,9 +1,9 @@
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
from datetime import datetime
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
|
||||
|
||||
import click
|
||||
from dictdiffer import diff
|
||||
@ -37,16 +37,21 @@ class Migrate:
|
||||
_upgrade_m2m: List[str] = []
|
||||
_downgrade_m2m: List[str] = []
|
||||
_aerich = Aerich.__name__
|
||||
_rename_old = []
|
||||
_rename_new = []
|
||||
_rename_old: List[str] = []
|
||||
_rename_new: List[str] = []
|
||||
|
||||
ddl: BaseDDL
|
||||
ddl_class: Type[BaseDDL]
|
||||
_last_version_content: Optional[dict] = None
|
||||
app: str
|
||||
migrate_location: Path
|
||||
dialect: str
|
||||
_db_version: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def get_field_by_name(name: str, fields: List[dict]) -> dict:
|
||||
return next(filter(lambda x: x.get("name") == name, fields))
|
||||
|
||||
@classmethod
|
||||
def get_all_version_files(cls) -> List[str]:
|
||||
return sorted(
|
||||
@ -56,35 +61,35 @@ class Migrate:
|
||||
|
||||
@classmethod
|
||||
def _get_model(cls, model: str) -> Type[Model]:
|
||||
return Tortoise.apps.get(cls.app).get(model)
|
||||
return Tortoise.apps[cls.app][model]
|
||||
|
||||
@classmethod
|
||||
async def get_last_version(cls) -> Optional[Aerich]:
|
||||
try:
|
||||
return await Aerich.filter(app=cls.app).first()
|
||||
except OperationalError:
|
||||
pass
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def _get_db_version(cls, connection: BaseDBAsyncClient):
|
||||
async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
|
||||
if cls.dialect == "mysql":
|
||||
sql = "select version() as version"
|
||||
ret = await connection.execute_query(sql)
|
||||
cls._db_version = ret[1][0].get("version")
|
||||
|
||||
@classmethod
|
||||
async def load_ddl_class(cls):
|
||||
async def load_ddl_class(cls) -> Type[BaseDDL]:
|
||||
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
|
||||
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
|
||||
|
||||
@classmethod
|
||||
async def init(cls, config: dict, app: str, location: str):
|
||||
async def init(cls, config: dict, app: str, location: str) -> None:
|
||||
await Tortoise.init(config=config)
|
||||
last_version = await cls.get_last_version()
|
||||
cls.app = app
|
||||
cls.migrate_location = Path(location, app)
|
||||
if last_version:
|
||||
cls._last_version_content = last_version.content
|
||||
cls._last_version_content = cast(dict, last_version.content)
|
||||
|
||||
connection = get_app_connection(config, app)
|
||||
cls.dialect = connection.schema_generator.DIALECT
|
||||
@ -93,7 +98,7 @@ class Migrate:
|
||||
await cls._get_db_version(connection)
|
||||
|
||||
@classmethod
|
||||
async def _get_last_version_num(cls):
|
||||
async def _get_last_version_num(cls) -> Optional[int]:
|
||||
last_version = await cls.get_last_version()
|
||||
if not last_version:
|
||||
return None
|
||||
@ -101,7 +106,7 @@ class Migrate:
|
||||
return int(version.split("_", 1)[0])
|
||||
|
||||
@classmethod
|
||||
async def generate_version(cls, name=None):
|
||||
async def generate_version(cls, name=None) -> str:
|
||||
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
|
||||
last_version_num = await cls._get_last_version_num()
|
||||
if last_version_num is None:
|
||||
@ -112,18 +117,15 @@ class Migrate:
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
async def _generate_diff_py(cls, name):
|
||||
async def _generate_diff_py(cls, name) -> str:
|
||||
version = await cls.generate_version(name)
|
||||
# delete if same version exists
|
||||
for version_file in cls.get_all_version_files():
|
||||
if version_file.startswith(version.split("_")[0]):
|
||||
os.unlink(Path(cls.migrate_location, version_file))
|
||||
|
||||
version_file = Path(cls.migrate_location, version)
|
||||
content = cls._get_diff_file_content()
|
||||
|
||||
with open(version_file, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
@ -136,10 +138,10 @@ class Migrate:
|
||||
"""
|
||||
if empty:
|
||||
return await cls._generate_diff_py(name)
|
||||
|
||||
new_version_content = get_models_describe(cls.app)
|
||||
cls.diff_models(cls._last_version_content, new_version_content)
|
||||
cls.diff_models(new_version_content, cls._last_version_content, False)
|
||||
last_version = cast(dict, cls._last_version_content)
|
||||
cls.diff_models(last_version, new_version_content)
|
||||
cls.diff_models(new_version_content, last_version, False)
|
||||
|
||||
cls._merge_operators()
|
||||
|
||||
@ -165,7 +167,7 @@ class Migrate:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
|
||||
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
|
||||
"""
|
||||
add operator,differentiate fk because fk is order limit
|
||||
:param operator:
|
||||
@ -186,19 +188,37 @@ class Migrate:
|
||||
cls.downgrade_operators.append(operator)
|
||||
|
||||
@classmethod
|
||||
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
|
||||
ret = []
|
||||
for index in indexes:
|
||||
if isinstance(index, Index):
|
||||
index.__hash__ = lambda self: md5( # nosec: B303
|
||||
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
|
||||
ret: list = []
|
||||
|
||||
def index_hash(self) -> str:
|
||||
h = hashlib.new("MD5", usedforsecurity=False)
|
||||
h.update(
|
||||
self.index_name(cls.ddl.schema_generator, model).encode()
|
||||
+ self.__class__.__name__.encode()
|
||||
).hexdigest()
|
||||
)
|
||||
return h.hexdigest()
|
||||
|
||||
for index in indexes:
|
||||
if isinstance(index, Index):
|
||||
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
|
||||
ret.append(index)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
|
||||
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
|
||||
indexes: Set[Union[Index, Tuple[str, ...]]] = set()
|
||||
for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
|
||||
if isinstance(x, Index):
|
||||
indexes.add(x)
|
||||
else:
|
||||
indexes.add(cast(Tuple[str, ...], tuple(x)))
|
||||
return indexes
|
||||
|
||||
@classmethod
|
||||
def diff_models(
|
||||
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
|
||||
) -> None:
|
||||
"""
|
||||
diff models and add operators
|
||||
:param old_models:
|
||||
@ -211,39 +231,35 @@ class Migrate:
|
||||
new_models.pop(_aerich, None)
|
||||
|
||||
for new_model_str, new_model_describe in new_models.items():
|
||||
model = cls._get_model(new_model_describe.get("name").split(".")[1])
|
||||
model = cls._get_model(new_model_describe["name"].split(".")[1])
|
||||
|
||||
if new_model_str not in old_models.keys():
|
||||
if new_model_str not in old_models:
|
||||
if upgrade:
|
||||
cls._add_operator(cls.add_model(model), upgrade)
|
||||
else:
|
||||
# we can't find origin model when downgrade, so skip
|
||||
pass
|
||||
else:
|
||||
old_model_describe = old_models.get(new_model_str)
|
||||
old_model_describe = cast(dict, old_models.get(new_model_str))
|
||||
# rename table
|
||||
new_table = new_model_describe.get("table")
|
||||
old_table = old_model_describe.get("table")
|
||||
new_table = cast(str, new_model_describe.get("table"))
|
||||
old_table = cast(str, old_model_describe.get("table"))
|
||||
if new_table != old_table:
|
||||
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
|
||||
old_unique_together = set(
|
||||
map(lambda x: tuple(x), old_model_describe.get("unique_together"))
|
||||
map(
|
||||
lambda x: tuple(x),
|
||||
cast(List[Iterable[str]], old_model_describe.get("unique_together")),
|
||||
)
|
||||
)
|
||||
new_unique_together = set(
|
||||
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
|
||||
)
|
||||
old_indexes = set(
|
||||
map(
|
||||
lambda x: x if isinstance(x, Index) else tuple(x),
|
||||
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
|
||||
)
|
||||
)
|
||||
new_indexes = set(
|
||||
map(
|
||||
lambda x: x if isinstance(x, Index) else tuple(x),
|
||||
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
|
||||
lambda x: tuple(x),
|
||||
cast(List[Iterable[str]], new_model_describe.get("unique_together")),
|
||||
)
|
||||
)
|
||||
old_indexes = cls._get_indexes(model, old_model_describe)
|
||||
new_indexes = cls._get_indexes(model, new_model_describe)
|
||||
old_pk_field = old_model_describe.get("pk_field")
|
||||
new_pk_field = new_model_describe.get("pk_field")
|
||||
# pk field
|
||||
@ -253,18 +269,19 @@ class Migrate:
|
||||
if action == "change" and option == "name":
|
||||
cls._add_operator(cls._rename_field(model, *change), upgrade)
|
||||
# m2m fields
|
||||
old_m2m_fields = old_model_describe.get("m2m_fields")
|
||||
new_m2m_fields = new_model_describe.get("m2m_fields")
|
||||
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
||||
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
||||
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
|
||||
if change[0][0] == "db_constraint":
|
||||
continue
|
||||
if isinstance(change[0][1], str):
|
||||
new_value = change[0][1]
|
||||
if isinstance(new_value, str):
|
||||
for new_m2m_field in new_m2m_fields:
|
||||
if new_m2m_field["name"] == change[0][1]:
|
||||
table = new_m2m_field.get("through")
|
||||
if new_m2m_field["name"] == new_value:
|
||||
table = cast(str, new_m2m_field.get("through"))
|
||||
break
|
||||
else:
|
||||
table = change[0][1].get("through")
|
||||
table = new_value.get("through")
|
||||
if action == "add":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
@ -274,12 +291,9 @@ class Migrate:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
|
||||
cls._add_operator(
|
||||
cls.create_m2m(
|
||||
model,
|
||||
change[0][1],
|
||||
new_models.get(change[0][1].get("model_name")),
|
||||
),
|
||||
cls.create_m2m(model, new_value, ref_desc),
|
||||
upgrade,
|
||||
fk_m2m_index=True,
|
||||
)
|
||||
@ -300,38 +314,36 @@ class Migrate:
|
||||
for index in old_unique_together.difference(new_unique_together):
|
||||
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
|
||||
# add indexes
|
||||
for index in new_indexes.difference(old_indexes):
|
||||
cls._add_operator(cls._add_index(model, index, False), upgrade, True)
|
||||
for idx in new_indexes.difference(old_indexes):
|
||||
cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
|
||||
# remove indexes
|
||||
for index in old_indexes.difference(new_indexes):
|
||||
cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
|
||||
for idx in old_indexes.difference(new_indexes):
|
||||
cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
|
||||
old_data_fields = list(
|
||||
filter(
|
||||
lambda x: x.get("db_field_types") is not None,
|
||||
old_model_describe.get("data_fields"),
|
||||
cast(List[dict], old_model_describe.get("data_fields")),
|
||||
)
|
||||
)
|
||||
new_data_fields = list(
|
||||
filter(
|
||||
lambda x: x.get("db_field_types") is not None,
|
||||
new_model_describe.get("data_fields"),
|
||||
cast(List[dict], new_model_describe.get("data_fields")),
|
||||
)
|
||||
)
|
||||
|
||||
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
|
||||
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
|
||||
old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
|
||||
new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
|
||||
|
||||
# add fields or rename fields
|
||||
for new_data_field_name in set(new_data_fields_name).difference(
|
||||
set(old_data_fields_name)
|
||||
):
|
||||
new_data_field = next(
|
||||
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
|
||||
)
|
||||
new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
|
||||
is_rename = False
|
||||
for old_data_field in old_data_fields:
|
||||
changes = list(diff(old_data_field, new_data_field))
|
||||
old_data_field_name = old_data_field.get("name")
|
||||
old_data_field_name = cast(str, old_data_field.get("name"))
|
||||
if len(changes) == 2:
|
||||
# rename field
|
||||
if (
|
||||
@ -392,7 +404,7 @@ class Migrate:
|
||||
if new_data_field["indexed"]:
|
||||
cls._add_operator(
|
||||
cls._add_index(
|
||||
model, {new_data_field["db_column"]}, new_data_field["unique"]
|
||||
model, (new_data_field["db_column"],), new_data_field["unique"]
|
||||
),
|
||||
upgrade,
|
||||
True,
|
||||
@ -406,45 +418,34 @@ class Migrate:
|
||||
not upgrade and old_data_field_name in cls._rename_new
|
||||
):
|
||||
continue
|
||||
old_data_field = next(
|
||||
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
|
||||
)
|
||||
db_column = old_data_field["db_column"]
|
||||
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
|
||||
db_column = cast(str, old_data_field["db_column"])
|
||||
cls._add_operator(
|
||||
cls._remove_field(
|
||||
model,
|
||||
db_column,
|
||||
),
|
||||
cls._remove_field(model, db_column),
|
||||
upgrade,
|
||||
)
|
||||
if old_data_field["indexed"]:
|
||||
cls._add_operator(
|
||||
cls._drop_index(
|
||||
model,
|
||||
{db_column},
|
||||
),
|
||||
cls._drop_index(model, {db_column}),
|
||||
upgrade,
|
||||
True,
|
||||
)
|
||||
|
||||
old_fk_fields = old_model_describe.get("fk_fields")
|
||||
new_fk_fields = new_model_describe.get("fk_fields")
|
||||
old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
|
||||
new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
|
||||
|
||||
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
|
||||
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
|
||||
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
|
||||
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
|
||||
|
||||
# add fk
|
||||
for new_fk_field_name in set(new_fk_fields_name).difference(
|
||||
set(old_fk_fields_name)
|
||||
):
|
||||
fk_field = next(
|
||||
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
|
||||
)
|
||||
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
|
||||
if fk_field.get("db_constraint"):
|
||||
ref_describe = cast(dict, new_models[fk_field["python_type"]])
|
||||
cls._add_operator(
|
||||
cls._add_fk(
|
||||
model, fk_field, new_models.get(fk_field.get("python_type"))
|
||||
),
|
||||
cls._add_fk(model, fk_field, ref_describe),
|
||||
upgrade,
|
||||
fk_m2m_index=True,
|
||||
)
|
||||
@ -452,25 +453,20 @@ class Migrate:
|
||||
for old_fk_field_name in set(old_fk_fields_name).difference(
|
||||
set(new_fk_fields_name)
|
||||
):
|
||||
old_fk_field = next(
|
||||
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
|
||||
old_fk_field = cls.get_field_by_name(
|
||||
old_fk_field_name, cast(List[dict], old_fk_fields)
|
||||
)
|
||||
if old_fk_field.get("db_constraint"):
|
||||
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
|
||||
cls._add_operator(
|
||||
cls._drop_fk(
|
||||
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
|
||||
),
|
||||
cls._drop_fk(model, old_fk_field, ref_describe),
|
||||
upgrade,
|
||||
fk_m2m_index=True,
|
||||
)
|
||||
# change fields
|
||||
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
|
||||
old_data_field = next(
|
||||
filter(lambda x: x.get("name") == field_name, old_data_fields)
|
||||
)
|
||||
new_data_field = next(
|
||||
filter(lambda x: x.get("name") == field_name, new_data_fields)
|
||||
)
|
||||
old_data_field = cls.get_field_by_name(field_name, old_data_fields)
|
||||
new_data_field = cls.get_field_by_name(field_name, new_data_fields)
|
||||
changes = diff(old_data_field, new_data_field)
|
||||
modified = False
|
||||
for change in changes:
|
||||
@ -479,12 +475,13 @@ class Migrate:
|
||||
# change index
|
||||
unique = new_data_field.get("unique")
|
||||
if old_new[0] is False and old_new[1] is True:
|
||||
cls._add_operator(
|
||||
cls._add_index(model, (field_name,), unique), upgrade, True
|
||||
)
|
||||
add_or_drop_index = cls._add_index
|
||||
else:
|
||||
# For drop case, unique value should get from old data
|
||||
# TODO: unique = old_data_field.get("unique")
|
||||
add_or_drop_index = cls._drop_index
|
||||
cls._add_operator(
|
||||
cls._drop_index(model, (field_name,), unique), upgrade, True
|
||||
add_or_drop_index(model, (field_name,), unique), upgrade, True
|
||||
)
|
||||
elif option == "db_field_types.":
|
||||
if new_data_field.get("field_type") == "DecimalField":
|
||||
@ -522,32 +519,33 @@ class Migrate:
|
||||
)
|
||||
modified = True
|
||||
|
||||
for old_model in old_models:
|
||||
if old_model not in new_models.keys():
|
||||
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
|
||||
for old_model in old_models.keys() - new_models.keys():
|
||||
cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
|
||||
|
||||
@classmethod
|
||||
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
|
||||
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
|
||||
return cls.ddl.rename_table(model, old_table_name, new_table_name)
|
||||
|
||||
@classmethod
|
||||
def add_model(cls, model: Type[Model]):
|
||||
def add_model(cls, model: Type[Model]) -> str:
|
||||
return cls.ddl.create_table(model)
|
||||
|
||||
@classmethod
|
||||
def drop_model(cls, table_name: str):
|
||||
def drop_model(cls, table_name: str) -> str:
|
||||
return cls.ddl.drop_table(table_name)
|
||||
|
||||
@classmethod
|
||||
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
|
||||
def create_m2m(
|
||||
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
|
||||
) -> str:
|
||||
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
|
||||
|
||||
@classmethod
|
||||
def drop_m2m(cls, table_name: str):
|
||||
def drop_m2m(cls, table_name: str) -> str:
|
||||
return cls.ddl.drop_m2m(table_name)
|
||||
|
||||
@classmethod
|
||||
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
|
||||
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
|
||||
ret = []
|
||||
for field_name in fields_name:
|
||||
field = model._meta.fields_map[field_name]
|
||||
@ -560,65 +558,75 @@ class Migrate:
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
|
||||
def _drop_index(
|
||||
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
|
||||
) -> str:
|
||||
if isinstance(fields_name, Index):
|
||||
return cls.ddl.drop_index_by_name(
|
||||
model, fields_name.index_name(cls.ddl.schema_generator, model)
|
||||
)
|
||||
fields_name = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.drop_index(model, fields_name, unique)
|
||||
field_names = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.drop_index(model, field_names, unique)
|
||||
|
||||
@classmethod
|
||||
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
|
||||
def _add_index(
|
||||
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
|
||||
) -> str:
|
||||
if isinstance(fields_name, Index):
|
||||
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
|
||||
fields_name = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.add_index(model, fields_name, unique)
|
||||
field_names = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.add_index(model, field_names, unique)
|
||||
|
||||
@classmethod
|
||||
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
|
||||
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
|
||||
return cls.ddl.add_column(model, field_describe, is_pk)
|
||||
|
||||
@classmethod
|
||||
def _alter_default(cls, model: Type[Model], field_describe: dict):
|
||||
def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
|
||||
return cls.ddl.alter_column_default(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _alter_null(cls, model: Type[Model], field_describe: dict):
|
||||
def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
|
||||
return cls.ddl.alter_column_null(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _set_comment(cls, model: Type[Model], field_describe: dict):
|
||||
def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
|
||||
return cls.ddl.set_comment(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _modify_field(cls, model: Type[Model], field_describe: dict):
|
||||
def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
|
||||
return cls.ddl.modify_column(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
|
||||
def _drop_fk(
|
||||
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
|
||||
) -> str:
|
||||
return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
|
||||
|
||||
@classmethod
|
||||
def _remove_field(cls, model: Type[Model], column_name: str):
|
||||
def _remove_field(cls, model: Type[Model], column_name: str) -> str:
|
||||
return cls.ddl.drop_column(model, column_name)
|
||||
|
||||
@classmethod
|
||||
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
|
||||
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
|
||||
return cls.ddl.rename_column(model, old_field_name, new_field_name)
|
||||
|
||||
@classmethod
|
||||
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
|
||||
db_field_types = new_field_describe.get("db_field_types")
|
||||
def _change_field(
|
||||
cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
|
||||
) -> str:
|
||||
db_field_types = cast(dict, new_field_describe.get("db_field_types"))
|
||||
return cls.ddl.change_column(
|
||||
model,
|
||||
old_field_describe.get("db_column"),
|
||||
new_field_describe.get("db_column"),
|
||||
db_field_types.get(cls.dialect) or db_field_types.get(""),
|
||||
cast(str, old_field_describe.get("db_column")),
|
||||
cast(str, new_field_describe.get("db_column")),
|
||||
cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
|
||||
def _add_fk(
|
||||
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
|
||||
) -> str:
|
||||
"""
|
||||
add fk
|
||||
:param model:
|
||||
@ -629,7 +637,7 @@ class Migrate:
|
||||
return cls.ddl.add_fk(model, field_describe, reference_table_describe)
|
||||
|
||||
@classmethod
|
||||
def _merge_operators(cls):
|
||||
def _merge_operators(cls) -> None:
|
||||
"""
|
||||
fk/m2m/index must be last when add,first when drop
|
||||
:return:
|
||||
|
Loading…
x
Reference in New Issue
Block a user