62 Commits

Author SHA1 Message Date
long2ice
be41a1332a update tortoise-orm version 2021-02-04 20:53:04 +08:00
long2ice
09661c1d46 Fix unique_together 2021-02-04 14:39:07 +08:00
long2ice
abfa60133f Fix drop table 2021-02-04 14:23:46 +08:00
long2ice
048e428eac update tortoise-orm 2021-02-03 22:52:01 +08:00
long2ice
38a3df9b5a add support m2m 2021-02-03 22:22:22 +08:00
long2ice
0d94b22b3f Remove unused functions 2021-02-03 18:06:43 +08:00
long2ice
f1f0074255 Support rename field 2021-02-03 17:56:30 +08:00
long2ice
e3a14a2f60 Fix postgres index 2021-02-03 16:34:07 +08:00
long2ice
608ff8f071 update conftest.py 2021-02-03 15:49:40 +08:00
long2ice
d3a1342293 update README.md 2021-02-03 15:48:06 +08:00
long2ice
01e3de9522 basically completed 2021-02-03 15:43:04 +08:00
long2ice
c6c398fdf0 update 2021-02-02 22:52:50 +08:00
long2ice
c60bdd290e add fk and drop fk 2021-02-02 20:35:05 +08:00
long2ice
f443dc68db WIP 2021-02-01 16:54:35 +08:00
long2ice
36f84702b7 update 2021-02-01 14:00:12 +08:00
long2ice
b4cc2de0e3 v0.5 refactoring 2021-01-31 23:10:30 +08:00
long2ice
4780b90c1c add close_connections to fix stuck 2021-01-29 22:58:12 +08:00
long2ice
cd176c1fd6 Merge pull request #111 from lqmanh/bugfixes/fix-tortoise-orm-0.16.19
Fix Aerich b/c of a new feature in Tortoise ORM v0.16.19
2021-01-04 14:59:11 +08:00
long2ice
c2819fc8dc update CHANGELOG.md 2020-12-29 19:13:37 +08:00
long2ice
530e7cfce5 Fixed unnecessary import. (#113) 2020-12-29 19:12:36 +08:00
Lương Quang Mạnh
47824a100b Fix Aerich b/c of Tortoise ORM v0.16.19 2020-12-26 10:31:10 +07:00
long2ice
78a15f9f19 Merge pull request #108 from lqmanh/features/make-parent-dirs-as-needed
Make parent directories as needed
2020-12-25 22:10:56 +08:00
long2ice
5ae8b9e85f complete InspectDb 2020-12-25 21:44:26 +08:00
long2ice
55a6d4bbc7 add InspectDb and show_create_tables 2020-12-24 23:32:58 +08:00
long2ice
c5535f16e1 TODO: Add inspectdb command 2020-12-23 23:38:45 +08:00
long2ice
840cd71e44 Replace migrations separator to sql standard comment 2020-12-23 23:30:35 +08:00
Lương Quang Mạnh
e0d52b1210 Fix make style 2020-12-21 15:36:29 +07:00
Lương Quang Mạnh
4dc45f723a Make parent directories as needed 2020-12-21 15:13:26 +07:00
long2ice
d2e0a68351 Fix packaging error. (#92) 2020-12-02 23:03:15 +08:00
long2ice
ee6cc20c7d Fix empty items 2020-11-30 11:14:09 +08:00
long2ice
4e917495a0 Fix upgrade in new db. (#96) 2020-11-30 11:02:48 +08:00
long2ice
bfa66f6dd4 update changelog 2020-11-29 11:15:43 +08:00
long2ice
f00715d4c4 Merge pull request #97 from TrDex/pathlib-for-path-resolving
Use `pathlib` for path resolving
2020-11-29 11:02:44 +08:00
Mykola Solodukha
6e3105690a Use pathlib for path resolving 2020-11-28 19:23:34 +02:00
long2ice
c707f7ecb2 bug fix 2020-11-28 14:31:41 +08:00
long2ice
0bbc471e00 Fix sqlite stuck. (#90) 2020-11-26 23:38:57 +08:00
long2ice
fb6cc62047 update README and CHANGELOG 2020-11-23 16:44:16 +08:00
long2ice
e9ceaf471f Merge pull request #87 from ALexALed/remove-default-detections-for-callable
Remove callable detection for defaults
2020-11-23 16:41:30 +08:00
alexaled
85fc3b2aa2 Remove callable detection for defaults 2020-11-23 10:35:40 +02:00
long2ice
a677d506a9 Fix ci error 2020-11-19 10:41:52 +08:00
long2ice
9879004fee Add rename column support MySQL5 2020-11-19 10:11:52 +08:00
long2ice
5760fe2040 Merge pull request #83 from SakuraSound/fix-migrate-unlink
Catch OSError (if read-only file system)
2020-11-18 15:40:29 +08:00
Joir-dan Gumbs
b229c30558 Catch OSError (if read-only file system) 2020-11-17 23:28:00 -08:00
long2ice
5d2f1604c3 update github action poetry 2020-11-17 10:57:56 +08:00
long2ice
499c4e1c02 Fix black 2020-11-17 10:50:57 +08:00
long2ice
1463ee30bc update deps 2020-11-17 10:43:27 +08:00
long2ice
3b801932f5 Merge remote-tracking branch 'origin/dev' into dev 2020-11-17 10:36:14 +08:00
long2ice
c2eb4dc9e3 update poetry in github actions 2020-11-17 10:35:51 +08:00
long2ice
5927febd0c Delete .DS_Store 2020-11-17 10:10:32 +08:00
long2ice
a1c10ff330 exclude .DS_store 2020-11-17 10:09:37 +08:00
long2ice
f2013c931a Fix test error 2020-11-16 22:32:19 +08:00
long2ice
b21b954d32 Use .sql instead of .json to store version file. (#79) 2020-11-16 22:25:01 +08:00
long2ice
f5588a35c5 update deps 2020-11-12 21:27:58 +08:00
long2ice
f5dff84476 Fix encoding error. (#75) 2020-11-08 23:00:44 +08:00
long2ice
e399821116 update deps 2020-11-05 17:43:41 +08:00
long2ice
648f25a951 Compatible with models file in directory. (#70) 2020-10-30 19:51:46 +08:00
long2ice
fa73e132e2 remove .vscode 2020-10-30 16:45:12 +08:00
long2ice
1bac33cd33 add confirmation_option when downgrade 2020-10-30 16:39:14 +08:00
long2ice
4e76f12ccf update README.md 2020-10-28 17:12:23 +08:00
long2ice
724379700e Support multiple databases. (#68) 2020-10-28 17:02:02 +08:00
long2ice
bb929f2b55 update deps 2020-10-25 17:48:05 +08:00
long2ice
6339dc86a8 Fix migrate to new database error 2020-10-14 20:33:23 +08:00
24 changed files with 2251 additions and 1051 deletions

View File

@@ -11,7 +11,10 @@ jobs:
- uses: actions/setup-python@v2
with:
python-version: '3.x'
- uses: dschep/install-poetry-action@v1.3
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
- name: Build dists
run: make build
- name: Pypi Publish

View File

@@ -19,7 +19,10 @@ jobs:
- uses: actions/setup-python@v2
with:
python-version: '3.x'
- uses: dschep/install-poetry-action@v1.3
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
- name: CI
env:
MYSQL_PASS: root

2
.gitignore vendored
View File

@@ -144,3 +144,5 @@ cython_debug/
migrations
aerich.ini
src
.vscode
.DS_Store

View File

@@ -1,7 +1,51 @@
# ChangeLog
## 0.5
### 0.5.0
- Refactor core code, now has no limitation for everything.
## 0.4
### 0.4.4
- Fix unnecessary import. (#113)
### 0.4.3
- Replace migrations separator to sql standard comment.
- Add `inspectdb` command.
### 0.4.2
- Use `pathlib` for path resolving. (#89)
- Fix upgrade in new db. (#96)
- Fix packaging error. (#92)
### 0.4.1
- Bug fix. (#91 #93)
### 0.4.0
- Use `.sql` instead of `.json` to store version file.
- Add `rename` column support MySQL5.
- Remove callable detection for defaults. (#87)
- Fix `sqlite` stuck. (#90)
## 0.3
### 0.3.3
- Fix encoding error. (#75)
- Support multiple databases. (#68)
- Compatible with models file in directory. (#70)
### 0.3.2
- Fix migrate to new database error. (#62)
### 0.3.1
- Fix first version error.

View File

@@ -3,8 +3,10 @@ black_opts = -l 100 -t py38
py_warn = PYTHONDEVMODE=1
MYSQL_HOST ?= "127.0.0.1"
MYSQL_PORT ?= 3306
MYSQL_PASS ?= "123456"
POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= "123456"
help:
@echo "Aerich development makefile"
@@ -22,7 +24,7 @@ up:
@poetry update
deps:
@poetry install -E dbdrivers --no-root
@poetry install -E dbdrivers
style: deps
isort -src $(checkfiles)
@@ -43,7 +45,7 @@ test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql

View File

@@ -7,10 +7,12 @@
## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which like alembic for SQLAlchemy, or Django ORM with it\'s
own migrations solution.
Aerich is a database migrations tool for Tortoise-ORM, which like alembic for SQLAlchemy, or Django ORM with it\'s own
migrations solution.
**If you upgrade aerich from <= 0.2.5 to >= 0.3.0, see [changelog](https://github.com/tortoise/aerich/blob/dev/CHANGELOG.md) for upgrade steps.**
~~**Important: You can only use absolutely import in your `models.py` to make `aerich` work.**~~
From version `v0.5.0`, there is no such limitation now.
## Install
@@ -40,14 +42,14 @@ Commands:
history List all migrate items.
init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file.
upgrade Upgrade to latest version.
```
## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first,
example:
You need add `aerich.models` to your `Tortoise-ORM` config first, example:
```python
TORTOISE_ORM = {
@@ -103,22 +105,22 @@ If your Tortoise-ORM app is not default `models`, you must specify
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.json
Success migrate 1_202029051520102929_drop_column.sql
```
Format of migrate filename is
`{version_num}_{datetime}_{name|update}.json`.
`{version_num}_{datetime}_{name|update}.sql`.
And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can choice `True` to rename column without column drop, or choice `False` to drop column then create.
If you use `MySQL`, only MySQL8.0+ support `rename..to` syntax.
And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can
choice `True` to rename column without column drop, or choice `False` to drop column then create, note that the after
maybe lose data.
### Upgrade to latest version
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.json
Success upgrade 1_202029051520102929_drop_column.sql
```
Now your db is migrated to latest.
@@ -134,13 +136,17 @@ Usage: aerich downgrade [OPTIONS]
Options:
-v, --version INTEGER Specified version, default to last. [default: -1]
-d, --delete Delete version files at the same time. [default:
False]
--yes Confirm the action without prompting.
-h, --help Show this message and exit.
```
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.json
Success downgrade 1_202029051520102929_drop_column.sql
```
Now your db rollback to specified version.
@@ -150,7 +156,7 @@ Now your db rollback to specified version.
```shell
> aerich history
1_202029051520102929_drop_column.json
1_202029051520102929_drop_column.sql
```
### Show heads to be migrated
@@ -158,13 +164,56 @@ Now your db rollback to specified version.
```shell
> aerich heads
1_202029051520102929_drop_column.json
1_202029051520102929_drop_column.sql
```
## Support this project
### Inspect db tables to TortoiseORM model
- Just give a star!
- Donation.
Currently, only support MySQL.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Inspect all tables and print to console:
```shell
aerich --app models inspectdb
```
Inspect a specified table in default app and redirect to `models.py`:
```shell
aerich inspectdb -t user > models.py
```
Note that this command is restricted, which is not supported in some solutions, such as `IntEnumField`
and `ForeignKeyField` and so on.
### Multiple databases
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["tests.models_second"], "default_connection": "second", },
},
}
```
You need only specify `aerich.models` in one app, and must specify `--app` when run `aerich migrate` and so on.
## Support this project
| AliPay | WeChatPay | PayPal |
| -------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |

View File

@@ -1 +1 @@
__version__ = "0.3.1"
__version__ = "0.5.0"

View File

@@ -1,9 +1,10 @@
import asyncio
import json
import os
import sys
from configparser import ConfigParser
from functools import wraps
from pathlib import Path
from typing import List
import click
from click import Context, UsageError
@@ -12,8 +13,16 @@ from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.inspectdb import InspectDb
from aerich.migrate import Migrate
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
get_tortoise_config,
get_version_content_from_file,
write_version_file,
)
from . import __version__
from .enums import Color
@@ -26,11 +35,7 @@ def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs))
app = ctx.obj.get("app")
if app:
Migrate.remove_old_model_file(app, ctx.obj["location"])
return wrapper
@@ -38,7 +43,11 @@ def coro(f):
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version")
@click.option(
"-c", "--config", default="aerich.ini", show_default=True, help="Config file.",
"-c",
"--config",
default="aerich.ini",
show_default=True,
help="Config file.",
)
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
@@ -57,7 +66,7 @@ async def cli(ctx: Context, config, app, name):
invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init":
if not os.path.exists(config):
if not Path(config).exists():
raise UsageError("You must exec init first", ctx=ctx)
parser.read(config)
@@ -66,14 +75,14 @@ async def cli(ctx: Context, config, app, name):
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
if "aerich.models" not in tortoise_config.get("apps").get(app).get("models"):
raise UsageError("Check your tortoise config and add aerich.models to it.", ctx=ctx)
ctx.obj["config"] = tortoise_config
ctx.obj["location"] = location
ctx.obj["app"] = app
Migrate.app = app
if invoked_subcommand != "init-db":
await Migrate.init_with_old_models(tortoise_config, app, location)
if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init(tortoise_config, app, location)
@cli.command(help="Generate migrate changes file.")
@@ -93,7 +102,6 @@ async def migrate(ctx: Context, name):
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False
for version_file in Migrate.get_all_version_files():
try:
@@ -102,21 +110,20 @@ async def upgrade(ctx: Context):
exists = False
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file)
with open(file_path, "r", encoding="utf-8") as f:
content = json.load(f)
file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=app,
content=Migrate.get_models_content(config, app, location),
content=get_models_describe(app),
)
click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True
if not migrated:
click.secho("No migrate items", fg=Color.yellow)
click.secho("No upgrade items found", fg=Color.yellow)
@cli.command(help="Downgrade to specified version.")
@@ -128,9 +135,20 @@ async def upgrade(ctx: Context):
show_default=True,
help="Specified version, default to last.",
)
@click.option(
"-d",
"--delete",
is_flag=True,
default=False,
show_default=True,
help="Delete version files at the same time.",
)
@click.pass_context
@click.confirmation_option(
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
)
@coro
async def downgrade(ctx: Context, version: int):
async def downgrade(ctx: Context, version: int, delete: bool):
app = ctx.obj["app"]
config = ctx.obj["config"]
if version == -1:
@@ -146,15 +164,16 @@ async def downgrade(ctx: Context, version: int):
for version in versions:
file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file)
with open(file_path, "r", encoding="utf-8") as f:
content = json.load(f)
file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
return click.secho("No downgrade item found", fg=Color.yellow)
click.secho("No downgrade items found", fg=Color.yellow)
return
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
if delete:
os.unlink(file_path)
click.secho(f"Success downgrade {file}", fg=Color.green)
@@ -193,16 +212,21 @@ async def history(ctx: Context):
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
)
@click.option(
"--location", default="./migrations", show_default=True, help="Migrate store location."
"--location",
default="./migrations",
show_default=True,
help="Migrate store location.",
)
@click.pass_context
@coro
async def init(
ctx: Context, tortoise_orm, location,
ctx: Context,
tortoise_orm,
location,
):
config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if os.path.exists(config_file):
if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow)
parser.add_section(name)
@@ -212,8 +236,7 @@ async def init(
with open(config_file, "w", encoding="utf-8") as f:
parser.write(f)
if not os.path.isdir(location):
os.mkdir(location)
Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success generate config file {config_file}", fg=Color.green)
@@ -234,12 +257,14 @@ async def init_db(ctx: Context, safe):
location = ctx.obj["location"]
app = ctx.obj["app"]
dirname = os.path.join(location, app)
if not os.path.isdir(dirname):
os.mkdir(dirname)
dirname = Path(location, app)
try:
dirname.mkdir(parents=True)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
else:
return click.secho(f"Inited {app} already", fg=Color.yellow)
except FileExistsError:
return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
)
await Tortoise.init(config=config)
connection = get_app_connection(config, app)
@@ -249,16 +274,40 @@ async def init_db(ctx: Context, safe):
version = await Migrate.generate_version()
await Aerich.create(
version=version, app=app, content=Migrate.get_models_content(config, app, location)
version=version,
app=app,
content=get_models_describe(app),
)
with open(os.path.join(dirname, version), "w", encoding="utf-8") as f:
content = {
"upgrade": [schema],
}
json.dump(content, f, ensure_ascii=False, indent=2)
return click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
write_version_file(Path(dirname, version), content)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
@click.option(
"-t",
"--table",
help="Which tables to inspect.",
multiple=True,
required=False,
)
@click.pass_context
@coro
async def inspectdb(ctx: Context, table: List[str]):
config = ctx.obj["config"]
app = ctx.obj["app"]
connection = get_app_connection(config, app)
inspect = InspectDb(connection, table)
await inspect.inspect()
def main():
sys.path.insert(0, ".")
cli()
if __name__ == "__main__":
main()

View File

@@ -1,8 +1,8 @@
from enum import Enum
from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
class BaseDDL:
@@ -11,6 +11,7 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
)
@@ -20,8 +21,11 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
)
def __init__(self, client: "BaseDBAsyncClient"):
self.client = client
@@ -30,43 +34,50 @@ class BaseDDL:
def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
def drop_table(self, model: "Type[Model]"):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
def drop_table(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance):
def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
return self._M2M_TABLE_TEMPLATE.format(
table_name=field.through,
table_name=through,
backward_table=model._meta.db_table,
forward_table=field.related_model._meta.db_table,
forward_table=reference_table_describe.get("table"),
backward_field=model._meta.db_pk_column,
forward_field=field.related_model._meta.db_pk_column,
backward_key=field.backward_key,
forward_field=reference_id,
backward_key=field_describe.get("backward_key"),
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key,
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=field.through),
forward_key=field_describe.get("forward_key"),
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description
table=through, comment=description
)
if field.description
if description
else "",
)
def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_object: Field):
def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = field_object.default
db_column = field_object.model_field_name
auto_now_add = getattr(field_object, "auto_now_add", False)
auto_now = getattr(field_object, "auto_now", False)
default = field_describe.get("default")
if isinstance(default, Enum):
default = default.value
db_column = field_describe.get("db_column")
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)):
if field_describe.get("field_type") in ["UUIDField", "TextField", "JSONField"]:
default = ""
else:
default = field_object.to_db_value(default, model)
try:
default = self.schema_generator._column_default_generator(
db_table,
@@ -78,28 +89,33 @@ class BaseDDL:
except NotImplementedError:
default = ""
else:
default = ""
default = None
return default
def add_column(self, model: "Type[Model]", field_object: Field):
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="UNIQUE" if field_object.unique else "",
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
column=db_column,
comment=field_describe.get("description"),
)
if field_object.description
if description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
is_primary_key=is_pk,
default=default,
),
)
@@ -108,24 +124,28 @@ class BaseDDL:
table_name=model._meta.db_table, column_name=column_name
)
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_object.description
if field_describe.get("description")
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
is_primary_key=is_pk,
default=default,
),
)
@@ -136,6 +156,16 @@ class BaseDDL:
new_column_name=new_column_name,
)
def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
new_column_name=new_column_name,
new_column_type=new_column_type,
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
@@ -154,48 +184,49 @@ class BaseDDL:
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_column = field.source_field or field.model_field_name + "_id"
db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column")
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
db_column=db_column,
table=field.related_model._meta.db_table,
field=to_field_name,
on_delete=field.on_delete,
table=reference_table_describe.get("table"),
field=reference_id,
on_delete=field_describe.get("on_delete"),
)
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format(
table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id",
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
def alter_column_default(self, model: "Type[Model]", field_object: Field):
pass
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column"),
default="SET" + default if default is not None else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
pass
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotImplementedError
def set_comment(self, model: "Type[Model]", field_object: Field):
pass
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotImplementedError

View File

@@ -1,6 +1,10 @@
from typing import Type
from tortoise import Model
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError
class MysqlDDL(BaseDDL):
@@ -8,6 +12,10 @@ class MysqlDDL(BaseDDL):
DIALECT = MySQLSchemaGenerator.DIALECT
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
_CHANGE_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
)
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
@@ -18,5 +26,11 @@ class MysqlDDL(BaseDDL):
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in MySQL.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in MySQL.")

View File

@@ -1,8 +1,7 @@
from typing import List, Type
from typing import Type
from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL
@@ -10,66 +9,36 @@ from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE INDEX "{index_name}" ON "{table_name}" ({column_names})'
_ADD_UNIQUE_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD CONSTRAINT "{index_name}" UNIQUE ({column_names})'
)
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
set_drop="DROP" if field_object.null else "SET",
column=field_describe.get("db_column"),
set_drop="DROP" if field_describe.get("nullable") else "SET",
)
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
column=field_describe.get("db_column"),
datatype=db_field_types.get(self.DIALECT) or db_field_types.get(""),
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._ADD_UNIQUE_TEMPLATE if unique else self._ADD_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._DROP_UNIQUE_TEMPLATE if unique else self._DROP_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
)
def set_comment(self, model: "Type[Model]", field_object: Field):
def set_comment(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
comment="'{}'".format(field_object.description) if field_object.description else "NULL",
column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL",
)

View File

@@ -2,7 +2,6 @@ from typing import Type
from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError
@@ -15,5 +14,14 @@ class SqliteDDL(BaseDDL):
def drop_column(self, model: "Type[Model]", column_name: str):
raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

85
aerich/inspectdb.py Normal file
View File

@@ -0,0 +1,85 @@
import sys
from typing import List, Optional
from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient
class InspectDb:
_table_template = "class {table}(Model):\n"
_field_template_mapping = {
"INT": " {field} = fields.IntField({pk}{unique}{comment})",
"SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})",
"TINYINT": " {field} = fields.BooleanField({null}{default}{comment})",
"VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
}
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.tables = tables
self.DIALECT = conn.schema_generator.DIALECT
async def show_create_tables(self):
if self.DIALECT == "mysql":
if not self.tables:
sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';" # nosec: B608
ret = await self.conn.execute_query(sql_tables)
self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
for table in self.tables:
sql_show_create_table = f"SHOW CREATE TABLE {table}"
ret = await self.conn.execute_query(sql_show_create_table)
yield ret[1][0]["Create Table"]
else:
raise NotImplementedError("Currently only support MySQL")
async def inspect(self):
ddl_list = self.show_create_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
async for ddl in ddl_list:
parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
table = parser.parse()
name = table.name.title()
columns = table.columns
fields = []
model = self._table_template.format(table=name)
for column_name, column in columns.items():
comment = default = length = unique = null = pk = ""
if column.primary_key:
pk = "pk=True, "
if column.unique:
unique = "unique=True, "
if column.data_type == "VARCHAR":
length = f"max_length={column.length}, "
if not column.not_null:
null = "null=True, "
if column.default is not None:
if column.data_type == "TINYINT":
default = f"default={'True' if column.default == '1' else 'False'}, "
elif column.data_type == "DATETIME":
if "CURRENT_TIMESTAMP" in column.default:
if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
default = "auto_now_add=True, "
else:
default = "auto_now=True, "
else:
default = f"default={column.default}, "
if column.comment:
comment = f"description='{column.comment}', "
field = self._field_template_mapping[column.data_type].format(
field=column_name,
pk=pk,
unique=unique,
length=length,
null=null,
default=default,
comment=comment,
)
fields.append(field)
tables.append(model + "\n".join(fields))
sys.stdout.write(result + "\n\n\n".join(tables))

View File

@@ -1,25 +1,16 @@
import json
import os
import re
from datetime import datetime
from importlib import import_module
from io import StringIO
from typing import Dict, List, Tuple, Type
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
import click
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
Model,
Tortoise,
)
from tortoise.fields import Field
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection
from aerich.utils import get_app_connection, get_models_describe, write_version_file
class Migrate:
@@ -34,50 +25,45 @@ class Migrate:
_rename_new = []
ddl: BaseDDL
migrate_config: dict
old_models = "old_models"
diff_app = "diff_models"
_last_version_content: Optional[dict] = None
app: str
migrate_location: str
dialect: str
@classmethod
def get_old_model_file(cls, app: str, location: str):
return os.path.join(location, app, cls.old_models + ".py")
_db_version: Optional[str] = None
@classmethod
def get_all_version_files(cls) -> List[str]:
return sorted(
filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)),
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]),
)
@classmethod
async def get_last_version(cls) -> Aerich:
return await Aerich.filter(app=cls.app).first()
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
@classmethod
def remove_old_model_file(cls, app: str, location: str):
async def get_last_version(cls) -> Optional[Aerich]:
try:
os.unlink(cls.get_old_model_file(app, location))
except FileNotFoundError:
return await Aerich.filter(app=cls.app).first()
except OperationalError:
pass
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql":
sql = "select version() as version"
ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version")
@classmethod
async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
last_version = await cls.get_last_version()
if last_version:
content = last_version.content
with open(cls.get_old_model_file(app, location), "w") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.app = app
cls.migrate_config = migrate_config
cls.migrate_location = os.path.join(location, app)
await Tortoise.init(config=migrate_config)
cls.migrate_location = Path(location, app)
if last_version:
cls._last_version_content = last_version.content
connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
@@ -93,6 +79,7 @@ class Migrate:
from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection)
await cls._get_db_version(connection)
@classmethod
async def _get_last_version_num(cls):
@@ -107,8 +94,8 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
return f"0_{now}_init.json"
version = f"{last_version_num + 1}_{now}_{name}.json"
return f"0_{now}_init.sql"
version = f"{last_version_num + 1}_{now}_{name}.sql"
if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version
@@ -119,13 +106,12 @@ class Migrate:
# delete if same version exists
for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]):
os.unlink(os.path.join(cls.migrate_location, version_file))
os.unlink(Path(cls.migrate_location, version_file))
content = {
"upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators,
}
with open(os.path.join(cls.migrate_location, version), "w", encoding="utf-8") as f:
json.dump(content, f, indent=2, ensure_ascii=False)
write_version_file(Path(cls.migrate_location, version), content)
return version
@classmethod
@@ -135,12 +121,9 @@ class Migrate:
:param name:
:return:
"""
apps = Tortoise.apps
diff_models = apps.get(cls.diff_app)
app_models = apps.get(cls.app)
cls.diff_models(diff_models, app_models)
cls.diff_models(app_models, diff_models, False)
new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False)
cls._merge_operators()
@@ -170,48 +153,7 @@ class Migrate:
cls.downgrade_operators.append(operator)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
path = os.path.join(location, app, cls.old_models)
path = path.replace(os.sep, ".").lstrip(".")
config["apps"][cls.diff_app] = {
"models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"),
}
return config
@classmethod
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
old_model_files.append(import_module(model).__file__)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod
def diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
):
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
diff models and add operators
:param old_models:
@@ -219,165 +161,248 @@ class Migrate:
:param upgrade:
:return:
"""
old_models.pop(cls._aerich, None)
new_models.pop(cls._aerich, None)
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
for new_model_str, new_model in new_models.items():
if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade)
cls._add_operator(cls.add_model(model), upgrade)
else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
old_model_describe = old_models.get(new_model_str)
old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together"))
)
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
table = change[0][1].get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
# add unique_together
for index in new_unique_together.difference(old_unique_together):
cls._add_operator(
cls._add_index(model, index, True),
upgrade,
)
# remove unique_together
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(
cls._drop_index(model, index, True),
upgrade,
)
old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields")
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name)
):
new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False
for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name")
if len(changes) == 2:
# rename field
if changes[0] == (
"change",
"name",
(old_data_field_name, new_data_field_name),
) and changes[1] == (
"change",
"db_column",
(old_data_field.get("db_column"), new_data_field.get("db_column")),
):
if upgrade:
is_rename = click.prompt(
f"Rename {old_data_field_name} to {new_data_field_name}?",
default=True,
type=bool,
show_choices=True,
)
else:
is_rename = old_data_field_name in cls._rename_new
if is_rename:
cls._rename_new.append(new_data_field_name)
cls._rename_old.append(old_data_field_name)
# only MySQL8+ has rename syntax
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(
model, new_data_field, old_data_field
),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(model, *changes[1][2]),
upgrade,
)
if not is_rename:
cls._add_operator(
cls._add_field(
model,
new_data_field,
),
upgrade,
)
# remove fields
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is rename
if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new
):
continue
cls._add_operator(
cls._remove_field(
model,
next(
filter(
lambda x: x.get("name") == old_data_field_name, old_data_fields
)
).get("db_column"),
),
upgrade,
)
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
cls._add_operator(
cls._add_fk(model, fk_field, new_models.get(fk_field.get("python_type"))),
upgrade,
fk_m2m=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
)
cls._add_operator(
cls._drop_fk(
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next(
filter(lambda x: x.get("name") == field_name, old_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field)
for change in changes:
_, option, old_new = change
if option == "indexed":
# change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True:
cls._add_operator(
cls._add_index(model, (field_name,), unique),
upgrade,
)
else:
cls._add_operator(
cls._drop_index(model, (field_name,), unique),
upgrade,
)
elif option == "db_field_types.":
# change column
cls._add_operator(
cls._change_field(model, old_data_field, new_data_field),
upgrade,
)
elif option == "default":
cls._add_operator(cls._alter_default(model, new_data_field), upgrade)
for old_model in old_models:
if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
@classmethod
def _is_fk_m2m(cls, field: Field):
return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance))
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod
def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model)
@classmethod
def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model)
def drop_model(cls, table_name: str):
return cls.ddl.drop_table(table_name)
@classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
"""
diff single model
:param old_model:
:param new_model:
:param upgrade:
:return:
"""
old_indexes = old_model._meta.indexes
new_indexes = new_model._meta.indexes
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
old_unique_together = old_model._meta.unique_together
new_unique_together = new_model._meta.unique_together
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}?",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field), upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field), upgrade, cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(
cls._alter_null(new_model, new_field), upgrade=upgrade
)
if new_field.default != old_field.default:
cls._add_operator(
cls._alter_default(new_model, new_field), upgrade=upgrade
)
if new_field.description != old_field.description:
cls._add_operator(
cls._set_comment(new_model, new_field), upgrade=upgrade
)
if new_field.field_type != old_field.field_type:
cls._add_operator(
cls._modify_field(new_model, new_field), upgrade=upgrade
)
else:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
cls._add_operator(
cls._remove_index(
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
):
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field), upgrade, True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field), upgrade, True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(cls._add_index(new_model, new_index,), upgrade)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)
for new_unique in new_unique_together:
if new_unique not in old_unique_together:
cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade)
for old_unique in old_unique_together:
if old_unique not in new_unique_together:
cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade)
@classmethod
def drop_m2m(cls, table_name: str):
return cls.ddl.drop_m2m(table_name)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@@ -390,7 +415,7 @@ class Migrate:
return ret
@classmethod
def _remove_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique)
@@ -400,87 +425,57 @@ class Migrate:
return cls.ddl.add_index(model, fields_name, unique)
@classmethod
def _exclude_field(cls, field: Field, upgrade=False):
"""
exclude BackwardFKRelation and repeat m2m field
:param field:
:return:
"""
if isinstance(field, ManyToManyFieldInstance):
through = field.through
if upgrade:
if through in cls._upgrade_m2m:
return True
else:
cls._upgrade_m2m.append(through)
return False
else:
if through in cls._downgrade_m2m:
return True
else:
cls._downgrade_m2m.append(through)
return False
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod
def _add_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.add_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
def _alter_default(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_default(model, field_describe)
@classmethod
def _alter_default(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_default(model, field)
def _alter_null(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_null(model, field_describe)
@classmethod
def _alter_null(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_null(model, field)
def _set_comment(cls, model: Type[Model], field_describe: dict):
return cls.ddl.set_comment(model, field_describe)
@classmethod
def _set_comment(cls, model: Type[Model], field: Field):
return cls.ddl.set_comment(model, field)
def _modify_field(cls, model: Type[Model], field_describe: dict):
return cls.ddl.modify_column(model, field_describe)
@classmethod
def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field)
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
def _remove_field(cls, model: Type[Model], column_name: str):
return cls.ddl.drop_column(model, column_name)
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
db_field_types = new_field_describe.get("db_field_types")
return cls.ddl.change_column(
model,
old_field_describe.get("db_column"),
new_field_describe.get("db_column"),
db_field_types.get(cls.dialect) or db_field_types.get(""),
)
@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
"""
add fk
:param model:
:param field:
:param field_describe:
:param reference_table_describe:
:return:
"""
return cls.ddl.add_fk(model, field)
@classmethod
def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
drop fk
:param model:
:param field:
:return:
"""
return cls.ddl.drop_fk(model, field)
return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod
def _merge_operators(cls):

View File

@@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.TextField()
content = fields.JSONField()
class Meta:
ordering = ["-id"]

View File

@@ -1,17 +1,24 @@
import importlib
from typing import Dict
from click import BadOptionUsage, Context
from tortoise import BaseDBAsyncClient, Tortoise
def get_app_connection_name(config, app) -> str:
def get_app_connection_name(config, app_name: str) -> str:
"""
get connection name
:param config:
:param app:
:param app_name:
:return:
"""
return config.get("apps").get(app).get("default_connection", "default")
app = config.get("apps").get(app_name)
if app:
return app.get("default_connection", "default")
raise BadOptionUsage(
option_name="--app",
message=f'Can\'t get app named "{app_name}"',
)
def get_app_connection(config, app) -> BaseDBAsyncClient:
@@ -49,3 +56,68 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
ctx=ctx,
)
return config
_UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: str) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: str, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n")
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
:param app:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret

View File

@@ -13,10 +13,18 @@ from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate
db_url = os.getenv("TEST_DB", "sqlite://:memory:")
db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:")
tortoise_orm = {
"connections": {"default": expand_db_url(db_url, True)},
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
"models_second": {"models": ["tests.models_second"], "default_connection": "second"},
},
}
@@ -31,7 +39,7 @@ def reset_migrate():
Migrate._downgrade_m2m = []
@pytest.yield_fixture(scope="session")
@pytest.fixture(scope="session")
def event_loop():
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
@@ -46,12 +54,6 @@ def event_loop():
@pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request):
tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:"
tortoise_orm["apps"]["diff_models"] = {
"models": ["tests.diff_models"],
"default_connection": "diff_models",
}
await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
@@ -62,5 +64,5 @@ async def initialize_tests(event_loop, request):
Migrate.ddl = SqliteDDL(client)
elif client.schema_generator is AsyncpgSchemaGenerator:
Migrate.ddl = PostgresDDL(client)
Migrate.dialect = Migrate.ddl.DIALECT
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))

863
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "aerich"
version = "0.3.1"
version = "0.5.0"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
@@ -16,16 +16,18 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies]
python = "^3.7"
tortoise-orm = "*"
tortoise-orm = "^0.16.21"
click = "*"
pydantic = "*"
aiomysql = { version = "*", optional = true }
asyncpg = { version = "*", optional = true }
ddlparse = "*"
dictdiffer = "*"
[tool.poetry.dev-dependencies]
flake8 = "*"
isort = "*"
black = "^19.10b0"
black = "^20.8b1"
pytest = "*"
pytest-xdist = "*"
pytest-asyncio = "*"

View File

@@ -23,18 +23,19 @@ class Status(IntEnum):
class User(Model):
username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=200)
password = fields.CharField(max_length=100)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
class Email(Model):
email = fields.CharField(max_length=200)
email_id = fields.IntField(pk=True)
email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
class Category(Model):
@@ -47,17 +48,21 @@ class Category(Model):
class Product(Model):
categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200)
pic = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
unique_together = (("name", "type"),)
class Config(Model):
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)
status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User")

63
tests/models_second.py Normal file
View File

@@ -0,0 +1,63 @@
import datetime
from enum import IntEnum
from tortoise import Model, fields
class ProductType(IntEnum):
article = 1
page = 2
class PermissionAction(IntEnum):
create = 1
delete = 2
update = 3
read = 4
class Status(IntEnum):
on = 1
off = 0
class User(Model):
username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=200)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models_second.User", db_constraint=False)
class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models_second.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model):
categories = fields.ManyToManyField("models_second.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
class Config(Model):
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@@ -24,7 +24,7 @@ class Status(IntEnum):
class User(Model):
username = fields.CharField(max_length=20)
password = fields.CharField(max_length=200)
last_login_at = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
@@ -34,17 +34,18 @@ class User(Model):
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("diff_models.User", db_constraint=True)
user = fields.ForeignKeyField("models.User", db_constraint=False)
class Category(Model):
slug = fields.CharField(max_length=200)
user = fields.ForeignKeyField("diff_models.User", description="User")
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model):
categories = fields.ManyToManyField("diff_models.Category")
categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()

View File

@@ -5,7 +5,7 @@ from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate
from tests.models import Category, User
from tests.models import Category, Product, User
def test_create_table():
@@ -42,7 +42,7 @@ def test_create_table():
"id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(200) NOT NULL,
"name" VARCHAR(200) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
);
COMMENT ON COLUMN "category"."user_id" IS 'User';"""
@@ -50,7 +50,7 @@ COMMENT ON COLUMN "category"."user_id" IS 'User';"""
def test_drop_table():
ret = Migrate.ddl.drop_table(Category)
ret = Migrate.ddl.drop_table(Category._meta.db_table)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "DROP TABLE IF EXISTS `category`"
else:
@@ -58,7 +58,7 @@ def test_drop_table():
def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name"))
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL"
else:
@@ -67,13 +67,12 @@ def test_add_column():
def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
return
else:
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
ret0 = Migrate.ddl.modify_column(
Category, Category._meta.fields_map.get("name").describe(False)
)
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL):
@@ -89,47 +88,56 @@ def test_modify_column():
def test_alter_column_default():
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, SqliteDDL):
return
ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
else:
assert ret is None
elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ALTER COLUMN `name` DROP DEFAULT"
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at"))
ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert (
ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
)
else:
assert ret is None
elif isinstance(Migrate.ddl, MysqlDDL):
assert (
ret
== "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)"
)
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar"))
ret = Migrate.ddl.alter_column_default(
Product, Product._meta.fields_map.get("view_num").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\''
else:
assert ret is None
assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0"
def test_alter_column_null():
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.alter_column_null(
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
else:
assert ret is None
def test_set_comment():
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL):
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."user_id" IS \'User\''
def test_drop_column():
@@ -154,10 +162,7 @@ def test_add_index():
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert (
index_u
== 'ALTER TABLE "category" ADD CONSTRAINT "uid_category_name_8b0cb9" UNIQUE ("name")'
)
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
else:
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert (
@@ -173,14 +178,16 @@ def test_drop_index():
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP CONSTRAINT "uid_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
else:
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
def test_add_fk():
ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user"))
ret = Migrate.ddl.add_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert (
ret
@@ -194,7 +201,9 @@ def test_add_fk():
def test_drop_fk():
ret = Migrate.ddl.drop_fk(Category, Category._meta.fields_map.get("user"))
ret = Migrate.ddl.drop_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`"
elif isinstance(Migrate.ddl, PostgresDDL):

View File

@@ -1,60 +1,871 @@
import pytest
from pytest_mock import MockerFixture
from tortoise import Tortoise
from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate
from aerich.utils import get_models_describe
old_models_describe = {
"models.Category": {
"name": "models.Category",
"app": "models",
"table": "category",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "slug",
"field_type": "CharField",
"db_column": "slug",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "created_at",
"field_type": "DatetimeField",
"db_column": "created_at",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"readOnly": True},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
{
"name": "user_id",
"field_type": "IntField",
"db_column": "user_id",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
],
"fk_fields": [
{
"name": "user",
"field_type": "ForeignKeyFieldInstance",
"python_type": "models.User",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {},
"raw_field": "user_id",
"on_delete": "CASCADE",
}
],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [
{
"name": "products",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Product",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Product",
"related_name": "categories",
"forward_key": "product_id",
"backward_key": "category_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": True,
}
],
},
"models.Config": {
"name": "models.Config",
"app": "models",
"table": "config",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "label",
"field_type": "CharField",
"db_column": "label",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "key",
"field_type": "CharField",
"db_column": "key",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "value",
"field_type": "JSONField",
"db_column": "value",
"python_type": "Union[dict, list]",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "postgres": "JSONB"},
},
{
"name": "status",
"field_type": "IntEnumFieldInstance",
"db_column": "status",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": 1,
"description": "on: 1\noff: 0",
"docstring": None,
"constraints": {"ge": -32768, "le": 32767},
"db_field_types": {"": "SMALLINT"},
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Email": {
"name": "models.Email",
"app": "models",
"table": "email",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "email",
"field_type": "CharField",
"db_column": "email",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "is_primary",
"field_type": "BooleanField",
"db_column": "is_primary",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "user_id",
"field_type": "IntField",
"db_column": "user_id",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
],
"fk_fields": [
{
"name": "user",
"field_type": "ForeignKeyFieldInstance",
"python_type": "models.User",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"raw_field": "user_id",
"on_delete": "CASCADE",
}
],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Product": {
"name": "models.Product",
"app": "models",
"table": "product",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 50},
"db_field_types": {"": "VARCHAR(50)"},
},
{
"name": "view_num",
"field_type": "IntField",
"db_column": "view_num",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "View Num",
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "sort",
"field_type": "IntField",
"db_column": "sort",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "is_reviewed",
"field_type": "BooleanField",
"db_column": "is_reviewed",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "Is Reviewed",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "type",
"field_type": "IntEnumFieldInstance",
"db_column": "type",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "Product Type",
"docstring": None,
"constraints": {"ge": -32768, "le": 32767},
"db_field_types": {"": "SMALLINT"},
},
{
"name": "image",
"field_type": "CharField",
"db_column": "image",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "body",
"field_type": "TextField",
"db_column": "body",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "mysql": "LONGTEXT"},
},
{
"name": "created_at",
"field_type": "DatetimeField",
"db_column": "created_at",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"readOnly": True},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "products",
"forward_key": "category_id",
"backward_key": "product_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": False,
}
],
},
"models.User": {
"name": "models.User",
"app": "models",
"table": "user",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "username",
"field_type": "CharField",
"db_column": "username",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "password",
"field_type": "CharField",
"db_column": "password",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "last_login",
"field_type": "DatetimeField",
"db_column": "last_login",
"python_type": "datetime.datetime",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "<function None.now>",
"description": "Last Login",
"docstring": None,
"constraints": {},
"db_field_types": {
"": "TIMESTAMP",
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": False,
"auto_now": False,
},
{
"name": "is_active",
"field_type": "BooleanField",
"db_column": "is_active",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": True,
"description": "Is Active",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "is_superuser",
"field_type": "BooleanField",
"db_column": "is_superuser",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": "Is SuperUser",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
},
{
"name": "avatar",
"field_type": "CharField",
"db_column": "avatar",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "",
"description": None,
"docstring": None,
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "intro",
"field_type": "TextField",
"db_column": "intro",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": "",
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "mysql": "LONGTEXT"},
},
],
"fk_fields": [],
"backward_fk_fields": [
{
"name": "categorys",
"field_type": "BackwardFKRelation",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": "User",
"docstring": None,
"constraints": {},
},
{
"name": "emails",
"field_type": "BackwardFKRelation",
"python_type": "models.Email",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
},
],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
"models.Aerich": {
"name": "models.Aerich",
"app": "models",
"table": "aerich",
"abstract": False,
"description": None,
"docstring": None,
"unique_together": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
{
"name": "version",
"field_type": "CharField",
"db_column": "version",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 255},
"db_field_types": {"": "VARCHAR(255)"},
},
{
"name": "app",
"field_type": "CharField",
"db_column": "app",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
{
"name": "content",
"field_type": "JSONField",
"db_column": "content",
"python_type": "Union[dict, list]",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "TEXT", "postgres": "JSONB"},
},
],
"fk_fields": [],
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
},
}
def test_migrate(mocker: MockerFixture):
mocker.patch("click.prompt", return_value=True)
apps = Tortoise.apps
models = apps.get("models")
diff_models = apps.get("diff_models")
Migrate.diff_models(diff_models, models)
"""
models.py diff with old_models.py
- change email pk: id -> email_id
- add field: Email.address
- add fk: Config.user
- drop fk: Email.user
- drop field: User.avatar
- add index: Email.email
- add many to many: Email.users
- remove unique: User.username
- change column: length User.password
- add unique_together: (name,type) of Product
- alter default: Config.status
- rename column: Product.image -> Product.pic
"""
mocker.patch("click.prompt", side_effect=(False, True))
models_describe = get_models_describe("models")
Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
Migrate.diff_models(models, diff_models, False)
Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False)
else:
Migrate.diff_models(models, diff_models, False)
Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False)
Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL):
assert Migrate.upgrade_operators == [
assert sorted(Migrate.upgrade_operators) == sorted(
[
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`",
"ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` CHANGE password password VARCHAR(100)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
]
assert Migrate.downgrade_operators == [
"ALTER TABLE `category` DROP COLUMN `name`",
"ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`",
"ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY "
"(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
)
assert sorted(Migrate.downgrade_operators) == sorted(
[
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` CHANGE password password VARCHAR(200)",
"DROP TABLE IF EXISTS `email_user`",
]
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert Migrate.upgrade_operators == [
assert sorted(Migrate.upgrade_operators) == sorted(
[
'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE UNIQUE INDEX "uid_product_name_f14935" ON "product" ("name", "type")',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" CHANGE password password VARCHAR(100)',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
]
assert Migrate.downgrade_operators == [
'ALTER TABLE "category" DROP COLUMN "name"',
'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"',
'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"',
)
assert sorted(Migrate.downgrade_operators) == sorted(
[
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'DROP INDEX "idx_email_email_4a1a33"',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'DROP INDEX "idx_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_f14935"',
'ALTER TABLE "user" CHANGE password password VARCHAR(200)',
'DROP TABLE IF EXISTS "email_user"',
]
)
elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == [
'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL /* User */',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
]
assert Migrate.downgrade_operators == []
@@ -62,18 +873,18 @@ def test_sort_all_version_files(mocker):
mocker.patch(
"os.listdir",
return_value=[
"1_datetime_update.json",
"11_datetime_update.json",
"10_datetime_update.json",
"2_datetime_update.json",
"1_datetime_update.sql",
"11_datetime_update.sql",
"10_datetime_update.sql",
"2_datetime_update.sql",
],
)
Migrate.migrate_location = "."
assert Migrate.get_all_version_files() == [
"1_datetime_update.json",
"2_datetime_update.json",
"10_datetime_update.json",
"11_datetime_update.json",
"1_datetime_update.sql",
"2_datetime_update.sql",
"10_datetime_update.sql",
"11_datetime_update.sql",
]