Add rename column support MySQL5

This commit is contained in:
long2ice 2020-11-19 10:11:52 +08:00
parent 5760fe2040
commit 9879004fee
8 changed files with 66 additions and 10 deletions

View File

@ -5,6 +5,7 @@
### 0.4.0 ### 0.4.0
- Use `.sql` instead of `.json` to store version file. - Use `.sql` instead of `.json` to store version file.
- Add `rename` column support MySQL5.
## 0.3 ## 0.3

View File

@ -3,8 +3,10 @@ black_opts = -l 100 -t py38
py_warn = PYTHONDEVMODE=1 py_warn = PYTHONDEVMODE=1
MYSQL_HOST ?= "127.0.0.1" MYSQL_HOST ?= "127.0.0.1"
MYSQL_PORT ?= 3306 MYSQL_PORT ?= 3306
MYSQL_PASS ?= "123456"
POSTGRES_HOST ?= "127.0.0.1" POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432 POSTGRES_PORT ?= 5432
POSTGRES_PAS ?= "123456"
help: help:
@echo "Aerich development makefile" @echo "Aerich development makefile"

View File

@ -134,6 +134,10 @@ Usage: aerich downgrade [OPTIONS]
Options: Options:
-v, --version INTEGER Specified version, default to last. [default: -1] -v, --version INTEGER Specified version, default to last. [default: -1]
-d, --delete Delete version files at the same time. [default:
False]
--yes Confirm the action without prompting.
-h, --help Show this message and exit. -h, --help Show this message and exit.
``` ```

View File

@ -135,12 +135,20 @@ async def upgrade(ctx: Context):
show_default=True, show_default=True,
help="Specified version, default to last.", help="Specified version, default to last.",
) )
@click.option(
"-d",
"--delete",
is_flag=True,
default=False,
show_default=True,
help="Delete version files at the same time.",
)
@click.pass_context @click.pass_context
@click.confirmation_option( @click.confirmation_option(
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?", prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
) )
@coro @coro
async def downgrade(ctx: Context, version: int): async def downgrade(ctx: Context, version: int, delete: bool):
app = ctx.obj["app"] app = ctx.obj["app"]
config = ctx.obj["config"] config = ctx.obj["config"]
if version == -1: if version == -1:
@ -164,7 +172,8 @@ async def downgrade(ctx: Context, version: int):
for downgrade_query in downgrade_query_list: for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query) await conn.execute_query(downgrade_query)
await version.delete() await version.delete()
os.unlink(file_path) if delete:
os.unlink(file_path)
click.secho(f"Success downgrade {file}", fg=Color.green) click.secho(f"Success downgrade {file}", fg=Color.green)

View File

@ -22,6 +22,9 @@ class BaseDDL:
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};' _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
)
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient"):
self.client = client self.client = client
@ -136,6 +139,16 @@ class BaseDDL:
new_column_name=new_column_name, new_column_name=new_column_name,
) )
def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
new_column_name=new_column_name,
new_column_type=new_column_type,
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "", unique="UNIQUE" if unique else "",

View File

@ -4,12 +4,15 @@ import re
from datetime import datetime from datetime import datetime
from importlib import import_module from importlib import import_module
from io import StringIO from io import StringIO
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type, Union
import click import click
from packaging import version
from packaging.version import LegacyVersion, Version
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
BackwardOneToOneRelation, BackwardOneToOneRelation,
BaseDBAsyncClient,
ForeignKeyFieldInstance, ForeignKeyFieldInstance,
ManyToManyFieldInstance, ManyToManyFieldInstance,
Model, Model,
@ -41,6 +44,7 @@ class Migrate:
app: str app: str
migrate_location: str migrate_location: str
dialect: str dialect: str
_db_version: Union[LegacyVersion, Version] = None
@classmethod @classmethod
def get_old_model_file(cls, app: str, location: str): def get_old_model_file(cls, app: str, location: str):
@ -67,6 +71,13 @@ class Migrate:
except (OSError, FileNotFoundError): except (OSError, FileNotFoundError):
pass pass
@classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql":
sql = "select version() as version"
ret = await connection.execute_query(sql)
cls._db_version = version.parse(ret[1][0].get("version"))
@classmethod @classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str): async def init_with_old_models(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config) await Tortoise.init(config=config)
@ -83,7 +94,6 @@ class Migrate:
await Tortoise.init(config=migrate_config) await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql": if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
@ -96,6 +106,8 @@ class Migrate:
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection) cls.ddl = PostgresDDL(connection)
cls.dialect = cls.ddl.DIALECT
await cls._get_db_version(connection)
@classmethod @classmethod
async def _get_last_version_num(cls): async def _get_last_version_num(cls):
@ -300,10 +312,16 @@ class Migrate:
else: else:
is_rename = diff_key in cls._rename_new is_rename = diff_key in cls._rename_new
if is_rename: if is_rename:
cls._add_operator( if cls.dialect == "mysql" and cls._db_version.major == 5:
cls._rename_field(new_model, old_field, new_field), cls._add_operator(
upgrade, cls._change_field(new_model, old_field, new_field),
) upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break break
else: else:
cls._add_operator( cls._add_operator(
@ -487,6 +505,15 @@ class Migrate:
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name) return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
@classmethod
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.change_column(
model,
old_field.model_field_name,
new_field.model_field_name,
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"),
)
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
""" """

View File

@ -67,5 +67,5 @@ async def initialize_tests(event_loop, request):
Migrate.ddl = SqliteDDL(client) Migrate.ddl = SqliteDDL(client)
elif client.schema_generator is AsyncpgSchemaGenerator: elif client.schema_generator is AsyncpgSchemaGenerator:
Migrate.ddl = PostgresDDL(client) Migrate.ddl = PostgresDDL(client)
Migrate.dialect = Migrate.ddl.DIALECT
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases())) request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.3.3" version = "0.4.0"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"