fix: aerich migrate raises tortoise.exceptions.FieldError when index.INDEX_TYPE is not empty (#415)

* fix: aerich migrate raises `tortoise.exceptions.FieldError` when `index.INDEX_TYPE` is not empty

* feat: add `IF NOT EXISTS` to postgres create index template

* chore: explicit declare type hints of function parameters
This commit is contained in:
Waket Zheng 2025-02-13 18:48:45 +08:00 committed by GitHub
parent 0be5c1b545
commit 6bdfdfc6db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 320 additions and 164 deletions

View File

@ -8,6 +8,7 @@
- feat: add --fake to upgrade/downgrade. ([#398]) - feat: add --fake to upgrade/downgrade. ([#398])
#### Fixed #### Fixed
- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
- fix: inspectdb raise KeyError 'int2' for smallint. ([#401]) - fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
### Changed ### Changed
@ -16,6 +17,7 @@
[#398]: https://github.com/tortoise/aerich/pull/398 [#398]: https://github.com/tortoise/aerich/pull/398
[#401]: https://github.com/tortoise/aerich/pull/401 [#401]: https://github.com/tortoise/aerich/pull/401
[#412]: https://github.com/tortoise/aerich/pull/412 [#412]: https://github.com/tortoise/aerich/pull/412
[#415]: https://github.com/tortoise/aerich/pull/415
### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27 ### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27

View File

@ -39,7 +39,7 @@ class Command:
async def init(self) -> None: async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location) await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file, fake=False) -> None: async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
file_path = Path(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path) m = import_py_file(file_path)
upgrade = m.upgrade upgrade = m.upgrade
@ -51,7 +51,7 @@ class Command:
content=get_models_describe(self.app), content=get_models_describe(self.app),
) )
async def upgrade(self, run_in_transaction: bool = True, fake=False) -> List[str]: async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> List[str]:
migrated = [] migrated = []
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@ -69,7 +69,7 @@ class Command:
migrated.append(version_file) migrated.append(version_file)
return migrated return migrated
async def downgrade(self, version: int, delete: bool, fake=False) -> List[str]: async def downgrade(self, version: int, delete: bool, fake: bool = False) -> List[str]:
ret: List[str] = [] ret: List[str] = []
if version == -1: if version == -1:
specified_version = await Migrate.get_last_version() specified_version = await Migrate.get_last_version()

View File

@ -9,6 +9,9 @@ from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj) -> Any: def default(self, obj) -> Any:
if isinstance(obj, Index): if isinstance(obj, Index):
if hasattr(obj, "describe"):
# For tortoise>=0.24
return obj.describe()
return { return {
"type": "index", "type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301 "val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
@ -17,11 +20,20 @@ class JsonEncoder(json.JSONEncoder):
return super().default(obj) return super().default(obj)
def load_index(obj: dict) -> Index:
    """Build an ``Index`` instance from a dict produced by ``Index.describe()``.

    Uses ``fields`` when present, otherwise falls back to ``expressions``;
    optional ``extra`` and ``type`` entries are copied onto the instance.
    """
    field_names = obj["fields"] or obj["expressions"]
    index = Index(fields=field_names, name=obj.get("name"))
    extra = obj.get("extra")
    if extra:
        index.extra = extra
    idx_type = obj.get("type")
    if idx_type:
        index.INDEX_TYPE = idx_type
    return index
def object_hook(obj) -> Any: def object_hook(obj) -> Any:
_type = obj.get("type") if (type_ := obj.get("type")) and type_ == "index" and (val := obj.get("val")):
if not _type: return pickle.loads(base64.b64decode(val)) # nosec: B301
return obj return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict) -> str: def encoder(obj: dict) -> str:

View File

@ -1,16 +1,20 @@
from __future__ import annotations
import re import re
from enum import Enum from enum import Enum
from typing import Any, List, Type, cast from typing import TYPE_CHECKING, Any, cast
import tortoise import tortoise
from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from aerich.utils import is_default_function from aerich.utils import is_default_function
if TYPE_CHECKING:
from tortoise import BaseDBAsyncClient, Model
class BaseDDL: class BaseDDL:
schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator schema_generator_cls: type[BaseSchemaGenerator] = BaseSchemaGenerator
DIALECT = "sql" DIALECT = "sql"
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"' _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}' _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
@ -19,10 +23,8 @@ class BaseDDL:
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"' 'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
) )
_ADD_INDEX_TEMPLATE = ( _ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}'
'ALTER TABLE "{table_name}" ADD {unique}INDEX "{index_name}" ({column_names})' _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"'
)
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = ( _M2M_TABLE_TEMPLATE = (
@ -37,11 +39,11 @@ class BaseDDL:
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"' _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient") -> None: def __init__(self, client: BaseDBAsyncClient) -> None:
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]") -> str: def create_table(self, model: type[Model]) -> str:
schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"] schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
if tortoise.__version__ <= "0.23.0": if tortoise.__version__ <= "0.23.0":
# Remove extra space # Remove extra space
@ -52,7 +54,7 @@ class BaseDDL:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m( def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str: ) -> str:
through = cast(str, field_describe.get("through")) through = cast(str, field_describe.get("through"))
description = field_describe.get("description") description = field_describe.get("description")
@ -81,7 +83,7 @@ class BaseDDL:
def drop_m2m(self, table_name: str) -> str: def drop_m2m(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any: def _get_default(self, model: type[Model], field_describe: dict) -> Any:
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_describe.get("default")
if isinstance(default, Enum): if isinstance(default, Enum):
@ -111,10 +113,12 @@ class BaseDDL:
default = None default = None
return default return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk) return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str: def _add_or_modify_column(
self, model: type[Model], field_describe: dict, is_pk: bool, modify: bool = False
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description") description = field_describe.get("description")
db_column = cast(str, field_describe.get("db_column")) db_column = cast(str, field_describe.get("db_column"))
@ -150,17 +154,15 @@ class BaseDDL:
column = column.replace(" ", " ") column = column.replace(" ", " ")
return template.format(table_name=db_table, column=column) return template.format(table_name=db_table, column=column)
def drop_column(self, model: "Type[Model]", column_name: str) -> str: def drop_column(self, model: type[Model], column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk, modify=True) return self._add_or_modify_column(model, field_describe, is_pk, modify=True)
def rename_column( def rename_column(self, model: type[Model], old_column_name: str, new_column_name: str) -> str:
self, model: "Type[Model]", old_column_name: str, new_column_name: str
) -> str:
return self._RENAME_COLUMN_TEMPLATE.format( return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@ -168,7 +170,7 @@ class BaseDDL:
) )
def change_column( def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str self, model: type[Model], old_column_name: str, new_column_name: str, new_column_type: str
) -> str: ) -> str:
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
@ -177,32 +179,46 @@ class BaseDDL:
new_column_type=new_column_type, new_column_type=new_column_type,
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
return self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
)
def add_index(
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
index_type: str = "",
extra: str | None = "",
) -> str:
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "", unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name( index_name=name or self._index_name(unique, model, field_names),
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table, table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names), column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
index_type=f"{index_type} " if index_type else "",
extra=f"{extra}" if extra else "",
) )
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: def drop_index(
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name( index_name=name or self._index_name(unique, model, field_names),
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str: def drop_index_by_name(self, model: type[Model], index_name: str) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self.drop_index(model, [], name=index_name)
index_name=index_name,
table_name=model._meta.db_table,
)
def _generate_fk_name( def _generate_fk_name(
self, db_table, field_describe: dict, reference_table_describe: dict self, db_table: str, field_describe: dict, reference_table_describe: dict
) -> str: ) -> str:
"""Generate fk name""" """Generate fk name"""
db_column = cast(str, field_describe.get("raw_field")) db_column = cast(str, field_describe.get("raw_field"))
@ -217,7 +233,7 @@ class BaseDDL:
) )
def add_fk( def add_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str: ) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
@ -234,13 +250,13 @@ class BaseDDL:
) )
def drop_fk( def drop_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str: ) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe) fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name) return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str: def alter_column_default(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format( return self._ALTER_DEFAULT_TEMPLATE.format(
@ -249,13 +265,13 @@ class BaseDDL:
default="SET" + default if default is not None else "DROP DEFAULT", default="SET" + default if default is not None else "DROP DEFAULT",
) )
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str: def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str: def set_comment(self, model: type[Model], field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str: def rename_table(self, model: type[Model], old_table_name: str, new_table_name: str) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format( return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name

View File

@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, List, Type from __future__ import annotations
from typing import TYPE_CHECKING
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
@ -21,9 +23,7 @@ class MysqlDDL(BaseDDL):
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`" "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
) )
_ADD_INDEX_TEMPLATE = ( _ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}"
"ALTER TABLE `{table_name}` ADD {unique}INDEX `{index_name}` ({column_names})"
)
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
@ -36,7 +36,7 @@ class MysqlDDL(BaseDDL):
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`" _RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def _index_name(self, unique: bool, model: "Type[Model]", field_names: List[str]) -> str: def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
if unique: if unique:
if len(field_names) == 1: if len(field_names) == 1:
# Example: `email = CharField(max_length=50, unique=True)` # Example: `email = CharField(max_length=50, unique=True)`
@ -47,17 +47,3 @@ class MysqlDDL(BaseDDL):
else: else:
index_prefix = "idx" index_prefix = "idx"
return self.schema_generator._generate_index_name(index_prefix, model, field_names) return self.schema_generator._generate_index_name(index_prefix, model, field_names)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
)

View File

@ -1,4 +1,6 @@
from typing import Type, cast from __future__ import annotations
from typing import cast
from tortoise import Model from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
@ -9,7 +11,7 @@ from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL): class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})' _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"' _DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = ( _MODIFY_COLUMN_TEMPLATE = (
@ -18,7 +20,7 @@ class PostgresDDL(BaseDDL):
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"'
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str: def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
@ -26,7 +28,7 @@ class PostgresDDL(BaseDDL):
set_drop="DROP" if field_describe.get("nullable") else "SET", set_drop="DROP" if field_describe.get("nullable") else "SET",
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = cast(dict, field_describe.get("db_field_types")) db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column") db_column = field_describe.get("db_column")
@ -38,7 +40,7 @@ class PostgresDDL(BaseDDL):
using=f' USING "{db_column}"::{datatype}', using=f' USING "{db_column}"::{datatype}',
) )
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str: def set_comment(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format( return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table, table_name=db_table,

View File

@ -13,6 +13,7 @@ from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
from tortoise.indexes import Index from tortoise.indexes import Index
from aerich.coder import load_index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import ( from aerich.utils import (
@ -120,7 +121,7 @@ class Migrate:
return int(version.split("_", 1)[0]) return int(version.split("_", 1)[0])
@classmethod @classmethod
async def generate_version(cls, name=None) -> str: async def generate_version(cls, name: str | None = None) -> str:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num() last_version_num = await cls._get_last_version_num()
if last_version_num is None: if last_version_num is None:
@ -197,7 +198,7 @@ class Migrate:
) )
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None: def _add_operator(cls, operator: str, upgrade: bool = True, fk_m2m_index: bool = False) -> None:
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
@ -245,6 +246,8 @@ class Migrate:
for x in cls._handle_indexes(model, model_describe.get("indexes", [])): for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
if isinstance(x, Index): if isinstance(x, Index):
indexes.add(x) indexes.add(x)
elif isinstance(x, dict):
indexes.add(load_index(x))
else: else:
indexes.add(cast("tuple[str, ...]", tuple(x))) indexes.add(cast("tuple[str, ...]", tuple(x)))
return indexes return indexes
@ -439,10 +442,10 @@ class Migrate:
cls._add_operator(cls._drop_index(model, index, True), upgrade, True) cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes # add indexes
for idx in new_indexes.difference(old_indexes): for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, idx, False), upgrade, True) cls._add_operator(cls._add_index(model, idx), upgrade, fk_m2m_index=True)
# remove indexes # remove indexes
for idx in old_indexes.difference(new_indexes): for idx in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, idx, False), upgrade, True) cls._add_operator(cls._drop_index(model, idx), upgrade, fk_m2m_index=True)
old_data_fields = list( old_data_fields = list(
filter( filter(
lambda x: x.get("db_field_types") is not None, lambda x: x.get("db_field_types") is not None,
@ -584,63 +587,64 @@ class Migrate:
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = cls.get_field_by_name(field_name, old_data_fields) cls._handle_field_changes(
new_data_field = cls.get_field_by_name(field_name, new_data_fields) model, field_name, old_data_fields, new_data_fields, upgrade
changes = cls._exclude_extra_field_types(diff(old_data_field, new_data_field)) )
modified = False
for change in changes:
_, option, old_new = change
if option == "indexed":
# change index
if old_new[0] is False and old_new[1] is True:
unique = new_data_field.get("unique")
cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True
)
else:
unique = old_data_field.get("unique")
cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.":
if new_data_field.get("field_type") == "DecimalField":
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
else:
continue
elif option == "default":
if not (
is_default_function(old_new[0]) or is_default_function(old_new[1])
):
# change column default
cls._add_operator(
cls._alter_default(model, new_data_field), upgrade
)
elif option == "unique":
# because indexed include it
continue
elif option == "nullable":
# change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
elif option == "description":
# change comment
cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
else:
if modified:
continue
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
modified = True
for old_model in old_models.keys() - new_models.keys(): for old_model in old_models.keys() - new_models.keys():
cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade) cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
@classmethod
def _handle_field_changes(
    cls,
    model: type[Model],
    field_name: str,
    old_data_fields: list[dict],
    new_data_fields: list[dict],
    upgrade: bool,
) -> None:
    """Queue migration operators for one field whose describe-dict changed.

    Diffs the old and new describe-dicts for ``field_name`` and, for each
    reported change, registers the matching DDL operator via
    ``cls._add_operator``.

    :param model: the model class the field belongs to.
    :param field_name: name of the field being compared.
    :param old_data_fields: field describe-dicts from the previous version.
    :param new_data_fields: field describe-dicts from the current models.
    :param upgrade: whether the operators target the upgrade (True) or
        downgrade (False) direction.
    """
    old_data_field = cls.get_field_by_name(field_name, old_data_fields)
    new_data_field = cls.get_field_by_name(field_name, new_data_fields)
    # `diff` yields (action, option, (old, new)) tuples for each changed key.
    changes = cls._exclude_extra_field_types(diff(old_data_field, new_data_field))
    # Set of changed option names, used below to detect overlapping changes.
    options = {c[1] for c in changes}
    # Guard so the generic "modify column" operator is emitted at most once.
    modified = False
    for change in changes:
        _, option, old_new = change
        if option == "indexed":
            # Index added (False -> True) or dropped; pass as fk_m2m_index so
            # the operator is ordered after FK/M2M handling.
            if old_new[0] is False and old_new[1] is True:
                unique = new_data_field.get("unique")
                cls._add_operator(cls._add_index(model, (field_name,), unique), upgrade, True)
            else:
                unique = old_data_field.get("unique")
                cls._add_operator(cls._drop_index(model, (field_name,), unique), upgrade, True)
        elif option == "db_field_types.":
            if new_data_field.get("field_type") == "DecimalField":
                # Precision/scale change on a decimal column requires MODIFY.
                cls._add_operator(cls._modify_field(model, new_data_field), upgrade)
        elif option == "default":
            if not (is_default_function(old_new[0]) or is_default_function(old_new[1])):
                # Change the column default; callable defaults are handled in
                # Python, not in the schema, so they are skipped here.
                cls._add_operator(cls._alter_default(model, new_data_field), upgrade)
        elif option == "unique":
            if "indexed" in options:
                # The "indexed" change above already covers the unique flag.
                continue
            # Change unique for indexed field, e.g.: `db_index=True, unique=False` --> `db_index=True, unique=True`
            # TODO
        elif option == "nullable":
            # Toggle NULL/NOT NULL on the column.
            cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
        elif option == "description":
            # Update the column comment.
            cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
        else:
            if modified:
                continue
            # Any other change falls back to a single generic column MODIFY.
            cls._add_operator(cls._modify_field(model, new_data_field), upgrade)
            modified = True
@classmethod @classmethod
def rename_table(cls, model: type[Model], old_table_name: str, new_table_name: str) -> str: def rename_table(cls, model: type[Model], old_table_name: str, new_table_name: str) -> str:
return cls.ddl.rename_table(model, old_table_name, new_table_name) return cls.ddl.rename_table(model, old_table_name, new_table_name)
@ -685,6 +689,16 @@ class Migrate:
cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str: ) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
if cls.dialect == "mysql":
# schema_generator of MySQL returns an empty index SQL
if hasattr(fields_name, "field_names"):
# tortoise>=0.24
fields = fields_name.field_names
else:
# TODO: remove else when drop support for tortoise<0.24
if not (fields := fields_name.fields):
fields = [getattr(i, "get_sql")() for i in fields_name.expressions]
return cls.ddl.drop_index(model, fields, unique, name=fields_name.name)
return cls.ddl.drop_index_by_name( return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model) model, fields_name.index_name(cls.ddl.schema_generator, model)
) )
@ -696,7 +710,29 @@ class Migrate:
cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False cls, model: type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str: ) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False) if cls.dialect == "mysql":
# schema_generator of MySQL returns an empty index SQL
if hasattr(fields_name, "field_names"):
# tortoise>=0.24
fields = fields_name.field_names
else:
# TODO: remove else when drop support for tortoise<0.24
if not (fields := fields_name.fields):
fields = [getattr(i, "get_sql")() for i in fields_name.expressions]
return cls.ddl.add_index(
model,
fields,
name=fields_name.name,
index_type=fields_name.INDEX_TYPE,
extra=fields_name.extra,
)
sql = fields_name.get_sql(cls.ddl.schema_generator, model, safe=True)
if tortoise.__version__ < "0.24":
sql = sql.replace(" ", " ")
if cls.dialect == "postgres" and (exists := "IF NOT EXISTS ") not in sql:
idx = " INDEX "
sql = sql.replace(idx, idx + exists)
return sql
field_names = cls._resolve_fk_fields_name(model, fields_name) field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, field_names, unique) return cls.ddl.add_index(model, field_names, unique)

68
poetry.lock generated
View File

@ -690,16 +690,19 @@ files = [
[[package]] [[package]]
name = "pbr" name = "pbr"
version = "6.1.0" version = "6.1.1"
description = "Python Build Reasonableness" description = "Python Build Reasonableness"
optional = false optional = false
python-versions = ">=2.6" python-versions = ">=2.6"
groups = ["dev"] groups = ["dev"]
files = [ files = [
{file = "pbr-6.1.0-py2.py3-none-any.whl", hash = "sha256:a776ae228892d8013649c0aeccbb3d5f99ee15e005a4cbb7e61d55a067b28a2a"}, {file = "pbr-6.1.1-py2.py3-none-any.whl", hash = "sha256:38d4daea5d9fa63b3f626131b9d34947fd0c8be9b05a29276870580050a25a76"},
{file = "pbr-6.1.0.tar.gz", hash = "sha256:788183e382e3d1d7707db08978239965e8b9e4e5ed42669bf4758186734d5f24"}, {file = "pbr-6.1.1.tar.gz", hash = "sha256:93ea72ce6989eb2eed99d0f75721474f69ad88128afdef5ac377eb797c4bf76b"},
] ]
[package.dependencies]
setuptools = "*"
[[package]] [[package]]
name = "platformdirs" name = "platformdirs"
version = "4.3.6" version = "4.3.6"
@ -1085,32 +1088,53 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
[[package]] [[package]]
name = "ruff" name = "ruff"
version = "0.9.4" version = "0.9.6"
description = "An extremely fast Python linter and code formatter, written in Rust." description = "An extremely fast Python linter and code formatter, written in Rust."
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
groups = ["dev"] groups = ["dev"]
files = [ files = [
{file = "ruff-0.9.4-py3-none-linux_armv6l.whl", hash = "sha256:64e73d25b954f71ff100bb70f39f1ee09e880728efb4250c632ceed4e4cdf706"}, {file = "ruff-0.9.6-py3-none-linux_armv6l.whl", hash = "sha256:2f218f356dd2d995839f1941322ff021c72a492c470f0b26a34f844c29cdf5ba"},
{file = "ruff-0.9.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:6ce6743ed64d9afab4fafeaea70d3631b4d4b28b592db21a5c2d1f0ef52934bf"}, {file = "ruff-0.9.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b908ff4df65dad7b251c9968a2e4560836d8f5487c2f0cc238321ed951ea0504"},
{file = "ruff-0.9.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:54499fb08408e32b57360f6f9de7157a5fec24ad79cb3f42ef2c3f3f728dfe2b"}, {file = "ruff-0.9.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b109c0ad2ececf42e75fa99dc4043ff72a357436bb171900714a9ea581ddef83"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37c892540108314a6f01f105040b5106aeb829fa5fb0561d2dcaf71485021137"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1de4367cca3dac99bcbd15c161404e849bb0bfd543664db39232648dc00112dc"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de9edf2ce4b9ddf43fd93e20ef635a900e25f622f87ed6e3047a664d0e8f810e"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3ee4d7c2c92ddfdaedf0bf31b2b176fa7aa8950efc454628d477394d35638b"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:87c90c32357c74f11deb7fbb065126d91771b207bf9bfaaee01277ca59b574ec"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dc1edd1775270e6aa2386119aea692039781429f0be1e0949ea5884e011aa8e"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:56acd6c694da3695a7461cc55775f3a409c3815ac467279dfa126061d84b314b"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:4a091729086dffa4bd070aa5dab7e39cc6b9d62eb2bef8f3d91172d30d599666"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0c93e7d47ed951b9394cf352d6695b31498e68fd5782d6cbc282425655f687a"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d1bbc6808bf7b15796cef0815e1dfb796fbd383e7dbd4334709642649625e7c5"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1d4c8772670aecf037d1bf7a07c39106574d143b26cfe5ed1787d2f31e800214"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:589d1d9f25b5754ff230dce914a174a7c951a85a4e9270613a2b74231fdac2f5"},
{file = "ruff-0.9.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bfc5f1d7afeda8d5d37660eeca6d389b142d7f2b5a1ab659d9214ebd0e025231"}, {file = "ruff-0.9.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc61dd5131742e21103fbbdcad683a8813be0e3c204472d520d9a5021ca8b217"},
{file = "ruff-0.9.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:faa935fc00ae854d8b638c16a5f1ce881bc3f67446957dd6f2af440a5fc8526b"}, {file = "ruff-0.9.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5e2d9126161d0357e5c8f30b0bd6168d2c3872372f14481136d13de9937f79b6"},
{file = "ruff-0.9.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a6c634fc6f5a0ceae1ab3e13c58183978185d131a29c425e4eaa9f40afe1e6d6"}, {file = "ruff-0.9.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:68660eab1a8e65babb5229a1f97b46e3120923757a68b5413d8561f8a85d4897"},
{file = "ruff-0.9.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:433dedf6ddfdec7f1ac7575ec1eb9844fa60c4c8c2f8887a070672b8d353d34c"}, {file = "ruff-0.9.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c4cae6c4cc7b9b4017c71114115db0445b00a16de3bcde0946273e8392856f08"},
{file = "ruff-0.9.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d612dbd0f3a919a8cc1d12037168bfa536862066808960e0cc901404b77968f0"}, {file = "ruff-0.9.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19f505b643228b417c1111a2a536424ddde0db4ef9023b9e04a46ed8a1cb4656"},
{file = "ruff-0.9.4-py3-none-win32.whl", hash = "sha256:db1192ddda2200671f9ef61d9597fcef89d934f5d1705e571a93a67fb13a4402"}, {file = "ruff-0.9.6-py3-none-win32.whl", hash = "sha256:194d8402bceef1b31164909540a597e0d913c0e4952015a5b40e28c146121b5d"},
{file = "ruff-0.9.4-py3-none-win_amd64.whl", hash = "sha256:05bebf4cdbe3ef75430d26c375773978950bbf4ee3c95ccb5448940dc092408e"}, {file = "ruff-0.9.6-py3-none-win_amd64.whl", hash = "sha256:03482d5c09d90d4ee3f40d97578423698ad895c87314c4de39ed2af945633caa"},
{file = "ruff-0.9.4-py3-none-win_arm64.whl", hash = "sha256:585792f1e81509e38ac5123492f8875fbc36f3ede8185af0a26df348e5154f41"}, {file = "ruff-0.9.6-py3-none-win_arm64.whl", hash = "sha256:0e2bb706a2be7ddfea4a4af918562fdc1bcb16df255e5fa595bbd800ce322a5a"},
{file = "ruff-0.9.4.tar.gz", hash = "sha256:6907ee3529244bb0ed066683e075f09285b38dd5b4039370df6ff06041ca19e7"}, {file = "ruff-0.9.6.tar.gz", hash = "sha256:81761592f72b620ec8fa1068a6fd00e98a5ebee342a3642efd84454f3031dca9"},
] ]
[[package]]
name = "setuptools"
version = "75.3.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.8"
groups = ["dev"]
files = [
{file = "setuptools-75.3.0-py3-none-any.whl", hash = "sha256:f2504966861356aa38616760c0f66568e535562374995367b4e69c7143cf6bcd"},
{file = "setuptools-75.3.0.tar.gz", hash = "sha256:fba5dd4d766e97be1b1681d98712680ae8f2f26d7881245f2ce9e40714f1a686"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.5.2)"]
core = ["importlib-metadata (>=6)", "importlib-resources (>=5.10.2)", "jaraco.collections", "jaraco.functools", "jaraco.text (>=3.7)", "more-itertools", "more-itertools (>=8.8)", "packaging", "packaging (>=24)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib-metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.12.*)", "pytest-mypy"]
[[package]] [[package]]
name = "sniffio" name = "sniffio"
version = "1.3.1" version = "1.3.1"

View File

@ -44,3 +44,27 @@ async def init_db(tortoise_orm, generate_schemas=True) -> None:
def copy_files(*src_files: Path, target_dir: Path) -> None: def copy_files(*src_files: Path, target_dir: Path) -> None:
for src in src_files: for src in src_files:
shutil.copy(src, target_dir) shutil.copy(src, target_dir)
class Dialect:
    """Detect which database backend the test suite is targeting.

    The backend is inferred from the ``TEST_DB`` environment variable,
    which is read lazily on first use and cached on the class.
    """

    # Cached copy of the TEST_DB environment variable ("" when unset).
    test_db_url: str

    @classmethod
    def load_env(cls) -> None:
        """Populate ``test_db_url`` from the environment unless already cached."""
        if getattr(cls, "test_db_url", None) is None:
            cls.test_db_url = os.environ.get("TEST_DB", "")

    @classmethod
    def _url_contains(cls, marker: str) -> bool:
        # Ensure the URL is loaded before inspecting it.
        cls.load_env()
        return marker in cls.test_db_url

    @classmethod
    def is_postgres(cls) -> bool:
        """Return True when the test database URL points at PostgreSQL."""
        return cls._url_contains("postgres")

    @classmethod
    def is_mysql(cls) -> bool:
        """Return True when the test database URL points at MySQL."""
        return cls._url_contains("mysql")

    @classmethod
    def is_sqlite(cls) -> bool:
        """Return True for SQLite, the default backend when TEST_DB is unset."""
        return cls._url_contains("sqlite") or not cls.test_db_url

View File

@ -1,10 +1,15 @@
from __future__ import annotations
import datetime import datetime
import uuid import uuid
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
from tortoise.contrib.mysql.indexes import FullTextIndex
from tortoise.contrib.postgres.indexes import HashIndex
from tortoise.indexes import Index from tortoise.indexes import Index
from tests._utils import Dialect
from tests.indexes import CustomIndex from tests.indexes import CustomIndex
@ -63,6 +68,14 @@ class Category(Model):
title = fields.CharField(max_length=20, unique=False) title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
if Dialect.is_postgres():
indexes = [HashIndex(fields=("slug",))]
elif Dialect.is_mysql():
indexes = [FullTextIndex(fields=("slug",))] # type:ignore
else:
indexes = [Index(fields=("slug",))] # type:ignore
class Product(Model): class Product(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField( categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
@ -75,7 +88,7 @@ class Product(Model):
view_num = fields.IntField(description="View Num", default=0) view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField( type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias" ProductType, description="Product Type", source_field="type_db_alias"
) )
pic = fields.CharField(max_length=200) pic = fields.CharField(max_length=200)

View File

@ -56,7 +56,7 @@ class Product(Model):
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField( type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias" ProductType, description="Product Type", source_field="type_db_alias"
) )
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)

View File

@ -55,6 +55,9 @@ class Category(Model):
title = fields.CharField(max_length=20, unique=True) title = fields.CharField(max_length=20, unique=True)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
indexes = [Index(fields=("slug",))]
class Product(Model): class Product(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
@ -63,7 +66,7 @@ class Product(Model):
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_review = fields.BooleanField(description="Is Reviewed") is_review = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField( type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias" ProductType, description="Product Type", source_field="type_db_alias"
) )
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)

View File

@ -1,3 +1,5 @@
import tortoise
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
@ -8,6 +10,21 @@ from tests.models import Category, Product, User
def test_create_table(): def test_create_table():
ret = Migrate.ddl.create_table(Category) ret = Migrate.ddl.create_table(Category)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
if tortoise.__version__ >= "0.24":
assert (
ret
== """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE,
FULLTEXT KEY `idx_category_slug_e9bcff` (`slug`)
) CHARACTER SET utf8mb4"""
)
return
assert ( assert (
ret ret
== """CREATE TABLE IF NOT EXISTS `category` ( == """CREATE TABLE IF NOT EXISTS `category` (
@ -18,20 +35,23 @@ def test_create_table():
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User', `owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4""" ) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX `idx_category_slug_e9bcff` ON `category` (`slug`)"""
) )
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
exists = "IF NOT EXISTS " if tortoise.__version__ >= "0.24" else ""
assert ( assert (
ret ret
== """CREATE TABLE IF NOT EXISTS "category" ( == f"""CREATE TABLE IF NOT EXISTS "category" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL, "title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
)""" );
CREATE INDEX {exists}"idx_category_slug_e9bcff" ON "category" ("slug")"""
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
@ -45,6 +65,7 @@ def test_create_table():
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug");
COMMENT ON COLUMN "category"."owner_id" IS 'User'""" COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
) )
@ -163,6 +184,14 @@ def test_add_index():
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)" assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
elif isinstance(Migrate.ddl, PostgresDDL):
assert (
index == 'CREATE INDEX IF NOT EXISTS "idx_category_name_8b0cb9" ON "category" ("name")'
)
assert (
index_u
== 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_name_8b0cb9" ON "category" ("name")'
)
else: else:
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")' assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'

View File

@ -35,7 +35,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [], "indexes": [describe_index(Index(fields=("slug",)))],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@ -929,6 +929,7 @@ def test_migrate(mocker: MockerFixture):
- drop fk field: Email.user - drop fk field: Email.user
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- change index type for indexed field: Email.slug
- add many to many: Email.users - add many to many: Email.users
- add one to one: Email.config - add one to one: Email.config
- remove unique: Category.title - remove unique: Category.title
@ -965,6 +966,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `category` DROP INDEX `title`", "ALTER TABLE `category` DROP INDEX `title`",
"ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`", "ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`",
"ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `category` ADD FULLTEXT INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP COLUMN `name`", "ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` DROP INDEX `name`", "ALTER TABLE `config` DROP INDEX `name`",
@ -1007,6 +1010,8 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)", "ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
"ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`", "ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`",
"ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`", "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`",
"ALTER TABLE `category` ADD INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE", "ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)", "ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
@ -1050,6 +1055,8 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"', 'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"',
'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug")',
'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
'ALTER TABLE "config" DROP COLUMN "name"', 'ALTER TABLE "config" DROP COLUMN "name"',
'DROP INDEX IF EXISTS "uid_config_name_2c83c8"', 'DROP INDEX IF EXISTS "uid_config_name_2c83c8"',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
@ -1070,12 +1077,12 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE INDEX IF NOT EXISTS "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")', 'CREATE INDEX IF NOT EXISTS "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'', 'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"', 'DROP TABLE IF EXISTS "config_category"',
@ -1087,13 +1094,15 @@ def test_migrate(mocker: MockerFixture):
assert not upgrade_less_than_expected assert not upgrade_less_than_expected
expected_downgrade_operators = { expected_downgrade_operators = {
'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_title_f7fc03" ON "category" ("title")',
'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'ALTER TABLE "category" RENAME COLUMN "owner_id" TO "user_id"', 'ALTER TABLE "category" RENAME COLUMN "owner_id" TO "user_id"',
'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"', 'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"',
'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" ("slug")',
'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE', 'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_config_name_2c83c8" ON "config" ("name")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_config_name_2c83c8" ON "config" ("name")',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "config" DROP CONSTRAINT IF EXISTS "fk_config_user_17daa970"', 'ALTER TABLE "config" DROP CONSTRAINT IF EXISTS "fk_config_user_17daa970"',
'ALTER TABLE "config" RENAME TO "configs"', 'ALTER TABLE "config" RENAME TO "configs"',
@ -1104,7 +1113,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "email" DROP COLUMN "config_id"', 'ALTER TABLE "email" DROP COLUMN "config_id"',
'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"', 'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE', 'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"', 'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',