1 Commits
v0.8.0 ... main

Author      SHA1        Message                                      Date
long2ice    40d0823c01  refactor: make in_transaction default True   2023-08-04 10:35:46 +08:00
28 changed files with 1046 additions and 1367 deletions

View File

@@ -18,22 +18,13 @@ jobs:
         POSTGRES_PASSWORD: 123456
         POSTGRES_USER: postgres
       options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
-    strategy:
-      matrix:
-        python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
     steps:
       - name: Start MySQL
        run: sudo systemctl start mysql.service
-      - uses: actions/cache@v4
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: '3.x'
       - name: Install and configure Poetry
         run: |
           pip install -U pip poetry

View File

@@ -7,8 +7,8 @@ jobs:
   publish:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
         with:
           python-version: '3.x'
       - name: Install and configure Poetry

View File

@@ -1,21 +1,8 @@
 # ChangeLog
-## 0.8
-### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04
-- Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
-- Fix KeyError when deleting a field with unqiue=True. (#364)
-- Correct the click import. (#360)
-- Improve CLI help text and output. (#355)
-- Fix mysql drop unique index raises OperationalError. (#346)
-**Upgrade note:**
-1. Use column name as unique key name for mysql
-2. Drop support for Python3.7
 ## 0.7
-### [0.7.2](../../releases/tag/v0.7.2) - 2023-07-20
+### 0.7.2
 - Support virtual fields.
 - Fix modify multiple times. (#279)

View File

@@ -14,20 +14,13 @@ up:
 deps:
     @poetry install -E asyncpg -E asyncmy

-_style:
+style: deps
     @isort -src $(checkfiles)
     @black $(black_opts) $(checkfiles)
-style: deps _style

-_check:
+check: deps
     @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
-    @ruff check $(checkfiles)
-    @mypy $(checkfiles)
-ifneq ($(shell python -c 'import sys;is_py38=sys.version_info<(3,9);rc=int(is_py38);sys.exit(rc)'),)
-    # Run bandit with Python3.9+, as the `usedforsecurity=...` parameter of `hashlib.new` is only added from Python 3.9 onwards.
-    @bandit -r aerich
-endif
-check: deps _check
+    @ruff $(checkfiles)

 test: deps
     $(py_warn) TEST_DB=sqlite://:memory: py.test
@@ -41,10 +34,9 @@ test_mysql:
 test_postgres:
     $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s

-_testall: test_sqlite test_postgres test_mysql
-testall: deps _testall
+testall: deps test_sqlite test_postgres test_mysql

 build: deps
     @poetry build

-ci: check _testall
+ci: check testall

View File

@@ -46,7 +46,7 @@ Commands:
 ## Usage

-You need to add `aerich.models` to your `Tortoise-ORM` config first. Example:
+You need add `aerich.models` to your `Tortoise-ORM` config first. Example:

 ```python
 TORTOISE_ORM = {
@@ -113,14 +113,6 @@ If `aerich` guesses you are renaming a column, it will ask `Rename {old_column}
 `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
 lose data.

-If you need to manually write migration, you could generate empty file:
-
-```shell
-> aerich migrate --name add_index --empty
-Success migrate 1_202326122220101229_add_index.py
-```
-
 ### Upgrade to latest version

 ```shell

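For context on the `-` side's `--empty` feature: an aerich migration file is a plain Python module whose `upgrade`/`downgrade` coroutines return SQL, so an "empty" one is just a skeleton to fill in by hand. A minimal sketch, with the signatures assumed from aerich's version-file template:

```python
from tortoise import BaseDBAsyncClient


async def upgrade(db: BaseDBAsyncClient) -> str:
    # hand-written SQL goes here, e.g. 'CREATE INDEX "idx_user_email" ON "user" ("email")'
    return ""


async def downgrade(db: BaseDBAsyncClient) -> str:
    # the inverse statement, e.g. 'DROP INDEX "idx_user_email"'
    return ""
```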
View File

@@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Type from typing import List
from tortoise import Tortoise, generate_schema_for_client from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@@ -20,9 +20,6 @@ from aerich.utils import (
import_py_file, import_py_file,
) )
if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401
class Command: class Command:
def __init__( def __init__(
@@ -30,16 +27,16 @@ class Command:
tortoise_config: dict, tortoise_config: dict,
app: str = "models", app: str = "models",
location: str = "./migrations", location: str = "./migrations",
) -> None: ):
self.tortoise_config = tortoise_config self.tortoise_config = tortoise_config
self.app = app self.app = app
self.location = location self.location = location
Migrate.app = app Migrate.app = app
async def init(self) -> None: async def init(self):
await Migrate.init(self.tortoise_config, self.app, self.location) await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file) -> None: async def _upgrade(self, conn, version_file):
file_path = Path(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path) m = import_py_file(file_path)
upgrade = getattr(m, "upgrade") upgrade = getattr(m, "upgrade")
@@ -50,7 +47,7 @@ class Command:
content=get_models_describe(self.app), content=get_models_describe(self.app),
) )
async def upgrade(self, run_in_transaction: bool = True) -> List[str]: async def upgrade(self, run_in_transaction: bool = True):
migrated = [] migrated = []
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@@ -68,8 +65,8 @@ class Command:
migrated.append(version_file) migrated.append(version_file)
return migrated return migrated
async def downgrade(self, version: int, delete: bool) -> List[str]: async def downgrade(self, version: int, delete: bool):
ret: List[str] = [] ret = []
if version == -1: if version == -1:
specified_version = await Migrate.get_last_version() specified_version = await Migrate.get_last_version()
else: else:
@@ -82,8 +79,8 @@ class Command:
versions = [specified_version] versions = [specified_version]
else: else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk) versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version_obj in versions: for version in versions:
file = version_obj.version file = version.version
async with in_transaction( async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
@@ -94,13 +91,13 @@ class Command:
if not downgrade_sql.strip(): if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found") raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql) await conn.execute_script(downgrade_sql)
await version_obj.delete() await version.delete()
if delete: if delete:
os.unlink(file_path) os.unlink(file_path)
ret.append(file) ret.append(file)
return ret return ret
async def heads(self) -> List[str]: async def heads(self):
ret = [] ret = []
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
for version in versions: for version in versions:
@@ -108,15 +105,15 @@ class Command:
ret.append(version) ret.append(version)
return ret return ret
async def history(self) -> List[str]: async def history(self):
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
return [version for version in versions] return [version for version in versions]
async def inspectdb(self, tables: Optional[List[str]] = None) -> str: async def inspectdb(self, tables: List[str] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app) connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT dialect = connection.schema_generator.DIALECT
if dialect == "mysql": if dialect == "mysql":
cls: Type["Inspect"] = InspectMySQL cls = InspectMySQL
elif dialect == "postgres": elif dialect == "postgres":
cls = InspectPostgres cls = InspectPostgres
elif dialect == "sqlite": elif dialect == "sqlite":
@@ -126,10 +123,10 @@ class Command:
inspect = cls(connection, tables) inspect = cls(connection, tables)
return await inspect.inspect() return await inspect.inspect()
async def migrate(self, name: str = "update", empty: bool = False) -> str: async def migrate(self, name: str = "update"):
return await Migrate.migrate(name, empty) return await Migrate.migrate(name)
async def init_db(self, safe: bool) -> None: async def init_db(self, safe: bool):
location = self.location location = self.location
app = self.app app = self.app
dirname = Path(location, app) dirname = Path(location, app)
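`Command` doubles as the programmatic API; a minimal sketch of driving it directly (the `settings` module and its `TORTOISE_ORM` dict are assumed placeholders for your own config):

```python
import asyncio

from aerich import Command

from settings import TORTOISE_ORM  # assumed: your own Tortoise-ORM config dict


async def main() -> None:
    command = Command(tortoise_config=TORTOISE_ORM, app="models")
    await command.init()
    migrated = await command.upgrade()
    print(migrated)


asyncio.run(main())
```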

View File

@@ -1,11 +1,14 @@
+import asyncio
 import os
+from functools import wraps
 from pathlib import Path
-from typing import Dict, List, cast
+from typing import List

-import asyncclick as click
+import click
 import tomlkit
-from asyncclick import Context, UsageError
+from click import Context, UsageError
 from tomlkit.exceptions import NonExistentKey
+from tortoise import Tortoise

 from aerich import Command
 from aerich.enums import Color
@@ -18,6 +21,21 @@ CONFIG_DEFAULT_VALUES = {
 }


+def coro(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        loop = asyncio.get_event_loop()
+        # Close db connections at the end of all but the cli group function
+        try:
+            loop.run_until_complete(f(*args, **kwargs))
+        finally:
+            if f.__name__ not in ["cli", "init"]:
+                loop.run_until_complete(Tortoise.close_connections())
+
+    return wrapper
+
+
 @click.group(context_settings={"help_option_names": ["-h", "--help"]})
 @click.version_option(__version__, "-V", "--version")
 @click.option(
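The `coro` helper added on the `+` side is the usual bridge for exposing async command bodies through synchronous `click`; a standalone sketch of the same pattern (the `hello` command is hypothetical, not part of this diff):

```python
import asyncio
from functools import wraps

import click


def coro(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        # On Python 3.10+ a plain asyncio.run(f(...)) is the preferred spelling;
        # get_event_loop() mirrors the code in this diff.
        return asyncio.get_event_loop().run_until_complete(f(*args, **kwargs))

    return wrapper


@click.command()
@coro  # applied innermost, so click registers the sync wrapper
async def hello() -> None:
    await asyncio.sleep(0.1)
    click.echo("done")
```

The `-` side needs no wrapper at all because `asyncclick` accepts coroutine callbacks natively.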
@@ -29,7 +47,8 @@ CONFIG_DEFAULT_VALUES = {
 )
 @click.option("--app", required=False, help="Tortoise-ORM app name.")
 @click.pass_context
-async def cli(ctx: Context, config, app) -> None:
+@coro
+async def cli(ctx: Context, config, app):
     ctx.ensure_object(dict)
     ctx.obj["config_file"] = config
@@ -37,62 +56,57 @@ async def cli(ctx: Context, config, app) -> None:
     if invoked_subcommand != "init":
         config_path = Path(config)
         if not config_path.exists():
-            raise UsageError(
-                "You need to run `aerich init` first to create the config file.", ctx=ctx
-            )
+            raise UsageError("You must exec init first", ctx=ctx)
         content = config_path.read_text()
-        doc: dict = tomlkit.parse(content)
+        doc = tomlkit.parse(content)
         try:
-            tool = cast(Dict[str, str], doc["tool"]["aerich"])
+            tool = doc["tool"]["aerich"]
             location = tool["location"]
             tortoise_orm = tool["tortoise_orm"]
             src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
         except NonExistentKey:
-            raise UsageError("You need run `aerich init` again when upgrading to aerich 0.6.0+.")
+            raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
         add_src_path(src_folder)
         tortoise_config = get_tortoise_config(ctx, tortoise_orm)
-        if not app:
-            apps_config = cast(dict, tortoise_config.get("apps"))
-            app = list(apps_config.keys())[0]
+        app = app or list(tortoise_config.get("apps").keys())[0]
         command = Command(tortoise_config=tortoise_config, app=app, location=location)
         ctx.obj["command"] = command
         if invoked_subcommand != "init-db":
             if not Path(location, app).exists():
-                raise UsageError(
-                    "You need to run `aerich init-db` first to initialize the database.", ctx=ctx
-                )
+                raise UsageError("You must exec init-db first", ctx=ctx)
             await command.init()


-@cli.command(help="Generate a migration file for the current state of the models.")
-@click.option("--name", default="update", show_default=True, help="Migration name.")
-@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
+@cli.command(help="Generate migrate changes file.")
+@click.option("--name", default="update", show_default=True, help="Migrate name.")
 @click.pass_context
-async def migrate(ctx: Context, name, empty) -> None:
+@coro
+async def migrate(ctx: Context, name):
     command = ctx.obj["command"]
-    ret = await command.migrate(name, empty)
+    ret = await command.migrate(name)
     if not ret:
         return click.secho("No changes detected", fg=Color.yellow)
-    click.secho(f"Success creating migration file {ret}", fg=Color.green)
+    click.secho(f"Success migrate {ret}", fg=Color.green)


-@cli.command(help="Upgrade to specified migration version.")
+@cli.command(help="Upgrade to specified version.")
 @click.option(
     "--in-transaction",
     "-i",
     default=True,
     type=bool,
-    help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
+    help="Make migrations in transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
 )
 @click.pass_context
-async def upgrade(ctx: Context, in_transaction: bool) -> None:
+@coro
+async def upgrade(ctx: Context, in_transaction: bool):
     command = ctx.obj["command"]
     migrated = await command.upgrade(run_in_transaction=in_transaction)
     if not migrated:
         click.secho("No upgrade items found", fg=Color.yellow)
     else:
         for version_file in migrated:
-            click.secho(f"Success upgrading to {version_file}", fg=Color.green)
+            click.secho(f"Success upgrade {version_file}", fg=Color.green)


 @cli.command(help="Downgrade to specified version.")
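The body of `Command.upgrade` is not shown in this compare, but given the helpers that are (`in_transaction`, `get_app_connection_name`), the `--in-transaction` flag plausibly switches between a transactional and a plain connection. A sketch, not the verbatim implementation:

```python
from tortoise.transactions import in_transaction

from aerich.migrate import Migrate
from aerich.utils import get_app_connection, get_app_connection_name


class Command:  # sketch of the relevant method only
    async def upgrade(self, run_in_transaction: bool = True):
        migrated = []
        for version_file in Migrate.get_all_version_files():
            # ... skip version files already recorded in the aerich table ...
            if run_in_transaction:
                # one transaction per version file, rolled back on error
                conn_name = get_app_connection_name(self.tortoise_config, self.app)
                async with in_transaction(conn_name) as conn:
                    await self._upgrade(conn, version_file)
            else:
                # needed for DDL that refuses to run inside a transaction,
                # e.g. CREATE INDEX CONCURRENTLY on PostgreSQL
                app_conn = get_app_connection(self.tortoise_config, self.app)
                await self._upgrade(app_conn, version_file)
            migrated.append(version_file)
        return migrated
```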
@@ -101,8 +115,8 @@ async def upgrade(ctx: Context, in_transaction: bool) -> None:
     "--version",
     default=-1,
     type=int,
-    show_default=False,
-    help="Specified version, default to last migration.",
+    show_default=True,
+    help="Specified version, default to last.",
 )
 @click.option(
     "-d",
@@ -110,56 +124,59 @@ async def upgrade(ctx: Context, in_transaction: bool) -> None:
     is_flag=True,
     default=False,
     show_default=True,
-    help="Also delete the migration files.",
+    help="Delete version files at the same time.",
 )
 @click.pass_context
 @click.confirmation_option(
-    prompt="Downgrade is dangerous: you might lose your data! Are you sure?",
+    prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
 )
-async def downgrade(ctx: Context, version: int, delete: bool) -> None:
+@coro
+async def downgrade(ctx: Context, version: int, delete: bool):
     command = ctx.obj["command"]
     try:
         files = await command.downgrade(version, delete)
     except DowngradeError as e:
         return click.secho(str(e), fg=Color.yellow)
     for file in files:
-        click.secho(f"Success downgrading to {file}", fg=Color.green)
+        click.secho(f"Success downgrade {file}", fg=Color.green)


-@cli.command(help="Show currently available heads (unapplied migrations).")
+@cli.command(help="Show current available heads in migrate location.")
 @click.pass_context
-async def heads(ctx: Context) -> None:
+@coro
+async def heads(ctx: Context):
     command = ctx.obj["command"]
     head_list = await command.heads()
     if not head_list:
-        return click.secho("No available heads.", fg=Color.green)
+        return click.secho("No available heads, try migrate first", fg=Color.green)
     for version in head_list:
         click.secho(version, fg=Color.green)


-@cli.command(help="List all migrations.")
+@cli.command(help="List all migrate items.")
 @click.pass_context
-async def history(ctx: Context) -> None:
+@coro
+async def history(ctx: Context):
     command = ctx.obj["command"]
     versions = await command.history()
     if not versions:
-        return click.secho("No migrations created yet.", fg=Color.green)
+        return click.secho("No history, try migrate", fg=Color.green)
     for version in versions:
         click.secho(version, fg=Color.green)


-@cli.command(help="Initialize aerich config and create migrations folder.")
+@cli.command(help="Init config file and generate root migrate location.")
 @click.option(
     "-t",
     "--tortoise-orm",
     required=True,
-    help="Tortoise-ORM config dict location, like `settings.TORTOISE_ORM`.",
+    help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
 )
 @click.option(
     "--location",
     default="./migrations",
     show_default=True,
-    help="Migrations folder.",
+    help="Migrate store location.",
 )
 @click.option(
     "-s",
@@ -169,7 +186,8 @@ async def history(ctx: Context) -> None:
     help="Folder of the source, relative to the project root.",
 )
 @click.pass_context
-async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
+@coro
+async def init(ctx: Context, tortoise_orm, location, src_folder):
     config_file = ctx.obj["config_file"]

     if os.path.isabs(src_folder):
@@ -184,9 +202,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
     config_path = Path(config_file)
     if config_path.exists():
         content = config_path.read_text()
+        doc = tomlkit.parse(content)
     else:
-        content = "[tool.aerich]"
-    doc: dict = tomlkit.parse(content)
+        doc = tomlkit.parse("[tool.aerich]")
     table = tomlkit.table()
     table["tortoise_orm"] = tortoise_orm
     table["location"] = location
@@ -197,36 +215,37 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
     Path(location).mkdir(parents=True, exist_ok=True)

-    click.secho(f"Success creating migrations folder {location}", fg=Color.green)
-    click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
+    click.secho(f"Success create migrate location {location}", fg=Color.green)
+    click.secho(f"Success write config to {config_file}", fg=Color.green)


-@cli.command(help="Generate schema and generate app migration folder.")
+@cli.command(help="Generate schema and generate app migrate location.")
 @click.option(
     "-s",
     "--safe",
     type=bool,
     is_flag=True,
     default=True,
-    help="Create tables only when they do not already exist.",
+    help="When set to true, creates the table only when it does not already exist.",
     show_default=True,
 )
 @click.pass_context
-async def init_db(ctx: Context, safe: bool) -> None:
+@coro
+async def init_db(ctx: Context, safe: bool):
     command = ctx.obj["command"]
     app = command.app
     dirname = Path(command.location, app)
     try:
         await command.init_db(safe)
-        click.secho(f"Success creating app migration folder {dirname}", fg=Color.green)
-        click.secho(f'Success generating initial migration file for app "{app}"', fg=Color.green)
+        click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
+        click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
     except FileExistsError:
         return click.secho(
-            f"App {app} is already initialized. Delete {dirname} and try again.", fg=Color.yellow
+            f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
         )


-@cli.command(help="Prints the current database tables to stdout as Tortoise-ORM models.")
+@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
 @click.option(
     "-t",
     "--table",
@@ -235,13 +254,14 @@ async def init_db(ctx: Context, safe: bool) -> None:
     required=False,
 )
 @click.pass_context
-async def inspectdb(ctx: Context, table: List[str]) -> None:
+@coro
+async def inspectdb(ctx: Context, table: List[str]):
     command = ctx.obj["command"]
     ret = await command.inspectdb(table)
     click.secho(ret)


-def main() -> None:
+def main():
     cli()

View File

@@ -1,13 +1,12 @@
 import base64
 import json
 import pickle  # nosec: B301,B403
-from typing import Any, Union

 from tortoise.indexes import Index


 class JsonEncoder(json.JSONEncoder):
-    def default(self, obj) -> Any:
+    def default(self, obj):
         if isinstance(obj, Index):
             return {
                 "type": "index",
@@ -17,16 +16,16 @@ class JsonEncoder(json.JSONEncoder):
         return super().default(obj)


-def object_hook(obj) -> Any:
+def object_hook(obj):
     _type = obj.get("type")
     if not _type:
         return obj
     return pickle.loads(base64.b64decode(obj["val"]))  # nosec: B301


-def encoder(obj: dict) -> str:
+def encoder(obj: dict):
     return json.dumps(obj, cls=JsonEncoder)


-def decoder(obj: Union[str, bytes]) -> Any:
+def decoder(obj: str):
     return json.loads(obj, object_hook=object_hook)
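This encoder/decoder pair round-trips objects such as `Index` through the JSON that aerich stores; a quick sketch, assuming the elided branch of `default()` emits a base64-pickled `val` payload alongside `"type": "index"`:

```python
from tortoise.indexes import Index

payload = encoder({"my_index": Index(fields=("email",))})  # JSON string
restored = decoder(payload)
# object_hook sees the inner {"type": "index", "val": ...} dict and
# rebuilds the Index via pickle; the outer dict passes through unchanged.
assert isinstance(restored["my_index"], Index)
```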

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Any, List, Type, cast from typing import List, Type
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -35,26 +35,25 @@ class BaseDDL:
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"' _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient") -> None: def __init__(self, client: "BaseDBAsyncClient"):
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]") -> str: def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
";" ";"
) )
def drop_table(self, table_name: str) -> str: def drop_table(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m( def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str: ):
through = cast(str, field_describe.get("through")) through = field_describe.get("through")
description = field_describe.get("description") description = field_describe.get("description")
pk_field = cast(dict, reference_table_describe.get("pk_field")) reference_id = reference_table_describe.get("pk_field").get("db_column")
reference_id = pk_field.get("db_column") db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=through, table_name=through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
@@ -67,22 +66,22 @@ class BaseDDL:
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through), extra=self.schema_generator._table_generate_extra(table=through),
comment=( comment=self.schema_generator._table_comment_generator(
self.schema_generator._table_comment_generator(table=through, comment=description) table=through, comment=description
)
if description if description
else "" else "",
),
) )
def drop_m2m(self, table_name: str) -> str: def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any: def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_describe.get("default")
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
db_column = cast(str, field_describe.get("db_column")) db_column = field_describe.get("db_column")
auto_now_add = field_describe.get("auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False) auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
@@ -107,55 +106,64 @@ class BaseDDL:
default = None default = None
return default return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description") description = field_describe.get("description")
db_column = cast(str, field_describe.get("db_column")) db_column = field_describe.get("db_column")
db_field_types = cast(dict, field_describe.get("db_field_types")) db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
if default is None: if default is None:
default = "" default = ""
if modify: return self._ADD_COLUMN_TEMPLATE.format(
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
unique = "UNIQUE" if field_describe.get("unique") else ""
template = self._ADD_COLUMN_TEMPLATE
return template.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=db_column, db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique=unique, unique="UNIQUE" if field_describe.get("unique") else "",
comment=( comment=self.schema_generator._column_comment_generator(
self.schema_generator._column_comment_generator(
table=db_table, table=db_table,
column=db_column, column=db_column,
comment=description, comment=field_describe.get("description"),
) )
if description if description
else "" else "",
),
is_primary_key=is_pk, is_primary_key=is_pk,
default=default, default=default,
), ),
) )
def drop_column(self, model: "Type[Model]", column_name: str) -> str: def drop_column(self, model: "Type[Model]", column_name: str):
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str: def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
return self._add_or_modify_column(model, field_describe, is_pk, modify=True) db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_describe.get("description")
else "",
is_primary_key=is_pk,
default=default,
),
)
def rename_column( def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
self, model: "Type[Model]", old_column_name: str, new_column_name: str
) -> str:
return self._RENAME_COLUMN_TEMPLATE.format( return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -164,7 +172,7 @@ class BaseDDL:
def change_column( def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
) -> str: ):
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -172,7 +180,7 @@ class BaseDDL:
new_column_type=new_column_type, new_column_type=new_column_type,
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "", unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
@@ -182,7 +190,7 @@ class BaseDDL:
column_names=", ".join(self.schema_generator.quote(f) for f in field_names), column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
) )
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str: def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names "idx" if not unique else "uid", model, field_names
@@ -190,52 +198,45 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str: def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name, index_name=index_name,
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def _generate_fk_name( def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
self, db_table, field_describe: dict, reference_table_describe: dict
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
return self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=to_table,
to_field=to_field,
)
def add_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_column = field_describe.get("raw_field") db_column = field_describe.get("raw_field")
pk_field = cast(dict, reference_table_describe.get("pk_field")) reference_id = reference_table_describe.get("pk_field").get("db_column")
reference_id = pk_field.get("db_column") fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe), fk_name=fk_name,
db_column=db_column, db_column=db_column,
table=reference_table_describe.get("table"), table=reference_table_describe.get("table"),
field=reference_id, field=reference_id,
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
) )
def drop_fk( def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe) return self._DROP_FK_TEMPLATE.format(
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name) table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str: def alter_column_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format( return self._ALTER_DEFAULT_TEMPLATE.format(
@@ -244,13 +245,13 @@ class BaseDDL:
default="SET" + default if default is not None else "DROP DEFAULT", default="SET" + default if default is not None else "DROP DEFAULT",
) )
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str: def alter_column_null(self, model: "Type[Model]", field_describe: dict):
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str: def set_comment(self, model: "Type[Model]", field_describe: dict):
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str: def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format( return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name

View File

@@ -1,12 +1,7 @@
-from typing import TYPE_CHECKING, List, Type
-
 from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator

 from aerich.ddl import BaseDDL

-if TYPE_CHECKING:
-    from tortoise import Model  # noqa:F401
-

 class MysqlDDL(BaseDDL):
     schema_generator_cls = MySQLSchemaGenerator
@@ -35,29 +30,3 @@ class MysqlDDL(BaseDDL):
     )
     _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
     _RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
-
-    def _index_name(self, unique: bool, model: "Type[Model]", field_names: List[str]) -> str:
-        if unique:
-            if len(field_names) == 1:
-                # Example: `email = CharField(max_length=50, unique=True)`
-                # Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
-                # Unique index key is the same as field name: `email`
-                return field_names[0]
-            index_prefix = "uid"
-        else:
-            index_prefix = "idx"
-        return self.schema_generator._generate_index_name(index_prefix, model, field_names)
-
-    def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
-        return self._ADD_INDEX_TEMPLATE.format(
-            unique="UNIQUE " if unique else "",
-            index_name=self._index_name(unique, model, field_names),
-            table_name=model._meta.db_table,
-            column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
-        )
-
-    def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
-        return self._DROP_INDEX_TEMPLATE.format(
-            index_name=self._index_name(unique, model, field_names),
-            table_name=model._meta.db_table,
-        )
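The `_index_name` override on the `-` side is what the 0.8.0 changelog's "use column name as unique key name for mysql" note refers to. Roughly, for a model with a unique `email` column (statement shapes assumed from the `_ADD_INDEX_TEMPLATE` pattern, not quoted from this diff):

```python
# Sketch: the naming difference for a unique single-column index.
ddl = MysqlDDL(client)  # client: a connected BaseDBAsyncClient (assumed)

ddl.add_index(User, ["email"], unique=True)
# v0.8.0 (-):  ... ADD UNIQUE INDEX `email` (`email`)
#              key named after the column, matching what MySQL itself creates
#              for an inline UNIQUE constraint, so a later DROP finds it
# pre-0.8 (+): ... ADD UNIQUE INDEX `uid_user_email_<hash>` (`email`)
#              generated name that can diverge from the existing key and
#              make DROP INDEX raise OperationalError (#346)
```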

View File

@@ -1,4 +1,4 @@
-from typing import Type, cast
+from typing import Type

 from tortoise import Model
 from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
@@ -18,7 +18,7 @@ class PostgresDDL(BaseDDL):
     _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
     _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'

-    def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
+    def alter_column_null(self, model: "Type[Model]", field_describe: dict):
         db_table = model._meta.db_table
         return self._ALTER_NULL_TEMPLATE.format(
             table_name=db_table,
@@ -26,9 +26,9 @@ class PostgresDDL(BaseDDL):
             set_drop="DROP" if field_describe.get("nullable") else "SET",
         )

-    def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
+    def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
         db_table = model._meta.db_table
-        db_field_types = cast(dict, field_describe.get("db_field_types"))
+        db_field_types = field_describe.get("db_field_types")
         db_column = field_describe.get("db_column")
         datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
         return self._MODIFY_COLUMN_TEMPLATE.format(
@@ -38,14 +38,12 @@ class PostgresDDL(BaseDDL):
             using=f' USING "{db_column}"::{datatype}',
         )

-    def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
+    def set_comment(self, model: "Type[Model]", field_describe: dict):
         db_table = model._meta.db_table
         return self._SET_COMMENT_TEMPLATE.format(
             table_name=db_table,
             column=field_describe.get("db_column") or field_describe.get("raw_field"),
-            comment=(
-                "'{}'".format(field_describe.get("description"))
-                if field_describe.get("description")
-                else "NULL"
-            ),
+            comment="'{}'".format(field_describe.get("description"))
+            if field_describe.get("description")
+            else "NULL",
         )

View File

@@ -1,39 +1,24 @@
from __future__ import annotations from typing import Any, List, Optional
from typing import Any, Callable, Dict, Optional, TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from tortoise import BaseDBAsyncClient from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel): class Column(BaseModel):
name: str name: str
data_type: str data_type: str
null: bool null: bool
default: Any default: Any
comment: Optional[str] = None comment: Optional[str]
pk: bool pk: bool
unique: bool unique: bool
index: bool index: bool
length: Optional[int] = None length: Optional[int]
extra: Optional[str] = None extra: Optional[str]
decimal_places: Optional[int] = None decimal_places: Optional[int]
max_digits: Optional[int] = None max_digits: Optional[int]
def translate(self) -> ColumnInfoDict: def translate(self) -> dict:
comment = default = length = index = null = pk = "" comment = default = length = index = null = pk = ""
if self.pk: if self.pk:
pk = "pk=True, " pk = "pk=True, "
@@ -43,24 +28,23 @@ class Column(BaseModel):
else: else:
if self.index: if self.index:
index = "index=True, " index = "index=True, "
if self.data_type in ("varchar", "VARCHAR"): if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, " length = f"max_length={self.length}, "
elif self.data_type in ("decimal", "numeric"): if self.data_type in ["decimal", "numeric"]:
length_parts = [] length_parts = []
if self.max_digits: if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}") length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places: if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}") length_parts.append(f"decimal_places={self.decimal_places}")
if length_parts: length = ", ".join(length_parts)
length = ", ".join(length_parts) + ", "
if self.null: if self.null:
null = "null=True, " null = "null=True, "
if self.default is not None: if self.default is not None:
if self.data_type in ("tinyint", "INT"): if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, " default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool": elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, " default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"): elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default: if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, " default = "auto_now=True, "
@@ -71,8 +55,6 @@ class Column(BaseModel):
default = f"default={self.default.split('::')[0]}, " default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"): elif self.default.endswith("()"):
default = "" default = ""
elif self.default == "":
default = 'default=""'
else: else:
default = f"default={self.default}, " default = f"default={self.default}, "
@@ -92,16 +74,16 @@ class Column(BaseModel):
class Inspect: class Inspect:
_table_template = "class {table}(Model):\n" _table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None: def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn self.conn = conn
try: try:
self.database = conn.database # type:ignore[attr-defined] self.database = conn.database
except AttributeError: except AttributeError:
pass pass
self.tables = tables self.tables = tables
@property @property
def field_map(self) -> FieldMapDict: def field_map(self) -> dict:
raise NotImplementedError raise NotImplementedError
async def inspect(self) -> str: async def inspect(self) -> str:
@@ -119,10 +101,10 @@ class Inspect:
tables.append(model + "\n".join(fields)) tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables) return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> list[Column]: async def get_columns(self, table: str) -> List[Column]:
raise NotImplementedError raise NotImplementedError
async def get_all_tables(self) -> list[str]: async def get_all_tables(self) -> List[str]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
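`translate()` produces the keyword-argument fragments (`"pk=True, "`, `"max_length=50, "`, ...) that `inspect()` splices into each generated field declaration; a sketch with assumed values:

```python
col = Column(
    name="email",
    data_type="varchar",
    null=False,
    default=None,
    comment=None,
    pk=False,
    unique=True,
    index=False,
    length=50,
    extra=None,
    decimal_places=None,
    max_digits=None,
)
info = col.translate()
# info["length"] == "max_length=50, " per the varchar branch above; the
# fragments end up rendered into something like:
#   email = fields.CharField(max_length=50, unique=True)
```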

View File

@@ -1,18 +1,17 @@
-from __future__ import annotations
+from typing import List

-from aerich.inspectdb import Column, FieldMapDict, Inspect
+from aerich.inspectdb import Column, Inspect


 class InspectMySQL(Inspect):
     @property
-    def field_map(self) -> FieldMapDict:
+    def field_map(self) -> dict:
         return {
             "int": self.int_field,
             "smallint": self.smallint_field,
             "tinyint": self.bool_field,
             "bigint": self.bigint_field,
             "varchar": self.char_field,
-            "char": self.char_field,
             "longtext": self.text_field,
             "text": self.text_field,
             "datetime": self.datetime_field,
@@ -24,12 +23,12 @@ class InspectMySQL(Inspect):
             "longblob": self.binary_field,
         }

-    async def get_all_tables(self) -> list[str]:
+    async def get_all_tables(self) -> List[str]:
         sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
         ret = await self.conn.execute_query_dict(sql, [self.database])
         return list(map(lambda x: x["TABLE_NAME"], ret))

-    async def get_columns(self, table: str) -> list[Column]:
+    async def get_columns(self, table: str) -> List[Column]:
         columns = []
         sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
 from information_schema.COLUMNS c
@@ -60,8 +59,7 @@ where c.TABLE_SCHEMA = %s
                     comment=row["COLUMN_COMMENT"],
                     unique=row["COLUMN_KEY"] == "UNI",
                     extra=row["EXTRA"],
-                    # TODO: why `unque`?
-                    unque=unique,  # type:ignore
+                    unque=unique,
                     index=index,
                     length=row["CHARACTER_MAXIMUM_LENGTH"],
                     max_digits=row["NUMERIC_PRECISION"],

View File

@@ -1,20 +1,17 @@
-from __future__ import annotations
+from typing import List, Optional

-from typing import TYPE_CHECKING
+from tortoise import BaseDBAsyncClient

-from aerich.inspectdb import Column, FieldMapDict, Inspect
-
-if TYPE_CHECKING:
-    from tortoise.backends.base_postgres.client import BasePostgresClient
+from aerich.inspectdb import Column, Inspect


 class InspectPostgres(Inspect):
-    def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
+    def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
         super().__init__(conn, tables)
-        self.schema = conn.server_settings.get("schema") or "public"
+        self.schema = self.conn.server_settings.get("schema") or "public"

     @property
-    def field_map(self) -> FieldMapDict:
+    def field_map(self) -> dict:
         return {
             "int4": self.int_field,
             "int8": self.int_field,
@@ -36,12 +33,12 @@ class InspectPostgres(Inspect):
             "timestamp": self.datetime_field,
         }

-    async def get_all_tables(self) -> list[str]:
+    async def get_all_tables(self) -> List[str]:
         sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
         ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
         return list(map(lambda x: x["table_name"], ret))

-    async def get_columns(self, table: str) -> list[Column]:
+    async def get_columns(self, table: str) -> List[Column]:
         columns = []
         sql = f"""select c.column_name,
     col_description('public.{table}'::regclass, ordinal_position) as column_comment,
@@ -58,7 +55,7 @@ from information_schema.constraint_column_usage const
          right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
 where c.table_catalog = $1
   and c.table_name = $2
-  and c.table_schema = $3"""  # nosec:B608
+  and c.table_schema = $3"""
         ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
         for row in ret:
             columns.append(

View File

@@ -1,11 +1,11 @@
-from __future__ import annotations
+from typing import List

-from aerich.inspectdb import Column, FieldMapDict, Inspect
+from aerich.inspectdb import Column, Inspect


 class InspectSQLite(Inspect):
     @property
-    def field_map(self) -> FieldMapDict:
+    def field_map(self) -> dict:
         return {
             "INTEGER": self.int_field,
             "INT": self.bool_field,
@@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
             "BLOB": self.binary_field,
         }

-    async def get_columns(self, table: str) -> list[Column]:
+    async def get_columns(self, table: str) -> List[Column]:
         columns = []
         sql = f"PRAGMA table_info({table})"
         ret = await self.conn.execute_query_dict(sql)
@@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
         )
         return columns

-    async def _get_columns_index(self, table: str) -> dict[str, str]:
+    async def _get_columns_index(self, table: str):
         sql = f"PRAGMA index_list ({table})"
         indexes = await self.conn.execute_query_dict(sql)
         ret = {}
@@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
             ret[index_info["name"]] = "unique" if index["unique"] else "index"
         return ret

-    async def get_all_tables(self) -> list[str]:
+    async def get_all_tables(self) -> List[str]:
         sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
         ret = await self.conn.execute_query_dict(sql)
         return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,11 +1,11 @@
import hashlib
import importlib import importlib
import os import os
from datetime import datetime from datetime import datetime
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast from typing import Dict, List, Optional, Tuple, Type, Union
import asyncclick as click import click
from dictdiffer import diff from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@@ -37,21 +37,16 @@ class Migrate:
_upgrade_m2m: List[str] = [] _upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = [] _downgrade_m2m: List[str] = []
_aerich = Aerich.__name__ _aerich = Aerich.__name__
_rename_old: List[str] = [] _rename_old = []
_rename_new: List[str] = [] _rename_new = []
ddl: BaseDDL ddl: BaseDDL
ddl_class: Type[BaseDDL]
_last_version_content: Optional[dict] = None _last_version_content: Optional[dict] = None
app: str app: str
migrate_location: Path migrate_location: Path
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@staticmethod
def get_field_by_name(name: str, fields: List[dict]) -> dict:
return next(filter(lambda x: x.get("name") == name, fields))
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( return sorted(
@@ -61,35 +56,35 @@ class Migrate:
     @classmethod
     def _get_model(cls, model: str) -> Type[Model]:
-        return Tortoise.apps[cls.app][model]
+        return Tortoise.apps.get(cls.app).get(model)

     @classmethod
     async def get_last_version(cls) -> Optional[Aerich]:
         try:
             return await Aerich.filter(app=cls.app).first()
         except OperationalError:
-            return None
+            pass

     @classmethod
-    async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
+    async def _get_db_version(cls, connection: BaseDBAsyncClient):
         if cls.dialect == "mysql":
             sql = "select version() as version"
             ret = await connection.execute_query(sql)
             cls._db_version = ret[1][0].get("version")

     @classmethod
-    async def load_ddl_class(cls) -> Type[BaseDDL]:
+    async def load_ddl_class(cls):
         ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
         return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")

     @classmethod
-    async def init(cls, config: dict, app: str, location: str) -> None:
+    async def init(cls, config: dict, app: str, location: str):
         await Tortoise.init(config=config)
         last_version = await cls.get_last_version()
         cls.app = app
         cls.migrate_location = Path(location, app)
         if last_version:
-            cls._last_version_content = cast(dict, last_version.content)
+            cls._last_version_content = last_version.content
         connection = get_app_connection(config, app)
         cls.dialect = connection.schema_generator.DIALECT
@@ -98,7 +93,7 @@ class Migrate:
         await cls._get_db_version(connection)

     @classmethod
-    async def _get_last_version_num(cls) -> Optional[int]:
+    async def _get_last_version_num(cls):
         last_version = await cls.get_last_version()
         if not last_version:
             return None
@@ -106,7 +101,7 @@ class Migrate:
         return int(version.split("_", 1)[0])

     @classmethod
-    async def generate_version(cls, name=None) -> str:
+    async def generate_version(cls, name=None):
         now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
         last_version_num = await cls._get_last_version_num()
         if last_version_num is None:
@@ -117,31 +112,33 @@ class Migrate:
         return version

     @classmethod
-    async def _generate_diff_py(cls, name) -> str:
+    async def _generate_diff_py(cls, name):
         version = await cls.generate_version(name)
         # delete if same version exists
         for version_file in cls.get_all_version_files():
             if version_file.startswith(version.split("_")[0]):
                 os.unlink(Path(cls.migrate_location, version_file))
-        content = cls._get_diff_file_content()
-        Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
+        version_file = Path(cls.migrate_location, version)
+        content = MIGRATE_TEMPLATE.format(
+            upgrade_sql=";\n ".join(cls.upgrade_operators) + ";",
+            downgrade_sql=";\n ".join(cls.downgrade_operators) + ";",
+        )
+        with open(version_file, "w", encoding="utf-8") as f:
+            f.write(content)
         return version

     @classmethod
-    async def migrate(cls, name: str, empty: bool) -> str:
+    async def migrate(cls, name) -> str:
         """
         diff old models and new models to generate diff content
-        :param name: str name for migration
-        :param empty: bool if True generates empty migration
+        :param name:
         :return:
         """
-        if empty:
-            return await cls._generate_diff_py(name)
         new_version_content = get_models_describe(cls.app)
-        last_version = cast(dict, cls._last_version_content)
-        cls.diff_models(last_version, new_version_content)
-        cls.diff_models(new_version_content, last_version, False)
+        cls.diff_models(cls._last_version_content, new_version_content)
+        cls.diff_models(new_version_content, cls._last_version_content, False)
         cls._merge_operators()
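Both versions render the migration file the same way: join the collected SQL operators into upgrade/downgrade bodies and substitute them into a module template. A self-contained sketch of that mechanism (the TEMPLATE string below is a simplified stand-in, not aerich's actual MIGRATE_TEMPLATE):

    # simplified stand-in for aerich's MIGRATE_TEMPLATE, for illustration only
    TEMPLATE = (
        "from tortoise import BaseDBAsyncClient\n\n\n"
        "async def upgrade(db: BaseDBAsyncClient) -> str:\n"
        '    return """\n        {upgrade_sql}"""\n\n\n'
        "async def downgrade(db: BaseDBAsyncClient) -> str:\n"
        '    return """\n        {downgrade_sql}"""\n'
    )

    upgrade_ops = ["ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''"]
    downgrade_ops = ["ALTER TABLE `user` DROP COLUMN `avatar`"]
    content = TEMPLATE.format(
        upgrade_sql=";\n        ".join(upgrade_ops) + ";",
        downgrade_sql=";\n        ".join(downgrade_ops) + ";",
    )
    print(content)  # a complete upgrade()/downgrade() migration module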
@@ -151,23 +148,7 @@ class Migrate:
         return await cls._generate_diff_py(name)

     @classmethod
-    def _get_diff_file_content(cls) -> str:
-        """
-        builds content for diff file from template
-        """
-
-        def join_lines(lines: List[str]) -> str:
-            if not lines:
-                return ""
-            return ";\n ".join(lines) + ";"
-
-        return MIGRATE_TEMPLATE.format(
-            upgrade_sql=join_lines(cls.upgrade_operators),
-            downgrade_sql=join_lines(cls.downgrade_operators),
-        )
-
-    @classmethod
-    def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
+    def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
         """
         add operator,differentiate fk because fk is order limit
         :param operator:
@@ -188,37 +169,19 @@ class Migrate:
             cls.downgrade_operators.append(operator)

     @classmethod
-    def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
-        ret: list = []
-
-        def index_hash(self) -> str:
-            h = hashlib.new("MD5", usedforsecurity=False)  # type:ignore[call-arg]
-            h.update(
-                self.index_name(cls.ddl.schema_generator, model).encode()
-                + self.__class__.__name__.encode()
-            )
-            return h.hexdigest()
-
+    def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
+        ret = []
         for index in indexes:
             if isinstance(index, Index):
-                index.__hash__ = index_hash  # type:ignore[method-assign,assignment]
+                index.__hash__ = lambda self: md5(  # nosec: B303
+                    self.index_name(cls.ddl.schema_generator, model).encode()
+                    + self.__class__.__name__.encode()
+                ).hexdigest()
             ret.append(index)
         return ret

     @classmethod
-    def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
-        indexes: Set[Union[Index, Tuple[str, ...]]] = set()
-        for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
-            if isinstance(x, Index):
-                indexes.add(x)
-            else:
-                indexes.add(cast(Tuple[str, ...], tuple(x)))
-        return indexes
-
-    @classmethod
-    def diff_models(
-        cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
-    ) -> None:
+    def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
         """
         diff models and add operators
         :param old_models:
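Both implementations patch Index.__hash__ for the same reason: Index objects coming from two different describe() snapshots must hash consistently so the set differences computed inside diff_models line up. A standalone sketch of the idea, using a simplified Index stand-in rather than tortoise's class:

    import hashlib


    class Index:  # simplified stand-in for tortoise's Index
        def __init__(self, *fields: str) -> None:
            self.fields = fields

        def index_name(self) -> str:
            return "idx_" + "_".join(self.fields)


    def stable_hash(self) -> int:
        digest = hashlib.md5(
            self.index_name().encode() + self.__class__.__name__.encode()
        ).hexdigest()
        return int(digest, 16)


    Index.__hash__ = stable_hash
    a, b = Index("name"), Index("name")
    assert hash(a) == hash(b)  # the same logical index from two snapshots hashes alike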
@@ -231,35 +194,39 @@ class Migrate:
         new_models.pop(_aerich, None)

         for new_model_str, new_model_describe in new_models.items():
-            model = cls._get_model(new_model_describe["name"].split(".")[1])
-            if new_model_str not in old_models:
+            model = cls._get_model(new_model_describe.get("name").split(".")[1])
+            if new_model_str not in old_models.keys():
                 if upgrade:
                     cls._add_operator(cls.add_model(model), upgrade)
                 else:
                     # we can't find origin model when downgrade, so skip
                     pass
             else:
-                old_model_describe = cast(dict, old_models.get(new_model_str))
+                old_model_describe = old_models.get(new_model_str)
                 # rename table
-                new_table = cast(str, new_model_describe.get("table"))
-                old_table = cast(str, old_model_describe.get("table"))
+                new_table = new_model_describe.get("table")
+                old_table = old_model_describe.get("table")
                 if new_table != old_table:
                     cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
                 old_unique_together = set(
-                    map(
-                        lambda x: tuple(x),
-                        cast(List[Iterable[str]], old_model_describe.get("unique_together")),
-                    )
+                    map(lambda x: tuple(x), old_model_describe.get("unique_together"))
                 )
                 new_unique_together = set(
-                    map(
-                        lambda x: tuple(x),
-                        cast(List[Iterable[str]], new_model_describe.get("unique_together")),
-                    )
+                    map(lambda x: tuple(x), new_model_describe.get("unique_together"))
                 )
-                old_indexes = cls._get_indexes(model, old_model_describe)
-                new_indexes = cls._get_indexes(model, new_model_describe)
+                old_indexes = set(
+                    map(
+                        lambda x: x if isinstance(x, Index) else tuple(x),
+                        cls._handle_indexes(model, old_model_describe.get("indexes", [])),
+                    )
+                )
+                new_indexes = set(
+                    map(
+                        lambda x: x if isinstance(x, Index) else tuple(x),
+                        cls._handle_indexes(model, new_model_describe.get("indexes", [])),
+                    )
+                )
                 old_pk_field = old_model_describe.get("pk_field")
                 new_pk_field = new_model_describe.get("pk_field")
                 # pk field
@@ -269,19 +236,12 @@ class Migrate:
                     if action == "change" and option == "name":
                         cls._add_operator(cls._rename_field(model, *change), upgrade)
                 # m2m fields
-                old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
-                new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
+                old_m2m_fields = old_model_describe.get("m2m_fields")
+                new_m2m_fields = new_model_describe.get("m2m_fields")
                 for action, option, change in diff(old_m2m_fields, new_m2m_fields):
                     if change[0][0] == "db_constraint":
                         continue
-                    new_value = change[0][1]
-                    if isinstance(new_value, str):
-                        for new_m2m_field in new_m2m_fields:
-                            if new_m2m_field["name"] == new_value:
-                                table = cast(str, new_m2m_field.get("through"))
-                                break
-                    else:
-                        table = new_value.get("through")
+                    table = change[0][1].get("through")
                     if action == "add":
                         add = False
                         if upgrade and table not in cls._upgrade_m2m:
@@ -291,9 +251,12 @@ class Migrate:
                             cls._downgrade_m2m.append(table)
                             add = True
                         if add:
-                            ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
                             cls._add_operator(
-                                cls.create_m2m(model, new_value, ref_desc),
+                                cls.create_m2m(
+                                    model,
+                                    change[0][1],
+                                    new_models.get(change[0][1].get("model_name")),
+                                ),
                                 upgrade,
                                 fk_m2m_index=True,
                             )
@@ -314,36 +277,38 @@ class Migrate:
                 for index in old_unique_together.difference(new_unique_together):
                     cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
                 # add indexes
-                for idx in new_indexes.difference(old_indexes):
-                    cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
+                for index in new_indexes.difference(old_indexes):
+                    cls._add_operator(cls._add_index(model, index, False), upgrade, True)
                 # remove indexes
-                for idx in old_indexes.difference(new_indexes):
-                    cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
+                for index in old_indexes.difference(new_indexes):
+                    cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
                 old_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        cast(List[dict], old_model_describe.get("data_fields")),
+                        old_model_describe.get("data_fields"),
                     )
                 )
                 new_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        cast(List[dict], new_model_describe.get("data_fields")),
+                        new_model_describe.get("data_fields"),
                     )
                 )
-                old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
-                new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
+                old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
+                new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
                 # add fields or rename fields
                 for new_data_field_name in set(new_data_fields_name).difference(
                     set(old_data_fields_name)
                 ):
-                    new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
+                    new_data_field = next(
+                        filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
+                    )
                     is_rename = False
                     for old_data_field in old_data_fields:
                         changes = list(diff(old_data_field, new_data_field))
-                        old_data_field_name = cast(str, old_data_field.get("name"))
+                        old_data_field_name = old_data_field.get("name")
                         if len(changes) == 2:
                             # rename field
                             if (
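The add/rename/remove decisions in this stretch of diff_models are plain set algebra over field names; a toy illustration of the control flow (field names invented):

    old_fields = {"id", "name", "image"}
    new_fields = {"id", "name", "pic", "address"}

    candidates_to_add = new_fields - old_fields     # {'pic', 'address'}
    candidates_to_drop = old_fields - new_fields    # {'image'}
    unchanged_names = new_fields & old_fields       # {'id', 'name'} -> diffed field by field
    # an add/drop pair such as image/pic may really be a rename, which is why
    # the code compares describe dicts and asks via the prompt before choosing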
@@ -404,7 +369,7 @@ class Migrate:
                     if new_data_field["indexed"]:
                         cls._add_operator(
                             cls._add_index(
-                                model, (new_data_field["db_column"],), new_data_field["unique"]
+                                model, {new_data_field["db_column"]}, new_data_field["unique"]
                             ),
                             upgrade,
                             True,
@@ -418,35 +383,45 @@ class Migrate:
                         not upgrade and old_data_field_name in cls._rename_new
                     ):
                         continue
-                    old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
-                    db_column = cast(str, old_data_field["db_column"])
+                    old_data_field = next(
+                        filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
+                    )
+                    db_column = old_data_field["db_column"]
                     cls._add_operator(
-                        cls._remove_field(model, db_column),
+                        cls._remove_field(
+                            model,
+                            db_column,
+                        ),
                         upgrade,
                     )
                     if old_data_field["indexed"]:
-                        is_unique_field = old_data_field.get("unique")
                         cls._add_operator(
-                            cls._drop_index(model, {db_column}, is_unique_field),
+                            cls._drop_index(
+                                model,
+                                {db_column},
+                            ),
                             upgrade,
                             True,
                         )

-                old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
-                new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
-                old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
-                new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
+                old_fk_fields = old_model_describe.get("fk_fields")
+                new_fk_fields = new_model_describe.get("fk_fields")
+                old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
+                new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
                 # add fk
                 for new_fk_field_name in set(new_fk_fields_name).difference(
                     set(old_fk_fields_name)
                 ):
-                    fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
+                    fk_field = next(
+                        filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
+                    )
                     if fk_field.get("db_constraint"):
-                        ref_describe = cast(dict, new_models[fk_field["python_type"]])
                         cls._add_operator(
-                            cls._add_fk(model, fk_field, ref_describe),
+                            cls._add_fk(
+                                model, fk_field, new_models.get(fk_field.get("python_type"))
+                            ),
                             upgrade,
                             fk_m2m_index=True,
                         )
@@ -454,33 +429,37 @@ class Migrate:
                 for old_fk_field_name in set(old_fk_fields_name).difference(
                     set(new_fk_fields_name)
                 ):
-                    old_fk_field = cls.get_field_by_name(
-                        old_fk_field_name, cast(List[dict], old_fk_fields)
-                    )
+                    old_fk_field = next(
+                        filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
+                    )
                     if old_fk_field.get("db_constraint"):
-                        ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
                         cls._add_operator(
-                            cls._drop_fk(model, old_fk_field, ref_describe),
+                            cls._drop_fk(
+                                model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
+                            ),
                             upgrade,
                             fk_m2m_index=True,
                         )
                 # change fields
                 for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
-                    old_data_field = cls.get_field_by_name(field_name, old_data_fields)
-                    new_data_field = cls.get_field_by_name(field_name, new_data_fields)
+                    old_data_field = next(
+                        filter(lambda x: x.get("name") == field_name, old_data_fields)
+                    )
+                    new_data_field = next(
+                        filter(lambda x: x.get("name") == field_name, new_data_fields)
+                    )
                     changes = diff(old_data_field, new_data_field)
                     modified = False
                     for change in changes:
                         _, option, old_new = change
                         if option == "indexed":
                             # change index
-                            if old_new[0] is False and old_new[1] is True:
-                                unique = new_data_field.get("unique")
+                            unique = new_data_field.get("unique")
+                            if old_new[0] is False and old_new[1] is True:
                                 cls._add_operator(
                                     cls._add_index(model, (field_name,), unique), upgrade, True
                                 )
                             else:
-                                unique = old_data_field.get("unique")
                                 cls._add_operator(
                                     cls._drop_index(model, (field_name,), unique), upgrade, True
                                 )
@@ -507,9 +486,6 @@ class Migrate:
                         elif option == "nullable":
                             # change nullable
                             cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
-                        elif option == "description":
-                            # change comment
-                            cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
                         else:
                             if modified:
                                 continue
@@ -520,118 +496,103 @@ class Migrate:
                                 )
                                 modified = True

-        for old_model in old_models.keys() - new_models.keys():
-            cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
+        for old_model in old_models:
+            if old_model not in new_models.keys():
+                cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)

     @classmethod
-    def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
+    def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
         return cls.ddl.rename_table(model, old_table_name, new_table_name)

     @classmethod
-    def add_model(cls, model: Type[Model]) -> str:
+    def add_model(cls, model: Type[Model]):
         return cls.ddl.create_table(model)

     @classmethod
-    def drop_model(cls, table_name: str) -> str:
+    def drop_model(cls, table_name: str):
         return cls.ddl.drop_table(table_name)

     @classmethod
-    def create_m2m(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
-    ) -> str:
+    def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
         return cls.ddl.create_m2m(model, field_describe, reference_table_describe)

     @classmethod
-    def drop_m2m(cls, table_name: str) -> str:
+    def drop_m2m(cls, table_name: str):
         return cls.ddl.drop_m2m(table_name)

     @classmethod
-    def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
+    def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
         ret = []
         for field_name in fields_name:
-            try:
-                field = model._meta.fields_map[field_name]
-            except KeyError:
-                # field dropped or to be add
-                pass
-            else:
-                if field.source_field:
-                    field_name = field.source_field
-                elif field_name in model._meta.fk_fields:
-                    field_name += "_id"
-            ret.append(field_name)
+            field = model._meta.fields_map[field_name]
+            if field.source_field:
+                ret.append(field.source_field)
+            elif field_name in model._meta.fk_fields:
+                ret.append(field_name + "_id")
+            else:
+                ret.append(field_name)
         return ret
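The column-name resolution above follows one rule in both versions: prefer an explicit source_field, append _id for foreign keys, otherwise keep the field name. A standalone sketch with plain dicts in place of Model._meta (field names invented):

    fields_map = {
        "uid": {"source_field": "uuid"},   # explicit db column via source_field
        "user": {"source_field": None},    # FK field -> column name gets an _id suffix
        "name": {"source_field": None},    # ordinary field -> column equals field name
    }
    fk_fields = {"user"}


    def resolve(field_name: str) -> str:
        field = fields_map[field_name]
        if field["source_field"]:
            return field["source_field"]
        if field_name in fk_fields:
            return field_name + "_id"
        return field_name


    assert [resolve(n) for n in ("uid", "user", "name")] == ["uuid", "user_id", "name"]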
     @classmethod
-    def _drop_index(
-        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
-    ) -> str:
+    def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
         if isinstance(fields_name, Index):
             return cls.ddl.drop_index_by_name(
                 model, fields_name.index_name(cls.ddl.schema_generator, model)
             )
-        field_names = cls._resolve_fk_fields_name(model, fields_name)
-        return cls.ddl.drop_index(model, field_names, unique)
+        fields_name = cls._resolve_fk_fields_name(model, fields_name)
+        return cls.ddl.drop_index(model, fields_name, unique)

     @classmethod
-    def _add_index(
-        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
-    ) -> str:
+    def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
         if isinstance(fields_name, Index):
             return fields_name.get_sql(cls.ddl.schema_generator, model, False)
-        field_names = cls._resolve_fk_fields_name(model, fields_name)
-        return cls.ddl.add_index(model, field_names, unique)
+        fields_name = cls._resolve_fk_fields_name(model, fields_name)
+        return cls.ddl.add_index(model, fields_name, unique)

     @classmethod
-    def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
+    def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
         return cls.ddl.add_column(model, field_describe, is_pk)

     @classmethod
-    def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
+    def _alter_default(cls, model: Type[Model], field_describe: dict):
         return cls.ddl.alter_column_default(model, field_describe)

     @classmethod
-    def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
+    def _alter_null(cls, model: Type[Model], field_describe: dict):
         return cls.ddl.alter_column_null(model, field_describe)

     @classmethod
-    def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
+    def _set_comment(cls, model: Type[Model], field_describe: dict):
         return cls.ddl.set_comment(model, field_describe)

     @classmethod
-    def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
+    def _modify_field(cls, model: Type[Model], field_describe: dict):
         return cls.ddl.modify_column(model, field_describe)

     @classmethod
-    def _drop_fk(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
-    ) -> str:
+    def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
         return cls.ddl.drop_fk(model, field_describe, reference_table_describe)

     @classmethod
-    def _remove_field(cls, model: Type[Model], column_name: str) -> str:
+    def _remove_field(cls, model: Type[Model], column_name: str):
         return cls.ddl.drop_column(model, column_name)

     @classmethod
-    def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
+    def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
         return cls.ddl.rename_column(model, old_field_name, new_field_name)

     @classmethod
-    def _change_field(
-        cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
-    ) -> str:
-        db_field_types = cast(dict, new_field_describe.get("db_field_types"))
+    def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
+        db_field_types = new_field_describe.get("db_field_types")
         return cls.ddl.change_column(
             model,
-            cast(str, old_field_describe.get("db_column")),
-            cast(str, new_field_describe.get("db_column")),
-            cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
+            old_field_describe.get("db_column"),
+            new_field_describe.get("db_column"),
+            db_field_types.get(cls.dialect) or db_field_types.get(""),
         )

     @classmethod
-    def _add_fk(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
-    ) -> str:
+    def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
         """
         add fk
         :param model:
@@ -642,7 +603,7 @@ class Migrate:
         return cls.ddl.add_fk(model, field_describe, reference_table_describe)

     @classmethod
-    def _merge_operators(cls) -> None:
+    def _merge_operators(cls):
         """
         fk/m2m/index must be last when add,first when drop
         :return:

View File

@@ -9,7 +9,7 @@ MAX_APP_LENGTH = 100
 class Aerich(Model):
     version = fields.CharField(max_length=MAX_VERSION_LENGTH)
     app = fields.CharField(max_length=MAX_APP_LENGTH)
-    content: dict = fields.JSONField(encoder=encoder, decoder=decoder)
+    content = fields.JSONField(encoder=encoder, decoder=decoder)

     class Meta:
         ordering = ["-id"]

View File

@@ -3,10 +3,9 @@ import os
 import re
 import sys
 from pathlib import Path
-from types import ModuleType
-from typing import Dict, Optional, Union
+from typing import Dict

-from asyncclick import BadOptionUsage, ClickException, Context
+from click import BadOptionUsage, ClickException, Context
 from tortoise import BaseDBAsyncClient, Tortoise
@@ -85,19 +84,19 @@ def get_models_describe(app: str) -> Dict:
     :return:
     """
     ret = {}
-    for model in Tortoise.apps[app].values():
+    for model in Tortoise.apps.get(app).values():
         describe = model.describe()
         ret[describe.get("name")] = describe
     return ret


-def is_default_function(string: str) -> Optional[re.Match]:
+def is_default_function(string: str):
     return re.match(r"^<function.+>$", str(string or ""))


-def import_py_file(file: Union[str, Path]) -> ModuleType:
+def import_py_file(file: Path):
     module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
     spec = importlib.util.spec_from_file_location(module_name, file)
-    module = importlib.util.module_from_spec(spec)  # type:ignore[arg-type]
-    spec.loader.exec_module(module)  # type:ignore[union-attr]
+    module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(module)
     return module
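For context, the spec_from_file_location / module_from_spec / exec_module dance is the standard library's way to import a module from an arbitrary path; a minimal usage sketch (the file written here is invented for illustration):

    import importlib.util
    import pathlib

    pathlib.Path("my_module.py").write_text("answer = 42\n")
    spec = importlib.util.spec_from_file_location("my_module", "my_module.py")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # executes the file's top-level code
    print(module.answer)  # 42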

View File

@@ -1 +1 @@
-__version__ = "0.8.0"
+__version__ = "0.7.2"

View File

@@ -1,22 +1,19 @@
 import asyncio
 import os
-from typing import Generator

 import pytest
 from tortoise import Tortoise, expand_db_url, generate_schema_for_client
 from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
 from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
 from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
-from tortoise.exceptions import DBConnectionError, OperationalError

 from aerich.ddl.mysql import MysqlDDL
 from aerich.ddl.postgres import PostgresDDL
 from aerich.ddl.sqlite import SqliteDDL
 from aerich.migrate import Migrate

-MEMORY_SQLITE = "sqlite://:memory:"
-db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
-db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
+db_url = os.getenv("TEST_DB", "sqlite://:memory:")
+db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:")
 tortoise_orm = {
     "connections": {
         "default": expand_db_url(db_url, True),
@@ -30,7 +27,7 @@ tortoise_orm = {

 @pytest.fixture(scope="function", autouse=True)
-def reset_migrate() -> None:
+def reset_migrate():
     Migrate.upgrade_operators = []
     Migrate.downgrade_operators = []
     Migrate._upgrade_fk_m2m_index_operators = []
@@ -40,27 +37,20 @@ def reset_migrate() -> None:

 @pytest.fixture(scope="session")
-def event_loop() -> Generator:
+def event_loop():
     policy = asyncio.get_event_loop_policy()
     res = policy.new_event_loop()
     asyncio.set_event_loop(res)
-    res._close = res.close  # type:ignore[attr-defined]
-    res.close = lambda: None  # type:ignore[method-assign]
+    res._close = res.close
+    res.close = lambda: None
     yield res
-    res._close()  # type:ignore[attr-defined]
+    res._close()


 @pytest.fixture(scope="session", autouse=True)
-async def initialize_tests(event_loop, request) -> None:
-    # Placing init outside the try block since it doesn't
-    # establish connections to the DB eagerly.
-    await Tortoise.init(config=tortoise_orm)
-    try:
-        await Tortoise._drop_databases()
-    except (DBConnectionError, OperationalError):
-        pass
+async def initialize_tests(event_loop, request):
     await Tortoise.init(config=tortoise_orm, _create_db=True)
     await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

poetry.lock generated

File diff suppressed because it is too large

View File

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "aerich"
-version = "0.8.0"
+version = "0.7.2"
 description = "A database migrations tool for Tortoise ORM."
 authors = ["long2ice <long2ice@gmail.com>"]
 license = "Apache-2.0"
@@ -15,57 +15,52 @@ packages = [
 include = ["CHANGELOG.md", "LICENSE", "README.md"]

 [tool.poetry.dependencies]
-python = "^3.8"
+python = "^3.7"
 tortoise-orm = "*"
+click = "*"
 asyncpg = { version = "*", optional = true }
-asyncmy = { version = "^0.2.9", optional = true, allow-prereleases = true }
-pydantic = "^2.0"
+asyncmy = { version = "^0.2.8rc1", optional = true, allow-prereleases = true }
+pydantic = "*"
 dictdiffer = "*"
 tomlkit = "*"
-asyncclick = "^8.1.7.2"

-[tool.poetry.group.dev.dependencies]
+[tool.poetry.dev-dependencies]
 ruff = "*"
 isort = "*"
 black = "*"
 pytest = "*"
 pytest-xdist = "*"
-# Breaking change in 0.23.*
-# https://github.com/pytest-dev/pytest-asyncio/issues/706
-pytest-asyncio = "^0.21.2"
+pytest-asyncio = "*"
 bandit = "*"
 pytest-mock = "*"
 cryptography = "*"
-mypy = "^1.10.0"

 [tool.poetry.extras]
 asyncmy = ["asyncmy"]
 asyncpg = ["asyncpg"]

 [tool.aerich]
 tortoise_orm = "conftest.tortoise_orm"
 location = "./migrations"
 src_folder = "./."

 [build-system]
-requires = ["poetry-core>=1.0.0"]
-build-backend = "poetry.core.masonry.api"
+requires = ["poetry>=0.12"]
+build-backend = "poetry.masonry.api"

 [tool.poetry.scripts]
 aerich = "aerich.cli:main"

 [tool.black]
 line-length = 100
-target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
+target-version = ['py36', 'py37', 'py38', 'py39']

 [tool.pytest.ini_options]
 asyncio_mode = 'auto'

 [tool.mypy]
 pretty = true
-python_version = "3.8"
 ignore_missing_imports = true

-[tool.ruff.lint]
+[tool.ruff]
 ignore = ['E501']

View File

@@ -33,11 +33,11 @@ class User(Model):
 class Email(Model):
-    email_id = fields.IntField(primary_key=True)
-    email = fields.CharField(max_length=200, db_index=True)
+    email_id = fields.IntField(pk=True)
+    email = fields.CharField(max_length=200, index=True)
     is_primary = fields.BooleanField(default=False)
     address = fields.CharField(max_length=200)
-    users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
+    users = fields.ManyToManyField("models.User")


 def default_name():
@@ -47,15 +47,12 @@ def default_name():
 class Category(Model):
     slug = fields.CharField(max_length=100)
     name = fields.CharField(max_length=200, null=True, default=default_name)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models.User", description="User"
-    )
-    title = fields.CharField(max_length=20, unique=False)
+    user = fields.ForeignKeyField("models.User", description="User")
     created_at = fields.DatetimeField(auto_now_add=True)


 class Product(Model):
-    categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
+    categories = fields.ManyToManyField("models.Category")
     name = fields.CharField(max_length=50)
     view_num = fields.IntField(description="View Num", default=0)
     sort = fields.IntField()
@@ -75,11 +72,9 @@ class Product(Model):
 class Config(Model):
     label = fields.CharField(max_length=200)
     key = fields.CharField(max_length=20)
-    value: dict = fields.JSONField()
+    value = fields.JSONField()
     status: Status = fields.IntEnumField(Status)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models.User", description="User"
-    )
+    user = fields.ForeignKeyField("models.User", description="User")


 class NewModel(Model):

View File

@@ -34,24 +34,18 @@ class User(Model):
 class Email(Model):
     email = fields.CharField(max_length=200)
     is_primary = fields.BooleanField(default=False)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models_second.User", db_constraint=False
-    )
+    user = fields.ForeignKeyField("models_second.User", db_constraint=False)


 class Category(Model):
     slug = fields.CharField(max_length=200)
     name = fields.CharField(max_length=200)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models_second.User", description="User"
-    )
+    user = fields.ForeignKeyField("models_second.User", description="User")
     created_at = fields.DatetimeField(auto_now_add=True)


 class Product(Model):
-    categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
-        "models_second.Category"
-    )
+    categories = fields.ManyToManyField("models_second.Category")
     name = fields.CharField(max_length=50)
     view_num = fields.IntField(description="View Num")
     sort = fields.IntField()
@@ -67,5 +61,5 @@ class Product(Model):
 class Config(Model):
     label = fields.CharField(max_length=200)
     key = fields.CharField(max_length=20)
-    value: dict = fields.JSONField()
+    value = fields.JSONField()
     status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@@ -35,23 +35,18 @@ class User(Model):
 class Email(Model):
     email = fields.CharField(max_length=200)
     is_primary = fields.BooleanField(default=False)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models.User", db_constraint=False
-    )
+    user = fields.ForeignKeyField("models.User", db_constraint=False)


 class Category(Model):
     slug = fields.CharField(max_length=200)
     name = fields.CharField(max_length=200)
-    user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
-        "models.User", description="User"
-    )
+    user = fields.ForeignKeyField("models.User", description="User")
     created_at = fields.DatetimeField(auto_now_add=True)


 class Product(Model):
-    categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
-    uid = fields.IntField(source_field="uuid", unique=True)
+    categories = fields.ManyToManyField("models.Category")
     name = fields.CharField(max_length=50)
     view_num = fields.IntField(description="View Num")
     sort = fields.IntField()
@@ -65,10 +60,9 @@ class Product(Model):
 class Config(Model):
-    name = fields.CharField(max_length=100, unique=True)
     label = fields.CharField(max_length=200)
     key = fields.CharField(max_length=20)
-    value: dict = fields.JSONField()
+    value = fields.JSONField()
     status: Status = fields.IntEnumField(Status, default=Status.on)

     class Meta:

View File

@@ -14,7 +14,6 @@ def test_create_table():
     `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
     `slug` VARCHAR(100) NOT NULL,
     `name` VARCHAR(200),
-    `title` VARCHAR(20) NOT NULL,
     `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
     `user_id` INT NOT NULL COMMENT 'User',
     CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
@@ -28,7 +27,6 @@ def test_create_table():
     "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
     "slug" VARCHAR(100) NOT NULL,
     "name" VARCHAR(200),
-    "title" VARCHAR(20) NOT NULL,
     "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
 )"""
@@ -41,7 +39,6 @@ def test_create_table():
     "id" SERIAL NOT NULL PRIMARY KEY,
     "slug" VARCHAR(100) NOT NULL,
     "name" VARCHAR(200),
-    "title" VARCHAR(20) NOT NULL,
     "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
     "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
 );
@@ -154,7 +151,9 @@ def test_add_index():
     index_u = Migrate.ddl.add_index(Category, ["name"], True)
     if isinstance(Migrate.ddl, MysqlDDL):
         assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
-        assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
+        assert (
+            index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)"
+        )
     elif isinstance(Migrate.ddl, PostgresDDL):
         assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
         assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
@@ -170,7 +169,7 @@ def test_drop_index():
     ret_u = Migrate.ddl.drop_index(Category, ["name"], True)
     if isinstance(Migrate.ddl, MysqlDDL):
         assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`"
-        assert ret_u == "ALTER TABLE `category` DROP INDEX `name`"
+        assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
     elif isinstance(Migrate.ddl, PostgresDDL):
         assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
         assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
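The hashed suffixes in these expected names (8b0cb9 and the like) come from tortoise-orm's index-name scheme: a prefix, the table, and the leading column, padded with a short hash so the identifier stays unique within MySQL's 64-character limit. A rough sketch of the idea only, not tortoise's exact algorithm:

    from hashlib import sha256


    def index_name(prefix: str, table: str, fields: list) -> str:
        # rough sketch: a short digest keeps names unique once the
        # human-readable part is truncated to fit identifier limits
        digest = sha256(":".join([table] + fields).encode()).hexdigest()[:6]
        readable = f"{prefix}_{table}_{fields[0]}"[:20]
        return f"{readable}_{digest}"


    print(index_name("idx", "category", ["name"]))  # e.g. idx_category_name_xxxxxx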

View File

@@ -1,15 +1,11 @@
-from pathlib import Path
-from typing import List, cast
-
 import pytest
-import tortoise
 from pytest_mock import MockerFixture

 from aerich.ddl.mysql import MysqlDDL
 from aerich.ddl.postgres import PostgresDDL
 from aerich.ddl.sqlite import SqliteDDL
 from aerich.exceptions import NotSupportError
-from aerich.migrate import MIGRATE_TEMPLATE, Migrate
+from aerich.migrate import Migrate
 from aerich.utils import get_models_describe

 old_models_describe = {
@@ -104,21 +100,6 @@ old_models_describe = {
                 "constraints": {"ge": 1, "le": 2147483647},
                 "db_field_types": {"": "INT"},
             },
-            {
-                "name": "title",
-                "field_type": "CharField",
-                "db_column": "title",
-                "python_type": "str",
-                "generated": False,
-                "nullable": False,
-                "unique": True,
-                "indexed": True,
-                "default": None,
-                "description": None,
-                "docstring": None,
-                "constraints": {"max_length": 20},
-                "db_field_types": {"": "VARCHAR(20)"},
-            },
         ],
         "fk_fields": [
             {
@@ -188,21 +169,6 @@ old_models_describe = {
                 "db_field_types": {"": "INT"},
             },
         "data_fields": [
-            {
-                "name": "name",
-                "field_type": "CharField",
-                "db_column": "name",
-                "python_type": "str",
-                "generated": False,
-                "nullable": False,
-                "unique": True,
-                "indexed": True,
-                "default": None,
-                "description": None,
-                "docstring": None,
-                "constraints": {"max_length": 100},
-                "db_field_types": {"": "VARCHAR(100)"},
-            },
             {
                 "name": "label",
                 "field_type": "CharField",
@@ -403,21 +369,6 @@ old_models_describe = {
                 "constraints": {"max_length": 50},
                 "db_field_types": {"": "VARCHAR(50)"},
             },
-            {
-                "name": "uid",
-                "field_type": "IntField",
-                "db_column": "uuid",
-                "python_type": "int",
-                "generated": False,
-                "nullable": False,
-                "unique": True,
-                "indexed": True,
-                "default": None,
-                "description": None,
-                "docstring": None,
-                "constraints": {"ge": -2147483648, "le": 2147483647},
-                "db_field_types": {"": "INT"},
-            },
             {
                 "name": "view_num",
                 "field_type": "IntField",
@@ -822,16 +773,6 @@ old_models_describe = {
 }

-def should_add_user_id_column_type_alter_sql() -> bool:
-    if tortoise.__version__ < "0.21":
-        return False
-    # tortoise-orm>=0.21 changes IntField constraints
-    # from {"ge": 1, "le": 2147483647} to {"ge": -2147483648,"le": 2147483647}
-    data_fields = cast(List[dict], old_models_describe["models.Category"]["data_fields"])
-    user_id_constraints = data_fields[-1]["constraints"]
-    return tortoise.fields.data.IntField.constraints != user_id_constraints
-
-
 def test_migrate(mocker: MockerFixture):
     """
     models.py diff with old_models.py
@@ -842,25 +783,20 @@ def test_migrate(mocker: MockerFixture):
     - drop field: User.avatar
     - add index: Email.email
     - add many to many: Email.users
-    - remove unique: Category.title
-    - add unique: User.username
+    - remove unique: User.username
     - change column: length User.password
     - add unique_together: (name,type) of Product
-    - drop unique field: Config.name
     - alter default: Config.status
     - rename column: Product.image -> Product.pic
     """
-    mocker.patch("asyncclick.prompt", side_effect=(True,))
+    mocker.patch("click.prompt", side_effect=(True,))
     models_describe = get_models_describe("models")
     Migrate.app = "models"
     if isinstance(Migrate.ddl, SqliteDDL):
         with pytest.raises(NotSupportError):
             Migrate.diff_models(old_models_describe, models_describe)
-        Migrate.upgrade_operators.clear()
-        with pytest.raises(NotSupportError):
             Migrate.diff_models(models_describe, old_models_describe, False)
-        Migrate.downgrade_operators.clear()
     else:
         Migrate.diff_models(old_models_describe, models_describe)
         Migrate.diff_models(models_describe, old_models_describe, False)
@@ -869,9 +805,6 @@ def test_migrate(mocker: MockerFixture):
         expected_upgrade_operators = {
             "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)",
             "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL",
-            "ALTER TABLE `category` DROP INDEX `title`",
-            "ALTER TABLE `config` DROP COLUMN `name`",
-            "ALTER TABLE `config` DROP INDEX `name`",
             "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
             "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
             "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
@@ -879,8 +812,6 @@ def test_migrate(mocker: MockerFixture):
             "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
             "ALTER TABLE `email` DROP COLUMN `user_id`",
             "ALTER TABLE `configs` RENAME TO `config`",
-            "ALTER TABLE `product` DROP COLUMN `uuid`",
-            "ALTER TABLE `product` DROP INDEX `uuid`",
             "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
             "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
             "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
@@ -896,7 +827,7 @@ def test_migrate(mocker: MockerFixture):
             "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
             "ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
             "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
-            "ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
+            "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
             "CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
             "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
             "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
@@ -906,9 +837,6 @@ def test_migrate(mocker: MockerFixture):
         expected_downgrade_operators = {
             "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
             "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
-            "ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
-            "ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
-            "ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
             "ALTER TABLE `config` DROP COLUMN `user_id`",
             "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
             "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
@@ -917,14 +845,12 @@ def test_migrate(mocker: MockerFixture):
             "ALTER TABLE `config` RENAME TO `configs`",
             "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
             "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
-            "ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
-            "ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
             "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
             "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
             "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
             "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
             "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
-            "ALTER TABLE `user` DROP INDEX `username`",
+            "ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
             "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
             "DROP TABLE IF EXISTS `email_user`",
             "DROP TABLE IF EXISTS `newmodel`",
@@ -940,10 +866,6 @@ def test_migrate(mocker: MockerFixture):
             "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
             "ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
         }
-        if should_add_user_id_column_type_alter_sql():
-            sql = "ALTER TABLE `category` MODIFY COLUMN `user_id` INT NOT NULL COMMENT 'User'"
-            expected_upgrade_operators.add(sql)
-            expected_downgrade_operators.add(sql)
         assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
         assert not set(Migrate.downgrade_operators).symmetric_difference(
@@ -952,12 +874,9 @@ def test_migrate(mocker: MockerFixture):
     elif isinstance(Migrate.ddl, PostgresDDL):
         expected_upgrade_operators = {
-            'DROP INDEX "uid_category_title_f7fc03"',
             'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
             'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
             'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
-            'ALTER TABLE "config" DROP COLUMN "name"',
-            'DROP INDEX "uid_config_name_2c83c8"',
             'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
             'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
             'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
@@ -967,8 +886,6 @@ def test_migrate(mocker: MockerFixture):
             'ALTER TABLE "email" DROP COLUMN "user_id"',
             'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
             'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
-            'DROP INDEX "uid_product_uuid_d33c18"',
-            'ALTER TABLE "product" DROP COLUMN "uuid"',
             'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
             'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
             'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
@@ -989,12 +906,9 @@ def test_migrate(mocker: MockerFixture):
             'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
         }
         expected_downgrade_operators = {
-            'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")',
             'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
             'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
             'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
-            'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
-            'CREATE UNIQUE INDEX "uid_config_name_2c83c8" ON "config" ("name")',
             'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
             'ALTER TABLE "config" DROP COLUMN "user_id"',
             'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
@@ -1004,8 +918,6 @@ def test_migrate(mocker: MockerFixture):
             'ALTER TABLE "email" DROP COLUMN "address"',
             'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
             'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
-            'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
-            'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")',
             'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
             'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
             'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
@@ -1020,15 +932,11 @@ def test_migrate(mocker: MockerFixture):
             'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
             'DROP INDEX "idx_product_name_869427"',
             'DROP INDEX "idx_email_email_4a1a33"',
-            'DROP INDEX "uid_user_usernam_9987ab"',
+            'DROP INDEX "idx_user_usernam_9987ab"',
             'DROP INDEX "uid_product_name_869427"',
             'DROP TABLE IF EXISTS "email_user"',
             'DROP TABLE IF EXISTS "newmodel"',
         }
-        if should_add_user_id_column_type_alter_sql():
-            sql = 'ALTER TABLE "category" ALTER COLUMN "user_id" TYPE INT USING "user_id"::INT'
-            expected_upgrade_operators.add(sql)
-            expected_downgrade_operators.add(sql)
         assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
         assert not set(Migrate.downgrade_operators).symmetric_difference(
             expected_downgrade_operators
@@ -1058,15 +966,3 @@ def test_sort_all_version_files(mocker):
         "10_datetime_update.py",
         "11_datetime_update.py",
     ]
-
-
-async def test_empty_migration(mocker, tmp_path: Path) -> None:
-    mocker.patch("os.listdir", return_value=[])
-    Migrate.app = "foo"
-    expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
-    Migrate.migrate_location = tmp_path
-    migration_file = await Migrate.migrate("update", True)
-    f = tmp_path / migration_file
-    assert f.read_text() == expected_content
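Throughout test_migrate the expected and generated operator sets are compared order-independently; a tiny illustration of the symmetric_difference idiom used in those assertions:

    expected = {"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''"}
    actual = ["ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''"]
    # empty symmetric difference <=> same elements, regardless of order or duplicates
    assert not set(actual).symmetric_difference(expected)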

View File

@@ -1,6 +1,6 @@
 from aerich.utils import import_py_file


-def test_import_py_file() -> None:
+def test_import_py_file():
     m = import_py_file("aerich/utils.py")
     assert getattr(m, "import_py_file")