fix: add m2m field with custom m2m through generate duplicated table when migrating (#393)
* fix: m2m table duplicated when using custom model for through * Add testcase * docs: update changelog * tests: add m2m custom through example test
This commit is contained in:
parent
d8addadb37
commit
4fc7f324d4
@ -5,6 +5,7 @@
|
|||||||
### [0.8.1](Unreleased)
|
### [0.8.1](Unreleased)
|
||||||
|
|
||||||
#### Fixed
|
#### Fixed
|
||||||
|
- Migrate add m2m field with custom through generate duplicated table. (#393)
|
||||||
- Migrate drop the wrong m2m field when model have multi m2m fields. (#376)
|
- Migrate drop the wrong m2m field when model have multi m2m fields. (#376)
|
||||||
- KeyError raised when removing or renaming an existing model (#386)
|
- KeyError raised when removing or renaming an existing model (#386)
|
||||||
- fix: error when there is `__init__.py` in the migration folder (#272)
|
- fix: error when there is `__init__.py` in the migration folder (#272)
|
||||||
|
@ -228,12 +228,18 @@ class Migrate:
|
|||||||
indexes.add(cast(Tuple[str, ...], tuple(x)))
|
indexes.add(cast(Tuple[str, ...], tuple(x)))
|
||||||
return indexes
|
return indexes
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_custom_m2m_through(field: dict) -> None:
|
||||||
|
# TODO: Check whether field includes required fk columns
|
||||||
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _handle_m2m_fields(
|
def _handle_m2m_fields(
|
||||||
cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
|
cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
|
||||||
) -> None:
|
) -> None:
|
||||||
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
|
||||||
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
|
||||||
|
new_tables: Dict[str, dict] = {field["table"]: field for field in new_models.values()}
|
||||||
for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
|
for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
|
||||||
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
|
if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
|
||||||
continue
|
continue
|
||||||
@ -247,12 +253,16 @@ class Migrate:
|
|||||||
table = new_value.get("through")
|
table = new_value.get("through")
|
||||||
if action == "add":
|
if action == "add":
|
||||||
add = False
|
add = False
|
||||||
if upgrade and table not in cls._upgrade_m2m:
|
if upgrade:
|
||||||
cls._upgrade_m2m.append(table)
|
if field := new_tables.get(table):
|
||||||
add = True
|
cls._validate_custom_m2m_through(field)
|
||||||
elif not upgrade and table not in cls._downgrade_m2m:
|
elif table not in cls._upgrade_m2m:
|
||||||
cls._downgrade_m2m.append(table)
|
cls._upgrade_m2m.append(table)
|
||||||
add = True
|
add = True
|
||||||
|
else:
|
||||||
|
if table not in cls._downgrade_m2m:
|
||||||
|
cls._downgrade_m2m.append(table)
|
||||||
|
add = True
|
||||||
if add:
|
if add:
|
||||||
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
|
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
|
||||||
cls._add_operator(
|
cls._add_operator(
|
||||||
|
@ -130,6 +130,18 @@ async def test_without_age_field() -> None:
|
|||||||
await Foo.create(name=name, age=0)
|
await Foo.create(name=name, age=0)
|
||||||
obj = await Foo.get(name=name)
|
obj = await Foo.get(name=name)
|
||||||
assert getattr(obj, "age", None) is None
|
assert getattr(obj, "age", None) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_m2m_with_custom_through() -> None:
|
||||||
|
from models import Group, FooGroup
|
||||||
|
name = "4_" + uuid.uuid4().hex
|
||||||
|
foo = await Foo.create(name=name)
|
||||||
|
group = await Group.create(name=name+"1")
|
||||||
|
await FooGroup.all().delete()
|
||||||
|
await foo.groups.add(group)
|
||||||
|
foo_group = await FooGroup.get(foo=foo, group=group)
|
||||||
|
assert not foo_group.is_active
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -236,3 +248,26 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
|
|||||||
config_file.write_text('[project]\nname = "project"')
|
config_file.write_text('[project]\nname = "project"')
|
||||||
run_aerich("init -t settings.TORTOISE_ORM")
|
run_aerich("init -t settings.TORTOISE_ORM")
|
||||||
assert "[tool.aerich]" in config_file.read_text()
|
assert "[tool.aerich]" in config_file.read_text()
|
||||||
|
|
||||||
|
# add m2m with custom model for through
|
||||||
|
new = """
|
||||||
|
groups = fields.ManyToManyField("models.Group", through="foo_group")
|
||||||
|
|
||||||
|
class Group(Model):
|
||||||
|
name = fields.CharField(max_length=60)
|
||||||
|
|
||||||
|
class FooGroup(Model):
|
||||||
|
foo = fields.ForeignKeyField("models.Foo")
|
||||||
|
group = fields.ForeignKeyField("models.Group")
|
||||||
|
is_active = fields.BooleanField(default=False)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
table = "foo_group"
|
||||||
|
"""
|
||||||
|
models_py.write_text(MODELS + new)
|
||||||
|
run_aerich("aerich migrate")
|
||||||
|
run_aerich("aerich upgrade")
|
||||||
|
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
|
||||||
|
assert "foo_group" in migration_file_1.read_text()
|
||||||
|
r = run_shell("pytest _test.py::test_m2m_with_custom_through")
|
||||||
|
assert r.returncode == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user