From 4fc7f324d4031264ae0d4fd4b7866155ab8d7074 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Tue, 17 Dec 2024 22:28:06 +0800 Subject: [PATCH] fix: add m2m field with custom m2m through generate duplicated table when migrating (#393) * fix: m2m table duplicated when using custom model for through * Add testcase * docs: update changelog * tests: add m2m custom through example test --- CHANGELOG.md | 1 + aerich/migrate.py | 22 ++++++++++++++++------ tests/test_sqlite_migrate.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1ddf1e3..85e5e65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### [0.8.1](Unreleased) #### Fixed +- Migrate add m2m field with custom through generate duplicated table. (#393) - Migrate drop the wrong m2m field when model have multi m2m fields. (#376) - KeyError raised when removing or renaming an existing model (#386) - fix: error when there is `__init__.py` in the migration folder (#272) diff --git a/aerich/migrate.py b/aerich/migrate.py index 397e62a..11d7b03 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -228,12 +228,18 @@ class Migrate: indexes.add(cast(Tuple[str, ...], tuple(x))) return indexes + @staticmethod + def _validate_custom_m2m_through(field: dict) -> None: + # TODO: Check whether field includes required fk columns + pass + @classmethod def _handle_m2m_fields( cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True ) -> None: old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) + new_tables: Dict[str, dict] = {field["table"]: field for field in new_models.values()} for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields): if (option and option[-1] == "nullable") or change[0][0] == "db_constraint": continue @@ -247,12 +253,16 @@ class Migrate: table = new_value.get("through") if action == "add": add = False - if upgrade and table not in cls._upgrade_m2m: - cls._upgrade_m2m.append(table) - add = True - elif not upgrade and table not in cls._downgrade_m2m: - cls._downgrade_m2m.append(table) - add = True + if upgrade: + if field := new_tables.get(table): + cls._validate_custom_m2m_through(field) + elif table not in cls._upgrade_m2m: + cls._upgrade_m2m.append(table) + add = True + else: + if table not in cls._downgrade_m2m: + cls._downgrade_m2m.append(table) + add = True if add: ref_desc = cast(dict, new_models.get(new_value.get("model_name"))) cls._add_operator( diff --git a/tests/test_sqlite_migrate.py b/tests/test_sqlite_migrate.py index 568f273..9f525eb 100644 --- a/tests/test_sqlite_migrate.py +++ b/tests/test_sqlite_migrate.py @@ -130,6 +130,18 @@ async def test_without_age_field() -> None: await Foo.create(name=name, age=0) obj = await Foo.get(name=name) assert getattr(obj, "age", None) is None + + +@pytest.mark.asyncio +async def test_m2m_with_custom_through() -> None: + from models import Group, FooGroup + name = "4_" + uuid.uuid4().hex + foo = await Foo.create(name=name) + group = await Group.create(name=name+"1") + await FooGroup.all().delete() + await foo.groups.add(group) + foo_group = await FooGroup.get(foo=foo, group=group) + assert not foo_group.is_active """ @@ -236,3 +248,26 @@ def test_sqlite_migrate(tmp_path: Path) -> None: config_file.write_text('[project]\nname = "project"') run_aerich("init -t settings.TORTOISE_ORM") assert "[tool.aerich]" in config_file.read_text() + + # add m2m with custom model for through + new = """ + groups = fields.ManyToManyField("models.Group", through="foo_group") + +class Group(Model): + name = fields.CharField(max_length=60) + +class FooGroup(Model): + foo = fields.ForeignKeyField("models.Foo") + group = fields.ForeignKeyField("models.Group") + is_active = fields.BooleanField(default=False) + + class Meta: + table = "foo_group" + """ + models_py.write_text(MODELS + new) + run_aerich("aerich migrate") + run_aerich("aerich upgrade") + migration_file_1 = list(migrations_dir.glob("1_*.py"))[0] + assert "foo_group" in migration_file_1.read_text() + r = run_shell("pytest _test.py::test_m2m_with_custom_through") + assert r.returncode == 0