Add Rename support

This commit is contained in:
long2ice
2020-09-25 17:48:32 +08:00
parent af4d4be19a
commit 141d7205bf
11 changed files with 238 additions and 192 deletions

View File

@@ -1,9 +1,7 @@
import functools
import json
import os
import sys
from configparser import ConfigParser
from enum import Enum
import asyncclick as click
from asyncclick import Context, UsageError
@@ -16,28 +14,12 @@ from aerich.migrate import Migrate
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config
from . import __version__
from .enums import Color
from .models import Aerich
class Color(str, Enum):
green = "green"
red = "red"
yellow = "yellow"
parser = ConfigParser()
def close_db(func):
@functools.wraps(func)
async def close_db_inner(*args, **kwargs):
result = await func(*args, **kwargs)
await Tortoise.close_connections()
return result
return close_db_inner
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version")
@click.option(
@@ -81,12 +63,10 @@ async def cli(ctx: Context, config, app, name):
@cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.pass_context
@close_db
async def migrate(ctx: Context, name):
config = ctx.obj["config"]
location = ctx.obj["location"]
app = ctx.obj["app"]
ret = await Migrate.migrate(name)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
@@ -96,7 +76,6 @@ async def migrate(ctx: Context, name):
@cli.command(help="Upgrade to latest version.")
@click.pass_context
@close_db
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
@@ -123,7 +102,6 @@ async def upgrade(ctx: Context):
@cli.command(help="Downgrade to previous version.")
@click.pass_context
@close_db
async def downgrade(ctx: Context):
app = ctx.obj["app"]
config = ctx.obj["config"]
@@ -146,7 +124,6 @@ async def downgrade(ctx: Context):
@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
@close_db
async def heads(ctx: Context):
app = ctx.obj["app"]
versions = Migrate.get_all_version_files()
@@ -161,7 +138,6 @@ async def heads(ctx: Context):
@cli.command(help="List all migrate items.")
@click.pass_context
@close_db
async def history(ctx: Context):
versions = Migrate.get_all_version_files()
for version in versions:
@@ -212,7 +188,6 @@ async def init(
show_default=True,
)
@click.pass_context
@close_db
async def init_db(ctx: Context, safe):
config = ctx.obj["config"]
location = ctx.obj["location"]

View File

@@ -11,6 +11,9 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
)
_ADD_INDEX_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD {unique} INDEX "{index_name}" ({column_names})'
)
@@ -125,6 +128,13 @@ class BaseDDL:
),
)
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
new_column_name=new_column_name,
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "",

View File

@@ -9,6 +9,9 @@ class MysqlDDL(BaseDDL):
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
)
_ADD_INDEX_TEMPLATE = (
"ALTER TABLE `{table_name}` ADD {unique} INDEX `{index_name}` ({column_names})"
)

7
aerich/enums.py Normal file
View File

@@ -0,0 +1,7 @@
from enum import Enum
class Color(str, Enum):
green = "green"
red = "red"
yellow = "yellow"

View File

@@ -5,6 +5,7 @@ from datetime import datetime
from importlib import import_module
from typing import Dict, List, Tuple, Type
import click
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
@@ -28,6 +29,8 @@ class Migrate:
_upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = []
_aerich = Aerich.__name__
_rename_old = []
_rename_new = []
ddl: BaseDDL
migrate_config: dict
@@ -264,11 +267,37 @@ class Migrate:
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name")
new_field_dict.pop("db_column")
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name")
old_field_dict.pop("db_column")
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field), upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
@@ -319,9 +348,12 @@ class Migrate:
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
@@ -413,6 +445,10 @@ class Migrate:
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)
@classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""