diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 192bb3e..cd9da5f 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -1,9 +1,8 @@ from enum import Enum from typing import List, Type -from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model +from tortoise import BaseDBAsyncClient, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator -from tortoise.fields import CASCADE class BaseDDL: @@ -22,7 +21,7 @@ class BaseDDL: _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' - _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};' + _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _CHANGE_COLUMN_TEMPLATE = ( 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' @@ -38,28 +37,34 @@ class BaseDDL: def drop_table(self, model: "Type[Model]"): return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table) - def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance): + def create_m2m( + self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict + ): + through = field_describe.get("through") + description = field_describe.get("description") + reference_id = reference_table_describe.get("pk_field").get("db_column") + db_field_types = reference_table_describe.get("pk_field").get("db_field_types") return self._M2M_TABLE_TEMPLATE.format( - table_name=field.through, + table_name=through, backward_table=model._meta.db_table, - forward_table=field.related_model._meta.db_table, + forward_table=reference_table_describe.get("table"), backward_field=model._meta.db_pk_column, - forward_field=field.related_model._meta.db_pk_column, - backward_key=field.backward_key, + forward_field=reference_id, + backward_key=field_describe.get("backward_key"), backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), - forward_key=field.forward_key, - forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), - on_delete=CASCADE, - extra=self.schema_generator._table_generate_extra(table=field.through), + forward_key=field_describe.get("forward_key"), + forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), + on_delete=field_describe.get("on_delete"), + extra=self.schema_generator._table_generate_extra(table=through), comment=self.schema_generator._table_comment_generator( - table=field.through, comment=field.description + table=through, comment=description ) - if field.description + if description else "", ) - def drop_m2m(self, field: ManyToManyFieldInstance): - return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) + def drop_m2m(self, table_name: str): + return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) def _get_default(self, model: "Type[Model]", field_describe: dict): db_table = model._meta.db_table diff --git a/aerich/ddl/mysql/__init__.py b/aerich/ddl/mysql/__init__.py index daa50d8..d6aa1fa 100644 --- a/aerich/ddl/mysql/__init__.py +++ b/aerich/ddl/mysql/__init__.py @@ -26,7 +26,7 @@ class MysqlDDL(BaseDDL): _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" - _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};" + _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" def alter_column_null(self, model: "Type[Model]", field_describe: dict): diff --git a/aerich/migrate.py b/aerich/migrate.py index b0c8eeb..a41fc0b 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -183,7 +183,39 @@ class Migrate: # current only support rename pk if action == "change" and option == "name": cls._add_operator(cls._rename_field(model, *change), upgrade) - + # m2m fields + old_m2m_fields = old_model_describe.get("m2m_fields") + new_m2m_fields = new_model_describe.get("m2m_fields") + for action, option, change in diff(old_m2m_fields, new_m2m_fields): + table = change[0][1].get("through") + if action == "add": + add = False + if upgrade and table not in cls._upgrade_m2m: + cls._upgrade_m2m.append(table) + add = True + elif not upgrade and table not in cls._downgrade_m2m: + cls._downgrade_m2m.append(table) + add = True + if add: + cls._add_operator( + cls.create_m2m( + model, + change[0][1], + new_models.get(change[0][1].get("model_name")), + ), + upgrade, + fk_m2m=True, + ) + elif action == "remove": + add = False + if upgrade and table not in cls._upgrade_m2m: + cls._upgrade_m2m.append(table) + add = True + elif not upgrade and table not in cls._downgrade_m2m: + cls._downgrade_m2m.append(table) + add = True + if add: + cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True) # add unique_together for index in set(new_unique_together).difference(set(old_unique_together)): cls._add_operator( @@ -298,6 +330,7 @@ class Migrate: cls._add_operator( cls._add_fk(model, fk_field, old_models.get(fk_field.get("python_type"))), upgrade, + fk_m2m=True, ) # drop fk for old_fk_field_name in set(old_fk_fields_name).difference( @@ -311,6 +344,7 @@ class Migrate: model, old_fk_field, old_models.get(old_fk_field.get("python_type")) ), upgrade, + fk_m2m=True, ) # change fields for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): @@ -357,6 +391,14 @@ class Migrate: def remove_model(cls, model: Type[Model]): return cls.ddl.drop_table(model) + @classmethod + def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): + return cls.ddl.create_m2m(model, field_describe, reference_table_describe) + + @classmethod + def drop_m2m(cls, table_name: str): + return cls.ddl.drop_m2m(table_name) + @classmethod def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): ret = [] diff --git a/conftest.py b/conftest.py index 4eee9e1..d00b512 100644 --- a/conftest.py +++ b/conftest.py @@ -20,7 +20,10 @@ tortoise_orm = { "second": expand_db_url(db_url_second, True), }, "apps": { - "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"}, + "models": { + "models": ["tests.models", "aerich.models"], + "default_connection": "default", + }, "models_second": {"models": ["tests.models_second"], "default_connection": "second"}, }, } diff --git a/poetry.lock b/poetry.lock index 7970cba..8d867bb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -635,6 +635,7 @@ click = [ ] colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, + {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] ddlparse = [ {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, @@ -665,6 +666,7 @@ importlib-metadata = [ {file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"}, ] iniconfig = [ + {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] iso8601 = [ diff --git a/tests/models.py b/tests/models.py index adc47c1..81a2524 100644 --- a/tests/models.py +++ b/tests/models.py @@ -35,6 +35,7 @@ class Email(Model): email = fields.CharField(max_length=200, index=True) is_primary = fields.BooleanField(default=False) address = fields.CharField(max_length=200) + users = fields.ManyToManyField("models.User") class Category(Model): diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 0adf2e4..2ada0aa 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -81,6 +81,8 @@ old_models_describe = { "mysql": "DATETIME(6)", "postgres": "TIMESTAMPTZ", }, + "auto_now_add": True, + "auto_now": False, }, { "name": "user_id", @@ -131,6 +133,13 @@ old_models_describe = { "description": None, "docstring": None, "constraints": {}, + "model_name": "models.Product", + "related_name": "categories", + "forward_key": "product_id", + "backward_key": "category_id", + "through": "product_category", + "on_delete": "CASCADE", + "_generated": True, } ], }, @@ -464,6 +473,8 @@ old_models_describe = { "mysql": "DATETIME(6)", "postgres": "TIMESTAMPTZ", }, + "auto_now_add": True, + "auto_now": False, }, ], "fk_fields": [], @@ -483,6 +494,13 @@ old_models_describe = { "description": None, "docstring": None, "constraints": {}, + "model_name": "models.Category", + "related_name": "products", + "forward_key": "category_id", + "backward_key": "product_id", + "through": "product_category", + "on_delete": "CASCADE", + "_generated": False, } ], }, @@ -558,6 +576,8 @@ old_models_describe = { "mysql": "DATETIME(6)", "postgres": "TIMESTAMPTZ", }, + "auto_now_add": False, + "auto_now": False, }, { "name": "is_active", @@ -741,6 +761,7 @@ def test_migrate(mocker: MockerFixture): - drop fk: Email.user - drop field: User.avatar - add index: Email.email + - add many to many: Email.users - remove unique: User.username - change column: length User.password - add unique_together: (name,type) of Product @@ -776,6 +797,7 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` CHANGE password password VARCHAR(100)", "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", + "CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4", ] ) @@ -795,6 +817,7 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`", "ALTER TABLE `user` CHANGE password password VARCHAR(200)", + "DROP TABLE IF EXISTS `email_user`", ] ) @@ -815,6 +838,7 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" CHANGE password password VARCHAR(100)', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', + 'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)', ] ) assert sorted(Migrate.downgrade_operators) == sorted( @@ -833,6 +857,7 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'DROP INDEX "idx_user_usernam_9987ab"', 'ALTER TABLE "user" CHANGE password password VARCHAR(200)', + 'DROP TABLE IF EXISTS "email_user"', ] ) elif isinstance(Migrate.ddl, SqliteDDL):