Add rename column support MySQL5

This commit is contained in:
long2ice
2020-11-19 10:11:52 +08:00
parent 5760fe2040
commit 9879004fee
8 changed files with 66 additions and 10 deletions

View File

@@ -4,12 +4,15 @@ import re
from datetime import datetime
from importlib import import_module
from io import StringIO
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type, Union
import click
from packaging import version
from packaging.version import LegacyVersion, Version
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
BaseDBAsyncClient,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
Model,
@@ -41,6 +44,7 @@ class Migrate:
app: str
migrate_location: str
dialect: str
_db_version: Union[LegacyVersion, Version] = None
@classmethod
def get_old_model_file(cls, app: str, location: str):
@@ -67,6 +71,13 @@ class Migrate:
except (OSError, FileNotFoundError):
pass
@classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql":
sql = "select version() as version"
ret = await connection.execute_query(sql)
cls._db_version = version.parse(ret[1][0].get("version"))
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
@@ -83,7 +94,6 @@ class Migrate:
await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL
@@ -96,6 +106,8 @@ class Migrate:
from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection)
cls.dialect = cls.ddl.DIALECT
await cls._get_db_version(connection)
@classmethod
async def _get_last_version_num(cls):
@@ -300,10 +312,16 @@ class Migrate:
else:
is_rename = diff_key in cls._rename_new
if is_rename:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
if cls.dialect == "mysql" and cls._db_version.major == 5:
cls._add_operator(
cls._change_field(new_model, old_field, new_field),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
@@ -487,6 +505,15 @@ class Migrate:
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
@classmethod
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.change_column(
model,
old_field.model_field_name,
new_field.model_field_name,
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"),
)
@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""