Support db_constraint in fk

This commit is contained in:
long2ice
2020-09-28 10:40:04 +08:00
parent 43922d3734
commit ce8c0b1f06
7 changed files with 45 additions and 13 deletions

View File

@@ -2,7 +2,7 @@ from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import Field, JSONField, TextField, UUIDField
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
class BaseDDL:
@@ -20,7 +20,7 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE CASCADE){extra}{comment};'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
def __init__(self, client: "BaseDBAsyncClient"):
@@ -44,6 +44,7 @@ class BaseDDL:
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key,
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=field.through),
comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description

View File

@@ -132,7 +132,7 @@ class Migrate:
return await cls._generate_diff_sql(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk=False):
def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False):
"""
add operator,differentiate fk because fk is order limit
:param operator:
@@ -141,12 +141,12 @@ class Migrate:
:return:
"""
if upgrade:
if fk:
if fk_m2m:
cls._upgrade_fk_m2m_index_operators.append(operator)
else:
cls.upgrade_operators.append(operator)
else:
if fk:
if fk_m2m:
cls._downgrade_fk_m2m_index_operators.append(operator)
else:
cls.downgrade_operators.append(operator)
@@ -268,13 +268,13 @@ class Migrate:
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name")
new_field_dict.pop("db_column")
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name")
old_field_dict.pop("db_column")
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
@@ -294,9 +294,7 @@ class Migrate:
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
cls._add_field(new_model, new_field), upgrade, cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
@@ -344,6 +342,15 @@ class Migrate:
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field), upgrade, True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field), upgrade, True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
@@ -437,6 +444,10 @@ class Migrate:
def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field)
@classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):