Support migration for specified index. (#203)
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
import os
|
||||
from datetime import datetime
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import click
|
||||
from dictdiffer import diff
|
||||
from tortoise import BaseDBAsyncClient, Model, Tortoise
|
||||
from tortoise.exceptions import OperationalError
|
||||
from tortoise.indexes import Index
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.models import MAX_VERSION_LENGTH, Aerich
|
||||
@@ -32,7 +34,7 @@ class Migrate:
|
||||
ddl: BaseDDL
|
||||
_last_version_content: Optional[dict] = None
|
||||
app: str
|
||||
migrate_location: str
|
||||
migrate_location: Path
|
||||
dialect: str
|
||||
_db_version: Optional[str] = None
|
||||
|
||||
@@ -157,6 +159,18 @@ class Migrate:
|
||||
else:
|
||||
cls.downgrade_operators.append(operator)
|
||||
|
||||
@classmethod
|
||||
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
|
||||
ret = []
|
||||
for index in indexes:
|
||||
if isinstance(index, Index):
|
||||
index.__hash__ = lambda self: md5( # nosec: B303
|
||||
self.index_name(cls.ddl.schema_generator, model).encode()
|
||||
+ self.__class__.__name__.encode()
|
||||
).hexdigest()
|
||||
ret.append(index)
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
|
||||
"""
|
||||
@@ -192,8 +206,18 @@ class Migrate:
|
||||
new_unique_together = set(
|
||||
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
|
||||
)
|
||||
old_indexes = set(map(lambda x: tuple(x), old_model_describe.get("indexes", [])))
|
||||
new_indexes = set(map(lambda x: tuple(x), new_model_describe.get("indexes", [])))
|
||||
old_indexes = set(
|
||||
map(
|
||||
lambda x: x if isinstance(x, Index) else tuple(x),
|
||||
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
|
||||
)
|
||||
)
|
||||
new_indexes = set(
|
||||
map(
|
||||
lambda x: x if isinstance(x, Index) else tuple(x),
|
||||
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
|
||||
)
|
||||
)
|
||||
old_pk_field = old_model_describe.get("pk_field")
|
||||
new_pk_field = new_model_describe.get("pk_field")
|
||||
# pk field
|
||||
@@ -463,12 +487,18 @@ class Migrate:
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
|
||||
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
|
||||
if isinstance(fields_name, Index):
|
||||
return cls.ddl.drop_index_by_name(
|
||||
model, fields_name.index_name(cls.ddl.schema_generator, model)
|
||||
)
|
||||
fields_name = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.drop_index(model, fields_name, unique)
|
||||
|
||||
@classmethod
|
||||
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
|
||||
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
|
||||
if isinstance(fields_name, Index):
|
||||
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
|
||||
fields_name = cls._resolve_fk_fields_name(model, fields_name)
|
||||
return cls.ddl.add_index(model, fields_name, unique)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user