65 Commits

Author SHA1 Message Date
long2ice
0b01fa38d8 feat: add index inspect 2022-04-05 19:38:08 +08:00
long2ice
801dde15be feat: inspectdb support sqlite 2022-04-01 20:30:36 +08:00
long2ice
75480e2041 Merge remote-tracking branch 'origin/dev' into dev 2022-04-01 19:57:03 +08:00
long2ice
45129cef9f feat: improve inspectdb and support postgres 2022-04-01 19:56:48 +08:00
long2ice
3a0dd2355d Merge pull request #230 from ssilaev/dev
Increase max length of app column
2022-02-09 15:01:39 +08:00
Sergey Silaev
0e71bc16ae Increase max length of app column 2022-02-08 22:14:55 +03:00
long2ice
c39462820c upgrade deps 2022-01-17 22:26:13 +08:00
long2ice
f15cbaf9e0 Support migration for specified index. (#203) 2021-12-29 21:36:23 +08:00
long2ice
15131469df upgrade deps 2021-12-22 16:26:13 +08:00
long2ice
c60c1610f0 Fix pyproject.toml not existing error. (#217) 2021-12-12 22:11:51 +08:00
long2ice
63e8d06157 remove aiomysql 2021-12-08 14:43:33 +08:00
long2ice
68ef8ac676 Fix ci 2021-12-08 14:38:16 +08:00
long2ice
8b5cf6faa0 inspectdb support DATE. (#215) 2021-12-08 14:33:27 +08:00
long2ice
fac00d45cc Remove pydantic dependency. (#198) 2021-10-04 23:05:20 +08:00
long2ice
6f7893d376 Fix section name 2021-09-28 15:07:10 +08:00
long2ice
b1521c4cc7 update version 2021-09-27 19:55:38 +08:00
long2ice
24c1f4cb7d Change default config file from aerich.ini to pyproject.toml. (#197) 2021-09-27 11:05:20 +08:00
long2ice
661f241dac Compatible with old version in indexes 2021-08-31 17:53:17 +08:00
long2ice
01787558d6 Fix test 2021-08-31 17:41:13 +08:00
long2ice
699b0321a4 Support indexes change. (#193) 2021-08-31 17:36:25 +08:00
long2ice
4a83021892 Update FUNDING.yml 2021-08-26 20:39:31 +08:00
long2ice
af63221875 Fix no module found error. (#188) (#189) 2021-08-16 11:14:43 +08:00
long2ice
359525716c update README.md 2021-08-12 15:42:54 +08:00
long2ice
7d3eb2e151 Merge pull request #181 from Vovetta/dev
Fix: migrate doesn't use source_field in unique_together
2021-08-04 09:42:18 +08:00
Vovetta
d8abf79449 Updated changelog and version 2021-08-03 10:38:31 -07:00
Vovetta
aa9f40ae27 Fix: migrate doesn't use source_field in unique_together 2021-08-03 10:36:06 -07:00
long2ice
79b7ae343a update README.md 2021-08-03 16:25:06 +08:00
long2ice
6f5a9ab78c Add Command class. (#148) (#141) (#123) (#106) 2021-08-03 16:18:07 +08:00
long2ice
1e5a83c281 update deps 2021-07-26 17:44:18 +08:00
long2ice
180420843d update README.md 2021-07-26 15:27:49 +08:00
long2ice
58f66b91cf Fix redundant semicolons 2021-07-23 17:07:10 +08:00
long2ice
064d7ff675 Fix ci 2021-07-22 15:32:07 +08:00
long2ice
2da794d823 Fix db_constraint when fk changed. (#179) 2021-07-22 14:37:49 +08:00
long2ice
77005f3793 Fix MySQL 5.X rename column. 2021-07-09 10:53:13 +08:00
long2ice
5a873b8b69 Merge pull request #177 from yusukefs/add-default-src-folder-config
Add default value for src_folder config
2021-07-08 17:27:29 +08:00
Yusuke Sakai
3989b7c674 Update version and changelog 2021-07-08 18:01:59 +09:00
Yusuke Sakai
694b05356f Add default src_folder cofig value 2021-07-08 17:35:44 +09:00
long2ice
919d56c936 add ci branches-ignore master 2021-07-07 10:29:38 +08:00
long2ice
7bcf9b2fed Support drop column for sqlite. (#40) 2021-07-03 13:51:01 +08:00
long2ice
9f663299cf Merge pull request #174 from sasha00123/dev
Fixed typo in README.md concerning dowgrade usage
2021-06-25 13:52:34 +08:00
Alexander Batyrgariev
28dbdf2663 Fixed typo in README.md concerning dowgrade usage 2021-06-25 08:00:00 +03:00
long2ice
e71a4b60a5 Merge pull request #166 from spacemanspiff2007/dev
Added config option to specify source folder
2021-06-13 14:26:26 +08:00
-
62840136be used old black version 2021-06-11 15:36:54 +02:00
-
185514f711 reformatted with black 2021-06-11 15:18:06 +02:00
-
8e783e031e updated readme 2021-06-10 16:56:30 +02:00
-
10b7272ca8 Added an configuration option to specify the path of the source folder.
This will make aerich work with various folder structures (e.g. ./src/MyPythonModule)
Additionally this will try to import in init and show the user the error message on failure.
2021-06-10 16:52:03 +02:00
long2ice
0c763c6024 Fix repeat 2021-06-09 13:56:25 +08:00
long2ice
c6371a5c16 Fix repeat 2021-06-09 11:43:32 +08:00
long2ice
1dbf9185b6 Not catch exception when import config. (#164) 2021-06-04 17:47:39 +08:00
long2ice
9bf2de0b9a Fix incorrect index creation order. (#151) 2021-06-01 17:09:45 +08:00
long2ice
bf1cf21324 Merge pull request #158 from manzato/pyproject-update
Update URLs
2021-05-22 22:43:52 +08:00
Guillermo Manzato
8b08329493 Update URLs 2021-05-22 11:39:49 -03:00
long2ice
5bc7d23d95 Merge pull request #157 from tortoise/dependabot/pip/pydantic-1.8.2
Bump pydantic from 1.8.1 to 1.8.2
2021-05-14 09:30:47 +08:00
dependabot[bot]
a253aa96cb Bump pydantic from 1.8.1 to 1.8.2
Bumps [pydantic](https://github.com/samuelcolvin/pydantic) from 1.8.1 to 1.8.2.
- [Release notes](https://github.com/samuelcolvin/pydantic/releases)
- [Changelog](https://github.com/samuelcolvin/pydantic/blob/master/HISTORY.md)
- [Commits](https://github.com/samuelcolvin/pydantic/compare/v1.8.1...v1.8.2)

Signed-off-by: dependabot[bot] <support@github.com>
2021-05-13 20:51:51 +00:00
long2ice
15a6e874dd update deps 2021-05-03 14:23:27 +08:00
long2ice
19a5dcbf3f update deps 2021-04-26 21:01:40 +08:00
long2ice
922e3eef16 Fix CI 2021-04-05 17:11:28 +08:00
long2ice
44fd2fe6ae Fix default function when migrate. (#147) 2021-04-05 14:10:42 +08:00
long2ice
b147859960 Fix default function when migrate 2021-04-04 05:46:34 +00:00
long2ice
793cf2532c Create FUNDING.yml 2021-04-03 21:34:24 +08:00
long2ice
fa85e05d1d Fix postgre alter null. (#142) 2021-03-28 16:22:49 +08:00
long2ice
3f52ac348b Support rename table. (#139) 2021-03-25 21:21:49 +08:00
long2ice
f8aa7a8f34 Fix inspectdb for FloatField. (#138) 2021-03-22 14:16:59 +08:00
long2ice
44d520cc82 Fix postgres field type change error. (#135) 2021-03-21 21:18:08 +08:00
long2ice
364735f804 Fix rename field on the field add. (#134) 2021-03-21 20:43:05 +08:00
35 changed files with 1597 additions and 906 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: ["https://sponsor.long2ice.io"]

View File

@@ -1,5 +1,11 @@
name: ci name: ci
on: [ push, pull_request ] on:
push:
branches-ignore:
- master
pull_request:
branches-ignore:
- master
jobs: jobs:
ci: ci:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -20,9 +26,9 @@ jobs:
with: with:
python-version: '3.x' python-version: '3.x'
- name: Install and configure Poetry - name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1 run: |
with: pip install -U pip poetry
virtualenvs-create: false poetry config virtualenvs.create false
- name: CI - name: CI
env: env:
MYSQL_PASS: root MYSQL_PASS: root

View File

@@ -12,9 +12,9 @@ jobs:
with: with:
python-version: '3.x' python-version: '3.x'
- name: Install and configure Poetry - name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1 run: |
with: pip install -U pip poetry
virtualenvs-create: false poetry config virtualenvs.create false
- name: Build dists - name: Build dists
run: make build run: make build
- name: Pypi Publish - name: Pypi Publish

1
.gitignore vendored
View File

@@ -146,3 +146,4 @@ aerich.ini
src src
.vscode .vscode
.DS_Store .DS_Store
.python-version

View File

@@ -1,7 +1,68 @@
# ChangeLog # ChangeLog
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2
- Support migration for specified index. (#203)
### 0.6.1
- Fix `pyproject.toml` not existing error. (#217)
### 0.6.0
- Change default config file from `aerich.ini` to `pyproject.toml`. (#197)
**Upgrade note:**
1. Run `aerich init -t config.TORTOISE_ORM`.
2. Remove `aerich.ini`.
- Remove `pydantic` dependency. (#198)
- `inspectdb` support `DATE`. (#215)
## 0.5 ## 0.5
### 0.5.8
- Support `indexes` change. (#193)
### 0.5.7
- Fix no module found error. (#188) (#189)
### 0.5.6
- Add `Command` class. (#148) (#141) (#123) (#106)
- Fix: migrate doesn't use source_field in unique_together. (#181)
### 0.5.5
- Fix KeyError: 'src_folder' after upgrading aerich to 0.5.4. (#176)
- Fix MySQL 5.X rename column.
- Fix `db_constraint` when fk changed. (#179)
### 0.5.4
- Fix incorrect index creation order. (#151)
- Not catch exception when import config. (#164)
- Support `drop column` for sqlite. (#40)
### 0.5.3
- Fix postgre alter null. (#142)
- Fix default function when migrate. (#147)
### 0.5.2
- Fix rename field on the field add. (#134)
- Fix postgres field type change error. (#135)
- Fix inspectdb for `FloatField`. (#138)
- Support `rename table`. (#139)
### 0.5.1 ### 0.5.1
- Fix tortoise connections not being closed properly. (#120) - Fix tortoise connections not being closed properly. (#120)

View File

@@ -12,16 +12,15 @@ up:
@poetry update @poetry update
deps: deps:
@poetry install -E asyncpg -E asyncmy -E aiomysql @poetry install -E asyncpg -E asyncmy
style: deps style: deps
isort -src $(checkfiles) @isort -src $(checkfiles)
black $(black_opts) $(checkfiles) @black $(black_opts) $(checkfiles)
check: deps check: deps
black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
flake8 $(checkfiles) @pflake8 $(checkfiles)
bandit -x tests -r $(checkfiles)
test: deps test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: py.test

105
README.md
View File

@@ -1,25 +1,21 @@
# Aerich # Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich) [![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/long2ice/aerich)](https://github.com/long2ice/aerich) [![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/long2ice/aerich/workflows/pypi/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:pypi) [![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/long2ice/aerich/workflows/ci/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:ci) [![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
## Introduction ## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which is like alembic for SQLAlchemy, or like Django ORM with Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with
it\'s own migrations solution. it\'s own migration solution.
~~**Important: You can only use absolutely import in your `models.py` to make `aerich` work.**~~
From version `v0.5.0`, there is no such limitation now.
## Install ## Install
Just install from pypi: Just install from pypi:
```shell ```shell
> pip install aerich pip install aerich
``` ```
## Quick Start ## Quick Start
@@ -30,10 +26,9 @@ Just install from pypi:
Usage: aerich [OPTIONS] COMMAND [ARGS]... Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options: Options:
-c, --config TEXT Config file. [default: aerich.ini] -V, --version Show the version and exit.
--app TEXT Tortoise-ORM app name. [default: models] -c, --config TEXT Config file. [default: pyproject.toml]
-n, --name TEXT Name of section in .ini file to use for aerich config. --app TEXT Tortoise-ORM app name.
[default: aerich]
-h, --help Show this message and exit. -h, --help Show this message and exit.
Commands: Commands:
@@ -44,7 +39,7 @@ Commands:
init-db Generate schema and generate app migrate location. init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as... inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file. migrate Generate migrate changes file.
upgrade Upgrade to latest version. upgrade Upgrade to specified version.
``` ```
## Usage ## Usage
@@ -73,9 +68,10 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location. Init config file and generate root migrate location.
Options: Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM. -t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
[required] settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations] --location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit. -h, --help Show this message and exit.
``` ```
@@ -85,7 +81,7 @@ Initialize the config file and migrations location:
> aerich init -t tests.backends.mysql.TORTOISE_ORM > aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations Success create migrate location ./migrations
Success generate config file aerich.ini Success write config to pyproject.toml
``` ```
### Init db ### Init db
@@ -97,8 +93,8 @@ Success create app migrate location ./migrations/models
Success generate schema for app "models" Success generate schema for app "models"
``` ```
If your Tortoise-ORM app is not the default `models`, you must specify the correct app via `--app`, e.g. `aerich If your Tortoise-ORM app is not the default `models`, you must specify the correct app via `--app`,
--app other_models init-db`. e.g. `aerich --app other_models init-db`.
### Update models and make migrate ### Update models and make migrate
@@ -128,7 +124,7 @@ Now your db is migrated to latest.
### Downgrade to specified version ### Downgrade to specified version
```shell ```shell
> aerich init -h > aerich downgrade -h
Usage: aerich downgrade [OPTIONS] Usage: aerich downgrade [OPTIONS]
@@ -169,7 +165,7 @@ Now your db is rolled back to the specified version.
### Inspect db tables to TortoiseORM model ### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL. Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell ```shell
Usage: aerich inspectdb [OPTIONS] Usage: aerich inspectdb [OPTIONS]
@@ -193,8 +189,44 @@ Inspect a specified table in the default app and redirect to `models.py`:
aerich inspectdb -t user > models.py aerich inspectdb -t user > models.py
``` ```
Note that this command is limited and cannot infer some fields, such as `IntEnumField`, `ForeignKeyField`, and For example, you table is:
others.
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Now run `aerich inspectdb -t test` to see the generated model:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
### Multiple databases ### Multiple databases
@@ -213,11 +245,28 @@ tortoise_orm = {
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on. You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
## Support this project ## Restore `aerich` workflow
| AliPay | WeChatPay | PayPal | In some cases, such as broken changes from upgrade of `aerich`, you can't run `aerich migrate` or `aerich upgrade`, you
| -------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------------- | can make the following steps:
| <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/alipay.jpeg"/> | <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/wechatpay.jpeg"/> | [PayPal](https://www.paypal.me/long2ice) to my account long2ice. |
1. drop `aerich` table.
2. delete `migrations/{app}` directory.
3. rerun `aerich init-db`.
Note that these actions is safe, also you can do that to reset your migrations if your migration files is too many.
## Use `aerich` in application
You can use `aerich` out of cli by use `Command` class.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
```
## License ## License

View File

@@ -1 +1,149 @@
__version__ = "0.5.1" import os
from pathlib import Path
from typing import List
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspect.mysql import InspectMySQL
from aerich.inspect.postgres import InspectPostgres
from aerich.inspect.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
get_version_content_from_file,
write_version_file,
)
class Command:
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
):
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self):
await Migrate.init(self.tortoise_config, self.app, self.location)
async def upgrade(self):
migrated = []
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=self.app)
except OperationalError:
exists = False
if not exists:
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool):
ret = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(
app=self.app, version__startswith=f"{version}_"
).first()
if not specified_version:
raise DowngradeError("No specified version found")
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
raise DowngradeError("No downgrade items found")
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self):
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
if not await Aerich.exists(version=version, app=self.app):
ret.append(version)
return ret
async def history(self):
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
ret.append(version)
return ret
async def inspectdb(self, tables: List[str] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"):
return await Migrate.migrate(name)
async def init_db(self, safe: bool):
location = self.location
app = self.app
dirname = Path(location, app)
dirname.mkdir(parents=True)
await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=get_models_describe(app),
)
content = {
"upgrade": [schema],
}
write_version_file(Path(dirname, version), content)

View File

@@ -1,34 +1,24 @@
import asyncio import asyncio
import os import os
import sys
from configparser import ConfigParser
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import click import click
import tomlkit
from click import Context, UsageError from click import Context, UsageError
from tortoise import Tortoise, generate_schema_for_client from tomlkit.exceptions import NonExistentKey
from tortoise.exceptions import OperationalError from tortoise import Tortoise
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.inspectdb import InspectDb from aerich import Command
from aerich.migrate import Migrate from aerich.enums import Color
from aerich.utils import ( from aerich.exceptions import DowngradeError
get_app_connection, from aerich.utils import add_src_path, get_tortoise_config
get_app_connection_name, from aerich.version import __version__
get_models_describe,
get_tortoise_config,
get_version_content_from_file,
write_version_file,
)
from . import __version__ CONFIG_DEFAULT_VALUES = {
from .enums import Color "src_folder": ".",
from .models import Aerich }
parser = ConfigParser()
def coro(f): def coro(f):
@@ -40,7 +30,7 @@ def coro(f):
try: try:
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
finally: finally:
if f.__name__ != "cli": if Tortoise._inited:
loop.run_until_complete(Tortoise.close_connections()) loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper
@@ -51,44 +41,40 @@ def coro(f):
@click.option( @click.option(
"-c", "-c",
"--config", "--config",
default="aerich.ini", default="pyproject.toml",
show_default=True, show_default=True,
help="Config file.", help="Config file.",
) )
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
"-n",
"--name",
default="aerich",
show_default=True,
help="Name of section in .ini file to use for aerich config.",
)
@click.pass_context @click.pass_context
@coro @coro
async def cli(ctx: Context, config, app, name): async def cli(ctx: Context, config, app):
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config_file"] = config ctx.obj["config_file"] = config
ctx.obj["name"] = name
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
if not Path(config).exists(): if not Path(config).exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError("You must exec init first", ctx=ctx)
parser.read(config) with open(config, "r") as f:
content = f.read()
location = parser[name]["location"] doc = tomlkit.parse(content)
tortoise_orm = parser[name]["tortoise_orm"] try:
tool = doc["tool"]["aerich"]
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] app = app or list(tortoise_config.get("apps").keys())[0]
ctx.obj["config"] = tortoise_config command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["location"] = location ctx.obj["command"] = command
ctx.obj["app"] = app
Migrate.app = app
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx) raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init(tortoise_config, app, location) await command.init()
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate migrate changes file.")
@@ -96,7 +82,8 @@ async def cli(ctx: Context, config, app, name):
@click.pass_context @click.pass_context
@coro @coro
async def migrate(ctx: Context, name): async def migrate(ctx: Context, name):
ret = await Migrate.migrate(name) command = ctx.obj["command"]
ret = await command.migrate(name)
if not ret: if not ret:
return click.secho("No changes detected", fg=Color.yellow) return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success migrate {ret}", fg=Color.green) click.secho(f"Success migrate {ret}", fg=Color.green)
@@ -106,30 +93,13 @@ async def migrate(ctx: Context, name):
@click.pass_context @click.pass_context
@coro @coro
async def upgrade(ctx: Context): async def upgrade(ctx: Context):
config = ctx.obj["config"] command = ctx.obj["command"]
app = ctx.obj["app"] migrated = await command.upgrade()
migrated = False
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=app)
except OperationalError:
exists = False
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=app,
content=get_models_describe(app),
)
click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True
if not migrated: if not migrated:
click.secho("No upgrade items found", fg=Color.yellow) click.secho("No upgrade items found", fg=Color.yellow)
else:
for version_file in migrated:
click.secho(f"Success upgrade {version_file}", fg=Color.green)
@cli.command(help="Downgrade to specified version.") @cli.command(help="Downgrade to specified version.")
@@ -155,59 +125,37 @@ async def upgrade(ctx: Context):
) )
@coro @coro
async def downgrade(ctx: Context, version: int, delete: bool): async def downgrade(ctx: Context, version: int, delete: bool):
app = ctx.obj["app"] command = ctx.obj["command"]
config = ctx.obj["config"] try:
if version == -1: files = await command.downgrade(version, delete)
specified_version = await Migrate.get_last_version() except DowngradeError as e:
else: return click.secho(str(e), fg=Color.yellow)
specified_version = await Aerich.filter(app=app, version__startswith=f"{version}_").first() for file in files:
if not specified_version: click.secho(f"Success downgrade {file}", fg=Color.green)
return click.secho("No specified version found", fg=Color.yellow)
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
click.secho("No downgrade items found", fg=Color.yellow)
return
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
if delete:
os.unlink(file_path)
click.secho(f"Success downgrade {file}", fg=Color.green)
@cli.command(help="Show current available heads in migrate location.") @cli.command(help="Show current available heads in migrate location.")
@click.pass_context @click.pass_context
@coro @coro
async def heads(ctx: Context): async def heads(ctx: Context):
app = ctx.obj["app"] command = ctx.obj["command"]
versions = Migrate.get_all_version_files() head_list = await command.heads()
is_heads = False if not head_list:
for version in versions: return click.secho("No available heads, try migrate first", fg=Color.green)
if not await Aerich.exists(version=version, app=app): for version in head_list:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
is_heads = True
if not is_heads:
click.secho("No available heads,try migrate first", fg=Color.green)
@cli.command(help="List all migrate items.") @cli.command(help="List all migrate items.")
@click.pass_context @click.pass_context
@coro @coro
async def history(ctx: Context): async def history(ctx: Context):
versions = Migrate.get_all_version_files() command = ctx.obj["command"]
versions = await command.history()
if not versions:
return click.secho("No history, try migrate", fg=Color.green)
for version in versions: for version in versions:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
if not versions:
click.secho("No history,try migrate", fg=Color.green)
@cli.command(help="Init config file and generate root migrate location.") @cli.command(help="Init config file and generate root migrate location.")
@@ -223,29 +171,46 @@ async def history(ctx: Context):
show_default=True, show_default=True,
help="Migrate store location.", help="Migrate store location.",
) )
@click.option(
"-s",
"--src_folder",
default=CONFIG_DEFAULT_VALUES["src_folder"],
show_default=False,
help="Folder of the source, relative to the project root.",
)
@click.pass_context @click.pass_context
@coro @coro
async def init( async def init(ctx: Context, tortoise_orm, location, src_folder):
ctx: Context,
tortoise_orm,
location,
):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder)
# Add ./ so it's clear that this is relative path
if not src_folder.startswith("./"):
src_folder = "./" + src_folder
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
if Path(config_file).exists(): if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow) with open(config_file, "r") as f:
content = f.read()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
parser.add_section(name) with open(config_file, "w") as f:
parser.set(name, "tortoise_orm", tortoise_orm) f.write(tomlkit.dumps(doc))
parser.set(name, "location", location)
with open(config_file, "w", encoding="utf-8") as f:
parser.write(f)
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green) click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success generate config file {config_file}", fg=Color.green) click.secho(f"Success write config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.") @cli.command(help="Generate schema and generate app migrate location.")
@@ -259,37 +224,18 @@ async def init(
@click.pass_context @click.pass_context
@coro @coro
async def init_db(ctx: Context, safe): async def init_db(ctx: Context, safe):
config = ctx.obj["config"] command = ctx.obj["command"]
location = ctx.obj["location"] app = command.app
app = ctx.obj["app"] dirname = Path(command.location, app)
dirname = Path(location, app)
try: try:
dirname.mkdir(parents=True) await command.init_db(safe)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green) click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
except FileExistsError: except FileExistsError:
return click.secho( return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
) )
await Tortoise.init(config=config)
connection = get_app_connection(config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=get_models_describe(app),
)
content = {
"upgrade": [schema],
}
write_version_file(Path(dirname, version), content)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.") @cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
@click.option( @click.option(
@@ -302,16 +248,12 @@ async def init_db(ctx: Context, safe):
@click.pass_context @click.pass_context
@coro @coro
async def inspectdb(ctx: Context, table: List[str]): async def inspectdb(ctx: Context, table: List[str]):
config = ctx.obj["config"] command = ctx.obj["command"]
app = ctx.obj["app"] ret = await command.inspectdb(table)
connection = get_app_connection(config, app) click.secho(ret)
inspect = InspectDb(connection, table)
await inspect.inspect()
def main(): def main():
sys.path.insert(0, ".")
cli() cli()

31
aerich/coder.py Normal file
View File

@@ -0,0 +1,31 @@
import base64
import json
import pickle # nosec: B301,B403
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Index):
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj):
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict):
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
return json.loads(obj, object_hook=object_hook)

View File

@@ -4,6 +4,8 @@ from typing import List, Type
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from aerich.utils import is_default_function
class BaseDDL: class BaseDDL:
schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
@@ -26,6 +28,7 @@ class BaseDDL:
_CHANGE_COLUMN_TEMPLATE = ( _CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient"):
self.client = client self.client = client
@@ -75,7 +78,11 @@ class BaseDDL:
auto_now_add = field_describe.get("auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False) auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
if field_describe.get("field_type") in ["UUIDField", "TextField", "JSONField"]: if field_describe.get("field_type") in [
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = "" default = ""
else: else:
try: try:
@@ -184,6 +191,12 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
@@ -226,7 +239,13 @@ class BaseDDL:
) )
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotImplementedError return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotImplementedError return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
)

View File

@@ -1,10 +1,6 @@
from typing import Type
from tortoise import Model
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError
class MysqlDDL(BaseDDL): class MysqlDDL(BaseDDL):
@@ -28,9 +24,4 @@ class MysqlDDL(BaseDDL):
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}" _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in MySQL.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in MySQL.")

View File

@@ -12,7 +12,9 @@ class PostgresDDL(BaseDDL):
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})' _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}' _MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
)
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
@@ -27,10 +29,13 @@ class PostgresDDL(BaseDDL):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types") db_field_types = field_describe.get("db_field_types")
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column"), column=db_column,
datatype=db_field_types.get(self.DIALECT) or db_field_types.get(""), datatype=datatype,
using=f' USING "{db_column}"::{datatype}',
) )
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict):

View File

@@ -11,9 +11,6 @@ class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT DIALECT = SqliteSchemaGenerator.DIALECT
def drop_column(self, model: "Type[Model]", column_name: str):
raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True): def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.") raise NotSupportError("Modify column is unsupported in SQLite.")

View File

@@ -2,3 +2,9 @@ class NotSupportError(Exception):
""" """
raise when features not support raise when features not support
""" """
class DowngradeError(Exception):
"""
raise when downgrade error
"""

163
aerich/inspect/__init__.py Normal file
View File

@@ -0,0 +1,163 @@
from typing import Any, List, Optional
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str]
pk: bool
unique: bool
index: bool
length: Optional[int]
extra: Optional[str]
decimal_places: Optional[int]
max_digits: Optional[int]
def translate(self) -> dict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type == "decimal":
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables
@property
def field_map(self) -> dict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
raise NotImplementedError
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

69
aerich/inspect/mysql.py Normal file
View File

@@ -0,0 +1,69 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"],
unque=unique,
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File

@@ -0,0 +1,75 @@
from typing import List, Optional
from tortoise import BaseDBAsyncClient
from aerich.inspect import Column, Inspect
class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3"""
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

61
aerich/inspect/sqlite.py Normal file
View File

@@ -0,0 +1,61 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str):
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,85 +0,0 @@
import sys
from typing import List, Optional
from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient
class InspectDb:
_table_template = "class {table}(Model):\n"
_field_template_mapping = {
"INT": " {field} = fields.IntField({pk}{unique}{comment})",
"SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})",
"TINYINT": " {field} = fields.BooleanField({null}{default}{comment})",
"VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
}
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.tables = tables
self.DIALECT = conn.schema_generator.DIALECT
async def show_create_tables(self):
if self.DIALECT == "mysql":
if not self.tables:
sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';" # nosec: B608
ret = await self.conn.execute_query(sql_tables)
self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
for table in self.tables:
sql_show_create_table = f"SHOW CREATE TABLE {table}"
ret = await self.conn.execute_query(sql_show_create_table)
yield ret[1][0]["Create Table"]
else:
raise NotImplementedError("Currently only support MySQL")
async def inspect(self):
ddl_list = self.show_create_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
async for ddl in ddl_list:
parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
table = parser.parse()
name = table.name.title()
columns = table.columns
fields = []
model = self._table_template.format(table=name)
for column_name, column in columns.items():
comment = default = length = unique = null = pk = ""
if column.primary_key:
pk = "pk=True, "
if column.unique:
unique = "unique=True, "
if column.data_type == "VARCHAR":
length = f"max_length={column.length}, "
if not column.not_null:
null = "null=True, "
if column.default is not None:
if column.data_type == "TINYINT":
default = f"default={'True' if column.default == '1' else 'False'}, "
elif column.data_type == "DATETIME":
if "CURRENT_TIMESTAMP" in column.default:
if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
default = "auto_now_add=True, "
else:
default = "auto_now=True, "
else:
default = f"default={column.default}, "
if column.comment:
comment = f"description='{column.comment}', "
field = self._field_template_mapping[column.data_type].format(
field=column_name,
pk=pk,
unique=unique,
length=length,
null=null,
default=default,
comment=comment,
)
fields.append(field)
tables.append(model + "\n".join(fields))
sys.stdout.write(result + "\n\n\n".join(tables))

View File

@@ -1,16 +1,23 @@
import os import os
from datetime import datetime from datetime import datetime
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type, Union
import click import click
from dictdiffer import diff from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
from tortoise.indexes import Index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, get_models_describe, write_version_file from aerich.utils import (
get_app_connection,
get_models_describe,
is_default_function,
write_version_file,
)
class Migrate: class Migrate:
@@ -27,7 +34,7 @@ class Migrate:
ddl: BaseDDL ddl: BaseDDL
_last_version_content: Optional[dict] = None _last_version_content: Optional[dict] = None
app: str app: str
migrate_location: str migrate_location: Path
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@@ -108,8 +115,8 @@ class Migrate:
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = { content = {
"upgrade": cls.upgrade_operators, "upgrade": list(dict.fromkeys(cls.upgrade_operators)),
"downgrade": cls.downgrade_operators, "downgrade": list(dict.fromkeys(cls.downgrade_operators)),
} }
write_version_file(Path(cls.migrate_location, version), content) write_version_file(Path(cls.migrate_location, version), content)
return version return version
@@ -133,25 +140,37 @@ class Migrate:
return await cls._generate_diff_sql(name) return await cls._generate_diff_sql(name)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False): def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
:param upgrade: :param upgrade:
:param fk_m2m: :param fk_m2m_index:
:return: :return:
""" """
if upgrade: if upgrade:
if fk_m2m: if fk_m2m_index:
cls._upgrade_fk_m2m_index_operators.append(operator) cls._upgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.upgrade_operators.append(operator) cls.upgrade_operators.append(operator)
else: else:
if fk_m2m: if fk_m2m_index:
cls._downgrade_fk_m2m_index_operators.append(operator) cls._downgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
ret = []
for index in indexes:
if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
ret.append(index)
return ret
@classmethod @classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True): def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
""" """
@@ -176,14 +195,29 @@ class Migrate:
pass pass
else: else:
old_model_describe = old_models.get(new_model_str) old_model_describe = old_models.get(new_model_str)
# rename table
new_table = new_model_describe.get("table")
old_table = old_model_describe.get("table")
if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set( old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together")) map(lambda x: tuple(x), old_model_describe.get("unique_together"))
) )
new_unique_together = set( new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together")) map(lambda x: tuple(x), new_model_describe.get("unique_together"))
) )
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
)
)
old_pk_field = old_model_describe.get("pk_field") old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field") new_pk_field = new_model_describe.get("pk_field")
# pk field # pk field
@@ -196,6 +230,8 @@ class Migrate:
old_m2m_fields = old_model_describe.get("m2m_fields") old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields") new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields): for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint":
continue
table = change[0][1].get("through") table = change[0][1].get("through")
if action == "add": if action == "add":
add = False add = False
@@ -213,7 +249,7 @@ class Migrate:
new_models.get(change[0][1].get("model_name")), new_models.get(change[0][1].get("model_name")),
), ),
upgrade, upgrade,
fk_m2m=True, fk_m2m_index=True,
) )
elif action == "remove": elif action == "remove":
add = False add = False
@@ -224,20 +260,19 @@ class Migrate:
cls._downgrade_m2m.append(table) cls._downgrade_m2m.append(table)
add = True add = True
if add: if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True) cls._add_operator(cls.drop_m2m(table), upgrade, True)
# add unique_together # add unique_together
for index in new_unique_together.difference(old_unique_together): for index in new_unique_together.difference(old_unique_together):
cls._add_operator( cls._add_operator(cls._add_index(model, index, True), upgrade, True)
cls._add_index(model, index, True),
upgrade,
)
# remove unique_together # remove unique_together
for index in old_unique_together.difference(new_unique_together): for index in old_unique_together.difference(new_unique_together):
cls._add_operator( cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
cls._drop_index(model, index, True), # add indexes
upgrade, for index in new_indexes.difference(old_indexes):
) cls._add_operator(cls._add_index(model, index, False), upgrade, True)
# remove indexes
for index in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
old_data_fields = old_model_describe.get("data_fields") old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields") new_data_fields = new_model_describe.get("data_fields")
@@ -257,14 +292,23 @@ class Migrate:
old_data_field_name = old_data_field.get("name") old_data_field_name = old_data_field.get("name")
if len(changes) == 2: if len(changes) == 2:
# rename field # rename field
if changes[0] == ( if (
"change", changes[0]
"name", == (
(old_data_field_name, new_data_field_name), "change",
) and changes[1] == ( "name",
"change", (old_data_field_name, new_data_field_name),
"db_column", )
(old_data_field.get("db_column"), new_data_field.get("db_column")), and changes[1]
== (
"change",
"db_column",
(
old_data_field.get("db_column"),
new_data_field.get("db_column"),
),
)
and old_data_field_name not in new_data_fields_name
): ):
if upgrade: if upgrade:
is_rename = click.prompt( is_rename = click.prompt(
@@ -285,7 +329,9 @@ class Migrate:
and cls._db_version.startswith("5.") and cls._db_version.startswith("5.")
): ):
cls._add_operator( cls._add_operator(
cls._modify_field(model, new_data_field), cls._change_field(
model, old_data_field, new_data_field
),
upgrade, upgrade,
) )
else: else:
@@ -334,11 +380,14 @@ class Migrate:
fk_field = next( fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields) filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
) )
cls._add_operator( if fk_field.get("db_constraint"):
cls._add_fk(model, fk_field, new_models.get(fk_field.get("python_type"))), cls._add_operator(
upgrade, cls._add_fk(
fk_m2m=True, model, fk_field, new_models.get(fk_field.get("python_type"))
) ),
upgrade,
fk_m2m_index=True,
)
# drop fk # drop fk
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name) set(new_fk_fields_name)
@@ -346,13 +395,14 @@ class Migrate:
old_fk_field = next( old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
) )
cls._add_operator( if old_fk_field.get("db_constraint"):
cls._drop_fk( cls._add_operator(
model, old_fk_field, old_models.get(old_fk_field.get("python_type")) cls._drop_fk(
), model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
upgrade, ),
fk_m2m=True, upgrade,
) fk_m2m_index=True,
)
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next( old_data_field = next(
@@ -369,20 +419,29 @@ class Migrate:
unique = new_data_field.get("unique") unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True: if old_new[0] is False and old_new[1] is True:
cls._add_operator( cls._add_operator(
cls._add_index(model, (field_name,), unique), cls._add_index(model, (field_name,), unique), upgrade, True
upgrade,
) )
else: else:
cls._add_operator( cls._add_operator(
cls._drop_index(model, (field_name,), unique), cls._drop_index(model, (field_name,), unique), upgrade, True
upgrade,
) )
elif option == "db_field_types.": elif option == "db_field_types.":
# continue since repeated with others # continue since repeated with others
continue continue
elif option == "default": elif option == "default":
# change column default if not (
cls._add_operator(cls._alter_default(model, new_data_field), upgrade) is_default_function(old_new[0]) or is_default_function(old_new[1])
):
# change column default
cls._add_operator(
cls._alter_default(model, new_data_field), upgrade
)
elif option == "unique":
# because indexed include it
continue
elif option == "nullable":
# change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
else: else:
# modify column # modify column
cls._add_operator( cls._add_operator(
@@ -394,6 +453,10 @@ class Migrate:
if old_model not in new_models.keys(): if old_model not in new_models.keys():
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade) cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod @classmethod
def add_model(cls, model: Type[Model]): def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model) return cls.ddl.create_table(model)
@@ -414,19 +477,28 @@ class Migrate:
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
ret = [] ret = []
for field_name in fields_name: for field_name in fields_name:
if field_name in model._meta.fk_fields: field = model._meta.fields_map[field_name]
if field.source_field:
ret.append(field.source_field)
elif field_name in model._meta.fk_fields:
ret.append(field_name + "_id") ret.append(field_name + "_id")
else: else:
ret.append(field_name) ret.append(field_name)
return ret return ret
@classmethod @classmethod
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False): def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
fields_name = cls._resolve_fk_fields_name(model, fields_name) fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique) return cls.ddl.drop_index(model, fields_name, unique)
@classmethod @classmethod
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False): def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name) fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique) return cls.ddl.add_index(model, fields_name, unique)

View File

@@ -1,12 +1,15 @@
from tortoise import Model, fields from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255 MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20) app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField() content = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@@ -1,10 +1,30 @@
import importlib import importlib
from typing import Dict import os
import re
import sys
from pathlib import Path
from typing import Dict, Union
from click import BadOptionUsage, Context from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str:
"""
add a folder to the paths so we can import from there
:param path: path to add
:return: absolute path
"""
if not os.path.isabs(path):
# use the absolute path, otherwise some other things (e.g. __file__) won't work properly
path = os.path.abspath(path)
if not os.path.isdir(path):
raise ClickException(f"Specified source folder does not exist: {path}")
if path not in sys.path:
sys.path.insert(0, path)
return path
def get_app_connection_name(config, app_name: str) -> str: def get_app_connection_name(config, app_name: str) -> str:
""" """
get connection name get connection name
@@ -41,12 +61,11 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
splits = tortoise_orm.split(".") splits = tortoise_orm.split(".")
config_path = ".".join(splits[:-1]) config_path = ".".join(splits[:-1])
tortoise_config = splits[-1] tortoise_config = splits[-1]
try: try:
config_module = importlib.import_module(config_path) config_module = importlib.import_module(config_path)
except (ModuleNotFoundError, AttributeError): except ModuleNotFoundError as e:
raise BadOptionUsage( raise ClickException(f"Error while importing configuration module: {e}") from None
ctx=ctx, message=f'No config named "{config_path}"', option_name="--config"
)
config = getattr(config_module, tortoise_config, None) config = getattr(config_module, tortoise_config, None)
if not config: if not config:
@@ -62,7 +81,7 @@ _UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n" _DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: str) -> Dict: def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
""" """
get version content get version content
:param version_file: :param version_file:
@@ -84,7 +103,7 @@ def get_version_content_from_file(version_file: str) -> Dict:
return ret return ret
def write_version_file(version_file: str, content: Dict): def write_version_file(version_file: Path, content: Dict):
""" """
write version file write version file
:param version_file: :param version_file:
@@ -95,7 +114,9 @@ def write_version_file(version_file: str, content: Dict):
f.write(_UPGRADE) f.write(_UPGRADE)
upgrade = content.get("upgrade") upgrade = content.get("upgrade")
if len(upgrade) > 1: if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n") f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
else: else:
f.write(f"{upgrade[0]}") f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"): if not upgrade[0].endswith(";"):
@@ -121,3 +142,7 @@ def get_models_describe(app: str) -> Dict:
describe = model.describe() describe = model.describe()
ret[describe.get("name")] = describe ret[describe.get("name")] = describe
return ret return ret
def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or ""))

1
aerich/version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "0.6.3"

View File

@@ -20,10 +20,7 @@ tortoise_orm = {
"second": expand_db_url(db_url_second, True), "second": expand_db_url(db_url_second, True),
}, },
"apps": { "apps": {
"models": { "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
"models_second": {"models": ["tests.models_second"], "default_connection": "second"}, "models_second": {"models": ["tests.models_second"], "default_connection": "second"},
}, },
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

990
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.5.1" version = "0.6.3"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"
readme = "README.md" readme = "README.md"
homepage = "https://github.com/long2ice/aerich" homepage = "https://github.com/tortoise/aerich"
repository = "https://github.com/long2ice/aerich.git" repository = "https://github.com/tortoise/aerich.git"
documentation = "https://github.com/long2ice/aerich" documentation = "https://github.com/tortoise/aerich"
keywords = ["migrate", "Tortoise-ORM", "mysql"] keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [ packages = [
{ include = "aerich" } { include = "aerich" }
@@ -16,29 +16,34 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.7"
tortoise-orm = "^0.16.21" tortoise-orm = "*"
click = "*" click = "*"
pydantic = "*"
aiomysql = { version = "*", optional = true }
asyncpg = { version = "*", optional = true } asyncpg = { version = "*", optional = true }
ddlparse = "*" asyncmy = { version = "*", optional = true }
pydantic = "*"
dictdiffer = "*" dictdiffer = "*"
tomlkit = "*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
flake8 = "*" flake8 = "*"
isort = "*" isort = "*"
black = "^20.8b1" black = "*"
pytest = "*" pytest = "*"
pytest-xdist = "*" pytest-xdist = "*"
pytest-asyncio = "*" pytest-asyncio = "*"
bandit = "*" bandit = "*"
pytest-mock = "*" pytest-mock = "*"
cryptography = "*" cryptography = "*"
pyproject-flake8 = "*"
[tool.poetry.extras] [tool.poetry.extras]
asyncmy = ["asyncmy"] asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"] asyncpg = ["asyncpg"]
aiomysql = ["aiomysql"]
[tool.aerich]
tortoise_orm = "conftest.tortoise_orm"
location = "./migrations"
src_folder = "./."
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry>=0.12"]
@@ -46,3 +51,17 @@ build-backend = "poetry.masonry.api"
[tool.poetry.scripts] [tool.poetry.scripts]
aerich = "aerich.cli:main" aerich = "aerich.cli:main"
[tool.black]
line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39']
[tool.pytest.ini_options]
asyncio_mode = 'auto'
[tool.mypy]
pretty = true
ignore_missing_imports = true
[tool.flake8]
ignore = 'E501,W503,E203'

View File

@@ -1,2 +0,0 @@
[flake8]
ignore = E501,W503

View File

@@ -1,4 +1,5 @@
import datetime import datetime
import uuid
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
@@ -38,9 +39,13 @@ class Email(Model):
users = fields.ManyToManyField("models.User") users = fields.ManyToManyField("models.User")
def default_name():
return uuid.uuid4()
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=100) slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True) name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User") user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@@ -51,13 +56,16 @@ class Product(Model):
view_num = fields.IntField(description="View Num", default=0) view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
pic = fields.CharField(max_length=200) pic = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta: class Meta:
unique_together = (("name", "type"),) unique_together = (("name", "type"),)
indexes = (("name", "type"),)
class Config(Model): class Config(Model):

View File

@@ -50,7 +50,9 @@ class Product(Model):
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)

View File

@@ -50,7 +50,9 @@ class Product(Model):
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@@ -61,3 +63,6 @@ class Config(Model):
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta:
table = "configs"

View File

@@ -1,9 +1,6 @@
import pytest
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests.models import Category, Product, User from tests.models import Category, Product, User
@@ -76,7 +73,10 @@ def test_modify_column():
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)" assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret0 == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)' assert (
ret0
== 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)'
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
@@ -84,19 +84,19 @@ def test_modify_column():
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL' assert (
ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL'
)
def test_alter_column_default(): def test_alter_column_default():
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
return return
ret = Migrate.ddl.alter_column_default( ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("intro").describe(False))
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT' assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\''
elif isinstance(Migrate.ddl, MysqlDDL): elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ALTER COLUMN `name` DROP DEFAULT" assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''"
ret = Migrate.ddl.alter_column_default( ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False) Category, Category._meta.fields_map.get("created_at").describe(False)
@@ -141,11 +141,7 @@ def test_set_comment():
def test_drop_column(): def test_drop_column():
if isinstance(Migrate.ddl, SqliteDDL): ret = Migrate.ddl.drop_column(Category, "name")
with pytest.raises(NotSupportError):
ret = Migrate.ddl.drop_column(Category, "name")
else:
ret = Migrate.ddl.drop_column(Category, "name")
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP COLUMN `name`" assert ret == "ALTER TABLE `category` DROP COLUMN `name`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):

View File

@@ -17,6 +17,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -146,11 +147,12 @@ old_models_describe = {
"models.Config": { "models.Config": {
"name": "models.Config", "name": "models.Config",
"app": "models", "app": "models",
"table": "config", "table": "configs",
"abstract": False, "abstract": False,
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -242,6 +244,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -334,6 +337,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -413,7 +417,7 @@ old_models_describe = {
{ {
"name": "type", "name": "type",
"field_type": "IntEnumFieldInstance", "field_type": "IntEnumFieldInstance",
"db_column": "type", "db_column": "type_db_alias",
"python_type": "int", "python_type": "int",
"generated": False, "generated": False,
"nullable": False, "nullable": False,
@@ -512,6 +516,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -681,6 +686,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -768,7 +774,7 @@ def test_migrate(mocker: MockerFixture):
- alter default: Config.status - alter default: Config.status
- rename column: Product.image -> Product.pic - rename column: Product.image -> Product.pic
""" """
mocker.patch("click.prompt", side_effect=(False, True)) mocker.patch("click.prompt", side_effect=(True,))
models_describe = get_models_describe("models") models_describe = get_models_describe("models")
Migrate.app = "models" Migrate.app = "models"
@@ -790,15 +796,15 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`", "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `username` VARCHAR(20) NOT NULL",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4;", "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4;",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4", "CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
@@ -814,16 +820,16 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`", "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`", "ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `username` VARCHAR(20) NOT NULL",
"DROP TABLE IF EXISTS `email_user`", "DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`", "DROP TABLE IF EXISTS `newmodel`",
] ]
@@ -832,46 +838,46 @@ def test_migrate(mocker: MockerFixture):
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert sorted(Migrate.upgrade_operators) == sorted( assert sorted(Migrate.upgrade_operators) == sorted(
[ [
'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)', 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'ALTER TABLE "user" ALTER COLUMN "username" TYPE VARCHAR(20)',
'CREATE UNIQUE INDEX "uid_product_name_f14935" ON "product" ("name", "type")',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100)', 'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';', 'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)', 'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
] ]
) )
assert sorted(Migrate.downgrade_operators) == sorted( assert sorted(Migrate.downgrade_operators) == sorted(
[ [
'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)', 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200)', 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "user" ALTER COLUMN "username" TYPE VARCHAR(20)',
'ALTER TABLE "config" DROP COLUMN "user_id"', 'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"', 'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', 'ALTER TABLE "config" RENAME TO "configs"',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL', 'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"', 'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', 'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'DROP INDEX "idx_email_email_4a1a33"',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'DROP INDEX "idx_product_name_869427"',
'DROP INDEX "idx_email_email_4a1a33"',
'DROP INDEX "idx_user_usernam_9987ab"', 'DROP INDEX "idx_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_f14935"', 'DROP INDEX "uid_product_name_869427"',
'DROP TABLE IF EXISTS "email_user"', 'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"', 'DROP TABLE IF EXISTS "newmodel"',
] ]