61 Commits

Author SHA1 Message Date
long2ice
919d56c936 add ci branches-ignore master 2021-07-07 10:29:38 +08:00
long2ice
7bcf9b2fed Support drop column for sqlite. (#40) 2021-07-03 13:51:01 +08:00
long2ice
9f663299cf Merge pull request #174 from sasha00123/dev
Fixed typo in README.md concerning dowgrade usage
2021-06-25 13:52:34 +08:00
Alexander Batyrgariev
28dbdf2663 Fixed typo in README.md concerning dowgrade usage 2021-06-25 08:00:00 +03:00
long2ice
e71a4b60a5 Merge pull request #166 from spacemanspiff2007/dev
Added config option to specify source folder
2021-06-13 14:26:26 +08:00
-
62840136be used old black version 2021-06-11 15:36:54 +02:00
-
185514f711 reformatted with black 2021-06-11 15:18:06 +02:00
-
8e783e031e updated readme 2021-06-10 16:56:30 +02:00
-
10b7272ca8 Added an configuration option to specify the path of the source folder.
This will make aerich work with various folder structures (e.g. ./src/MyPythonModule)
Additionally this will try to import in init and show the user the error message on failure.
2021-06-10 16:52:03 +02:00
long2ice
0c763c6024 Fix repeat 2021-06-09 13:56:25 +08:00
long2ice
c6371a5c16 Fix repeat 2021-06-09 11:43:32 +08:00
long2ice
1dbf9185b6 Not catch exception when import config. (#164) 2021-06-04 17:47:39 +08:00
long2ice
9bf2de0b9a Fix incorrect index creation order. (#151) 2021-06-01 17:09:45 +08:00
long2ice
bf1cf21324 Merge pull request #158 from manzato/pyproject-update
Update URLs
2021-05-22 22:43:52 +08:00
Guillermo Manzato
8b08329493 Update URLs 2021-05-22 11:39:49 -03:00
long2ice
5bc7d23d95 Merge pull request #157 from tortoise/dependabot/pip/pydantic-1.8.2
Bump pydantic from 1.8.1 to 1.8.2
2021-05-14 09:30:47 +08:00
dependabot[bot]
a253aa96cb Bump pydantic from 1.8.1 to 1.8.2
Bumps [pydantic](https://github.com/samuelcolvin/pydantic) from 1.8.1 to 1.8.2.
- [Release notes](https://github.com/samuelcolvin/pydantic/releases)
- [Changelog](https://github.com/samuelcolvin/pydantic/blob/master/HISTORY.md)
- [Commits](https://github.com/samuelcolvin/pydantic/compare/v1.8.1...v1.8.2)

Signed-off-by: dependabot[bot] <support@github.com>
2021-05-13 20:51:51 +00:00
long2ice
15a6e874dd update deps 2021-05-03 14:23:27 +08:00
long2ice
19a5dcbf3f update deps 2021-04-26 21:01:40 +08:00
long2ice
922e3eef16 Fix CI 2021-04-05 17:11:28 +08:00
long2ice
44fd2fe6ae Fix default function when migrate. (#147) 2021-04-05 14:10:42 +08:00
long2ice
b147859960 Fix default function when migrate 2021-04-04 05:46:34 +00:00
long2ice
793cf2532c Create FUNDING.yml 2021-04-03 21:34:24 +08:00
long2ice
fa85e05d1d Fix postgre alter null. (#142) 2021-03-28 16:22:49 +08:00
long2ice
3f52ac348b Support rename table. (#139) 2021-03-25 21:21:49 +08:00
long2ice
f8aa7a8f34 Fix inspectdb for FloatField. (#138) 2021-03-22 14:16:59 +08:00
long2ice
44d520cc82 Fix postgres field type change error. (#135) 2021-03-21 21:18:08 +08:00
long2ice
364735f804 Fix rename field on the field add. (#134) 2021-03-21 20:43:05 +08:00
long2ice
505d361597 Fix drop model in the downgrade. (#132) 2021-03-18 23:40:13 +08:00
long2ice
a19edd3a35 update ci name 2021-03-13 16:45:35 +08:00
long2ice
84d1f78019 update workflow name and add cryptography 2021-03-13 16:43:22 +08:00
long2ice
8fb07a6c9e update deps 2021-03-13 16:40:27 +08:00
long2ice
54da8b22af update aiomysql to asyncmy 2021-03-13 16:37:45 +08:00
long2ice
4c0308ff22 update test.yml 2021-03-03 22:03:38 +08:00
long2ice
38c4a15661 update test.yml 2021-03-03 20:42:18 +08:00
long2ice
52151270e0 Fix bug for field change. (#119) 2021-03-03 20:36:54 +08:00
long2ice
49897dc4fd Merge pull request #121 from AulonSal/close-tortoise-connections
Close Tortoise connections properly
2021-02-28 14:47:58 +08:00
AulonSal
d4ad0e270f Update version and changelog 2021-02-28 12:13:59 +05:30
AulonSal
e74fc304a5 Don't close db connections when group function \(cli\) is run 2021-02-27 00:43:55 +05:30
AulonSal
14d20455e6 Replace coro logic with tortoise.run_async 2021-02-23 13:06:40 +05:30
long2ice
bd9ecfd6e1 Merge pull request #122 from personalcomputer/personalcomputer/improve_readme_english
Improve English grammar / clarity in README.md
2021-02-22 12:31:15 +08:00
John Miller
de8500b9a1 Improve English grammar / clarity in README.md 2021-02-21 19:46:04 -08:00
AulonSal
90b47c5af7 Close connections even if command raises exception 2021-02-22 07:40:18 +05:30
AulonSal
02fe5a9d31 Close Tortoise connections properly 2021-02-20 13:11:29 +05:30
long2ice
be41a1332a update tortoise-orm version 2021-02-04 20:53:04 +08:00
long2ice
09661c1d46 Fix unique_together 2021-02-04 14:39:07 +08:00
long2ice
abfa60133f Fix drop table 2021-02-04 14:23:46 +08:00
long2ice
048e428eac update tortoise-orm 2021-02-03 22:52:01 +08:00
long2ice
38a3df9b5a add support m2m 2021-02-03 22:22:22 +08:00
long2ice
0d94b22b3f Remove unused functions 2021-02-03 18:06:43 +08:00
long2ice
f1f0074255 Support rename field 2021-02-03 17:56:30 +08:00
long2ice
e3a14a2f60 Fix postgres index 2021-02-03 16:34:07 +08:00
long2ice
608ff8f071 update conftest.py 2021-02-03 15:49:40 +08:00
long2ice
d3a1342293 update README.md 2021-02-03 15:48:06 +08:00
long2ice
01e3de9522 basically completed 2021-02-03 15:43:04 +08:00
long2ice
c6c398fdf0 update 2021-02-02 22:52:50 +08:00
long2ice
c60bdd290e add fk and drop fk 2021-02-02 20:35:05 +08:00
long2ice
f443dc68db WIP 2021-02-01 16:54:35 +08:00
long2ice
36f84702b7 update 2021-02-01 14:00:12 +08:00
long2ice
b4cc2de0e3 v0.5 refactoring 2021-01-31 23:10:30 +08:00
long2ice
4780b90c1c add close_connections to fix stuck 2021-01-29 22:58:12 +08:00
25 changed files with 1934 additions and 946 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: ["https://sponsor.long2ice.cn"]

View File

@@ -1,7 +1,13 @@
name: test name: ci
on: [ push, pull_request ] on:
push:
branches-ignore:
- master
pull_request:
branches-ignore:
- master
jobs: jobs:
testall: ci:
runs-on: ubuntu-latest runs-on: ubuntu-latest
services: services:
postgres: postgres:

1
.gitignore vendored
View File

@@ -146,3 +146,4 @@ aerich.ini
src src
.vscode .vscode
.DS_Store .DS_Store
.python-version

View File

@@ -1,5 +1,35 @@
# ChangeLog # ChangeLog
## 0.5
### 0.5.4
- Fix incorrect index creation order. (#151)
- Not catch exception when import config. (#164)
- Support `drop column` for sqlite. (#40)
### 0.5.3
- Fix postgre alter null. (#142)
- Fix default function when migrate. (#147)
### 0.5.2
- Fix rename field on the field add. (#134)
- Fix postgres field type change error. (#135)
- Fix inspectdb for `FloatField`. (#138)
- Support `rename table`. (#139)
### 0.5.1
- Fix tortoise connections not being closed properly. (#120)
- Fix bug for field change. (#119)
- Fix drop model in the downgrade. (#132)
### 0.5.0
- Refactor core code, now has no limitation for everything.
## 0.4 ## 0.4
### 0.4.4 ### 0.4.4

View File

@@ -8,23 +8,11 @@ POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432 POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= "123456" POSTGRES_PASS ?= "123456"
help:
@echo "Aerich development makefile"
@echo
@echo "usage: make <target>"
@echo "Targets:"
@echo " up Updates dev/test dependencies"
@echo " deps Ensure dev/test dependencies are installed"
@echo " check Checks that build is sane"
@echo " lint Reports all linter violations"
@echo " test Runs all tests"
@echo " style Auto-formats the code"
up: up:
@poetry update @poetry update
deps: deps:
@poetry install -E dbdrivers @poetry install -E asyncpg -E asyncmy -E aiomysql
style: deps style: deps
isort -src $(checkfiles) isort -src $(checkfiles)
@@ -45,7 +33,7 @@ test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
test_postgres: test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql testall: deps test_sqlite test_postgres test_mysql

View File

@@ -1,16 +1,14 @@
# Aerich # Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich) [![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/long2ice/aerich)](https://github.com/long2ice/aerich) [![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/long2ice/aerich/workflows/pypi/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:pypi) [![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/long2ice/aerich/workflows/test/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:test) [![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
## Introduction ## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which like alembic for SQLAlchemy, or Django ORM with it\'s own Aerich is a database migrations tool for Tortoise-ORM, which is like alembic for SQLAlchemy, or like Django ORM with
migrations solution. it\'s own migration solution.
**Important: You can only use absolutely import in your `models.py` to make `aerich` work.**
## Install ## Install
@@ -28,10 +26,12 @@ Just install from pypi:
Usage: aerich [OPTIONS] COMMAND [ARGS]... Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options: Options:
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: aerich.ini] -c, --config TEXT Config file. [default: aerich.ini]
--app TEXT Tortoise-ORM app name. [default: models] --app TEXT Tortoise-ORM app name.
-n, --name TEXT Name of section in .ini file to use for aerich config. -n, --name TEXT Name of section in .ini file to use for aerich config.
[default: aerich] [default: aerich]
-h, --help Show this message and exit. -h, --help Show this message and exit.
Commands: Commands:
@@ -42,12 +42,12 @@ Commands:
init-db Generate schema and generate app migrate location. init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as... inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file. migrate Generate migrate changes file.
upgrade Upgrade to latest version. upgrade Upgrade to specified version.
``` ```
## Usage ## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first, example: You need add `aerich.models` to your `Tortoise-ORM` config first. Example:
```python ```python
TORTOISE_ORM = { TORTOISE_ORM = {
@@ -70,14 +70,16 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location. Init config file and generate root migrate location.
Options: OOptions:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM. -t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
[required] settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations] --location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit. -h, --help Show this message and exit.
``` ```
Init config file and location: Initialize the config file and migrations location:
```shell ```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM > aerich init -t tests.backends.mysql.TORTOISE_ORM
@@ -95,8 +97,8 @@ Success create app migrate location ./migrations/models
Success generate schema for app "models" Success generate schema for app "models"
``` ```
If your Tortoise-ORM app is not default `models`, you must specify If your Tortoise-ORM app is not the default `models`, you must specify the correct app via `--app`,
`--app` like `aerich --app other_models init-db`. e.g. `aerich --app other_models init-db`.
### Update models and make migrate ### Update models and make migrate
@@ -109,8 +111,9 @@ Success migrate 1_202029051520102929_drop_column.sql
Format of migrate filename is Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`. `{version_num}_{datetime}_{name|update}.sql`.
And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose
choice `True` to rename column without column drop, or choice `False` to drop column then create. `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data.
### Upgrade to latest version ### Upgrade to latest version
@@ -125,7 +128,7 @@ Now your db is migrated to latest.
### Downgrade to specified version ### Downgrade to specified version
```shell ```shell
> aerich init -h > aerich downgrade -h
Usage: aerich downgrade [OPTIONS] Usage: aerich downgrade [OPTIONS]
@@ -146,7 +149,7 @@ Options:
Success downgrade 1_202029051520102929_drop_column.sql Success downgrade 1_202029051520102929_drop_column.sql
``` ```
Now your db rollback to specified version. Now your db is rolled back to the specified version.
### Show history ### Show history
@@ -166,6 +169,8 @@ Now your db rollback to specified version.
### Inspect db tables to TortoiseORM model ### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL.
```shell ```shell
Usage: aerich inspectdb [OPTIONS] Usage: aerich inspectdb [OPTIONS]
@@ -179,17 +184,16 @@ Options:
Inspect all tables and print to console: Inspect all tables and print to console:
```shell ```shell
aerich --app models inspectdb -t user aerich --app models inspectdb
``` ```
Inspect a specified table in default app and redirect to `models.py`: Inspect a specified table in the default app and redirect to `models.py`:
```shell ```shell
aerich inspectdb -t user > models.py aerich inspectdb -t user > models.py
``` ```
Note that this command is restricted, which is not supported in some solutions, such as `IntEnumField` Note that this command is limited and cannot infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
and `ForeignKeyField` and so on.
### Multiple databases ### Multiple databases
@@ -206,13 +210,7 @@ tortoise_orm = {
} }
``` ```
You need only specify `aerich.models` in one app, and must specify `--app` when run `aerich migrate` and so on. You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
## Support this project
| AliPay | WeChatPay | PayPal |
| -------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |
| <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/alipay.jpeg"/> | <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/wechatpay.jpeg"/> | [PayPal](https://www.paypal.me/long2ice) to my account long2ice. |
## License ## License

View File

@@ -1 +1 @@
__version__ = "0.4.4" __version__ = "0.5.4"

View File

@@ -1,6 +1,5 @@
import asyncio import asyncio
import os import os
import sys
from configparser import ConfigParser from configparser import ConfigParser
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
@@ -16,8 +15,10 @@ from tortoise.utils import get_schema_sql
from aerich.inspectdb import InspectDb from aerich.inspectdb import InspectDb
from aerich.migrate import Migrate from aerich.migrate import Migrate
from aerich.utils import ( from aerich.utils import (
add_src_path,
get_app_connection, get_app_connection,
get_app_connection_name, get_app_connection_name,
get_models_describe,
get_tortoise_config, get_tortoise_config,
get_version_content_from_file, get_version_content_from_file,
write_version_file, write_version_file,
@@ -34,11 +35,13 @@ def coro(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs)) # Close db connections at the end of all all but the cli group function
app = ctx.obj.get("app") try:
if app: loop.run_until_complete(f(*args, **kwargs))
Migrate.remove_old_model_file(app, ctx.obj["location"]) finally:
if f.__name__ != "cli":
loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper
@@ -46,11 +49,7 @@ def coro(f):
@click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version") @click.version_option(__version__, "-V", "--version")
@click.option( @click.option(
"-c", "-c", "--config", default="aerich.ini", show_default=True, help="Config file.",
"--config",
default="aerich.ini",
show_default=True,
help="Config file.",
) )
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option( @click.option(
@@ -75,6 +74,10 @@ async def cli(ctx: Context, config, app, name):
location = parser[name]["location"] location = parser[name]["location"]
tortoise_orm = parser[name]["tortoise_orm"] tortoise_orm = parser[name]["tortoise_orm"]
src_folder = parser[name]["src_folder"]
# Add specified source folder to path
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] app = app or list(tortoise_config.get("apps").keys())[0]
@@ -85,7 +88,7 @@ async def cli(ctx: Context, config, app, name):
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx) raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init_with_old_models(tortoise_config, app, location) await Migrate.init(tortoise_config, app, location)
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate migrate changes file.")
@@ -105,7 +108,6 @@ async def migrate(ctx: Context, name):
async def upgrade(ctx: Context): async def upgrade(ctx: Context):
config = ctx.obj["config"] config = ctx.obj["config"]
app = ctx.obj["app"] app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False migrated = False
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@@ -120,9 +122,7 @@ async def upgrade(ctx: Context):
for upgrade_query in upgrade_query_list: for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query) await conn.execute_script(upgrade_query)
await Aerich.create( await Aerich.create(
version=version_file, version=version_file, app=app, content=get_models_describe(app),
app=app,
content=Migrate.get_models_content(config, app, location),
) )
click.secho(f"Success upgrade {version_file}", fg=Color.green) click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True migrated = True
@@ -216,26 +216,37 @@ async def history(ctx: Context):
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.", help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
) )
@click.option( @click.option(
"--location", "--location", default="./migrations", show_default=True, help="Migrate store location.",
default="./migrations", )
show_default=True, @click.option(
help="Migrate store location.", "-s",
"--src_folder",
default=".",
show_default=False,
help="Folder of the source, relative to the project root.",
) )
@click.pass_context @click.pass_context
@coro @coro
async def init( async def init(ctx: Context, tortoise_orm, location, src_folder):
ctx: Context,
tortoise_orm,
location,
):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"] name = ctx.obj["name"]
if Path(config_file).exists(): if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow) return click.secho("Configuration file already created", fg=Color.yellow)
if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder)
# Add ./ so it's clear that this is relative path
if not src_folder.startswith("./"):
src_folder = "./" + src_folder
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
parser.add_section(name) parser.add_section(name)
parser.set(name, "tortoise_orm", tortoise_orm) parser.set(name, "tortoise_orm", tortoise_orm)
parser.set(name, "location", location) parser.set(name, "location", location)
parser.set(name, "src_folder", src_folder)
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
parser.write(f) parser.write(f)
@@ -278,9 +289,7 @@ async def init_db(ctx: Context, safe):
version = await Migrate.generate_version() version = await Migrate.generate_version()
await Aerich.create( await Aerich.create(
version=version, version=version, app=app, content=get_models_describe(app),
app=app,
content=Migrate.get_models_content(config, app, location),
) )
content = { content = {
"upgrade": [schema], "upgrade": [schema],
@@ -291,11 +300,7 @@ async def init_db(ctx: Context, safe):
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.") @cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
@click.option( @click.option(
"-t", "-t", "--table", help="Which tables to inspect.", multiple=True, required=False,
"--table",
help="Which tables to inspect.",
multiple=True,
required=False,
) )
@click.pass_context @click.pass_context
@coro @coro
@@ -309,7 +314,6 @@ async def inspectdb(ctx: Context, table: List[str]):
def main(): def main():
sys.path.insert(0, ".")
cli() cli()

View File

@@ -1,8 +1,10 @@
from enum import Enum
from typing import List, Type from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
from aerich.utils import is_default_function
class BaseDDL: class BaseDDL:
@@ -11,20 +13,22 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"' _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}' _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"' _DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"' 'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
) )
_ADD_INDEX_TEMPLATE = ( _ADD_INDEX_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD {unique} INDEX "{index_name}" ({column_names})' 'ALTER TABLE "{table_name}" ADD {unique}INDEX "{index_name}" ({column_names})'
) )
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};' _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = ( _CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient"):
self.client = client self.client = client
@@ -33,43 +37,54 @@ class BaseDDL:
def create_table(self, model: "Type[Model]"): def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"] return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
def drop_table(self, model: "Type[Model]"): def drop_table(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance): def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=field.through, table_name=through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
forward_table=field.related_model._meta.db_table, forward_table=reference_table_describe.get("table"),
backward_field=model._meta.db_pk_column, backward_field=model._meta.db_pk_column,
forward_field=field.related_model._meta.db_pk_column, forward_field=reference_id,
backward_key=field.backward_key, backward_key=field_describe.get("backward_key"),
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key, forward_key=field_describe.get("forward_key"),
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=CASCADE, on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=field.through), extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator( comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description table=through, comment=description
) )
if field.description if description
else "", else "",
) )
def drop_m2m(self, field: ManyToManyFieldInstance): def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_object: Field): def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_object.default default = field_describe.get("default")
db_column = field_object.model_field_name if isinstance(default, Enum):
auto_now_add = getattr(field_object, "auto_now_add", False) default = default.value
auto_now = getattr(field_object, "auto_now", False) db_column = field_describe.get("db_column")
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)): if field_describe.get("field_type") in [
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = "" default = ""
else: else:
default = field_object.to_db_value(default, model)
try: try:
default = self.schema_generator._column_default_generator( default = self.schema_generator._column_default_generator(
db_table, db_table,
@@ -81,28 +96,31 @@ class BaseDDL:
except NotImplementedError: except NotImplementedError:
default = "" default = ""
else: else:
default = "" default = None
return default return default
def add_column(self, model: "Type[Model]", field_object: Field): def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format( return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=field_object.model_field_name, db_column=db_column,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_object.null else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_object.unique else "", unique="UNIQUE" if field_describe.get("unique") else "",
comment=self.schema_generator._column_comment_generator( comment=self.schema_generator._column_comment_generator(
table=db_table, table=db_table, column=db_column, comment=field_describe.get("description"),
column=field_object.model_field_name,
comment=field_object.description,
) )
if field_object.description if description
else "", else "",
is_primary_key=field_object.pk, is_primary_key=is_pk,
default=self._get_default(model, field_object), default=default,
), ),
) )
@@ -111,24 +129,28 @@ class BaseDDL:
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=field_object.model_field_name, db_column=field_describe.get("db_column"),
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_object.null else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="", unique="",
comment=self.schema_generator._column_comment_generator( comment=self.schema_generator._column_comment_generator(
table=db_table, table=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column"),
comment=field_object.description, comment=field_describe.get("description"),
) )
if field_object.description if field_describe.get("description")
else "", else "",
is_primary_key=field_object.pk, is_primary_key=is_pk,
default=self._get_default(model, field_object), default=default,
), ),
) )
@@ -151,7 +173,7 @@ class BaseDDL:
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "", unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names "idx" if not unique else "uid", model, field_names
), ),
@@ -167,48 +189,55 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_column = field.source_field or field.model_field_name + "_id" db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column")
fk_name = self.schema_generator._generate_fk_name( fk_name = self.schema_generator._generate_fk_name(
from_table=db_table, from_table=db_table,
from_field=db_column, from_field=db_column,
to_table=field.related_model._meta.db_table, to_table=reference_table_describe.get("table"),
to_field=to_field_name, to_field=reference_table_describe.get("pk_field").get("db_column"),
) )
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=fk_name, fk_name=fk_name,
db_column=db_column, db_column=db_column,
table=field.related_model._meta.db_table, table=reference_table_describe.get("table"),
field=to_field_name, field=reference_id,
on_delete=field.on_delete, on_delete=field_describe.get("on_delete"),
) )
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_table = model._meta.db_table db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format( return self._DROP_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=self.schema_generator._generate_fk_name( fk_name=self.schema_generator._generate_fk_name(
from_table=db_table, from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id", from_field=field_describe.get("raw_field"),
to_table=field.related_model._meta.db_table, to_table=reference_table_describe.get("table"),
to_field=to_field_name, to_field=reference_table_describe.get("pk_field").get("db_column"),
), ),
) )
def alter_column_default(self, model: "Type[Model]", field_object: Field): def alter_column_default(self, model: "Type[Model]", field_describe: dict):
pass db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column"),
default="SET" + default if default is not None else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field): def alter_column_null(self, model: "Type[Model]", field_describe: dict):
pass return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_object: Field): def set_comment(self, model: "Type[Model]", field_describe: dict):
pass return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
)

View File

@@ -8,15 +8,20 @@ class MysqlDDL(BaseDDL):
DIALECT = MySQLSchemaGenerator.DIALECT DIALECT = MySQLSchemaGenerator.DIALECT
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`" _DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}" _ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
_CHANGE_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
)
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`" _DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`" "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
) )
_ADD_INDEX_TEMPLATE = ( _ADD_INDEX_TEMPLATE = (
"ALTER TABLE `{table_name}` ADD {unique} INDEX `{index_name}` ({column_names})" "ALTER TABLE `{table_name}` ADD {unique}INDEX `{index_name}` ({column_names})"
) )
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};" _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"

View File

@@ -1,8 +1,7 @@
from typing import List, Type from typing import Type
from tortoise import Model from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
@@ -10,66 +9,41 @@ from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL): class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE INDEX "{index_name}" ON "{table_name}" ({column_names})' _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_ADD_UNIQUE_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD CONSTRAINT "{index_name}" UNIQUE ({column_names})'
)
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}' _MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
)
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_default(self, model: "Type[Model]", field_object: Field): def alter_column_null(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column"),
set_drop="DROP" if field_object.null else "SET", set_drop="DROP" if field_describe.get("nullable") else "SET",
) )
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_object.model_field_name, column=db_column,
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), datatype=datatype,
using=f' USING "{db_column}"::{datatype}',
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def set_comment(self, model: "Type[Model]", field_describe: dict):
template = self._ADD_UNIQUE_TEMPLATE if unique else self._ADD_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._DROP_UNIQUE_TEMPLATE if unique else self._DROP_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
)
def set_comment(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format( return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_object.description) if field_object.description else "NULL", comment="'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL",
) )

View File

@@ -2,7 +2,6 @@ from typing import Type
from tortoise import Model from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
@@ -12,8 +11,14 @@ class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT DIALECT = SqliteSchemaGenerator.DIALECT
def drop_column(self, model: "Type[Model]", column_name: str): def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: Field):
raise NotSupportError("Modify column is unsupported in SQLite.") raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

View File

@@ -15,6 +15,7 @@ class InspectDb:
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})", "LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})", "TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})", "DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
"FLOAT": " {field} = fields.FloatField({null}{default}{comment})",
} }
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):

View File

@@ -1,29 +1,21 @@
import inspect
import os import os
import re
from datetime import datetime from datetime import datetime
from importlib import import_module
from io import StringIO
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
import click import click
from tortoise import ( from dictdiffer import diff
BackwardFKRelation, from tortoise import BaseDBAsyncClient, Model, Tortoise
BackwardOneToOneRelation,
BaseDBAsyncClient,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
Model,
Tortoise,
)
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, write_version_file from aerich.utils import (
get_app_connection,
get_models_describe,
is_default_function,
write_version_file,
)
class Migrate: class Migrate:
@@ -38,18 +30,12 @@ class Migrate:
_rename_new = [] _rename_new = []
ddl: BaseDDL ddl: BaseDDL
migrate_config: dict _last_version_content: Optional[dict] = None
old_models = "old_models"
diff_app = "diff_models"
app: str app: str
migrate_location: str migrate_location: str
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@classmethod
def get_old_model_file(cls, app: str, location: str):
return Path(location, app, cls.old_models + ".py")
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( return sorted(
@@ -57,6 +43,10 @@ class Migrate:
key=lambda x: int(x.split("_")[0]), key=lambda x: int(x.split("_")[0]),
) )
@classmethod
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
@classmethod @classmethod
async def get_last_version(cls) -> Optional[Aerich]: async def get_last_version(cls) -> Optional[Aerich]:
try: try:
@@ -64,13 +54,6 @@ class Migrate:
except OperationalError: except OperationalError:
pass pass
@classmethod
def remove_old_model_file(cls, app: str, location: str):
try:
os.unlink(cls.get_old_model_file(app, location))
except (OSError, FileNotFoundError):
pass
@classmethod @classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient): async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql": if cls.dialect == "mysql":
@@ -79,19 +62,13 @@ class Migrate:
cls._db_version = ret[1][0].get("version") cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str): async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = Path(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
content = last_version.content cls._last_version_content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.migrate_config = migrate_config
await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT cls.dialect = connection.schema_generator.DIALECT
@@ -136,8 +113,8 @@ class Migrate:
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = { content = {
"upgrade": cls.upgrade_operators, "upgrade": list(dict.fromkeys(cls.upgrade_operators)),
"downgrade": cls.downgrade_operators, "downgrade": list(dict.fromkeys(cls.downgrade_operators)),
} }
write_version_file(Path(cls.migrate_location, version), content) write_version_file(Path(cls.migrate_location, version), content)
return version return version
@@ -149,12 +126,9 @@ class Migrate:
:param name: :param name:
:return: :return:
""" """
apps = Tortoise.apps new_version_content = get_models_describe(cls.app)
diff_models = apps.get(cls.diff_app) cls.diff_models(cls._last_version_content, new_version_content)
app_models = apps.get(cls.app) cls.diff_models(new_version_content, cls._last_version_content, False)
cls.diff_models(diff_models, app_models)
cls.diff_models(app_models, diff_models, False)
cls._merge_operators() cls._merge_operators()
@@ -184,58 +158,7 @@ class Migrate:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod @classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str): def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
path = Path(location, app, cls.old_models).as_posix().replace("/", ".")
config["apps"][cls.diff_app] = {
"models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"),
}
return config
@classmethod
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
if isinstance(model, ModuleType):
module = model
else:
module = import_module(model)
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in filter(
lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model,
possible_models,
):
file = inspect.getfile(attr)
if file not in old_model_files:
old_model_files.append(file)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod
def diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
):
""" """
diff models and add operators diff models and add operators
:param old_models: :param old_models:
@@ -243,192 +166,263 @@ class Migrate:
:param upgrade: :param upgrade:
:return: :return:
""" """
old_models.pop(cls._aerich, None) _aerich = f"{cls.app}.{cls._aerich}"
new_models.pop(cls._aerich, None) old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
for new_model_str, new_model in new_models.items():
if new_model_str not in old_models.keys(): if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade) if upgrade:
cls._add_operator(cls.add_model(model), upgrade)
else:
# we can't find origin model when downgrade, so skip
pass
else: else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade) old_model_describe = old_models.get(new_model_str)
# rename table
new_table = new_model_describe.get("table")
old_table = old_model_describe.get("table")
if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together"))
)
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
table = change[0][1].get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
# add unique_together
for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
# remove unique_together
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields")
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name)
):
new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False
for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name")
if len(changes) == 2:
# rename field
if (
changes[0]
== ("change", "name", (old_data_field_name, new_data_field_name),)
and changes[1]
== (
"change",
"db_column",
(
old_data_field.get("db_column"),
new_data_field.get("db_column"),
),
)
and old_data_field_name not in new_data_fields_name
):
if upgrade:
is_rename = click.prompt(
f"Rename {old_data_field_name} to {new_data_field_name}?",
default=True,
type=bool,
show_choices=True,
)
else:
is_rename = old_data_field_name in cls._rename_new
if is_rename:
cls._rename_new.append(new_data_field_name)
cls._rename_old.append(old_data_field_name)
# only MySQL8+ has rename syntax
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._modify_field(model, new_data_field), upgrade,
)
else:
cls._add_operator(
cls._rename_field(model, *changes[1][2]), upgrade,
)
if not is_rename:
cls._add_operator(
cls._add_field(model, new_data_field,), upgrade,
)
# remove fields
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is rename
if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new
):
continue
cls._add_operator(
cls._remove_field(
model,
next(
filter(
lambda x: x.get("name") == old_data_field_name, old_data_fields
)
).get("db_column"),
),
upgrade,
)
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
cls._add_operator(
cls._add_fk(model, fk_field, new_models.get(fk_field.get("python_type"))),
upgrade,
fk_m2m=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
)
cls._add_operator(
cls._drop_fk(
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next(
filter(lambda x: x.get("name") == field_name, old_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field)
for change in changes:
_, option, old_new = change
if option == "indexed":
# change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True:
cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True
)
else:
cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.":
# continue since repeated with others
continue
elif option == "default":
if not (
is_default_function(old_new[0]) or is_default_function(old_new[1])
):
# change column default
cls._add_operator(
cls._alter_default(model, new_data_field), upgrade
)
elif option == "unique":
# because indexed include it
continue
elif option == "nullable":
# change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
else:
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field), upgrade,
)
for old_model in old_models: for old_model in old_models:
if old_model not in new_models.keys(): if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade) cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod @classmethod
def _is_fk_m2m(cls, field: Field): def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)) return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod @classmethod
def add_model(cls, model: Type[Model]): def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model) return cls.ddl.create_table(model)
@classmethod @classmethod
def remove_model(cls, model: Type[Model]): def drop_model(cls, table_name: str):
return cls.ddl.drop_table(model) return cls.ddl.drop_table(table_name)
@classmethod @classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True): def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
""" return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
diff single model
:param old_model:
:param new_model:
:param upgrade:
:return:
"""
old_indexes = old_model._meta.indexes
new_indexes = new_model._meta.indexes
old_unique_together = old_model._meta.unique_together @classmethod
new_unique_together = new_model._meta.unique_together def drop_m2m(cls, table_name: str):
return cls.ddl.drop_m2m(table_name)
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}?",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(new_model, old_field, new_field),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(
cls._alter_null(new_model, new_field), upgrade=upgrade
)
if new_field.default != old_field.default and not callable(
new_field.default
):
cls._add_operator(
cls._alter_default(new_model, new_field), upgrade=upgrade
)
if new_field.description != old_field.description:
cls._add_operator(
cls._set_comment(new_model, new_field), upgrade=upgrade
)
if new_field.field_type != old_field.field_type:
cls._add_operator(
cls._modify_field(new_model, new_field), upgrade=upgrade
)
else:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
cls._add_operator(
cls._remove_index(
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
):
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field),
upgrade,
True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field),
upgrade,
True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(
cls._add_index(
new_model,
new_index,
),
upgrade,
)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)
for new_unique in new_unique_together:
if new_unique not in old_unique_together:
cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade)
for old_unique in old_unique_together:
if old_unique not in new_unique_together:
cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@@ -441,7 +435,7 @@ class Migrate:
return ret return ret
@classmethod @classmethod
def _remove_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False): def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
fields_name = cls._resolve_fk_fields_name(model, fields_name) fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique) return cls.ddl.drop_index(model, fields_name, unique)
@@ -451,96 +445,57 @@ class Migrate:
return cls.ddl.add_index(model, fields_name, unique) return cls.ddl.add_index(model, fields_name, unique)
@classmethod @classmethod
def _exclude_field(cls, field: Field, upgrade=False): def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
""" return cls.ddl.add_column(model, field_describe, is_pk)
exclude BackwardFKRelation and repeat m2m field
:param field:
:return:
"""
if isinstance(field, ManyToManyFieldInstance):
through = field.through
if upgrade:
if through in cls._upgrade_m2m:
return True
else:
cls._upgrade_m2m.append(through)
return False
else:
if through in cls._downgrade_m2m:
return True
else:
cls._downgrade_m2m.append(through)
return False
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field: Field): def _alter_default(cls, model: Type[Model], field_describe: dict):
if isinstance(field, ForeignKeyFieldInstance): return cls.ddl.alter_column_default(model, field_describe)
return cls.ddl.add_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field: Field): def _alter_null(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_default(model, field) return cls.ddl.alter_column_null(model, field_describe)
@classmethod @classmethod
def _alter_null(cls, model: Type[Model], field: Field): def _set_comment(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_null(model, field) return cls.ddl.set_comment(model, field_describe)
@classmethod @classmethod
def _set_comment(cls, model: Type[Model], field: Field): def _modify_field(cls, model: Type[Model], field_describe: dict):
return cls.ddl.set_comment(model, field) return cls.ddl.modify_column(model, field_describe)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field: Field): def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.modify_column(model, field) return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): def _remove_field(cls, model: Type[Model], column_name: str):
return cls.ddl.drop_fk(model, field) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], field: Field): def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
if isinstance(field, ForeignKeyFieldInstance): return cls.ddl.rename_column(model, old_field_name, new_field_name)
return cls.ddl.drop_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name) db_field_types = new_field_describe.get("db_field_types")
@classmethod
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.change_column( return cls.ddl.change_column(
model, model,
old_field.model_field_name, old_field_describe.get("db_column"),
new_field.model_field_name, new_field_describe.get("db_column"),
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"), db_field_types.get(cls.dialect) or db_field_types.get(""),
) )
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
""" """
add fk add fk
:param model: :param model:
:param field: :param field_describe:
:param reference_table_describe:
:return: :return:
""" """
return cls.ddl.add_fk(model, field) return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod
def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
drop fk
:param model:
:param field:
:return:
"""
return cls.ddl.drop_fk(model, field)
@classmethod @classmethod
def _merge_operators(cls): def _merge_operators(cls):

View File

@@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20) app = fields.CharField(max_length=20)
content = fields.TextField() content = fields.JSONField()
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@@ -1,10 +1,30 @@
import importlib import importlib
import os
import re
import sys
from pathlib import Path
from typing import Dict from typing import Dict
from click import BadOptionUsage, Context from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str:
"""
add a folder to the paths so we can import from there
:param path: path to add
:return: absolute path
"""
if not os.path.isabs(path):
# use the absolute path, otherwise some other things (e.g. __file__) won't work properly
path = os.path.abspath(path)
if not os.path.isdir(path):
raise ClickException(f"Specified source folder does not exist: {path}")
if path not in sys.path:
sys.path.insert(0, path)
return path
def get_app_connection_name(config, app_name: str) -> str: def get_app_connection_name(config, app_name: str) -> str:
""" """
get connection name get connection name
@@ -16,8 +36,7 @@ def get_app_connection_name(config, app_name: str) -> str:
if app: if app:
return app.get("default_connection", "default") return app.get("default_connection", "default")
raise BadOptionUsage( raise BadOptionUsage(
option_name="--app", option_name="--app", message=f'Can\'t get app named "{app_name}"',
message=f'Can\'t get app named "{app_name}"',
) )
@@ -41,12 +60,11 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
splits = tortoise_orm.split(".") splits = tortoise_orm.split(".")
config_path = ".".join(splits[:-1]) config_path = ".".join(splits[:-1])
tortoise_config = splits[-1] tortoise_config = splits[-1]
try: try:
config_module = importlib.import_module(config_path) config_module = importlib.import_module(config_path)
except (ModuleNotFoundError, AttributeError): except ModuleNotFoundError as e:
raise BadOptionUsage( raise ClickException(f"Error while importing configuration module: {e}") from None
ctx=ctx, message=f'No config named "{config_path}"', option_name="--config"
)
config = getattr(config_module, tortoise_config, None) config = getattr(config_module, tortoise_config, None)
if not config: if not config:
@@ -84,7 +102,7 @@ def get_version_content_from_file(version_file: str) -> Dict:
return ret return ret
def write_version_file(version_file: str, content: Dict): def write_version_file(version_file: Path, content: Dict):
""" """
write version file write version file
:param version_file: :param version_file:
@@ -108,3 +126,20 @@ def write_version_file(version_file: str, content: Dict):
f.write(";\n".join(downgrade) + ";\n") f.write(";\n".join(downgrade) + ";\n")
else: else:
f.write(f"{downgrade[0]};\n") f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
:param app:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret
def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or ""))

View File

@@ -51,12 +51,6 @@ def event_loop():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request): async def initialize_tests(event_loop, request):
tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:"
tortoise_orm["apps"]["diff_models"] = {
"models": ["tests.diff_models"],
"default_connection": "diff_models",
}
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

647
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.4.4" version = "0.5.4"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/long2ice/aerich" homepage = "https://github.com/tortoise/aerich"
repository = "https://github.com/long2ice/aerich.git" repository = "https://github.com/tortoise/aerich.git"
documentation = "https://github.com/long2ice/aerich" documentation = "https://github.com/tortoise/aerich"
keywords = ["migrate", "Tortoise-ORM", "mysql"] keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [ packages = [
{ include = "aerich" } { include = "aerich" }
@@ -19,22 +19,26 @@ python = "^3.7"
tortoise-orm = "*" tortoise-orm = "*"
click = "*" click = "*"
pydantic = "*" pydantic = "*"
aiomysql = {version = "*", optional = true} aiomysql = { version = "*", optional = true }
asyncpg = {version = "*", optional = true} asyncpg = { version = "*", optional = true }
ddlparse = "*" ddlparse = "*"
dictdiffer = "*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
flake8 = "*" flake8 = "*"
isort = "*" isort = "*"
black = "^20.8b1" black = "19.10b0"
pytest = "*" pytest = "*"
pytest-xdist = "*" pytest-xdist = "*"
pytest-asyncio = "*" pytest-asyncio = "*"
bandit = "*" bandit = "*"
pytest-mock = "*" pytest-mock = "*"
cryptography = "*"
[tool.poetry.extras] [tool.poetry.extras]
dbdrivers = ["aiomysql", "asyncpg"] asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"]
aiomysql = ["aiomysql"]
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry>=0.12"]

View File

@@ -1,4 +1,5 @@
import datetime import datetime
import uuid
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
@@ -23,23 +24,28 @@ class Status(IntEnum):
class User(Model): class User(Model):
username = fields.CharField(max_length=20, unique=True) username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=200) password = fields.CharField(max_length=100)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="") intro = fields.TextField(default="")
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email_id = fields.IntField(pk=True)
email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False) address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
def default_name():
return uuid.uuid4()
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User") user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@@ -47,17 +53,25 @@ class Category(Model):
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200) pic = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
unique_together = (("name", "type"),)
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User")
class NewModel(Model):
name = fields.CharField(max_length=50)

View File

@@ -24,7 +24,7 @@ class Status(IntEnum):
class User(Model): class User(Model):
username = fields.CharField(max_length=20) username = fields.CharField(max_length=20)
password = fields.CharField(max_length=200) password = fields.CharField(max_length=200)
last_login_at = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="") avatar = fields.CharField(max_length=200, default="")
@@ -34,17 +34,18 @@ class User(Model):
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("diff_models.User", db_constraint=True) user = fields.ForeignKeyField("models.User", db_constraint=False)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
user = fields.ForeignKeyField("diff_models.User", description="User") name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("diff_models.Category") categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
@@ -60,3 +61,6 @@ class Config(Model):
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta:
table = "configs"

View File

@@ -1,11 +1,8 @@
import pytest
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests.models import Category, User from tests.models import Category, Product, User
def test_create_table(): def test_create_table():
@@ -15,8 +12,8 @@ def test_create_table():
ret ret
== """CREATE TABLE IF NOT EXISTS `category` ( == """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(200) NOT NULL, `slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200) NOT NULL, `name` VARCHAR(200),
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`user_id` INT NOT NULL COMMENT 'User', `user_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
@@ -28,8 +25,8 @@ def test_create_table():
ret ret
== """CREATE TABLE IF NOT EXISTS "category" ( == """CREATE TABLE IF NOT EXISTS "category" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(200) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200) NOT NULL, "name" VARCHAR(200),
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
);""" );"""
@@ -40,8 +37,8 @@ def test_create_table():
ret ret
== """CREATE TABLE IF NOT EXISTS "category" ( == """CREATE TABLE IF NOT EXISTS "category" (
"id" SERIAL NOT NULL PRIMARY KEY, "id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(200) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200) NOT NULL, "name" VARCHAR(200),
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
@@ -50,7 +47,7 @@ COMMENT ON COLUMN "category"."user_id" IS 'User';"""
def test_drop_table(): def test_drop_table():
ret = Migrate.ddl.drop_table(Category) ret = Migrate.ddl.drop_table(Category._meta.db_table)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "DROP TABLE IF EXISTS `category`" assert ret == "DROP TABLE IF EXISTS `category`"
else: else:
@@ -58,26 +55,28 @@ def test_drop_table():
def test_add_column(): def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name")) ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL" assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
else: else:
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL' assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'
def test_modify_column(): def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): return
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
else: ret0 = Migrate.ddl.modify_column(
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) Category, Category._meta.fields_map.get("name").describe(False)
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active")) )
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL" assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret0 == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)' assert (
ret0
== 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)'
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
@@ -85,59 +84,64 @@ def test_modify_column():
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL' assert (
ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL'
)
def test_alter_column_default(): def test_alter_column_default():
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, SqliteDDL):
return
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("intro").describe(False))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT' assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\''
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''"
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at")) ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP' ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
) )
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert (
ret
== "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)"
)
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar")) ret = Migrate.ddl.alter_column_default(
Product, Product._meta.fields_map.get("view_num").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\'' assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0"
def test_alter_column_null(): def test_alter_column_null():
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.alter_column_null(
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL' assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL'
else:
assert ret is None
def test_set_comment(): def test_set_comment():
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
if isinstance(Migrate.ddl, PostgresDDL): return
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
else: assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user")) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user").describe(False))
if isinstance(Migrate.ddl, PostgresDDL): assert ret == 'COMMENT ON COLUMN "category"."user_id" IS \'User\''
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
else:
assert ret is None
def test_drop_column(): def test_drop_column():
if isinstance(Migrate.ddl, SqliteDDL): ret = Migrate.ddl.drop_column(Category, "name")
with pytest.raises(NotSupportError):
ret = Migrate.ddl.drop_column(Category, "name")
else:
ret = Migrate.ddl.drop_column(Category, "name")
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP COLUMN `name`" assert ret == "ALTER TABLE `category` DROP COLUMN `name`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
@@ -148,18 +152,15 @@ def test_add_index():
index = Migrate.ddl.add_index(Category, ["name"]) index = Migrate.ddl.add_index(Category, ["name"])
index_u = Migrate.ddl.add_index(Category, ["name"], True) index_u = Migrate.ddl.add_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert ( assert (
index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)" index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)"
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert ( assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
index_u
== 'ALTER TABLE "category" ADD CONSTRAINT "uid_category_name_8b0cb9" UNIQUE ("name")'
)
else: else:
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")' assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert ( assert (
index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")' index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")'
) )
@@ -173,14 +174,16 @@ def test_drop_index():
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`" assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"' assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP CONSTRAINT "uid_category_name_8b0cb9"' assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
else: else:
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"' assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"' assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
def test_add_fk(): def test_add_fk():
ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user")) ret = Migrate.ddl.add_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
ret ret
@@ -194,7 +197,9 @@ def test_add_fk():
def test_drop_fk(): def test_drop_fk():
ret = Migrate.ddl.drop_fk(Category, Category._meta.fields_map.get("user")) ret = Migrate.ddl.drop_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`" assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):

View File

@@ -1,60 +1,884 @@
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from tortoise import Tortoise
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import Migrate
from aerich.utils import get_models_describe
old_models_describe = {
"models.Category": {
"name": "models.Category",
"app": "models",
"table": "category",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "slug",
"field_type": "CharField",
"db_column": "slug",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "created_at",
"field_type": "DatetimeField",
"db_column": "created_at",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"readOnly": True},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
{
"name": "user_id",
"field_type": "IntField",
"db_column": "user_id",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
],
"fk_fields": [
{
"name": "user",
"field_type": "ForeignKeyFieldInstance",
"python_type": "models.User",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {},
"raw_field": "user_id",
"on_delete": "CASCADE",
}
],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [
{
"name": "products",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Product",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Product",
"related_name": "categories",
"forward_key": "product_id",
"backward_key": "category_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": True,
}
],
},
"models.Config": {
"name": "models.Config",
"app": "models",
"table": "configs",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "label",
"field_type": "CharField",
"db_column": "label",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "key",
"field_type": "CharField",
"db_column": "key",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "value",
"field_type": "JSONField",
"db_column": "value",
"python_type": "Union[dict, list]",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "postgres": "JSONB"},
},
{
"name": "status",
"field_type": "IntEnumFieldInstance",
"db_column": "status",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": 1,
"description": "on: 1\noff: 0",
"docstring": None,
"constraints": {"ge": -32768, "le": 32767},
"db_field_types": {"": "SMALLINT"},
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Email": {
"name": "models.Email",
"app": "models",
"table": "email",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "email",
"field_type": "CharField",
"db_column": "email",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "is_primary",
"field_type": "BooleanField",
"db_column": "is_primary",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "user_id",
"field_type": "IntField",
"db_column": "user_id",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
],
"fk_fields": [
{
"name": "user",
"field_type": "ForeignKeyFieldInstance",
"python_type": "models.User",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"raw_field": "user_id",
"on_delete": "CASCADE",
}
],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Product": {
"name": "models.Product",
"app": "models",
"table": "product",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 50},
"db_field_types": {"": "VARCHAR(50)"},
},
{
"name": "view_num",
"field_type": "IntField",
"db_column": "view_num",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "View Num",
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "sort",
"field_type": "IntField",
"db_column": "sort",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "is_reviewed",
"field_type": "BooleanField",
"db_column": "is_reviewed",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "Is Reviewed",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "type",
"field_type": "IntEnumFieldInstance",
"db_column": "type",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "Product Type",
"docstring": None,
"constraints": {"ge": -32768, "le": 32767},
"db_field_types": {"": "SMALLINT"},
},
{
"name": "image",
"field_type": "CharField",
"db_column": "image",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "body",
"field_type": "TextField",
"db_column": "body",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "mysql": "LONGTEXT"},
},
{
"name": "created_at",
"field_type": "DatetimeField",
"db_column": "created_at",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"readOnly": True},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "products",
"forward_key": "category_id",
"backward_key": "product_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": False,
}
],
},
"models.User": {
"name": "models.User",
"app": "models",
"table": "user",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "username",
"field_type": "CharField",
"db_column": "username",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "password",
"field_type": "CharField",
"db_column": "password",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "last_login",
"field_type": "DatetimeField",
"db_column": "last_login",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "<function None.now>",
"description": "Last Login",
"docstring": None,
"constraints": {},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": False,
"auto_now": False,
},
{
"name": "is_active",
"field_type": "BooleanField",
"db_column": "is_active",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": True,
"description": "Is Active",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "is_superuser",
"field_type": "BooleanField",
"db_column": "is_superuser",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": "Is SuperUser",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "avatar",
"field_type": "CharField",
"db_column": "avatar",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "",
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "intro",
"field_type": "TextField",
"db_column": "intro",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "",
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "mysql": "LONGTEXT"},
},
],
"fk_fields": [],
"backward_fk_fields": [
{
"name": "categorys",
"field_type": "BackwardFKRelation",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {},
},
{
"name": "emails",
"field_type": "BackwardFKRelation",
"python_type": "models.Email",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
},
],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Aerich": {
"name": "models.Aerich",
"app": "models",
"table": "aerich",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "version",
"field_type": "CharField",
"db_column": "version",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 255},
"db_field_types": {"": "VARCHAR(255)"},
},
{
"name": "app",
"field_type": "CharField",
"db_column": "app",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "content",
"field_type": "JSONField",
"db_column": "content",
"python_type": "Union[dict, list]",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "postgres": "JSONB"},
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
}
def test_migrate(mocker: MockerFixture): def test_migrate(mocker: MockerFixture):
mocker.patch("click.prompt", return_value=True) """
apps = Tortoise.apps models.py diff with old_models.py
models = apps.get("models") - change email pk: id -> email_id
diff_models = apps.get("diff_models") - add field: Email.address
Migrate.diff_models(diff_models, models) - add fk: Config.user
- drop fk: Email.user
- drop field: User.avatar
- add index: Email.email
- add many to many: Email.users
- remove unique: User.username
- change column: length User.password
- add unique_together: (name,type) of Product
- alter default: Config.status
- rename column: Product.image -> Product.pic
"""
mocker.patch("click.prompt", side_effect=(True,))
models_describe = get_models_describe("models")
Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): with pytest.raises(NotSupportError):
Migrate.diff_models(models, diff_models, False) Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False)
else: else:
Migrate.diff_models(models, diff_models, False) Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False)
Migrate._merge_operators() Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert Migrate.upgrade_operators == [ assert sorted(Migrate.upgrade_operators) == sorted(
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", [
"ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
] "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
assert Migrate.downgrade_operators == [ "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `category` DROP COLUMN `name`", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`", "ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY " "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
] "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4;",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
]
)
assert sorted(Migrate.downgrade_operators) == sorted(
[
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`",
]
)
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert Migrate.upgrade_operators == [ assert sorted(Migrate.upgrade_operators) == sorted(
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"', [
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
] 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
assert Migrate.downgrade_operators == [ 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "category" DROP COLUMN "name"', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"', 'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
] 'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE UNIQUE INDEX "uid_product_name_f14935" ON "product" ("name", "type")',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
]
)
assert sorted(Migrate.downgrade_operators) == sorted(
[
'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "config" RENAME TO "configs"',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'DROP INDEX "idx_email_email_4a1a33"',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'DROP INDEX "idx_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_f14935"',
'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"',
]
)
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == []
'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',
]
assert Migrate.downgrade_operators == [] assert Migrate.downgrade_operators == []