fix: migrate drop the wrong m2m field when model have multi m2m fields (#390)

* fix: migrate drop the wrong m2m field when model have multi m2m fields

* Make style and update changelog

* refactor: return new lists instead of change argument values in function

* refactor: use custom diff function instead of reorder lists

* docs: fix typo

* Fix hardcoded and rename custom diff function

* Update function doc
This commit is contained in:
Waket Zheng
2024-12-17 01:20:02 +08:00
committed by GitHub
parent 5af8c9cd56
commit 0780919ef3
7 changed files with 313 additions and 46 deletions

View File

@@ -13,7 +13,12 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, get_models_describe, is_default_function
from aerich.utils import (
get_app_connection,
get_dict_diff_by_key,
get_models_describe,
is_default_function,
)
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
@@ -223,6 +228,49 @@ class Migrate:
indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes
@classmethod
def _handle_m2m_fields(
cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
) -> None:
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
continue
new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator(
cls.create_m2m(model, new_value, ref_desc),
upgrade,
fk_m2m_index=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
@@ -277,48 +325,9 @@ class Migrate:
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
if old_m2m_fields and len(new_m2m_fields) >= 2:
length = len(old_m2m_fields)
field_index = {f["name"]: i for i, f in enumerate(new_m2m_fields)}
new_m2m_fields.sort(key=lambda field: field_index.get(field["name"], length))
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
continue
new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator(
cls.create_m2m(model, new_value, ref_desc),
upgrade,
fk_m2m_index=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)
cls._handle_m2m_fields(
old_model_describe, new_model_describe, model, new_models, upgrade
)
# add unique_together
for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True)

View File

@@ -1,12 +1,15 @@
from __future__ import annotations
import importlib.util
import os
import re
import sys
from pathlib import Path
from types import ModuleType
from typing import Dict, Optional, Union
from typing import Dict, Generator, Optional, Union
from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise
@@ -101,3 +104,43 @@ def import_py_file(file: Union[str, Path]) -> ModuleType:
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module
def get_dict_diff_by_key(
old_fields: list[dict], new_fields: list[dict], key="through"
) -> Generator[tuple]:
"""
Compare two list by key instead of by index
:param old_fields: previous field info list
:param new_fields: current field info list
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
:return: similar to dictdiffer.diff
Example::
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
>>> list(diff(old, new))
[('change', [1, 'through'], ('b', 'c')),
('remove', '', [(2, {'through': 'c'})])]
>>> list(get_dict_diff_by_key(old, new))
[('remove', '', [(0, {'through': 'b'})])]
"""
length_old, length_new = len(old_fields), len(new_fields)
if length_old == 0 or length_new == 0 or length_old == length_new == 1:
yield from diff(old_fields, new_fields)
else:
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)}
additions = set(range(length_new))
for field in old_fields:
value = field[key]
if (index := value_index.get(value)) is not None:
additions.remove(index)
yield from diff([field], [new_fields[index]]) # change
else:
yield from diff([field], []) # remove
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add