From 69ce0cafa153e515dfc18c5318e8a71360d939fd Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Wed, 18 Dec 2024 00:13:19 +0800 Subject: [PATCH] fix: intermediate table for m2m relation not created (#394) * fix: intermediate table for m2m relation not created * Add unittest * docs: update changelog --- CHANGELOG.md | 1 + aerich/migrate.py | 6 +++--- tests/test_sqlite_migrate.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85e5e65..5029231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### [0.8.1](Unreleased) #### Fixed +- fix: intermediate table for m2m relation not created. (#394) - Migrate add m2m field with custom through generate duplicated table. (#393) - Migrate drop the wrong m2m field when model have multi m2m fields. (#376) - KeyError raised when removing or renaming an existing model (#386) diff --git a/aerich/migrate.py b/aerich/migrate.py index 11d7b03..a24b72f 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -237,8 +237,8 @@ class Migrate: def _handle_m2m_fields( cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True ) -> None: - old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields")) - new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields")) + old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields", [])) + new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields", [])) new_tables: Dict[str, dict] = {field["table"]: field for field in new_models.values()} for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields): if (option and option[-1] == "nullable") or change[0][0] == "db_constraint": @@ -298,10 +298,10 @@ class Migrate: for new_model_str, new_model_describe in new_models.items(): model = cls._get_model(new_model_describe["name"].split(".")[1]) - if new_model_str not in old_models: if upgrade: cls._add_operator(cls.add_model(model), upgrade) + cls._handle_m2m_fields({}, new_model_describe, model, new_models, upgrade) else: # we can't find origin model when downgrade, so skip pass diff --git a/tests/test_sqlite_migrate.py b/tests/test_sqlite_migrate.py index 9f525eb..01b9582 100644 --- a/tests/test_sqlite_migrate.py +++ b/tests/test_sqlite_migrate.py @@ -142,6 +142,16 @@ async def test_m2m_with_custom_through() -> None: await foo.groups.add(group) foo_group = await FooGroup.get(foo=foo, group=group) assert not foo_group.is_active + + +@pytest.mark.asyncio +async def test_add_m2m_field_after_init_db() -> None: + from models import Group + name = "5_" + uuid.uuid4().hex + foo = await Foo.create(name=name) + group = await Group.create(name=name+"1") + await foo.groups.add(group) + assert (await group.users.all().first()) == foo """ @@ -271,3 +281,24 @@ class FooGroup(Model): assert "foo_group" in migration_file_1.read_text() r = run_shell("pytest _test.py::test_m2m_with_custom_through") assert r.returncode == 0 + + # add m2m field after init-db + new = """ + groups = fields.ManyToManyField("models.Group", through="foo_group", related_name="users") + +class Group(Model): + name = fields.CharField(max_length=60) + """ + if db_file.exists(): + db_file.unlink() + if migrations_dir.exists(): + shutil.rmtree(migrations_dir) + models_py.write_text(MODELS) + run_aerich("aerich init-db") + models_py.write_text(MODELS + new) + run_aerich("aerich migrate") + run_aerich("aerich upgrade") + migration_file_1 = list(migrations_dir.glob("1_*.py"))[0] + assert "foo_group" in migration_file_1.read_text() + r = run_shell("pytest _test.py::test_add_m2m_field_after_init_db") + assert r.returncode == 0