fix: migrate drop the wrong m2m field when model have multi m2m fields (#390)
* fix: migrate drop the wrong m2m field when model have multi m2m fields * Make style and update changelog * refactor: return new lists instead of change argument values in function * refactor: use custom diff function instead of reorder lists * docs: fix typo * Fix hardcoded and rename custom diff function * Update function doc
This commit is contained in:
@@ -13,7 +13,12 @@ from tortoise.indexes import Index
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.models import MAX_VERSION_LENGTH, Aerich
|
||||
from aerich.utils import get_app_connection, get_models_describe, is_default_function
|
||||
from aerich.utils import (
|
||||
get_app_connection,
|
||||
get_dict_diff_by_key,
|
||||
get_models_describe,
|
||||
is_default_function,
|
||||
)
|
||||
|
||||
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
|
||||
|
||||
@@ -223,6 +228,49 @@ class Migrate:
|
||||
indexes.add(cast(Tuple[str, ...], tuple(x)))
|
||||
return indexes
|
||||
|
||||
@classmethod
|
||||
def _handle_m2m_fields(
|
||||
cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
|
||||
) -> None:
|
||||
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
||||
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
||||
for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
|
||||
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
|
||||
continue
|
||||
new_value = change[0][1]
|
||||
if isinstance(new_value, str):
|
||||
for new_m2m_field in new_m2m_fields:
|
||||
if new_m2m_field["name"] == new_value:
|
||||
table = cast(str, new_m2m_field.get("through"))
|
||||
break
|
||||
else:
|
||||
table = new_value.get("through")
|
||||
if action == "add":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
|
||||
cls._add_operator(
|
||||
cls.create_m2m(model, new_value, ref_desc),
|
||||
upgrade,
|
||||
fk_m2m_index=True,
|
||||
)
|
||||
elif action == "remove":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
cls._add_operator(cls.drop_m2m(table), upgrade, True)
|
||||
|
||||
@classmethod
|
||||
def diff_models(
|
||||
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
|
||||
@@ -277,48 +325,9 @@ class Migrate:
|
||||
if action == "change" and option == "name":
|
||||
cls._add_operator(cls._rename_field(model, *change), upgrade)
|
||||
# m2m fields
|
||||
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
||||
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
||||
if old_m2m_fields and len(new_m2m_fields) >= 2:
|
||||
length = len(old_m2m_fields)
|
||||
field_index = {f["name"]: i for i, f in enumerate(new_m2m_fields)}
|
||||
new_m2m_fields.sort(key=lambda field: field_index.get(field["name"], length))
|
||||
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
|
||||
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
|
||||
continue
|
||||
new_value = change[0][1]
|
||||
if isinstance(new_value, str):
|
||||
for new_m2m_field in new_m2m_fields:
|
||||
if new_m2m_field["name"] == new_value:
|
||||
table = cast(str, new_m2m_field.get("through"))
|
||||
break
|
||||
else:
|
||||
table = new_value.get("through")
|
||||
if action == "add":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
|
||||
cls._add_operator(
|
||||
cls.create_m2m(model, new_value, ref_desc),
|
||||
upgrade,
|
||||
fk_m2m_index=True,
|
||||
)
|
||||
elif action == "remove":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
cls._add_operator(cls.drop_m2m(table), upgrade, True)
|
||||
cls._handle_m2m_fields(
|
||||
old_model_describe, new_model_describe, model, new_models, upgrade
|
||||
)
|
||||
# add unique_together
|
||||
for index in new_unique_together.difference(old_unique_together):
|
||||
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Dict, Generator, Optional, Union
|
||||
|
||||
from asyncclick import BadOptionUsage, ClickException, Context
|
||||
from dictdiffer import diff
|
||||
from tortoise import BaseDBAsyncClient, Tortoise
|
||||
|
||||
|
||||
@@ -101,3 +104,43 @@ def import_py_file(file: Union[str, Path]) -> ModuleType:
|
||||
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
|
||||
spec.loader.exec_module(module) # type:ignore[union-attr]
|
||||
return module
|
||||
|
||||
|
||||
def get_dict_diff_by_key(
|
||||
old_fields: list[dict], new_fields: list[dict], key="through"
|
||||
) -> Generator[tuple]:
|
||||
"""
|
||||
Compare two list by key instead of by index
|
||||
|
||||
:param old_fields: previous field info list
|
||||
:param new_fields: current field info list
|
||||
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
|
||||
:return: similar to dictdiffer.diff
|
||||
|
||||
Example::
|
||||
|
||||
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
|
||||
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
|
||||
>>> list(diff(old, new))
|
||||
[('change', [1, 'through'], ('b', 'c')),
|
||||
('remove', '', [(2, {'through': 'c'})])]
|
||||
>>> list(get_dict_diff_by_key(old, new))
|
||||
[('remove', '', [(0, {'through': 'b'})])]
|
||||
|
||||
"""
|
||||
length_old, length_new = len(old_fields), len(new_fields)
|
||||
if length_old == 0 or length_new == 0 or length_old == length_new == 1:
|
||||
yield from diff(old_fields, new_fields)
|
||||
else:
|
||||
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)}
|
||||
additions = set(range(length_new))
|
||||
for field in old_fields:
|
||||
value = field[key]
|
||||
if (index := value_index.get(value)) is not None:
|
||||
additions.remove(index)
|
||||
yield from diff([field], [new_fields[index]]) # change
|
||||
else:
|
||||
yield from diff([field], []) # remove
|
||||
if additions:
|
||||
for index in sorted(additions):
|
||||
yield from diff([], [new_fields[index]]) # add
|
||||
|
||||
Reference in New Issue
Block a user