fix: migrate drop the wrong m2m field when model have multi m2m fields (#390)

* fix: migrate drop the wrong m2m field when model have multi m2m fields

* Make style and update changelog

* refactor: return new lists instead of change argument values in function

* refactor: use custom diff function instead of reorder lists

* docs: fix typo

* Fix hardcoded and rename custom diff function

* Update function doc
This commit is contained in:
Waket Zheng 2024-12-17 01:20:02 +08:00 committed by GitHub
parent 5af8c9cd56
commit 0780919ef3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 313 additions and 46 deletions

View File

@ -5,6 +5,7 @@
### [0.8.1](Unreleased) ### [0.8.1](Unreleased)
#### Fixed #### Fixed
- Migrate drop the wrong m2m field when model have multi m2m fields. (#376)
- KeyError raised when removing or renaming an existing model (#386) - KeyError raised when removing or renaming an existing model (#386)
- fix: error when there is `__init__.py` in the migration folder (#272) - fix: error when there is `__init__.py` in the migration folder (#272)
- Setting null=false on m2m field causes migration to fail. (#334) - Setting null=false on m2m field causes migration to fail. (#334)

View File

@ -13,7 +13,12 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, get_models_describe, is_default_function from aerich.utils import (
get_app_connection,
get_dict_diff_by_key,
get_models_describe,
is_default_function,
)
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
@ -223,6 +228,49 @@ class Migrate:
indexes.add(cast(Tuple[str, ...], tuple(x))) indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes return indexes
@classmethod
def _handle_m2m_fields(
cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
) -> None:
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
continue
new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator(
cls.create_m2m(model, new_value, ref_desc),
upgrade,
fk_m2m_index=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)
@classmethod @classmethod
def diff_models( def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
@ -277,48 +325,9 @@ class Migrate:
if action == "change" and option == "name": if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade) cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields # m2m fields
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) cls._handle_m2m_fields(
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) old_model_describe, new_model_describe, model, new_models, upgrade
if old_m2m_fields and len(new_m2m_fields) >= 2:
length = len(old_m2m_fields)
field_index = {f["name"]: i for i, f in enumerate(new_m2m_fields)}
new_m2m_fields.sort(key=lambda field: field_index.get(field["name"], length))
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
continue
new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator(
cls.create_m2m(model, new_value, ref_desc),
upgrade,
fk_m2m_index=True,
) )
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)
# add unique_together # add unique_together
for index in new_unique_together.difference(old_unique_together): for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True) cls._add_operator(cls._add_index(model, index, True), upgrade, True)

View File

@ -1,12 +1,15 @@
from __future__ import annotations
import importlib.util import importlib.util
import os import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import Dict, Optional, Union from typing import Dict, Generator, Optional, Union
from asyncclick import BadOptionUsage, ClickException, Context from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
@ -101,3 +104,43 @@ def import_py_file(file: Union[str, Path]) -> ModuleType:
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type] module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr] spec.loader.exec_module(module) # type:ignore[union-attr]
return module return module
def get_dict_diff_by_key(
old_fields: list[dict], new_fields: list[dict], key="through"
) -> Generator[tuple]:
"""
Compare two list by key instead of by index
:param old_fields: previous field info list
:param new_fields: current field info list
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
:return: similar to dictdiffer.diff
Example::
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
>>> list(diff(old, new))
[('change', [1, 'through'], ('b', 'c')),
('remove', '', [(2, {'through': 'c'})])]
>>> list(get_dict_diff_by_key(old, new))
[('remove', '', [(0, {'through': 'b'})])]
"""
length_old, length_new = len(old_fields), len(new_fields)
if length_old == 0 or length_new == 0 or length_old == length_new == 1:
yield from diff(old_fields, new_fields)
else:
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)}
additions = set(range(length_new))
for field in old_fields:
value = field[key]
if (index := value_index.get(value)) is not None:
additions.remove(index)
yield from diff([field], [new_fields[index]]) # change
else:
yield from diff([field], []) # remove
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add

View File

@ -80,6 +80,9 @@ class Product(Model):
class Config(Model): class Config(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="category_set"
)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value: dict = fields.JSONField() value: dict = fields.JSONField()

View File

@ -65,6 +65,10 @@ class Product(Model):
class Config(Model): class Config(Model):
category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="config_set"
)
name = fields.CharField(max_length=100, unique=True) name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)

View File

@ -270,7 +270,48 @@ old_models_describe = {
"backward_fk_fields": [], "backward_fk_fields": [],
"o2o_fields": [], "o2o_fields": [],
"backward_o2o_fields": [], "backward_o2o_fields": [],
"m2m_fields": [], "m2m_fields": [
{
"name": "category",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "configs",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category",
"on_delete": "CASCADE",
"_generated": False,
},
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "config_set",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category_map",
"on_delete": "CASCADE",
"_generated": False,
},
],
}, },
"models.Email": { "models.Email": {
"name": "models.Email", "name": "models.Email",
@ -898,6 +939,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", "ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category`",
} }
expected_downgrade_operators = { expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
@ -937,6 +980,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", "ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category_map`",
} }
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
@ -983,6 +1028,8 @@ def test_migrate(mocker: MockerFixture):
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"',
} }
expected_downgrade_operators = { expected_downgrade_operators = {
'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")', 'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")',
@ -1022,6 +1069,8 @@ def test_migrate(mocker: MockerFixture):
'DROP INDEX IF EXISTS "uid_product_name_869427"', 'DROP INDEX IF EXISTS "uid_product_name_869427"',
'DROP TABLE IF EXISTS "email_user"', 'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"', 'DROP TABLE IF EXISTS "newmodel"',
'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category_map"',
} }
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference( assert not set(Migrate.downgrade_operators).symmetric_difference(

View File

@ -1,6 +1,164 @@
from aerich.utils import import_py_file from aerich.utils import get_dict_diff_by_key, import_py_file
def test_import_py_file() -> None: def test_import_py_file() -> None:
m = import_py_file("aerich/utils.py") m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file", None) assert getattr(m, "import_py_file", None)
class TestDiffFields:
def test_the_same_through_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "members", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert type(get_dict_diff_by_key(old, new)).__name__ == "generator"
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_same_through_with_different_orders(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
{"name": "members", "through": "users_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_the_same_field_name_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_same_field_name_with_different_orders(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_drop_one(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("remove", "", [(0, {"name": "users", "through": "users_group"})])]
def test_add_one(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("add", "", [(0, {"name": "users", "through": "users_group"})])]
def test_drop_some(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "staffs", "through": "staffs_group"})]),
]
def test_add_some(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("add", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "admins", "through": "admins_group"})]),
]
def test_some_through_unchanged(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("add", "", [(0, {"name": "users", "through": "users_group"})]),
]
def test_some_unchanged_without_drop_or_add(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users_new", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("change", [0, "name"], ("users", "users_new")),
]