add support modify column

This commit is contained in:
long2ice
2020-05-25 22:39:39 +08:00
parent 95e41720cb
commit 6ffca1a0c7
7 changed files with 67 additions and 16 deletions

View File

@@ -164,7 +164,7 @@ def history(ctx):
)
@click.pass_context
async def init(
ctx: Context, tortoise_orm, location,
ctx: Context, tortoise_orm, location,
):
config_file = ctx.obj["config_file"]
name = ctx.obj["name"]

View File

@@ -18,6 +18,7 @@ class BaseDDL:
_ADD_FK_TEMPLATE = "ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}"
_M2M_TABLE_TEMPLATE = "CREATE TABLE {table_name} ({backward_key} {backward_type} NOT NULL REFERENCES {backward_table} ({backward_field}) ON DELETE CASCADE,{forward_key} {forward_type} NOT NULL REFERENCES {forward_table} ({forward_field}) ON DELETE CASCADE){extra}{comment};"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE {table_name} MODIFY COLUMN {column}"
def __init__(self, client: "BaseDBAsyncClient"):
self.client = client
@@ -51,7 +52,7 @@ class BaseDDL:
def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def add_column(self, model: "Type[Model]", field_object: Field):
def _get_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = field_object.default
db_column = field_object.model_field_name
@@ -74,6 +75,11 @@ class BaseDDL:
default = ""
else:
default = ""
return default
def add_column(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
@@ -89,7 +95,7 @@ class BaseDDL:
if field_object.description
else "",
is_primary_key=field_object.pk,
default=default,
default=self._get_default(model, field_object),
),
)
@@ -98,6 +104,27 @@ class BaseDDL:
table_name=model._meta.db_table, column_name=column_name
)
def modify_column(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
)
if field_object.description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
),
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "",

View File

@@ -220,6 +220,10 @@ class Migrate:
if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
@classmethod
def _is_fk_m2m(cls, field: Field):
return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance))
@classmethod
def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model)
@@ -260,6 +264,14 @@ class Migrate:
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
@@ -268,7 +280,7 @@ class Migrate:
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
isinstance(old_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
@@ -276,16 +288,14 @@ class Migrate:
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
cls._is_fk_m2m(new_field),
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)
for new_index in new_indexes:
@@ -354,6 +364,10 @@ class Migrate:
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
@classmethod
def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field)
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):