Add Rename support
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from configparser import ConfigParser
|
||||
from enum import Enum
|
||||
|
||||
import asyncclick as click
|
||||
from asyncclick import Context, UsageError
|
||||
@@ -16,28 +14,12 @@ from aerich.migrate import Migrate
|
||||
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config
|
||||
|
||||
from . import __version__
|
||||
from .enums import Color
|
||||
from .models import Aerich
|
||||
|
||||
|
||||
class Color(str, Enum):
|
||||
green = "green"
|
||||
red = "red"
|
||||
yellow = "yellow"
|
||||
|
||||
|
||||
parser = ConfigParser()
|
||||
|
||||
|
||||
def close_db(func):
|
||||
@functools.wraps(func)
|
||||
async def close_db_inner(*args, **kwargs):
|
||||
result = await func(*args, **kwargs)
|
||||
await Tortoise.close_connections()
|
||||
return result
|
||||
|
||||
return close_db_inner
|
||||
|
||||
|
||||
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
|
||||
@click.version_option(__version__, "-V", "--version")
|
||||
@click.option(
|
||||
@@ -81,12 +63,10 @@ async def cli(ctx: Context, config, app, name):
|
||||
@cli.command(help="Generate migrate changes file.")
|
||||
@click.option("--name", default="update", show_default=True, help="Migrate name.")
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def migrate(ctx: Context, name):
|
||||
config = ctx.obj["config"]
|
||||
location = ctx.obj["location"]
|
||||
app = ctx.obj["app"]
|
||||
|
||||
ret = await Migrate.migrate(name)
|
||||
if not ret:
|
||||
return click.secho("No changes detected", fg=Color.yellow)
|
||||
@@ -96,7 +76,6 @@ async def migrate(ctx: Context, name):
|
||||
|
||||
@cli.command(help="Upgrade to latest version.")
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def upgrade(ctx: Context):
|
||||
config = ctx.obj["config"]
|
||||
app = ctx.obj["app"]
|
||||
@@ -123,7 +102,6 @@ async def upgrade(ctx: Context):
|
||||
|
||||
@cli.command(help="Downgrade to previous version.")
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def downgrade(ctx: Context):
|
||||
app = ctx.obj["app"]
|
||||
config = ctx.obj["config"]
|
||||
@@ -146,7 +124,6 @@ async def downgrade(ctx: Context):
|
||||
|
||||
@cli.command(help="Show current available heads in migrate location.")
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def heads(ctx: Context):
|
||||
app = ctx.obj["app"]
|
||||
versions = Migrate.get_all_version_files()
|
||||
@@ -161,7 +138,6 @@ async def heads(ctx: Context):
|
||||
|
||||
@cli.command(help="List all migrate items.")
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def history(ctx: Context):
|
||||
versions = Migrate.get_all_version_files()
|
||||
for version in versions:
|
||||
@@ -212,7 +188,6 @@ async def init(
|
||||
show_default=True,
|
||||
)
|
||||
@click.pass_context
|
||||
@close_db
|
||||
async def init_db(ctx: Context, safe):
|
||||
config = ctx.obj["config"]
|
||||
location = ctx.obj["location"]
|
||||
|
||||
@@ -11,6 +11,9 @@ class BaseDDL:
|
||||
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
|
||||
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
|
||||
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
|
||||
_RENAME_COLUMN_TEMPLATE = (
|
||||
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
|
||||
)
|
||||
_ADD_INDEX_TEMPLATE = (
|
||||
'ALTER TABLE "{table_name}" ADD {unique} INDEX "{index_name}" ({column_names})'
|
||||
)
|
||||
@@ -125,6 +128,13 @@ class BaseDDL:
|
||||
),
|
||||
)
|
||||
|
||||
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
|
||||
return self._RENAME_COLUMN_TEMPLATE.format(
|
||||
table_name=model._meta.db_table,
|
||||
old_column_name=old_column_name,
|
||||
new_column_name=new_column_name,
|
||||
)
|
||||
|
||||
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
||||
return self._ADD_INDEX_TEMPLATE.format(
|
||||
unique="UNIQUE" if unique else "",
|
||||
|
||||
@@ -9,6 +9,9 @@ class MysqlDDL(BaseDDL):
|
||||
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
|
||||
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
|
||||
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
|
||||
_RENAME_COLUMN_TEMPLATE = (
|
||||
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
|
||||
)
|
||||
_ADD_INDEX_TEMPLATE = (
|
||||
"ALTER TABLE `{table_name}` ADD {unique} INDEX `{index_name}` ({column_names})"
|
||||
)
|
||||
|
||||
7
aerich/enums.py
Normal file
7
aerich/enums.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(str, Enum):
|
||||
green = "green"
|
||||
red = "red"
|
||||
yellow = "yellow"
|
||||
@@ -5,6 +5,7 @@ from datetime import datetime
|
||||
from importlib import import_module
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
import click
|
||||
from tortoise import (
|
||||
BackwardFKRelation,
|
||||
BackwardOneToOneRelation,
|
||||
@@ -28,6 +29,8 @@ class Migrate:
|
||||
_upgrade_m2m: List[str] = []
|
||||
_downgrade_m2m: List[str] = []
|
||||
_aerich = Aerich.__name__
|
||||
_rename_old = []
|
||||
_rename_new = []
|
||||
|
||||
ddl: BaseDDL
|
||||
migrate_config: dict
|
||||
@@ -264,11 +267,37 @@ class Migrate:
|
||||
if cls._exclude_field(new_field, upgrade):
|
||||
continue
|
||||
if new_key not in old_keys:
|
||||
cls._add_operator(
|
||||
cls._add_field(new_model, new_field),
|
||||
upgrade,
|
||||
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
|
||||
)
|
||||
new_field_dict = new_field.describe(serializable=True)
|
||||
new_field_dict.pop("name")
|
||||
new_field_dict.pop("db_column")
|
||||
for diff_key in old_keys - new_keys:
|
||||
old_field = old_fields_map.get(diff_key)
|
||||
old_field_dict = old_field.describe(serializable=True)
|
||||
old_field_dict.pop("name")
|
||||
old_field_dict.pop("db_column")
|
||||
if old_field_dict == new_field_dict:
|
||||
if upgrade:
|
||||
is_rename = click.prompt(
|
||||
f"Rename {diff_key} to {new_key}",
|
||||
default=True,
|
||||
type=bool,
|
||||
show_choices=True,
|
||||
)
|
||||
cls._rename_new.append(new_key)
|
||||
cls._rename_old.append(diff_key)
|
||||
else:
|
||||
is_rename = diff_key in cls._rename_new
|
||||
if is_rename:
|
||||
cls._add_operator(
|
||||
cls._rename_field(new_model, old_field, new_field), upgrade,
|
||||
)
|
||||
break
|
||||
else:
|
||||
cls._add_operator(
|
||||
cls._add_field(new_model, new_field),
|
||||
upgrade,
|
||||
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
|
||||
)
|
||||
else:
|
||||
old_field = old_fields_map.get(new_key)
|
||||
new_field_dict = new_field.describe(serializable=True)
|
||||
@@ -319,9 +348,12 @@ class Migrate:
|
||||
for old_key in old_keys:
|
||||
field = old_fields_map.get(old_key)
|
||||
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
|
||||
cls._add_operator(
|
||||
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
|
||||
)
|
||||
if (upgrade and old_key not in cls._rename_old) or (
|
||||
not upgrade and old_key not in cls._rename_new
|
||||
):
|
||||
cls._add_operator(
|
||||
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
|
||||
)
|
||||
|
||||
for new_index in new_indexes:
|
||||
if new_index not in old_indexes:
|
||||
@@ -413,6 +445,10 @@ class Migrate:
|
||||
return cls.ddl.drop_m2m(field)
|
||||
return cls.ddl.drop_column(model, field.model_field_name)
|
||||
|
||||
@classmethod
|
||||
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
|
||||
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
|
||||
|
||||
@classmethod
|
||||
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user