From ce8c0b1f06167283d5b52d17462d6e164c5ec721 Mon Sep 17 00:00:00 2001 From: long2ice Date: Mon, 28 Sep 2020 10:40:04 +0800 Subject: [PATCH] Support `db_constraint` in fk --- CHANGELOG.md | 1 + Makefile | 2 +- aerich/ddl/__init__.py | 5 +++-- aerich/migrate.py | 31 +++++++++++++++++++++---------- tests/diff_models.py | 6 ++++++ tests/models.py | 6 ++++++ tests/test_migrate.py | 7 +++++++ 7 files changed, 45 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9295cb5..2fc071e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### 0.2.5 - Fix windows support. (#46) +- Support `db_constraint` in fk, m2m should manual define table with fk. (#52) ### 0.2.4 diff --git a/Makefile b/Makefile index 3e5495c..9e81f72 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ test_sqlite: $(py_warn) TEST_DB=sqlite://:memory: py.test test_mysql: - $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -v -s + $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s test_postgres: $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 83a923a..9ed462e 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -2,7 +2,7 @@ from typing import List, Type from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator -from tortoise.fields import Field, JSONField, TextField, UUIDField +from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField class BaseDDL: @@ -20,7 +20,7 @@ class BaseDDL: _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' - _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE CASCADE){extra}{comment};' + _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' def __init__(self, client: "BaseDBAsyncClient"): @@ -44,6 +44,7 @@ class BaseDDL: backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), forward_key=field.forward_key, forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), + on_delete=CASCADE, extra=self.schema_generator._table_generate_extra(table=field.through), comment=self.schema_generator._table_comment_generator( table=field.through, comment=field.description diff --git a/aerich/migrate.py b/aerich/migrate.py index d5d109a..0303e6d 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -132,7 +132,7 @@ class Migrate: return await cls._generate_diff_sql(name) @classmethod - def _add_operator(cls, operator: str, upgrade=True, fk=False): + def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False): """ add operator,differentiate fk because fk is order limit :param operator: @@ -141,12 +141,12 @@ class Migrate: :return: """ if upgrade: - if fk: + if fk_m2m: cls._upgrade_fk_m2m_index_operators.append(operator) else: cls.upgrade_operators.append(operator) else: - if fk: + if fk_m2m: cls._downgrade_fk_m2m_index_operators.append(operator) else: cls.downgrade_operators.append(operator) @@ -268,13 +268,13 @@ class Migrate: continue if new_key not in old_keys: new_field_dict = new_field.describe(serializable=True) - new_field_dict.pop("name") - new_field_dict.pop("db_column") + new_field_dict.pop("name", None) + new_field_dict.pop("db_column", None) for diff_key in old_keys - new_keys: old_field = old_fields_map.get(diff_key) old_field_dict = old_field.describe(serializable=True) - old_field_dict.pop("name") - old_field_dict.pop("db_column") + old_field_dict.pop("name", None) + old_field_dict.pop("db_column", None) if old_field_dict == new_field_dict: if upgrade: is_rename = click.prompt( @@ -294,9 +294,7 @@ class Migrate: break else: cls._add_operator( - cls._add_field(new_model, new_field), - upgrade, - isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)), + cls._add_field(new_model, new_field), upgrade, cls._is_fk_m2m(new_field), ) else: old_field = old_fields_map.get(new_key) @@ -344,6 +342,15 @@ class Migrate: upgrade, cls._is_fk_m2m(new_field), ) + if isinstance(new_field, ForeignKeyFieldInstance): + if old_field.db_constraint and not new_field.db_constraint: + cls._add_operator( + cls._drop_fk(new_model, new_field), upgrade, True, + ) + if new_field.db_constraint and not old_field.db_constraint: + cls._add_operator( + cls._add_fk(new_model, new_field), upgrade, True, + ) for old_key in old_keys: field = old_fields_map.get(old_key) @@ -437,6 +444,10 @@ class Migrate: def _modify_field(cls, model: Type[Model], field: Field): return cls.ddl.modify_column(model, field) + @classmethod + def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): + return cls.ddl.drop_fk(model, field) + @classmethod def _remove_field(cls, model: Type[Model], field: Field): if isinstance(field, ForeignKeyFieldInstance): diff --git a/tests/diff_models.py b/tests/diff_models.py index b249cc9..6c591a4 100644 --- a/tests/diff_models.py +++ b/tests/diff_models.py @@ -31,6 +31,12 @@ class User(Model): intro = fields.TextField(default="") +class Email(Model): + email = fields.CharField(max_length=200) + is_primary = fields.BooleanField(default=False) + user = fields.ForeignKeyField("diff_models.User", db_constraint=True) + + class Category(Model): slug = fields.CharField(max_length=200) user = fields.ForeignKeyField("diff_models.User", description="User") diff --git a/tests/models.py b/tests/models.py index c6f5073..a464776 100644 --- a/tests/models.py +++ b/tests/models.py @@ -31,6 +31,12 @@ class User(Model): intro = fields.TextField(default="") +class Email(Model): + email = fields.CharField(max_length=200) + is_primary = fields.BooleanField(default=False) + user = fields.ForeignKeyField("models.User", db_constraint=False) + + class Category(Model): slug = fields.CharField(max_length=200) name = fields.CharField(max_length=200) diff --git a/tests/test_migrate.py b/tests/test_migrate.py index be31788..e063b0c 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -20,8 +20,10 @@ def test_migrate(mocker: MockerFixture): Migrate.diff_models(models, diff_models, False) else: Migrate.diff_models(models, diff_models, False) + Migrate._merge_operators() if isinstance(Migrate.ddl, MysqlDDL): assert Migrate.upgrade_operators == [ + "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL", "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`", @@ -30,9 +32,12 @@ def test_migrate(mocker: MockerFixture): "ALTER TABLE `category` DROP COLUMN `name`", "ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`", "ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`", + "ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY " + "(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", ] elif isinstance(Migrate.ddl, PostgresDDL): assert Migrate.upgrade_operators == [ + 'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"', 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")', 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', @@ -41,9 +46,11 @@ def test_migrate(mocker: MockerFixture): 'ALTER TABLE "category" DROP COLUMN "name"', 'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"', 'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"', + 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', ] elif isinstance(Migrate.ddl, SqliteDDL): assert Migrate.upgrade_operators == [ + 'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"', 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")', 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',