Support rename field

This commit is contained in:
long2ice 2021-02-03 17:56:30 +08:00
parent e3a14a2f60
commit f1f0074255
5 changed files with 95 additions and 16 deletions

View File

@ -1,5 +1,11 @@
# ChangeLog # ChangeLog
## 0.5
### 0.5.0
- Refactor core code, now has no limitation for everything.
## 0.4 ## 0.4
### 0.4.4 ### 0.4.4

View File

@ -111,7 +111,7 @@ Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`. `{version_num}_{datetime}_{name|update}.sql`.
And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can
choice `True` to rename column without column drop, or choice `False` to drop column then create. choice `True` to rename column without column drop, or choice `False` to drop column then create, note that the after maybe lose data.
### Upgrade to latest version ### Upgrade to latest version

View File

@ -3,6 +3,7 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
import click
from dictdiffer import diff from dictdiffer import diff
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
@ -183,6 +184,15 @@ class Migrate:
old_unique_together = old_model_describe.get("unique_together") old_unique_together = old_model_describe.get("unique_together")
new_unique_together = new_model_describe.get("unique_together") new_unique_together = new_model_describe.get("unique_together")
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# add unique_together # add unique_together
for index in set(new_unique_together).difference(set(old_unique_together)): for index in set(new_unique_together).difference(set(old_unique_together)):
cls._add_operator( cls._add_operator(
@ -202,18 +212,62 @@ class Migrate:
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields # add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False
for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name")
if len(changes) == 2:
# rename field
if changes[0] == (
"change",
"name",
(old_data_field_name, new_data_field_name),
) and changes[1] == (
"change",
"db_column",
(old_data_field.get("db_column"), new_data_field.get("db_column")),
):
if upgrade:
is_rename = click.prompt(
f"Rename {old_data_field_name} to {new_data_field_name}?",
default=True,
type=bool,
show_choices=True,
)
else:
is_rename = old_data_field_name in cls._rename_new
if is_rename:
cls._rename_new.append(new_data_field_name)
cls._rename_old.append(old_data_field_name)
# only MySQL8+ has rename syntax
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(
model, new_data_field, old_data_field
),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(model, *changes[1][2]),
upgrade,
)
if not is_rename:
cls._add_operator( cls._add_operator(
cls._add_field( cls._add_field(
model, model,
next( new_data_field,
filter(
lambda x: x.get("name") == new_data_field_name, new_data_fields
)
),
), ),
upgrade, upgrade,
) )
@ -221,6 +275,11 @@ class Migrate:
for old_data_field_name in set(old_data_fields_name).difference( for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name) set(new_data_fields_name)
): ):
# don't remove field if is rename
if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new
):
continue
cls._add_operator( cls._add_operator(
cls._remove_field( cls._remove_field(
model, model,
@ -383,8 +442,8 @@ class Migrate:
return cls.ddl.drop_column(model, column_name) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name) return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod @classmethod
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):

View File

@ -31,6 +31,7 @@ class User(Model):
class Email(Model): class Email(Model):
email_id = fields.IntField(pk=True)
email = fields.CharField(max_length=200, index=True) email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) address = fields.CharField(max_length=200)
@ -50,7 +51,7 @@ class Product(Model):
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200) pic = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)

View File

@ -1,4 +1,5 @@
import pytest import pytest
from pytest_mock import MockerFixture
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
@ -731,9 +732,10 @@ old_models_describe = {
} }
def test_migrate(): def test_migrate(mocker: MockerFixture):
""" """
models.py diff with old_models.py models.py diff with old_models.py
- change email pk: id -> email_id
- add field: Email.address - add field: Email.address
- add fk: Config.user - add fk: Config.user
- drop fk: Email.user - drop fk: Email.user
@ -743,7 +745,10 @@ def test_migrate():
- change column: length User.password - change column: length User.password
- add unique_together: (name,type) of Product - add unique_together: (name,type) of Product
- alter default: Config.status - alter default: Config.status
- rename column: Product.image -> Product.pic
""" """
mocker.patch("click.prompt", side_effect=(False, True))
models_describe = get_models_describe("models") models_describe = get_models_describe("models")
Migrate.app = "models" Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
@ -762,6 +767,8 @@ def test_migrate():
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)",
@ -779,6 +786,8 @@ def test_migrate():
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`", "ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`",
@ -797,6 +806,8 @@ def test_migrate():
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"', 'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")', 'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE UNIQUE INDEX "uid_product_name_f14935" ON "product" ("name", "type")', 'CREATE UNIQUE INDEX "uid_product_name_f14935" ON "product" ("name", "type")',
@ -813,6 +824,8 @@ def test_migrate():
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL', 'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"', 'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'DROP INDEX "idx_email_email_4a1a33"', 'DROP INDEX "idx_email_email_4a1a33"',
'ALTER TABLE "product" DROP CONSTRAINT "uid_product_name_f14935"', 'ALTER TABLE "product" DROP CONSTRAINT "uid_product_name_f14935"',