Add type hints for aerich.migrate

parent 51117867a6
commit 6466a852c8

@@ -1,9 +1,9 @@
+import hashlib
 import importlib
 import os
 from datetime import datetime
-from hashlib import md5
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Type, Union
+from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast

 import click
 from dictdiffer import diff
@@ -37,16 +37,21 @@ class Migrate:
     _upgrade_m2m: List[str] = []
     _downgrade_m2m: List[str] = []
     _aerich = Aerich.__name__
-    _rename_old = []
-    _rename_new = []
+    _rename_old: List[str] = []
+    _rename_new: List[str] = []

     ddl: BaseDDL
+    ddl_class: Type[BaseDDL]
     _last_version_content: Optional[dict] = None
     app: str
     migrate_location: Path
     dialect: str
     _db_version: Optional[str] = None

+    @staticmethod
+    def get_field_by_name(name: str, fields: List[dict]) -> dict:
+        return next(filter(lambda x: x.get("name") == name, fields))
+
     @classmethod
     def get_all_version_files(cls) -> List[str]:
         return sorted(
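
A note on the new `get_field_by_name` helper: it centralizes the `next(filter(...))` lookups that were previously inlined throughout `diff_models`. A minimal standalone sketch (not part of the commit, with invented sample data) of how it behaves:

from typing import List


def get_field_by_name(name: str, fields: List[dict]) -> dict:
    # Same logic as the staticmethod above: return the first describe dict whose
    # "name" matches; next() without a default raises StopIteration on no match.
    return next(filter(lambda x: x.get("name") == name, fields))


# Hypothetical field-describe dicts, loosely shaped like Tortoise ORM describe output.
fields = [{"name": "id", "db_column": "id"}, {"name": "email", "db_column": "email"}]
assert get_field_by_name("email", fields)["db_column"] == "email"
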
@@ -56,35 +61,35 @@ class Migrate:

     @classmethod
     def _get_model(cls, model: str) -> Type[Model]:
-        return Tortoise.apps.get(cls.app).get(model)
+        return Tortoise.apps[cls.app][model]

     @classmethod
     async def get_last_version(cls) -> Optional[Aerich]:
         try:
             return await Aerich.filter(app=cls.app).first()
         except OperationalError:
-            pass
+            return None

     @classmethod
-    async def _get_db_version(cls, connection: BaseDBAsyncClient):
+    async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
         if cls.dialect == "mysql":
             sql = "select version() as version"
             ret = await connection.execute_query(sql)
             cls._db_version = ret[1][0].get("version")

     @classmethod
-    async def load_ddl_class(cls):
+    async def load_ddl_class(cls) -> Type[BaseDDL]:
         ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
         return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")

     @classmethod
-    async def init(cls, config: dict, app: str, location: str):
+    async def init(cls, config: dict, app: str, location: str) -> None:
         await Tortoise.init(config=config)
         last_version = await cls.get_last_version()
         cls.app = app
         cls.migrate_location = Path(location, app)
         if last_version:
-            cls._last_version_content = last_version.content
+            cls._last_version_content = cast(dict, last_version.content)

         connection = get_app_connection(config, app)
         cls.dialect = connection.schema_generator.DIALECT
@@ -93,7 +98,7 @@ class Migrate:
         await cls._get_db_version(connection)

     @classmethod
-    async def _get_last_version_num(cls):
+    async def _get_last_version_num(cls) -> Optional[int]:
         last_version = await cls.get_last_version()
         if not last_version:
             return None
@@ -101,7 +106,7 @@ class Migrate:
         return int(version.split("_", 1)[0])

     @classmethod
-    async def generate_version(cls, name=None):
+    async def generate_version(cls, name=None) -> str:
         now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
         last_version_num = await cls._get_last_version_num()
         if last_version_num is None:
@@ -112,18 +117,15 @@ class Migrate:
         return version

     @classmethod
-    async def _generate_diff_py(cls, name):
+    async def _generate_diff_py(cls, name) -> str:
         version = await cls.generate_version(name)
         # delete if same version exists
         for version_file in cls.get_all_version_files():
             if version_file.startswith(version.split("_")[0]):
                 os.unlink(Path(cls.migrate_location, version_file))

-        version_file = Path(cls.migrate_location, version)
         content = cls._get_diff_file_content()
-        with open(version_file, "w", encoding="utf-8") as f:
-            f.write(content)
+        Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
         return version

     @classmethod
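
The `_generate_diff_py` change swaps an explicit `open()`/`write()` pair for `pathlib.Path.write_text`, which opens, writes and closes in one call. A small sketch of the equivalence using only the standard library and a temporary directory (the file name here is made up):

import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp, "1_20230101000000_update.py")

    # Old style: open the file and write explicitly.
    with open(target, "w", encoding="utf-8") as f:
        f.write("content")

    # New style: a single call with the same result.
    target.write_text("content", encoding="utf-8")

    assert target.read_text(encoding="utf-8") == "content"
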
@@ -136,10 +138,10 @@ class Migrate:
         """
         if empty:
             return await cls._generate_diff_py(name)

         new_version_content = get_models_describe(cls.app)
-        cls.diff_models(cls._last_version_content, new_version_content)
-        cls.diff_models(new_version_content, cls._last_version_content, False)
+        last_version = cast(dict, cls._last_version_content)
+        cls.diff_models(last_version, new_version_content)
+        cls.diff_models(new_version_content, last_version, False)

         cls._merge_operators()

@@ -165,7 +167,7 @@ class Migrate:
         )

     @classmethod
-    def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
+    def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
         """
         add operator,differentiate fk because fk is order limit
         :param operator:
@@ -186,19 +188,37 @@ class Migrate:
                 cls.downgrade_operators.append(operator)

     @classmethod
-    def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
-        ret = []
+    def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
+        ret: list = []
+
+        def index_hash(self) -> str:
+            h = hashlib.new("MD5", usedforsecurity=False)
+            h.update(
+                self.index_name(cls.ddl.schema_generator, model).encode()
+                + self.__class__.__name__.encode()
+            )
+            return h.hexdigest()
+
         for index in indexes:
             if isinstance(index, Index):
-                index.__hash__ = lambda self: md5(  # nosec: B303
-                    self.index_name(cls.ddl.schema_generator, model).encode()
-                    + self.__class__.__name__.encode()
-                ).hexdigest()
+                index.__hash__ = index_hash  # type:ignore[method-assign,assignment]
             ret.append(index)
         return ret

     @classmethod
-    def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
+    def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
+        indexes: Set[Union[Index, Tuple[str, ...]]] = set()
+        for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
+            if isinstance(x, Index):
+                indexes.add(x)
+            else:
+                indexes.add(cast(Tuple[str, ...], tuple(x)))
+        return indexes
+
+    @classmethod
+    def diff_models(
+        cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
+    ) -> None:
         """
         diff models and add operators
         :param old_models:
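
Background on the index hashing above: giving each `Index` a deterministic `__hash__` (and collecting indexes into sets via `_get_indexes`) is what lets later code compute added and removed indexes with plain set difference. A simplified, self-contained sketch of that idea using a stand-in class rather than Tortoise's real `Index` (requires Python 3.9+ for `usedforsecurity`):

import hashlib


class FakeIndex:
    """Stand-in for an index object, hashed by its generated name plus class name."""

    def __init__(self, fields):
        self.fields = tuple(fields)

    def index_name(self) -> str:
        return "idx_" + "_".join(self.fields)

    def __hash__(self) -> int:
        # MD5 is used only as a stable fingerprint here, not for security.
        h = hashlib.new("md5", usedforsecurity=False)
        h.update(self.index_name().encode() + type(self).__name__.encode())
        return int(h.hexdigest(), 16)

    def __eq__(self, other) -> bool:
        return isinstance(other, FakeIndex) and self.fields == other.fields


old = {FakeIndex(["a"]), FakeIndex(["b"])}
new = {FakeIndex(["b"]), FakeIndex(["c"])}
assert {i.fields for i in new - old} == {("c",)}  # indexes to add
assert {i.fields for i in old - new} == {("a",)}  # indexes to drop
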
@@ -211,39 +231,35 @@ class Migrate:
         new_models.pop(_aerich, None)

         for new_model_str, new_model_describe in new_models.items():
-            model = cls._get_model(new_model_describe.get("name").split(".")[1])
+            model = cls._get_model(new_model_describe["name"].split(".")[1])

-            if new_model_str not in old_models.keys():
+            if new_model_str not in old_models:
                 if upgrade:
                     cls._add_operator(cls.add_model(model), upgrade)
                 else:
                     # we can't find origin model when downgrade, so skip
                     pass
             else:
-                old_model_describe = old_models.get(new_model_str)
+                old_model_describe = cast(dict, old_models.get(new_model_str))
                 # rename table
-                new_table = new_model_describe.get("table")
-                old_table = old_model_describe.get("table")
+                new_table = cast(str, new_model_describe.get("table"))
+                old_table = cast(str, old_model_describe.get("table"))
                 if new_table != old_table:
                     cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
                 old_unique_together = set(
-                    map(lambda x: tuple(x), old_model_describe.get("unique_together"))
+                    map(
+                        lambda x: tuple(x),
+                        cast(List[Iterable[str]], old_model_describe.get("unique_together")),
+                    )
                 )
                 new_unique_together = set(
-                    map(lambda x: tuple(x), new_model_describe.get("unique_together"))
-                )
-                old_indexes = set(
                     map(
-                        lambda x: x if isinstance(x, Index) else tuple(x),
-                        cls._handle_indexes(model, old_model_describe.get("indexes", [])),
+                        lambda x: tuple(x),
+                        cast(List[Iterable[str]], new_model_describe.get("unique_together")),
                     )
                 )
-                new_indexes = set(
-                    map(
-                        lambda x: x if isinstance(x, Index) else tuple(x),
-                        cls._handle_indexes(model, new_model_describe.get("indexes", [])),
-                    )
-                )
+                old_indexes = cls._get_indexes(model, old_model_describe)
+                new_indexes = cls._get_indexes(model, new_model_describe)
                 old_pk_field = old_model_describe.get("pk_field")
                 new_pk_field = new_model_describe.get("pk_field")
                 # pk field
@@ -253,18 +269,19 @@ class Migrate:
                     if action == "change" and option == "name":
                         cls._add_operator(cls._rename_field(model, *change), upgrade)
                 # m2m fields
-                old_m2m_fields = old_model_describe.get("m2m_fields")
-                new_m2m_fields = new_model_describe.get("m2m_fields")
+                old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
+                new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
                 for action, option, change in diff(old_m2m_fields, new_m2m_fields):
                     if change[0][0] == "db_constraint":
                         continue
-                    if isinstance(change[0][1], str):
+                    new_value = change[0][1]
+                    if isinstance(new_value, str):
                         for new_m2m_field in new_m2m_fields:
-                            if new_m2m_field["name"] == change[0][1]:
-                                table = new_m2m_field.get("through")
+                            if new_m2m_field["name"] == new_value:
+                                table = cast(str, new_m2m_field.get("through"))
                                 break
                     else:
-                        table = change[0][1].get("through")
+                        table = new_value.get("through")
                     if action == "add":
                         add = False
                         if upgrade and table not in cls._upgrade_m2m:
@@ -274,12 +291,9 @@ class Migrate:
                             cls._downgrade_m2m.append(table)
                             add = True
                         if add:
+                            ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
                             cls._add_operator(
-                                cls.create_m2m(
-                                    model,
-                                    change[0][1],
-                                    new_models.get(change[0][1].get("model_name")),
-                                ),
+                                cls.create_m2m(model, new_value, ref_desc),
                                 upgrade,
                                 fk_m2m_index=True,
                             )
@@ -300,38 +314,36 @@ class Migrate:
                 for index in old_unique_together.difference(new_unique_together):
                     cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
                 # add indexes
-                for index in new_indexes.difference(old_indexes):
-                    cls._add_operator(cls._add_index(model, index, False), upgrade, True)
+                for idx in new_indexes.difference(old_indexes):
+                    cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
                 # remove indexes
-                for index in old_indexes.difference(new_indexes):
-                    cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
+                for idx in old_indexes.difference(new_indexes):
+                    cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
                 old_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        old_model_describe.get("data_fields"),
+                        cast(List[dict], old_model_describe.get("data_fields")),
                     )
                 )
                 new_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        new_model_describe.get("data_fields"),
+                        cast(List[dict], new_model_describe.get("data_fields")),
                     )
                 )

-                old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
-                new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
+                old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
+                new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])

                 # add fields or rename fields
                 for new_data_field_name in set(new_data_fields_name).difference(
                     set(old_data_fields_name)
                 ):
-                    new_data_field = next(
-                        filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
-                    )
+                    new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
                     is_rename = False
                     for old_data_field in old_data_fields:
                         changes = list(diff(old_data_field, new_data_field))
-                        old_data_field_name = old_data_field.get("name")
+                        old_data_field_name = cast(str, old_data_field.get("name"))
                         if len(changes) == 2:
                             # rename field
                             if (
@@ -392,7 +404,7 @@ class Migrate:
                         if new_data_field["indexed"]:
                             cls._add_operator(
                                 cls._add_index(
-                                    model, {new_data_field["db_column"]}, new_data_field["unique"]
+                                    model, (new_data_field["db_column"],), new_data_field["unique"]
                                 ),
                                 upgrade,
                                 True,
@@ -406,45 +418,34 @@ class Migrate:
                         not upgrade and old_data_field_name in cls._rename_new
                     ):
                         continue
-                    old_data_field = next(
-                        filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
-                    )
-                    db_column = old_data_field["db_column"]
+                    old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
+                    db_column = cast(str, old_data_field["db_column"])
                     cls._add_operator(
-                        cls._remove_field(
-                            model,
-                            db_column,
-                        ),
+                        cls._remove_field(model, db_column),
                         upgrade,
                     )
                     if old_data_field["indexed"]:
                         cls._add_operator(
-                            cls._drop_index(
-                                model,
-                                {db_column},
-                            ),
+                            cls._drop_index(model, {db_column}),
                             upgrade,
                             True,
                         )

-                old_fk_fields = old_model_describe.get("fk_fields")
-                new_fk_fields = new_model_describe.get("fk_fields")
+                old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
+                new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))

-                old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
-                new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
+                old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
+                new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]

                 # add fk
                 for new_fk_field_name in set(new_fk_fields_name).difference(
                     set(old_fk_fields_name)
                 ):
-                    fk_field = next(
-                        filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
-                    )
+                    fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
                     if fk_field.get("db_constraint"):
+                        ref_describe = cast(dict, new_models[fk_field["python_type"]])
                         cls._add_operator(
-                            cls._add_fk(
-                                model, fk_field, new_models.get(fk_field.get("python_type"))
-                            ),
+                            cls._add_fk(model, fk_field, ref_describe),
                             upgrade,
                             fk_m2m_index=True,
                         )
@@ -452,25 +453,20 @@ class Migrate:
                 for old_fk_field_name in set(old_fk_fields_name).difference(
                     set(new_fk_fields_name)
                 ):
-                    old_fk_field = next(
-                        filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
+                    old_fk_field = cls.get_field_by_name(
+                        old_fk_field_name, cast(List[dict], old_fk_fields)
                     )
                     if old_fk_field.get("db_constraint"):
+                        ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
                         cls._add_operator(
-                            cls._drop_fk(
-                                model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
-                            ),
+                            cls._drop_fk(model, old_fk_field, ref_describe),
                             upgrade,
                             fk_m2m_index=True,
                         )
                 # change fields
                 for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
-                    old_data_field = next(
-                        filter(lambda x: x.get("name") == field_name, old_data_fields)
-                    )
-                    new_data_field = next(
-                        filter(lambda x: x.get("name") == field_name, new_data_fields)
-                    )
+                    old_data_field = cls.get_field_by_name(field_name, old_data_fields)
+                    new_data_field = cls.get_field_by_name(field_name, new_data_fields)
                     changes = diff(old_data_field, new_data_field)
                     modified = False
                     for change in changes:
@@ -479,13 +475,14 @@ class Migrate:
                             # change index
                             unique = new_data_field.get("unique")
                             if old_new[0] is False and old_new[1] is True:
-                                cls._add_operator(
-                                    cls._add_index(model, (field_name,), unique), upgrade, True
-                                )
+                                add_or_drop_index = cls._add_index
                             else:
-                                cls._add_operator(
-                                    cls._drop_index(model, (field_name,), unique), upgrade, True
-                                )
+                                # For drop case, unique value should get from old data
+                                # TODO: unique = old_data_field.get("unique")
+                                add_or_drop_index = cls._drop_index
+                            cls._add_operator(
+                                add_or_drop_index(model, (field_name,), unique), upgrade, True
+                            )
                         elif option == "db_field_types.":
                             if new_data_field.get("field_type") == "DecimalField":
                                 # modify column
@@ -522,32 +519,33 @@ class Migrate:
                                 )
                                 modified = True

-        for old_model in old_models:
-            if old_model not in new_models.keys():
-                cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
+        for old_model in old_models.keys() - new_models.keys():
+            cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)

     @classmethod
-    def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
+    def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
         return cls.ddl.rename_table(model, old_table_name, new_table_name)

     @classmethod
-    def add_model(cls, model: Type[Model]):
+    def add_model(cls, model: Type[Model]) -> str:
         return cls.ddl.create_table(model)

     @classmethod
-    def drop_model(cls, table_name: str):
+    def drop_model(cls, table_name: str) -> str:
         return cls.ddl.drop_table(table_name)

     @classmethod
-    def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
+    def create_m2m(
+        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+    ) -> str:
         return cls.ddl.create_m2m(model, field_describe, reference_table_describe)

     @classmethod
-    def drop_m2m(cls, table_name: str):
+    def drop_m2m(cls, table_name: str) -> str:
         return cls.ddl.drop_m2m(table_name)

     @classmethod
-    def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
+    def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
         ret = []
         for field_name in fields_name:
             field = model._meta.fields_map[field_name]
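
The dropped-model loop above now leans on the fact that `dict.keys()` returns a set-like view, so `old_models.keys() - new_models.keys()` yields exactly the model names that disappeared. A tiny sketch with made-up model names:

old_models = {"models.User": {"table": "user"}, "models.Legacy": {"table": "legacy"}}
new_models = {"models.User": {"table": "user"}}

# Key views support set operations directly; no explicit set() call is needed.
dropped = old_models.keys() - new_models.keys()
assert dropped == {"models.Legacy"}
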
@@ -560,65 +558,75 @@ class Migrate:
         return ret

     @classmethod
-    def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
+    def _drop_index(
+        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
+    ) -> str:
         if isinstance(fields_name, Index):
             return cls.ddl.drop_index_by_name(
                 model, fields_name.index_name(cls.ddl.schema_generator, model)
             )
-        fields_name = cls._resolve_fk_fields_name(model, fields_name)
-        return cls.ddl.drop_index(model, fields_name, unique)
+        field_names = cls._resolve_fk_fields_name(model, fields_name)
+        return cls.ddl.drop_index(model, field_names, unique)

     @classmethod
-    def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
+    def _add_index(
+        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
+    ) -> str:
         if isinstance(fields_name, Index):
             return fields_name.get_sql(cls.ddl.schema_generator, model, False)
-        fields_name = cls._resolve_fk_fields_name(model, fields_name)
-        return cls.ddl.add_index(model, fields_name, unique)
+        field_names = cls._resolve_fk_fields_name(model, fields_name)
+        return cls.ddl.add_index(model, field_names, unique)

     @classmethod
-    def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
+    def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
         return cls.ddl.add_column(model, field_describe, is_pk)

     @classmethod
-    def _alter_default(cls, model: Type[Model], field_describe: dict):
+    def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
         return cls.ddl.alter_column_default(model, field_describe)

     @classmethod
-    def _alter_null(cls, model: Type[Model], field_describe: dict):
+    def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
         return cls.ddl.alter_column_null(model, field_describe)

     @classmethod
-    def _set_comment(cls, model: Type[Model], field_describe: dict):
+    def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
         return cls.ddl.set_comment(model, field_describe)

     @classmethod
-    def _modify_field(cls, model: Type[Model], field_describe: dict):
+    def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
         return cls.ddl.modify_column(model, field_describe)

     @classmethod
-    def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
+    def _drop_fk(
+        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+    ) -> str:
         return cls.ddl.drop_fk(model, field_describe, reference_table_describe)

     @classmethod
-    def _remove_field(cls, model: Type[Model], column_name: str):
+    def _remove_field(cls, model: Type[Model], column_name: str) -> str:
         return cls.ddl.drop_column(model, column_name)

     @classmethod
-    def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
+    def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
         return cls.ddl.rename_column(model, old_field_name, new_field_name)

     @classmethod
-    def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
-        db_field_types = new_field_describe.get("db_field_types")
+    def _change_field(
+        cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
+    ) -> str:
+        db_field_types = cast(dict, new_field_describe.get("db_field_types"))
         return cls.ddl.change_column(
             model,
-            old_field_describe.get("db_column"),
-            new_field_describe.get("db_column"),
-            db_field_types.get(cls.dialect) or db_field_types.get(""),
+            cast(str, old_field_describe.get("db_column")),
+            cast(str, new_field_describe.get("db_column")),
+            cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
         )

     @classmethod
-    def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
+    def _add_fk(
+        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+    ) -> str:
         """
         add fk
         :param model:
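
Many of the edits in this hunk wrap `.get(...)` results in `typing.cast`. `cast` performs no runtime conversion or check; it only tells the type checker to treat the value as the stated type. A minimal sketch (the `describe` function here is hypothetical):

from typing import Dict, cast


def describe() -> Dict[str, str]:
    # Hypothetical stand-in for a model-describe mapping.
    return {"table": "user", "pk_field": "id"}


# .get() on Dict[str, str] is typed Optional[str]; cast(str, ...) narrows it for
# the type checker while returning the value unchanged at runtime.
table = cast(str, describe().get("table"))
assert table == "user"
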
@@ -629,7 +637,7 @@ class Migrate:
         return cls.ddl.add_fk(model, field_describe, reference_table_describe)

     @classmethod
-    def _merge_operators(cls):
+    def _merge_operators(cls) -> None:
         """
         fk/m2m/index must be last when add,first when drop
         :return: