Support db_constraint in fk

This commit is contained in:
long2ice 2020-09-28 10:40:04 +08:00
parent 43922d3734
commit ce8c0b1f06
7 changed files with 45 additions and 13 deletions

View File

@ -5,6 +5,7 @@
### 0.2.5 ### 0.2.5
- Fix windows support. (#46) - Fix windows support. (#46)
- Support `db_constraint` in fk, m2m should manual define table with fk. (#52)
### 0.2.4 ### 0.2.4

View File

@ -40,7 +40,7 @@ test_sqlite:
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: py.test
test_mysql: test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -v -s $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
test_postgres: test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest

View File

@ -2,7 +2,7 @@ from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import Field, JSONField, TextField, UUIDField from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
class BaseDDL: class BaseDDL:
@ -20,7 +20,7 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE CASCADE){extra}{comment};' _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient"):
@ -44,6 +44,7 @@ class BaseDDL:
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key, forward_key=field.forward_key,
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=field.through), extra=self.schema_generator._table_generate_extra(table=field.through),
comment=self.schema_generator._table_comment_generator( comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description table=field.through, comment=field.description

View File

@ -132,7 +132,7 @@ class Migrate:
return await cls._generate_diff_sql(name) return await cls._generate_diff_sql(name)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk=False): def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False):
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
@ -141,12 +141,12 @@ class Migrate:
:return: :return:
""" """
if upgrade: if upgrade:
if fk: if fk_m2m:
cls._upgrade_fk_m2m_index_operators.append(operator) cls._upgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.upgrade_operators.append(operator) cls.upgrade_operators.append(operator)
else: else:
if fk: if fk_m2m:
cls._downgrade_fk_m2m_index_operators.append(operator) cls._downgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@ -268,13 +268,13 @@ class Migrate:
continue continue
if new_key not in old_keys: if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True) new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name") new_field_dict.pop("name", None)
new_field_dict.pop("db_column") new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys: for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key) old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True) old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name") old_field_dict.pop("name", None)
old_field_dict.pop("db_column") old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict: if old_field_dict == new_field_dict:
if upgrade: if upgrade:
is_rename = click.prompt( is_rename = click.prompt(
@ -294,9 +294,7 @@ class Migrate:
break break
else: else:
cls._add_operator( cls._add_operator(
cls._add_field(new_model, new_field), cls._add_field(new_model, new_field), upgrade, cls._is_fk_m2m(new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
) )
else: else:
old_field = old_fields_map.get(new_key) old_field = old_fields_map.get(new_key)
@ -344,6 +342,15 @@ class Migrate:
upgrade, upgrade,
cls._is_fk_m2m(new_field), cls._is_fk_m2m(new_field),
) )
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field), upgrade, True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field), upgrade, True,
)
for old_key in old_keys: for old_key in old_keys:
field = old_fields_map.get(old_key) field = old_fields_map.get(old_key)
@ -437,6 +444,10 @@ class Migrate:
def _modify_field(cls, model: Type[Model], field: Field): def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field) return cls.ddl.modify_column(model, field)
@classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], field: Field): def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance): if isinstance(field, ForeignKeyFieldInstance):

View File

@ -31,6 +31,12 @@ class User(Model):
intro = fields.TextField(default="") intro = fields.TextField(default="")
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("diff_models.User", db_constraint=True)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
user = fields.ForeignKeyField("diff_models.User", description="User") user = fields.ForeignKeyField("diff_models.User", description="User")

View File

@ -31,6 +31,12 @@ class User(Model):
intro = fields.TextField(default="") intro = fields.TextField(default="")
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)

View File

@ -20,8 +20,10 @@ def test_migrate(mocker: MockerFixture):
Migrate.diff_models(models, diff_models, False) Migrate.diff_models(models, diff_models, False)
else: else:
Migrate.diff_models(models, diff_models, False) Migrate.diff_models(models, diff_models, False)
Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == [
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`",
"ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL", "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`", "ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`",
@ -30,9 +32,12 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `category` DROP COLUMN `name`", "ALTER TABLE `category` DROP COLUMN `name`",
"ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`", "ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`",
"ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`", "ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY "
"(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
] ]
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == [
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")', 'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',
@ -41,9 +46,11 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "category" DROP COLUMN "name"', 'ALTER TABLE "category" DROP COLUMN "name"',
'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"', 'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"',
'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"', 'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
] ]
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == [
'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")', 'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',