diff --git a/CHANGELOG.md b/CHANGELOG.md index 3dee6e0..05852f3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### [0.8.1](Unreleased) #### Fixed +- Migrate drop the wrong m2m field when model have multi m2m fields. (#376) - KeyError raised when removing or renaming an existing model (#386) - fix: error when there is `__init__.py` in the migration folder (#272) - Setting null=false on m2m field causes migration to fail. (#334) diff --git a/aerich/migrate.py b/aerich/migrate.py index fd0f765..397e62a 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -13,7 +13,12 @@ from tortoise.indexes import Index from aerich.ddl import BaseDDL from aerich.models import MAX_VERSION_LENGTH, Aerich -from aerich.utils import get_app_connection, get_models_describe, is_default_function +from aerich.utils import ( + get_app_connection, + get_dict_diff_by_key, + get_models_describe, + is_default_function, +) MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient @@ -223,6 +228,49 @@ class Migrate: indexes.add(cast(Tuple[str, ...], tuple(x))) return indexes + @classmethod + def _handle_m2m_fields( + cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True + ) -> None: + old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) + new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) + for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields): + if (option and option[-1] == "nullable") or change[0][0] == "db_constraint": + continue + new_value = change[0][1] + if isinstance(new_value, str): + for new_m2m_field in new_m2m_fields: + if new_m2m_field["name"] == new_value: + table = cast(str, new_m2m_field.get("through")) + break + else: + table = new_value.get("through") + if action == "add": + add = False + if upgrade and table not in cls._upgrade_m2m: + cls._upgrade_m2m.append(table) + add = True + elif not upgrade and table not in cls._downgrade_m2m: + cls._downgrade_m2m.append(table) + add = True + if add: + ref_desc = cast(dict, new_models.get(new_value.get("model_name"))) + cls._add_operator( + cls.create_m2m(model, new_value, ref_desc), + upgrade, + fk_m2m_index=True, + ) + elif action == "remove": + add = False + if upgrade and table not in cls._upgrade_m2m: + cls._upgrade_m2m.append(table) + add = True + elif not upgrade and table not in cls._downgrade_m2m: + cls._downgrade_m2m.append(table) + add = True + if add: + cls._add_operator(cls.drop_m2m(table), upgrade, True) + @classmethod def diff_models( cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True @@ -277,48 +325,9 @@ class Migrate: if action == "change" and option == "name": cls._add_operator(cls._rename_field(model, *change), upgrade) # m2m fields - old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) - new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) - if old_m2m_fields and len(new_m2m_fields) >= 2: - length = len(old_m2m_fields) - field_index = {f["name"]: i for i, f in enumerate(new_m2m_fields)} - new_m2m_fields.sort(key=lambda field: field_index.get(field["name"], length)) - for action, option, change in diff(old_m2m_fields, new_m2m_fields): - if (option and option[-1] == "nullable") or change[0][0] == "db_constraint": - continue - new_value = change[0][1] - if isinstance(new_value, str): - for new_m2m_field in new_m2m_fields: - if new_m2m_field["name"] == new_value: - table = cast(str, new_m2m_field.get("through")) - break - else: - table = new_value.get("through") - if action == "add": - add = False - if upgrade and table not in cls._upgrade_m2m: - cls._upgrade_m2m.append(table) - add = True - elif not upgrade and table not in cls._downgrade_m2m: - cls._downgrade_m2m.append(table) - add = True - if add: - ref_desc = cast(dict, new_models.get(new_value.get("model_name"))) - cls._add_operator( - cls.create_m2m(model, new_value, ref_desc), - upgrade, - fk_m2m_index=True, - ) - elif action == "remove": - add = False - if upgrade and table not in cls._upgrade_m2m: - cls._upgrade_m2m.append(table) - add = True - elif not upgrade and table not in cls._downgrade_m2m: - cls._downgrade_m2m.append(table) - add = True - if add: - cls._add_operator(cls.drop_m2m(table), upgrade, True) + cls._handle_m2m_fields( + old_model_describe, new_model_describe, model, new_models, upgrade + ) # add unique_together for index in new_unique_together.difference(old_unique_together): cls._add_operator(cls._add_index(model, index, True), upgrade, True) diff --git a/aerich/utils.py b/aerich/utils.py index 728c91b..0fbbb04 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import importlib.util import os import re import sys from pathlib import Path from types import ModuleType -from typing import Dict, Optional, Union +from typing import Dict, Generator, Optional, Union from asyncclick import BadOptionUsage, ClickException, Context +from dictdiffer import diff from tortoise import BaseDBAsyncClient, Tortoise @@ -101,3 +104,43 @@ def import_py_file(file: Union[str, Path]) -> ModuleType: module = importlib.util.module_from_spec(spec) # type:ignore[arg-type] spec.loader.exec_module(module) # type:ignore[union-attr] return module + + +def get_dict_diff_by_key( + old_fields: list[dict], new_fields: list[dict], key="through" +) -> Generator[tuple]: + """ + Compare two list by key instead of by index + + :param old_fields: previous field info list + :param new_fields: current field info list + :param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add + :return: similar to dictdiffer.diff + + Example:: + + >>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}] + >>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element + >>> list(diff(old, new)) + [('change', [1, 'through'], ('b', 'c')), + ('remove', '', [(2, {'through': 'c'})])] + >>> list(get_dict_diff_by_key(old, new)) + [('remove', '', [(0, {'through': 'b'})])] + + """ + length_old, length_new = len(old_fields), len(new_fields) + if length_old == 0 or length_new == 0 or length_old == length_new == 1: + yield from diff(old_fields, new_fields) + else: + value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)} + additions = set(range(length_new)) + for field in old_fields: + value = field[key] + if (index := value_index.get(value)) is not None: + additions.remove(index) + yield from diff([field], [new_fields[index]]) # change + else: + yield from diff([field], []) # remove + if additions: + for index in sorted(additions): + yield from diff([], [new_fields[index]]) # add diff --git a/tests/models.py b/tests/models.py index 3126fbe..e4ea87c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -80,6 +80,9 @@ class Product(Model): class Config(Model): + categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField( + "models.Category", through="config_category_map", related_name="category_set" + ) label = fields.CharField(max_length=200) key = fields.CharField(max_length=20) value: dict = fields.JSONField() diff --git a/tests/old_models.py b/tests/old_models.py index 5225597..5444232 100644 --- a/tests/old_models.py +++ b/tests/old_models.py @@ -65,6 +65,10 @@ class Product(Model): class Config(Model): + category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category") + categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField( + "models.Category", through="config_category_map", related_name="config_set" + ) name = fields.CharField(max_length=100, unique=True) label = fields.CharField(max_length=200) key = fields.CharField(max_length=20) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 6cb1d33..cab0fa9 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -270,7 +270,48 @@ old_models_describe = { "backward_fk_fields": [], "o2o_fields": [], "backward_o2o_fields": [], - "m2m_fields": [], + "m2m_fields": [ + { + "name": "category", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.Category", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "model_name": "models.Category", + "related_name": "configs", + "forward_key": "category_id", + "backward_key": "config_id", + "through": "config_category", + "on_delete": "CASCADE", + "_generated": False, + }, + { + "name": "categories", + "field_type": "ManyToManyFieldInstance", + "python_type": "models.Category", + "generated": False, + "nullable": False, + "unique": False, + "indexed": False, + "default": None, + "description": None, + "docstring": None, + "constraints": {}, + "model_name": "models.Category", + "related_name": "config_set", + "forward_key": "category_id", + "backward_key": "config_id", + "through": "config_category_map", + "on_delete": "CASCADE", + "_generated": False, + }, + ], }, "models.Email": { "name": "models.Email", @@ -898,6 +939,8 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", "CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", + "CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", + "DROP TABLE IF EXISTS `config_category`", } expected_downgrade_operators = { "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL", @@ -937,6 +980,8 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", + "CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", + "DROP TABLE IF EXISTS `config_category_map`", } assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) @@ -983,6 +1028,8 @@ def test_migrate(mocker: MockerFixture): 'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', + 'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)', + 'DROP TABLE IF EXISTS "config_category"', } expected_downgrade_operators = { 'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")', @@ -1022,6 +1069,8 @@ def test_migrate(mocker: MockerFixture): 'DROP INDEX IF EXISTS "uid_product_name_869427"', 'DROP TABLE IF EXISTS "email_user"', 'DROP TABLE IF EXISTS "newmodel"', + 'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)', + 'DROP TABLE IF EXISTS "config_category_map"', } assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) assert not set(Migrate.downgrade_operators).symmetric_difference( diff --git a/tests/test_utils.py b/tests/test_utils.py index 20ba73d..4be38a3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,164 @@ -from aerich.utils import import_py_file +from aerich.utils import get_dict_diff_by_key, import_py_file def test_import_py_file() -> None: m = import_py_file("aerich/utils.py") assert getattr(m, "import_py_file", None) + + +class TestDiffFields: + def test_the_same_through_order(self) -> None: + old = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "members", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert type(get_dict_diff_by_key(old, new)).__name__ == "generator" + assert len(diffs) == 1 + assert diffs == [("change", [0, "name"], ("users", "members"))] + + def test_same_through_with_different_orders(self) -> None: + old = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "admins", "through": "admins_group"}, + {"name": "members", "through": "users_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 1 + assert diffs == [("change", [0, "name"], ("users", "members"))] + + def test_the_same_field_name_order(self) -> None: + old = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "users", "through": "user_groups"}, + {"name": "admins", "through": "admin_groups"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 4 + assert diffs == [ + ("remove", "", [(0, {"name": "users", "through": "users_group"})]), + ("remove", "", [(0, {"name": "admins", "through": "admins_group"})]), + ("add", "", [(0, {"name": "users", "through": "user_groups"})]), + ("add", "", [(0, {"name": "admins", "through": "admin_groups"})]), + ] + + def test_same_field_name_with_different_orders(self) -> None: + old = [ + {"name": "admins", "through": "admins_group"}, + {"name": "users", "through": "users_group"}, + ] + new = [ + {"name": "users", "through": "user_groups"}, + {"name": "admins", "through": "admin_groups"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 4 + assert diffs == [ + ("remove", "", [(0, {"name": "admins", "through": "admins_group"})]), + ("remove", "", [(0, {"name": "users", "through": "users_group"})]), + ("add", "", [(0, {"name": "users", "through": "user_groups"})]), + ("add", "", [(0, {"name": "admins", "through": "admin_groups"})]), + ] + + def test_drop_one(self) -> None: + old = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "admins", "through": "admins_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 1 + assert diffs == [("remove", "", [(0, {"name": "users", "through": "users_group"})])] + + def test_add_one(self) -> None: + old = [ + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 1 + assert diffs == [("add", "", [(0, {"name": "users", "through": "users_group"})])] + + def test_drop_some(self) -> None: + old = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + {"name": "staffs", "through": "staffs_group"}, + ] + new = [ + {"name": "admins", "through": "admins_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 2 + assert diffs == [ + ("remove", "", [(0, {"name": "users", "through": "users_group"})]), + ("remove", "", [(0, {"name": "staffs", "through": "staffs_group"})]), + ] + + def test_add_some(self) -> None: + old = [ + {"name": "staffs", "through": "staffs_group"}, + ] + new = [ + {"name": "users", "through": "users_group"}, + {"name": "admins", "through": "admins_group"}, + {"name": "staffs", "through": "staffs_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 2 + assert diffs == [ + ("add", "", [(0, {"name": "users", "through": "users_group"})]), + ("add", "", [(0, {"name": "admins", "through": "admins_group"})]), + ] + + def test_some_through_unchanged(self) -> None: + old = [ + {"name": "staffs", "through": "staffs_group"}, + {"name": "admins", "through": "admins_group"}, + ] + new = [ + {"name": "users", "through": "users_group"}, + {"name": "admins_new", "through": "admins_group"}, + {"name": "staffs_new", "through": "staffs_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 3 + assert diffs == [ + ("change", [0, "name"], ("staffs", "staffs_new")), + ("change", [0, "name"], ("admins", "admins_new")), + ("add", "", [(0, {"name": "users", "through": "users_group"})]), + ] + + def test_some_unchanged_without_drop_or_add(self) -> None: + old = [ + {"name": "staffs", "through": "staffs_group"}, + {"name": "admins", "through": "admins_group"}, + {"name": "users", "through": "users_group"}, + ] + new = [ + {"name": "users_new", "through": "users_group"}, + {"name": "admins_new", "through": "admins_group"}, + {"name": "staffs_new", "through": "staffs_group"}, + ] + diffs = list(get_dict_diff_by_key(old, new)) + assert len(diffs) == 3 + assert diffs == [ + ("change", [0, "name"], ("staffs", "staffs_new")), + ("change", [0, "name"], ("admins", "admins_new")), + ("change", [0, "name"], ("users", "users_new")), + ]