Add rename column support MySQL5
This commit is contained in:
@@ -4,12 +4,15 @@ import re
|
||||
from datetime import datetime
|
||||
from importlib import import_module
|
||||
from io import StringIO
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
from packaging.version import LegacyVersion, Version
|
||||
from tortoise import (
|
||||
BackwardFKRelation,
|
||||
BackwardOneToOneRelation,
|
||||
BaseDBAsyncClient,
|
||||
ForeignKeyFieldInstance,
|
||||
ManyToManyFieldInstance,
|
||||
Model,
|
||||
@@ -41,6 +44,7 @@ class Migrate:
|
||||
app: str
|
||||
migrate_location: str
|
||||
dialect: str
|
||||
_db_version: Union[LegacyVersion, Version] = None
|
||||
|
||||
@classmethod
|
||||
def get_old_model_file(cls, app: str, location: str):
|
||||
@@ -67,6 +71,13 @@ class Migrate:
|
||||
except (OSError, FileNotFoundError):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def _get_db_version(cls, connection: BaseDBAsyncClient):
|
||||
if cls.dialect == "mysql":
|
||||
sql = "select version() as version"
|
||||
ret = await connection.execute_query(sql)
|
||||
cls._db_version = version.parse(ret[1][0].get("version"))
|
||||
|
||||
@classmethod
|
||||
async def init_with_old_models(cls, config: dict, app: str, location: str):
|
||||
await Tortoise.init(config=config)
|
||||
@@ -83,7 +94,6 @@ class Migrate:
|
||||
await Tortoise.init(config=migrate_config)
|
||||
|
||||
connection = get_app_connection(config, app)
|
||||
cls.dialect = connection.schema_generator.DIALECT
|
||||
if cls.dialect == "mysql":
|
||||
from aerich.ddl.mysql import MysqlDDL
|
||||
|
||||
@@ -96,6 +106,8 @@ class Migrate:
|
||||
from aerich.ddl.postgres import PostgresDDL
|
||||
|
||||
cls.ddl = PostgresDDL(connection)
|
||||
cls.dialect = cls.ddl.DIALECT
|
||||
await cls._get_db_version(connection)
|
||||
|
||||
@classmethod
|
||||
async def _get_last_version_num(cls):
|
||||
@@ -300,10 +312,16 @@ class Migrate:
|
||||
else:
|
||||
is_rename = diff_key in cls._rename_new
|
||||
if is_rename:
|
||||
cls._add_operator(
|
||||
cls._rename_field(new_model, old_field, new_field),
|
||||
upgrade,
|
||||
)
|
||||
if cls.dialect == "mysql" and cls._db_version.major == 5:
|
||||
cls._add_operator(
|
||||
cls._change_field(new_model, old_field, new_field),
|
||||
upgrade,
|
||||
)
|
||||
else:
|
||||
cls._add_operator(
|
||||
cls._rename_field(new_model, old_field, new_field),
|
||||
upgrade,
|
||||
)
|
||||
break
|
||||
else:
|
||||
cls._add_operator(
|
||||
@@ -487,6 +505,15 @@ class Migrate:
|
||||
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
|
||||
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
|
||||
|
||||
@classmethod
|
||||
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
|
||||
return cls.ddl.change_column(
|
||||
model,
|
||||
old_field.model_field_name,
|
||||
new_field.model_field_name,
|
||||
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user