155 Commits

Author SHA1 Message Date
long2ice
9da99824fe fix: postgres sql error (#263) 2022-09-23 23:33:49 +08:00
long2ice
75db7cea60 fix: test error 2022-09-23 10:35:33 +08:00
long2ice
d777c9c278 Merge remote-tracking branch 'origin/dev' into dev
# Conflicts:
#	aerich/utils.py
#	tests/test_utils.py
2022-09-23 10:30:46 +08:00
long2ice
e9b76bdd35 feat: use .py version files 2022-09-23 10:29:48 +08:00
long2ice
8b7864d886 Merge pull request #199 from ehdgua01/fix-ddl-format-and-writing-version-file
[Enhancement] Fix version file formats
2022-09-16 09:30:27 +08:00
KDH
bef45941f2 Fix testcase 2022-09-16 10:26:21 +09:00
KDH
7b472d7a84 Fix testcase 2022-09-16 10:08:34 +09:00
KDH
1f0a6dfb50 Fix typo 2022-09-16 09:58:04 +09:00
KDH
36282f123f Merge branch 'dev' into fix-ddl-format-and-writing-version-file
# Conflicts:
#	aerich/utils.py
#	tests/test_migrate.py
2022-09-16 09:54:51 +09:00
KDH
3cd4e24050 Merge branch 'dev' into fix-ddl-format-and-writing-version-file 2022-09-16 09:43:57 +09:00
long2ice
f8c2f1b551 Merge pull request #205 from GDGSNF/dev
Merge repeated `if` statements into single `if`
2022-09-16 08:40:28 +08:00
Yasser Tahiri
131d97a3d6 Merge branch 'dev' into dev 2022-09-15 23:19:59 +01:00
Yasser Tahiri
1a0371e977 Update aerich/utils.py
Co-authored-by: KDH <ehdgua01@naver.com>
2022-09-15 23:19:18 +01:00
long2ice
e5b092fd08 Merge pull request #260 from waketzheng/dev
refactor: use pathlib to read and write text
2022-09-12 11:44:02 +08:00
Waket Zheng
7a109f3c79 refactor: use pathlib to read and write text 2022-09-12 00:57:46 +08:00
Jinlong Peng
8c2ecbaef1 feat: support add/remove field with index 2022-08-26 18:04:20 +08:00
long2ice
b141363c51 Merge pull request #242 from ssilaev/dev
Hotfix for cli group function in v0.6.3
2022-07-22 08:39:02 +08:00
long2ice
9dd474d79f Merge remote-tracking branch 'origin/dev' into dev 2022-06-27 11:42:37 +08:00
long2ice
e4bb9d838e docs: update changelog 2022-06-27 11:41:48 +08:00
long2ice
029d522c79 Merge pull request #249 from tortoise/fix-decimal
Fix decimal field change
2022-06-27 11:38:03 +08:00
long2ice
d6627906c7 test: fix test_migrate 2022-06-27 11:36:09 +08:00
long2ice
3c88833154 fix: decimal field change (#246) 2022-06-27 11:29:47 +08:00
long2ice
8f68f08eba Merge pull request #248 from isaquealves/feature/load_ddl_class_per_dialect
feat: Add support for dynamically load DDL classes
2022-06-22 20:25:42 +08:00
Isaque Alves
60ba6963fd Update changelog 2022-06-22 09:22:26 -03:00
Isaque Alves
4c35c44bd2 feat: Add support for dynamically load DDL classes
Adopt a strategy of loading classes based on their names, allowing to
easily add new database support without changing Migrate class logic
2022-06-22 09:16:11 -03:00
long2ice
bdeaf5495e Merge pull request #247 from isaquealves/feature/postgresql-numeric-type-translate
refactor: Improve db inspection
2022-06-22 08:41:15 +08:00
Isaque Alves
db33059ec9 Resolve style issue 2022-06-20 15:42:21 -03:00
Isaque Alves
44b96058f8 fix(tests/test_migrate.py): Resolve issue with broken tests 2022-06-17 12:36:04 -03:00
Isaque Alves
abff753b6a refactor: Improve postgresql migrate operators tests 2022-06-17 09:45:02 -03:00
Isaque Alves
dcd8441a05 fix: add space following python style guide" 2022-06-17 02:03:41 -03:00
Isaque Alves
b4a735b814 fix: Adjust changelog formatting 2022-06-17 02:02:37 -03:00
Isaque Alves
83ba13e99a Update changelog 2022-06-17 02:00:35 -03:00
Isaque Alves
d7b1c07d13 fix: Add comma to separate value in join 2022-06-17 01:51:47 -03:00
Isaque Alves
1ac16188fc refactor: Improve db inspection
- Add support to postgresql numeric type.
- Improve field configuration handling for numeric and decimal types
2022-06-17 01:38:39 -03:00
long2ice
4abc464ce0 feat: add is_flag to init-db 2022-05-24 11:20:12 +08:00
Sergey Silaev
d4430cec0d Hotfix for cli group function in v0.6.3 2022-05-10 01:20:11 +04:00
long2ice
0b01fa38d8 feat: add index inspect 2022-04-05 19:38:08 +08:00
long2ice
801dde15be feat: inspectdb support sqlite 2022-04-01 20:30:36 +08:00
long2ice
75480e2041 Merge remote-tracking branch 'origin/dev' into dev 2022-04-01 19:57:03 +08:00
long2ice
45129cef9f feat: improve inspectdb and support postgres 2022-04-01 19:56:48 +08:00
long2ice
3a0dd2355d Merge pull request #230 from ssilaev/dev
Increase max length of app column
2022-02-09 15:01:39 +08:00
Sergey Silaev
0e71bc16ae Increase max length of app column 2022-02-08 22:14:55 +03:00
long2ice
c39462820c upgrade deps 2022-01-17 22:26:13 +08:00
long2ice
f15cbaf9e0 Support migration for specified index. (#203) 2021-12-29 21:36:23 +08:00
long2ice
15131469df upgrade deps 2021-12-22 16:26:13 +08:00
long2ice
c60c1610f0 Fix pyproject.toml not existing error. (#217) 2021-12-12 22:11:51 +08:00
long2ice
63e8d06157 remove aiomysql 2021-12-08 14:43:33 +08:00
long2ice
68ef8ac676 Fix ci 2021-12-08 14:38:16 +08:00
long2ice
8b5cf6faa0 inspectdb support DATE. (#215) 2021-12-08 14:33:27 +08:00
Yasser Tahiri
40c7ef7fd6 Merge repeated if statements into single if 2021-10-21 15:43:18 +01:00
KDH
7a826df43f Fix duplicated semicolon in table creation DDL 2021-10-12 11:24:37 +09:00
KDH
b1b9cc1454 Fix M2M table template 2021-10-12 11:23:29 +09:00
long2ice
fac00d45cc Remove pydantic dependency. (#198) 2021-10-04 23:05:20 +08:00
long2ice
6f7893d376 Fix section name 2021-09-28 15:07:10 +08:00
long2ice
b1521c4cc7 update version 2021-09-27 19:55:38 +08:00
long2ice
24c1f4cb7d Change default config file from aerich.ini to pyproject.toml. (#197) 2021-09-27 11:05:20 +08:00
long2ice
661f241dac Compatible with old version in indexes 2021-08-31 17:53:17 +08:00
long2ice
01787558d6 Fix test 2021-08-31 17:41:13 +08:00
long2ice
699b0321a4 Support indexes change. (#193) 2021-08-31 17:36:25 +08:00
long2ice
4a83021892 Update FUNDING.yml 2021-08-26 20:39:31 +08:00
long2ice
af63221875 Fix no module found error. (#188) (#189) 2021-08-16 11:14:43 +08:00
long2ice
359525716c update README.md 2021-08-12 15:42:54 +08:00
long2ice
7d3eb2e151 Merge pull request #181 from Vovetta/dev
Fix: migrate doesn't use source_field in unique_together
2021-08-04 09:42:18 +08:00
Vovetta
d8abf79449 Updated changelog and version 2021-08-03 10:38:31 -07:00
Vovetta
aa9f40ae27 Fix: migrate doesn't use source_field in unique_together 2021-08-03 10:36:06 -07:00
long2ice
79b7ae343a update README.md 2021-08-03 16:25:06 +08:00
long2ice
6f5a9ab78c Add Command class. (#148) (#141) (#123) (#106) 2021-08-03 16:18:07 +08:00
long2ice
1e5a83c281 update deps 2021-07-26 17:44:18 +08:00
long2ice
180420843d update README.md 2021-07-26 15:27:49 +08:00
long2ice
58f66b91cf Fix redundant semicolons 2021-07-23 17:07:10 +08:00
long2ice
064d7ff675 Fix ci 2021-07-22 15:32:07 +08:00
long2ice
2da794d823 Fix db_constraint when fk changed. (#179) 2021-07-22 14:37:49 +08:00
long2ice
77005f3793 Fix MySQL 5.X rename column. 2021-07-09 10:53:13 +08:00
long2ice
5a873b8b69 Merge pull request #177 from yusukefs/add-default-src-folder-config
Add default value for src_folder config
2021-07-08 17:27:29 +08:00
Yusuke Sakai
3989b7c674 Update version and changelog 2021-07-08 18:01:59 +09:00
Yusuke Sakai
694b05356f Add default src_folder cofig value 2021-07-08 17:35:44 +09:00
long2ice
919d56c936 add ci branches-ignore master 2021-07-07 10:29:38 +08:00
long2ice
7bcf9b2fed Support drop column for sqlite. (#40) 2021-07-03 13:51:01 +08:00
long2ice
9f663299cf Merge pull request #174 from sasha00123/dev
Fixed typo in README.md concerning dowgrade usage
2021-06-25 13:52:34 +08:00
Alexander Batyrgariev
28dbdf2663 Fixed typo in README.md concerning dowgrade usage 2021-06-25 08:00:00 +03:00
long2ice
e71a4b60a5 Merge pull request #166 from spacemanspiff2007/dev
Added config option to specify source folder
2021-06-13 14:26:26 +08:00
-
62840136be used old black version 2021-06-11 15:36:54 +02:00
-
185514f711 reformatted with black 2021-06-11 15:18:06 +02:00
-
8e783e031e updated readme 2021-06-10 16:56:30 +02:00
-
10b7272ca8 Added an configuration option to specify the path of the source folder.
This will make aerich work with various folder structures (e.g. ./src/MyPythonModule)
Additionally this will try to import in init and show the user the error message on failure.
2021-06-10 16:52:03 +02:00
long2ice
0c763c6024 Fix repeat 2021-06-09 13:56:25 +08:00
long2ice
c6371a5c16 Fix repeat 2021-06-09 11:43:32 +08:00
long2ice
1dbf9185b6 Not catch exception when import config. (#164) 2021-06-04 17:47:39 +08:00
long2ice
9bf2de0b9a Fix incorrect index creation order. (#151) 2021-06-01 17:09:45 +08:00
long2ice
bf1cf21324 Merge pull request #158 from manzato/pyproject-update
Update URLs
2021-05-22 22:43:52 +08:00
Guillermo Manzato
8b08329493 Update URLs 2021-05-22 11:39:49 -03:00
long2ice
5bc7d23d95 Merge pull request #157 from tortoise/dependabot/pip/pydantic-1.8.2
Bump pydantic from 1.8.1 to 1.8.2
2021-05-14 09:30:47 +08:00
dependabot[bot]
a253aa96cb Bump pydantic from 1.8.1 to 1.8.2
Bumps [pydantic](https://github.com/samuelcolvin/pydantic) from 1.8.1 to 1.8.2.
- [Release notes](https://github.com/samuelcolvin/pydantic/releases)
- [Changelog](https://github.com/samuelcolvin/pydantic/blob/master/HISTORY.md)
- [Commits](https://github.com/samuelcolvin/pydantic/compare/v1.8.1...v1.8.2)

Signed-off-by: dependabot[bot] <support@github.com>
2021-05-13 20:51:51 +00:00
long2ice
15a6e874dd update deps 2021-05-03 14:23:27 +08:00
long2ice
19a5dcbf3f update deps 2021-04-26 21:01:40 +08:00
long2ice
922e3eef16 Fix CI 2021-04-05 17:11:28 +08:00
long2ice
44fd2fe6ae Fix default function when migrate. (#147) 2021-04-05 14:10:42 +08:00
long2ice
b147859960 Fix default function when migrate 2021-04-04 05:46:34 +00:00
long2ice
793cf2532c Create FUNDING.yml 2021-04-03 21:34:24 +08:00
long2ice
fa85e05d1d Fix postgre alter null. (#142) 2021-03-28 16:22:49 +08:00
long2ice
3f52ac348b Support rename table. (#139) 2021-03-25 21:21:49 +08:00
long2ice
f8aa7a8f34 Fix inspectdb for FloatField. (#138) 2021-03-22 14:16:59 +08:00
long2ice
44d520cc82 Fix postgres field type change error. (#135) 2021-03-21 21:18:08 +08:00
long2ice
364735f804 Fix rename field on the field add. (#134) 2021-03-21 20:43:05 +08:00
long2ice
505d361597 Fix drop model in the downgrade. (#132) 2021-03-18 23:40:13 +08:00
long2ice
a19edd3a35 update ci name 2021-03-13 16:45:35 +08:00
long2ice
84d1f78019 update workflow name and add cryptography 2021-03-13 16:43:22 +08:00
long2ice
8fb07a6c9e update deps 2021-03-13 16:40:27 +08:00
long2ice
54da8b22af update aiomysql to asyncmy 2021-03-13 16:37:45 +08:00
long2ice
4c0308ff22 update test.yml 2021-03-03 22:03:38 +08:00
long2ice
38c4a15661 update test.yml 2021-03-03 20:42:18 +08:00
long2ice
52151270e0 Fix bug for field change. (#119) 2021-03-03 20:36:54 +08:00
long2ice
49897dc4fd Merge pull request #121 from AulonSal/close-tortoise-connections
Close Tortoise connections properly
2021-02-28 14:47:58 +08:00
AulonSal
d4ad0e270f Update version and changelog 2021-02-28 12:13:59 +05:30
AulonSal
e74fc304a5 Don't close db connections when group function \(cli\) is run 2021-02-27 00:43:55 +05:30
AulonSal
14d20455e6 Replace coro logic with tortoise.run_async 2021-02-23 13:06:40 +05:30
long2ice
bd9ecfd6e1 Merge pull request #122 from personalcomputer/personalcomputer/improve_readme_english
Improve English grammar / clarity in README.md
2021-02-22 12:31:15 +08:00
John Miller
de8500b9a1 Improve English grammar / clarity in README.md 2021-02-21 19:46:04 -08:00
AulonSal
90b47c5af7 Close connections even if command raises exception 2021-02-22 07:40:18 +05:30
AulonSal
02fe5a9d31 Close Tortoise connections properly 2021-02-20 13:11:29 +05:30
long2ice
be41a1332a update tortoise-orm version 2021-02-04 20:53:04 +08:00
long2ice
09661c1d46 Fix unique_together 2021-02-04 14:39:07 +08:00
long2ice
abfa60133f Fix drop table 2021-02-04 14:23:46 +08:00
long2ice
048e428eac update tortoise-orm 2021-02-03 22:52:01 +08:00
long2ice
38a3df9b5a add support m2m 2021-02-03 22:22:22 +08:00
long2ice
0d94b22b3f Remove unused functions 2021-02-03 18:06:43 +08:00
long2ice
f1f0074255 Support rename field 2021-02-03 17:56:30 +08:00
long2ice
e3a14a2f60 Fix postgres index 2021-02-03 16:34:07 +08:00
long2ice
608ff8f071 update conftest.py 2021-02-03 15:49:40 +08:00
long2ice
d3a1342293 update README.md 2021-02-03 15:48:06 +08:00
long2ice
01e3de9522 basically completed 2021-02-03 15:43:04 +08:00
long2ice
c6c398fdf0 update 2021-02-02 22:52:50 +08:00
long2ice
c60bdd290e add fk and drop fk 2021-02-02 20:35:05 +08:00
long2ice
f443dc68db WIP 2021-02-01 16:54:35 +08:00
long2ice
36f84702b7 update 2021-02-01 14:00:12 +08:00
long2ice
b4cc2de0e3 v0.5 refactoring 2021-01-31 23:10:30 +08:00
long2ice
4780b90c1c add close_connections to fix stuck 2021-01-29 22:58:12 +08:00
long2ice
cd176c1fd6 Merge pull request #111 from lqmanh/bugfixes/fix-tortoise-orm-0.16.19
Fix Aerich b/c of a new feature in Tortoise ORM v0.16.19
2021-01-04 14:59:11 +08:00
long2ice
c2819fc8dc update CHANGELOG.md 2020-12-29 19:13:37 +08:00
long2ice
530e7cfce5 Fixed unnecessary import. (#113) 2020-12-29 19:12:36 +08:00
Lương Quang Mạnh
47824a100b Fix Aerich b/c of Tortoise ORM v0.16.19 2020-12-26 10:31:10 +07:00
long2ice
78a15f9f19 Merge pull request #108 from lqmanh/features/make-parent-dirs-as-needed
Make parent directories as needed
2020-12-25 22:10:56 +08:00
long2ice
5ae8b9e85f complete InspectDb 2020-12-25 21:44:26 +08:00
long2ice
55a6d4bbc7 add InspectDb and show_create_tables 2020-12-24 23:32:58 +08:00
long2ice
c5535f16e1 TODO: Add inspectdb command 2020-12-23 23:38:45 +08:00
long2ice
840cd71e44 Replace migrations separator to sql standard comment 2020-12-23 23:30:35 +08:00
Lương Quang Mạnh
e0d52b1210 Fix make style 2020-12-21 15:36:29 +07:00
Lương Quang Mạnh
4dc45f723a Make parent directories as needed 2020-12-21 15:13:26 +07:00
long2ice
d2e0a68351 Fix packaging error. (#92) 2020-12-02 23:03:15 +08:00
long2ice
ee6cc20c7d Fix empty items 2020-11-30 11:14:09 +08:00
long2ice
4e917495a0 Fix upgrade in new db. (#96) 2020-11-30 11:02:48 +08:00
long2ice
bfa66f6dd4 update changelog 2020-11-29 11:15:43 +08:00
long2ice
f00715d4c4 Merge pull request #97 from TrDex/pathlib-for-path-resolving
Use `pathlib` for path resolving
2020-11-29 11:02:44 +08:00
Mykola Solodukha
6e3105690a Use pathlib for path resolving 2020-11-28 19:23:34 +02:00
long2ice
c707f7ecb2 bug fix 2020-11-28 14:31:41 +08:00
35 changed files with 3089 additions and 1470 deletions

1
.github/FUNDING.yml vendored Normal file
View File

@@ -0,0 +1 @@
custom: ["https://sponsor.long2ice.io"]

View File

@@ -1,7 +1,13 @@
name: test
on: [ push, pull_request ]
name: ci
on:
push:
branches-ignore:
- main
pull_request:
branches-ignore:
- main
jobs:
testall:
ci:
runs-on: ubuntu-latest
services:
postgres:
@@ -20,9 +26,9 @@ jobs:
with:
python-version: '3.x'
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
run: |
pip install -U pip poetry
poetry config virtualenvs.create false
- name: CI
env:
MYSQL_PASS: root

View File

@@ -12,9 +12,9 @@ jobs:
with:
python-version: '3.x'
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
run: |
pip install -U pip poetry
poetry config virtualenvs.create false
- name: Build dists
run: make build
- name: Pypi Publish

1
.gitignore vendored
View File

@@ -146,3 +146,4 @@ aerich.ini
src
.vscode
.DS_Store
.python-version

View File

@@ -1,7 +1,121 @@
# ChangeLog
## 0.7
### 0.7.1rc1
- Fix postgres sql error (#263)
### 0.7.0
**Now aerich use `.py` file to record versions.**
Upgrade Note:
1. Drop `aerich` table
2. Delete `migrations/models` folder
3. Run `aerich init-db`
- Improve `inspectdb` adding support to `postgresql::numeric` data type
- Add support for dynamically load DDL classes easing to add support to
new databases without changing `Migrate` class logic
- Fix decimal field change. (#246)
- Support add/remove field with index.
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2
- Support migration for specified index. (#203)
### 0.6.1
- Fix `pyproject.toml` not existing error. (#217)
### 0.6.0
- Change default config file from `aerich.ini` to `pyproject.toml`. (#197)
**Upgrade note:**
1. Run `aerich init -t config.TORTOISE_ORM`.
2. Remove `aerich.ini`.
- Remove `pydantic` dependency. (#198)
- `inspectdb` support `DATE`. (#215)
## 0.5
### 0.5.8
- Support `indexes` change. (#193)
### 0.5.7
- Fix no module found error. (#188) (#189)
### 0.5.6
- Add `Command` class. (#148) (#141) (#123) (#106)
- Fix: migrate doesn't use source_field in unique_together. (#181)
### 0.5.5
- Fix KeyError: 'src_folder' after upgrading aerich to 0.5.4. (#176)
- Fix MySQL 5.X rename column.
- Fix `db_constraint` when fk changed. (#179)
### 0.5.4
- Fix incorrect index creation order. (#151)
- Not catch exception when import config. (#164)
- Support `drop column` for sqlite. (#40)
### 0.5.3
- Fix postgre alter null. (#142)
- Fix default function when migrate. (#147)
### 0.5.2
- Fix rename field on the field add. (#134)
- Fix postgres field type change error. (#135)
- Fix inspectdb for `FloatField`. (#138)
- Support `rename table`. (#139)
### 0.5.1
- Fix tortoise connections not being closed properly. (#120)
- Fix bug for field change. (#119)
- Fix drop model in the downgrade. (#132)
### 0.5.0
- Refactor core code, now has no limitation for everything.
## 0.4
### 0.4.4
- Fix unnecessary import. (#113)
### 0.4.3
- Replace migrations separator to sql standard comment.
- Add `inspectdb` command.
### 0.4.2
- Use `pathlib` for path resolving. (#89)
- Fix upgrade in new db. (#96)
- Fix packaging error. (#92)
### 0.4.1
- Bug fix. (#91 #93)
### 0.4.0
- Use `.sql` instead of `.json` to store version file.

View File

@@ -8,32 +8,19 @@ POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= "123456"
help:
@echo "Aerich development makefile"
@echo
@echo "usage: make <target>"
@echo "Targets:"
@echo " up Updates dev/test dependencies"
@echo " deps Ensure dev/test dependencies are installed"
@echo " check Checks that build is sane"
@echo " lint Reports all linter violations"
@echo " test Runs all tests"
@echo " style Auto-formats the code"
up:
@poetry update
deps:
@poetry install -E dbdrivers
@poetry install -E asyncpg -E asyncmy
style: deps
isort -src $(checkfiles)
black $(black_opts) $(checkfiles)
@isort -src $(checkfiles)
@black $(black_opts) $(checkfiles)
check: deps
black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
flake8 $(checkfiles)
bandit -x tests -r $(checkfiles)
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles)
test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test
@@ -45,7 +32,7 @@ test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql

152
README.md
View File

@@ -1,23 +1,21 @@
# Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/long2ice/aerich)](https://github.com/long2ice/aerich)
[![image](https://github.com/long2ice/aerich/workflows/pypi/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:pypi)
[![image](https://github.com/long2ice/aerich/workflows/test/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:test)
[![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which like alembic for SQLAlchemy, or Django ORM with it\'s
own migrations solution.
**Important: You can only use absolutely import in your `models.py` to make `aerich` work.**
Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with
it\'s own migration solution.
## Install
Just install from pypi:
```shell
> pip install aerich
pip install aerich
```
## Quick Start
@@ -28,10 +26,9 @@ Just install from pypi:
Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options:
-c, --config TEXT Config file. [default: aerich.ini]
--app TEXT Tortoise-ORM app name. [default: models]
-n, --name TEXT Name of section in .ini file to use for aerich config.
[default: aerich]
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name.
-h, --help Show this message and exit.
Commands:
@@ -40,14 +37,14 @@ Commands:
history List all migrate items.
init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file.
upgrade Upgrade to latest version.
upgrade Upgrade to specified version.
```
## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first,
example:
You need add `aerich.models` to your `Tortoise-ORM` config first. Example:
```python
TORTOISE_ORM = {
@@ -71,19 +68,20 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location.
Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.
[required]
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit.
```
Init config file and location:
Initialize the config file and migrations location:
```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations
Success generate config file aerich.ini
Success write config to pyproject.toml
```
### Init db
@@ -95,28 +93,30 @@ Success create app migrate location ./migrations/models
Success generate schema for app "models"
```
If your Tortoise-ORM app is not default `models`, you must specify
`--app` like `aerich --app other_models init-db`.
If your Tortoise-ORM app is not the default `models`, you must specify the correct app via `--app`,
e.g. `aerich --app other_models init-db`.
### Update models and make migrate
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.sql
Success migrate 1_202029051520102929_drop_column.py
```
Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`.
`{version_num}_{datetime}_{name|update}.py`.
And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can choice `True` to rename column without column drop, or choice `False` to drop column then create.
If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data.
### Upgrade to latest version
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.sql
Success upgrade 1_202029051520102929_drop_column.py
```
Now your db is migrated to latest.
@@ -124,7 +124,7 @@ Now your db is migrated to latest.
### Downgrade to specified version
```shell
> aerich init -h
> aerich downgrade -h
Usage: aerich downgrade [OPTIONS]
@@ -142,17 +142,17 @@ Options:
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.sql
Success downgrade 1_202029051520102929_drop_column.py
```
Now your db rollback to specified version.
Now your db is rolled back to the specified version.
### Show history
```shell
> aerich history
1_202029051520102929_drop_column.sql
1_202029051520102929_drop_column.py
```
### Show heads to be migrated
@@ -160,9 +160,74 @@ Now your db rollback to specified version.
```shell
> aerich heads
1_202029051520102929_drop_column.sql
1_202029051520102929_drop_column.py
```
### Inspect db tables to TortoiseORM model
Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Inspect all tables and print to console:
```shell
aerich --app models inspectdb
```
Inspect a specified table in the default app and redirect to `models.py`:
```shell
aerich inspectdb -t user > models.py
```
For example, you table is:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Now run `aerich inspectdb -t test` to see the generated model:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
### Multiple databases
```python
@@ -178,13 +243,30 @@ tortoise_orm = {
}
```
You need only specify `aerich.models` in one app, and must specify `--app` when run `aerich migrate` and so on.
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
## Support this project
## Restore `aerich` workflow
| AliPay | WeChatPay | PayPal |
| -------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |
| <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/alipay.jpeg"/> | <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/wechatpay.jpeg"/> | [PayPal](https://www.paypal.me/long2ice) to my account long2ice. |
In some cases, such as broken changes from upgrade of `aerich`, you can't run `aerich migrate` or `aerich upgrade`, you
can make the following steps:
1. drop `aerich` table.
2. delete `migrations/{app}` directory.
3. rerun `aerich init-db`.
Note that these actions is safe, also you can do that to reset your migrations if your migration files is too many.
## Use `aerich` in application
You can use `aerich` out of cli by use `Command` class.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
```
## License

View File

@@ -1 +1,143 @@
__version__ = "0.4.0"
import os
from pathlib import Path
from typing import List
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
import_py_file,
)
class Command:
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
):
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self):
await Migrate.init(self.tortoise_config, self.app, self.location)
async def upgrade(self):
migrated = []
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=self.app)
except OperationalError:
exists = False
if not exists:
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
await upgrade(conn)
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool):
ret = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(
app=self.app, version__startswith=f"{version}_"
).first()
if not specified_version:
raise DowngradeError("No specified version found")
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
m = import_py_file(file_path)
downgrade = getattr(m, "downgrade", None)
if not downgrade:
raise DowngradeError("No downgrade items found")
await downgrade(conn)
await version.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self):
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
if not await Aerich.exists(version=version, app=self.app):
ret.append(version)
return ret
async def history(self):
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: List[str] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"):
return await Migrate.migrate(name)
async def init_db(self, safe: bool):
location = self.location
app = self.app
dirname = Path(location, app)
dirname.mkdir(parents=True)
await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=get_models_describe(app),
)
version_file = Path(dirname, version)
content = MIGRATE_TEMPLATE.format(upgrade_sql=f'"""{schema}"""', downgrade_sql="")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)

View File

@@ -1,42 +1,37 @@
import asyncio
import os
import sys
from configparser import ConfigParser
from functools import wraps
from pathlib import Path
from typing import List
import click
import tomlkit
from click import Context, UsageError
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from tomlkit.exceptions import NonExistentKey
from tortoise import Tortoise
from aerich.migrate import Migrate
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_tortoise_config,
get_version_content_from_file,
write_version_file,
)
from aerich import Command
from aerich.enums import Color
from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config
from aerich.version import __version__
from . import __version__
from .enums import Color
from .models import Aerich
parser = ConfigParser()
CONFIG_DEFAULT_VALUES = {
"src_folder": ".",
}
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
ctx = args[0]
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init_db", "init"]:
loop.run_until_complete(Tortoise.close_connections())
app = ctx.obj.get("app")
if app:
Migrate.remove_old_model_file(app, ctx.obj["location"])
return wrapper
@@ -46,42 +41,40 @@ def coro(f):
@click.option(
"-c",
"--config",
default="aerich.ini",
default="pyproject.toml",
show_default=True,
help="Config file.",
)
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
"-n",
"--name",
default="aerich",
show_default=True,
help="Name of section in .ini file to use for aerich config.",
)
@click.pass_context
@coro
async def cli(ctx: Context, config, app, name):
async def cli(ctx: Context, config, app):
ctx.ensure_object(dict)
ctx.obj["config_file"] = config
ctx.obj["name"] = name
invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init":
if not os.path.exists(config):
config_path = Path(config)
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx)
parser.read(config)
location = parser[name]["location"]
tortoise_orm = parser[name]["tortoise_orm"]
content = config_path.read_text()
doc = tomlkit.parse(content)
try:
tool = doc["tool"]["aerich"]
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
ctx.obj["config"] = tortoise_config
ctx.obj["location"] = location
ctx.obj["app"] = app
Migrate.app = app
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db":
await Migrate.init_with_old_models(tortoise_config, app, location)
if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx)
await command.init()
@cli.command(help="Generate migrate changes file.")
@@ -89,7 +82,8 @@ async def cli(ctx: Context, config, app, name):
@click.pass_context
@coro
async def migrate(ctx: Context, name):
ret = await Migrate.migrate(name)
command = ctx.obj["command"]
ret = await command.migrate(name)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success migrate {ret}", fg=Color.green)
@@ -99,32 +93,13 @@ async def migrate(ctx: Context, name):
@click.pass_context
@coro
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=app)
except OperationalError:
exists = False
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
print(upgrade_query_list)
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=app,
content=Migrate.get_models_content(config, app, location),
)
click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True
command = ctx.obj["command"]
migrated = await command.upgrade()
if not migrated:
click.secho("No migrate items", fg=Color.yellow)
click.secho("No upgrade items found", fg=Color.yellow)
else:
for version_file in migrated:
click.secho(f"Success upgrade {version_file}", fg=Color.green)
@cli.command(help="Downgrade to specified version.")
@@ -150,31 +125,12 @@ async def upgrade(ctx: Context):
)
@coro
async def downgrade(ctx: Context, version: int, delete: bool):
app = ctx.obj["app"]
config = ctx.obj["config"]
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(app=app, version__startswith=f"{version}_").first()
if not specified_version:
return click.secho("No specified version found", fg=Color.yellow)
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
return click.secho("No downgrade items found", fg=Color.yellow)
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
if delete:
os.unlink(file_path)
command = ctx.obj["command"]
try:
files = await command.downgrade(version, delete)
except DowngradeError as e:
return click.secho(str(e), fg=Color.yellow)
for file in files:
click.secho(f"Success downgrade {file}", fg=Color.green)
@@ -182,26 +138,24 @@ async def downgrade(ctx: Context, version: int, delete: bool):
@click.pass_context
@coro
async def heads(ctx: Context):
app = ctx.obj["app"]
versions = Migrate.get_all_version_files()
is_heads = False
for version in versions:
if not await Aerich.exists(version=version, app=app):
command = ctx.obj["command"]
head_list = await command.heads()
if not head_list:
return click.secho("No available heads, try migrate first", fg=Color.green)
for version in head_list:
click.secho(version, fg=Color.green)
is_heads = True
if not is_heads:
click.secho("No available heads,try migrate first", fg=Color.green)
@cli.command(help="List all migrate items.")
@click.pass_context
@coro
async def history(ctx: Context):
versions = Migrate.get_all_version_files()
command = ctx.obj["command"]
versions = await command.history()
if not versions:
return click.secho("No history, try migrate", fg=Color.green)
for version in versions:
click.secho(version, fg=Color.green)
if not versions:
click.secho("No history,try migrate", fg=Color.green)
@cli.command(help="Init config file and generate root migrate location.")
@@ -217,77 +171,90 @@ async def history(ctx: Context):
show_default=True,
help="Migrate store location.",
)
@click.option(
"-s",
"--src_folder",
default=CONFIG_DEFAULT_VALUES["src_folder"],
show_default=False,
help="Folder of the source, relative to the project root.",
)
@click.pass_context
@coro
async def init(
ctx: Context,
tortoise_orm,
location,
):
async def init(ctx: Context, tortoise_orm, location, src_folder):
config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if os.path.exists(config_file):
return click.secho("You have inited", fg=Color.yellow)
parser.add_section(name)
parser.set(name, "tortoise_orm", tortoise_orm)
parser.set(name, "location", location)
if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder)
# Add ./ so it's clear that this is relative path
if not src_folder.startswith("./"):
src_folder = "./" + src_folder
with open(config_file, "w", encoding="utf-8") as f:
parser.write(f)
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
if not os.path.isdir(location):
os.mkdir(location)
config_path.write_text(tomlkit.dumps(doc))
Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success generate config file {config_file}", fg=Color.green)
click.secho(f"Success write config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.")
@click.option(
"-s",
"--safe",
type=bool,
is_flag=True,
default=True,
help="When set to true, creates the table only when it does not already exist.",
show_default=True,
)
@click.pass_context
@coro
async def init_db(ctx: Context, safe):
config = ctx.obj["config"]
location = ctx.obj["location"]
app = ctx.obj["app"]
dirname = os.path.join(location, app)
if not os.path.isdir(dirname):
os.mkdir(dirname)
async def init_db(ctx: Context, safe: bool):
command = ctx.obj["command"]
app = command.app
dirname = Path(command.location, app)
try:
await command.init_db(safe)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
else:
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
except FileExistsError:
return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
)
await Tortoise.init(config=config)
connection = get_app_connection(config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=Migrate.get_models_content(config, app, location),
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
@click.option(
"-t",
"--table",
help="Which tables to inspect.",
multiple=True,
required=False,
)
content = {
"upgrade": [schema],
}
write_version_file(os.path.join(dirname, version), content)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
@click.pass_context
@coro
async def inspectdb(ctx: Context, table: List[str]):
command = ctx.obj["command"]
ret = await command.inspectdb(table)
click.secho(ret)
def main():
sys.path.insert(0, ".")
cli()

31
aerich/coder.py Normal file
View File

@@ -0,0 +1,31 @@
import base64
import json
import pickle # nosec: B301,B403
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Index):
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj):
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict):
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
return json.loads(obj, object_hook=object_hook)

View File

@@ -1,8 +1,10 @@
from enum import Enum
from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
from aerich.utils import is_default_function
class BaseDDL:
@@ -11,6 +13,7 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
)
@@ -20,11 +23,17 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_M2M_TABLE_TEMPLATE = (
'CREATE TABLE "{table_name}" (\n'
' "{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,\n'
' "{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}\n'
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
)
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"):
self.client = client
@@ -33,43 +42,54 @@ class BaseDDL:
def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
def drop_table(self, model: "Type[Model]"):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
def drop_table(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance):
def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
return self._M2M_TABLE_TEMPLATE.format(
table_name=field.through,
table_name=through,
backward_table=model._meta.db_table,
forward_table=field.related_model._meta.db_table,
forward_table=reference_table_describe.get("table"),
backward_field=model._meta.db_pk_column,
forward_field=field.related_model._meta.db_pk_column,
backward_key=field.backward_key,
forward_field=reference_id,
backward_key=field_describe.get("backward_key"),
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key,
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=field.through),
forward_key=field_describe.get("forward_key"),
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description
table=through, comment=description
)
if field.description
if description
else "",
)
def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_object: Field):
def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = field_object.default
db_column = field_object.model_field_name
auto_now_add = getattr(field_object, "auto_now_add", False)
auto_now = getattr(field_object, "auto_now", False)
default = field_describe.get("default")
if isinstance(default, Enum):
default = default.value
db_column = field_describe.get("db_column")
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)):
if field_describe.get("field_type") in [
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = ""
else:
default = field_object.to_db_value(default, model)
try:
default = self.schema_generator._column_default_generator(
db_table,
@@ -81,28 +101,33 @@ class BaseDDL:
except NotImplementedError:
default = ""
else:
default = ""
default = None
return default
def add_column(self, model: "Type[Model]", field_object: Field):
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="UNIQUE" if field_object.unique else "",
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
column=db_column,
comment=field_describe.get("description"),
)
if field_object.description
if description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
is_primary_key=is_pk,
default=default,
),
)
@@ -111,24 +136,28 @@ class BaseDDL:
table_name=model._meta.db_table, column_name=column_name
)
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_object.description
if field_describe.get("description")
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
is_primary_key=is_pk,
default=default,
),
)
@@ -156,7 +185,7 @@ class BaseDDL:
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
@@ -167,48 +196,61 @@ class BaseDDL:
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)
db_column = field.source_field or field.model_field_name + "_id"
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table
db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column")
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
db_column=db_column,
table=field.related_model._meta.db_table,
field=to_field_name,
on_delete=field.on_delete,
table=reference_table_describe.get("table"),
field=reference_id,
on_delete=field_describe.get("on_delete"),
)
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format(
table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id",
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
def alter_column_default(self, model: "Type[Model]", field_object: Field):
pass
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column"),
default="SET" + default if default is not None else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
pass
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_object: Field):
pass
def set_comment(self, model: "Type[Model]", field_describe: dict):
return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
)

View File

@@ -8,6 +8,10 @@ class MysqlDDL(BaseDDL):
DIALECT = MySQLSchemaGenerator.DIALECT
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
_CHANGE_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
)
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
@@ -18,5 +22,11 @@ class MysqlDDL(BaseDDL):
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
_M2M_TABLE_TEMPLATE = (
"CREATE TABLE `{table_name}` (\n"
" `{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,\n"
" `{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE\n"
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"

View File

@@ -1,8 +1,7 @@
from typing import List, Type
from typing import Type
from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL
@@ -10,66 +9,41 @@ from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE INDEX "{index_name}" ON "{table_name}" ({column_names})'
_ADD_UNIQUE_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD CONSTRAINT "{index_name}" UNIQUE ({column_names})'
)
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
_MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
)
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
set_drop="DROP" if field_object.null else "SET",
column=field_describe.get("db_column"),
set_drop="DROP" if field_describe.get("nullable") else "SET",
)
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
column=db_column,
datatype=datatype,
using=f' USING "{db_column}"::{datatype}',
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._ADD_UNIQUE_TEMPLATE if unique else self._ADD_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._DROP_UNIQUE_TEMPLATE if unique else self._DROP_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
)
def set_comment(self, model: "Type[Model]", field_object: Field):
def set_comment(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
comment="'{}'".format(field_object.description) if field_object.description else "NULL",
column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL",
)

View File

@@ -2,7 +2,6 @@ from typing import Type
from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError
@@ -12,8 +11,14 @@ class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT
def drop_column(self, model: "Type[Model]", column_name: str):
raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: Field):
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

View File

@@ -2,3 +2,9 @@ class NotSupportError(Exception):
"""
raise when features not support
"""
class DowngradeError(Exception):
"""
raise when downgrade error
"""

View File

@@ -0,0 +1,168 @@
from typing import Any, List, Optional
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str]
pk: bool
unique: bool
index: bool
length: Optional[int]
extra: Optional[str]
decimal_places: Optional[int]
max_digits: Optional[int]
def translate(self) -> dict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]:
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables
@property
def field_map(self) -> dict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
raise NotImplementedError
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

69
aerich/inspectdb/mysql.py Normal file
View File

@@ -0,0 +1,69 @@
from typing import List
from aerich.inspectdb import Column, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"],
unque=unique,
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File

@@ -0,0 +1,76 @@
from typing import List, Optional
from tortoise import BaseDBAsyncClient
from aerich.inspectdb import Column, Inspect
class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"numeric": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3"""
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

View File

@@ -0,0 +1,61 @@
from typing import List
from aerich.inspectdb import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str):
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,29 +1,36 @@
import inspect
import importlib
import os
import re
from datetime import datetime
from importlib import import_module
from io import StringIO
from hashlib import md5
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union
import click
from packaging import version
from packaging.version import LegacyVersion, Version
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
BaseDBAsyncClient,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
Model,
Tortoise,
)
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError
from tortoise.fields import Field
from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, write_version_file
from aerich.utils import get_app_connection, get_models_describe, is_default_function
MIGRATE_TEMPLATE = """from typing import List
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{upgrade_sql}
]
async def downgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{downgrade_sql}
]
"""
class Migrate:
@@ -38,25 +45,23 @@ class Migrate:
_rename_new = []
ddl: BaseDDL
migrate_config: dict
old_models = "old_models"
diff_app = "diff_models"
_last_version_content: Optional[dict] = None
app: str
migrate_location: str
migrate_location: Path
dialect: str
_db_version: Union[LegacyVersion, Version] = None
_db_version: Optional[str] = None
@classmethod
def get_old_model_file(cls, app: str, location: str):
return os.path.join(location, app, cls.old_models + ".py")
@classmethod
def get_all_version_files(cls) -> List[str]:
def get_all_version_files(cls) -> list[str]:
return sorted(
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]),
)
@classmethod
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
@classmethod
async def get_last_version(cls) -> Optional[Aerich]:
try:
@@ -64,49 +69,31 @@ class Migrate:
except OperationalError:
pass
@classmethod
def remove_old_model_file(cls, app: str, location: str):
try:
os.unlink(cls.get_old_model_file(app, location))
except (OSError, FileNotFoundError):
pass
@classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql":
sql = "select version() as version"
ret = await connection.execute_query(sql)
cls._db_version = version.parse(ret[1][0].get("version"))
cls._db_version = ret[1][0].get("version")
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
async def load_ddl_class(cls):
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod
async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
last_version = await cls.get_last_version()
cls.app = app
cls.migrate_location = os.path.join(location, app)
cls.migrate_location = Path(location, app)
if last_version:
content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.migrate_config = migrate_config
await Tortoise.init(config=migrate_config)
cls._last_version_content = last_version.content
connection = get_app_connection(config, app)
if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL
cls.ddl = MysqlDDL(connection)
elif cls.dialect == "sqlite":
from aerich.ddl.sqlite import SqliteDDL
cls.ddl = SqliteDDL(connection)
elif cls.dialect == "postgres":
from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection)
cls.dialect = cls.ddl.DIALECT
cls.dialect = connection.schema_generator.DIALECT
cls.ddl_class = await cls.load_ddl_class()
cls.ddl = cls.ddl_class(connection)
await cls._get_db_version(connection)
@classmethod
@@ -122,24 +109,27 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
return f"0_{now}_init.sql"
version = f"{last_version_num + 1}_{now}_{name}.sql"
return f"0_{now}_init.py"
version = f"{last_version_num + 1}_{now}_{name}.py"
if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version
@classmethod
async def _generate_diff_sql(cls, name):
async def _generate_diff_py(cls, name):
version = await cls.generate_version(name)
# delete if same version exists
for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]):
os.unlink(os.path.join(cls.migrate_location, version_file))
content = {
"upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators,
}
write_version_file(os.path.join(cls.migrate_location, version), content)
os.unlink(Path(cls.migrate_location, version_file))
version_file = Path(cls.migrate_location, version)
content = MIGRATE_TEMPLATE.format(
upgrade_sql=",\n ".join(map(lambda x: f"'{x}'", cls.upgrade_operators)),
downgrade_sql=",\n ".join(map(lambda x: f"'{x}'", cls.downgrade_operators)),
)
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version
@classmethod
@@ -149,91 +139,51 @@ class Migrate:
:param name:
:return:
"""
apps = Tortoise.apps
diff_models = apps.get(cls.diff_app)
app_models = apps.get(cls.app)
cls.diff_models(diff_models, app_models)
cls.diff_models(app_models, diff_models, False)
new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False)
cls._merge_operators()
if not cls.upgrade_operators:
return ""
return await cls._generate_diff_sql(name)
return await cls._generate_diff_py(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False):
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):
"""
add operator,differentiate fk because fk is order limit
:param operator:
:param upgrade:
:param fk_m2m:
:param fk_m2m_index:
:return:
"""
if upgrade:
if fk_m2m:
if fk_m2m_index:
cls._upgrade_fk_m2m_index_operators.append(operator)
else:
cls.upgrade_operators.append(operator)
else:
if fk_m2m:
if fk_m2m_index:
cls._downgrade_fk_m2m_index_operators.append(operator)
else:
cls.downgrade_operators.append(operator)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
path = os.path.join(location, app, cls.old_models)
path = path.replace(os.sep, ".").lstrip(".")
config["apps"][cls.diff_app] = {
"models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"),
}
return config
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
ret = []
for index in indexes:
if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
ret.append(index)
return ret
@classmethod
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
module = import_module(model)
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in filter(
lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model,
possible_models,
):
file = inspect.getfile(attr)
if file not in old_model_files:
old_model_files.append(file)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod
def diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
):
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
diff models and add operators
:param old_models:
@@ -241,304 +191,404 @@ class Migrate:
:param upgrade:
:return:
"""
old_models.pop(cls._aerich, None)
new_models.pop(cls._aerich, None)
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
for new_model_str, new_model in new_models.items():
if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade)
if upgrade:
cls._add_operator(cls.add_model(model), upgrade)
else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
# we can't find origin model when downgrade, so skip
pass
else:
old_model_describe = old_models.get(new_model_str)
# rename table
new_table = new_model_describe.get("table")
old_table = old_model_describe.get("table")
if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together"))
)
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
)
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint":
continue
table = change[0][1].get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m_index=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True)
# add unique_together
for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True)
# remove unique_together
for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes
for index in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, index, False), upgrade, True)
# remove indexes
for index in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True)
old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields")
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name)
):
new_data_field = next(
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False
for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name")
if len(changes) == 2:
# rename field
if (
changes[0]
== (
"change",
"name",
(old_data_field_name, new_data_field_name),
)
and changes[1]
== (
"change",
"db_column",
(
old_data_field.get("db_column"),
new_data_field.get("db_column"),
),
)
and old_data_field_name not in new_data_fields_name
):
if upgrade:
is_rename = click.prompt(
f"Rename {old_data_field_name} to {new_data_field_name}?",
default=True,
type=bool,
show_choices=True,
)
else:
is_rename = old_data_field_name in cls._rename_new
if is_rename:
cls._rename_new.append(new_data_field_name)
cls._rename_old.append(old_data_field_name)
# only MySQL8+ has rename syntax
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(
model, old_data_field, new_data_field
),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(model, *changes[1][2]),
upgrade,
)
if not is_rename:
cls._add_operator(
cls._add_field(
model,
new_data_field,
),
upgrade,
)
if new_data_field["indexed"]:
cls._add_operator(
cls._add_index(
model, {new_data_field["db_column"]}, new_data_field["unique"]
),
upgrade,
True,
)
# remove fields
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is renamed
if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new
):
continue
old_data_field = next(
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
)
db_column = old_data_field["db_column"]
cls._add_operator(
cls._remove_field(
model,
db_column,
),
upgrade,
)
if old_data_field["indexed"]:
cls._add_operator(
cls._drop_index(
model,
{db_column},
),
upgrade,
True,
)
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
if fk_field.get("db_constraint"):
cls._add_operator(
cls._add_fk(
model, fk_field, new_models.get(fk_field.get("python_type"))
),
upgrade,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
)
if old_fk_field.get("db_constraint"):
cls._add_operator(
cls._drop_fk(
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m_index=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next(
filter(lambda x: x.get("name") == field_name, old_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field)
for change in changes:
_, option, old_new = change
if option == "indexed":
# change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True:
cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True
)
else:
cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.":
if new_data_field.get("field_type") == "DecimalField":
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
else:
continue
elif option == "default":
if not (
is_default_function(old_new[0]) or is_default_function(old_new[1])
):
# change column default
cls._add_operator(
cls._alter_default(model, new_data_field), upgrade
)
elif option == "unique":
# because indexed include it
continue
elif option == "nullable":
# change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
else:
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
for old_model in old_models:
if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod
def _is_fk_m2m(cls, field: Field):
return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance))
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str):
return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod
def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model)
@classmethod
def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model)
def drop_model(cls, table_name: str):
return cls.ddl.drop_table(table_name)
@classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
"""
diff single model
:param old_model:
:param new_model:
:param upgrade:
:return:
"""
old_indexes = old_model._meta.indexes
new_indexes = new_model._meta.indexes
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
old_unique_together = old_model._meta.unique_together
new_unique_together = new_model._meta.unique_together
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}?",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.major == 5
):
cls._add_operator(
cls._change_field(new_model, old_field, new_field),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(
cls._alter_null(new_model, new_field), upgrade=upgrade
)
if new_field.default != old_field.default and not callable(
new_field.default
):
cls._add_operator(
cls._alter_default(new_model, new_field), upgrade=upgrade
)
if new_field.description != old_field.description:
cls._add_operator(
cls._set_comment(new_model, new_field), upgrade=upgrade
)
if new_field.field_type != old_field.field_type:
cls._add_operator(
cls._modify_field(new_model, new_field), upgrade=upgrade
)
else:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
cls._add_operator(
cls._remove_index(
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
):
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field),
upgrade,
True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field),
upgrade,
True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(
cls._add_index(
new_model,
new_index,
),
upgrade,
)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)
for new_unique in new_unique_together:
if new_unique not in old_unique_together:
cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade)
for old_unique in old_unique_together:
if old_unique not in new_unique_together:
cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade)
@classmethod
def drop_m2m(cls, table_name: str):
return cls.ddl.drop_m2m(table_name)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
ret = []
for field_name in fields_name:
if field_name in model._meta.fk_fields:
field = model._meta.fields_map[field_name]
if field.source_field:
ret.append(field.source_field)
elif field_name in model._meta.fk_fields:
ret.append(field_name + "_id")
else:
ret.append(field_name)
return ret
@classmethod
def _remove_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique)
@classmethod
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique)
@classmethod
def _exclude_field(cls, field: Field, upgrade=False):
"""
exclude BackwardFKRelation and repeat m2m field
:param field:
:return:
"""
if isinstance(field, ManyToManyFieldInstance):
through = field.through
if upgrade:
if through in cls._upgrade_m2m:
return True
else:
cls._upgrade_m2m.append(through)
return False
else:
if through in cls._downgrade_m2m:
return True
else:
cls._downgrade_m2m.append(through)
return False
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod
def _add_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.add_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
def _alter_default(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_default(model, field_describe)
@classmethod
def _alter_default(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_default(model, field)
def _alter_null(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_null(model, field_describe)
@classmethod
def _alter_null(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_null(model, field)
def _set_comment(cls, model: Type[Model], field_describe: dict):
return cls.ddl.set_comment(model, field_describe)
@classmethod
def _set_comment(cls, model: Type[Model], field: Field):
return cls.ddl.set_comment(model, field)
def _modify_field(cls, model: Type[Model], field_describe: dict):
return cls.ddl.modify_column(model, field_describe)
@classmethod
def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field)
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
def _remove_field(cls, model: Type[Model], column_name: str):
return cls.ddl.drop_column(model, column_name)
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str):
return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
@classmethod
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
db_field_types = new_field_describe.get("db_field_types")
return cls.ddl.change_column(
model,
old_field.model_field_name,
new_field.model_field_name,
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"),
old_field_describe.get("db_column"),
new_field_describe.get("db_column"),
db_field_types.get(cls.dialect) or db_field_types.get(""),
)
@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
"""
add fk
:param model:
:param field:
:param field_describe:
:param reference_table_describe:
:return:
"""
return cls.ddl.add_fk(model, field)
@classmethod
def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
drop fk
:param model:
:param field:
:return:
"""
return cls.ddl.drop_fk(model, field)
return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod
def _merge_operators(cls):

View File

@@ -1,12 +1,15 @@
from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.TextField()
app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta:
ordering = ["-id"]

View File

@@ -1,18 +1,44 @@
import importlib
import importlib.util
import os
import re
import sys
from pathlib import Path
from typing import Dict
from click import BadOptionUsage, Context
from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise
def get_app_connection_name(config, app) -> str:
def add_src_path(path: str) -> str:
"""
add a folder to the paths, so we can import from there
:param path: path to add
:return: absolute path
"""
if not os.path.isabs(path):
# use the absolute path, otherwise some other things (e.g. __file__) won't work properly
path = os.path.abspath(path)
if not os.path.isdir(path):
raise ClickException(f"Specified source folder does not exist: {path}")
if path not in sys.path:
sys.path.insert(0, path)
return path
def get_app_connection_name(config, app_name: str) -> str:
"""
get connection name
:param config:
:param app:
:param app_name:
:return:
"""
return config.get("apps").get(app).get("default_connection", "default")
app = config.get("apps").get(app_name)
if app:
return app.get("default_connection", "default")
raise BadOptionUsage(
option_name="--app",
message=f'Can\'t get app named "{app_name}"',
)
def get_app_connection(config, app) -> BaseDBAsyncClient:
@@ -35,12 +61,11 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
splits = tortoise_orm.split(".")
config_path = ".".join(splits[:-1])
tortoise_config = splits[-1]
try:
config_module = importlib.import_module(config_path)
except (ModuleNotFoundError, AttributeError):
raise BadOptionUsage(
ctx=ctx, message=f'No config named "{config_path}"', option_name="--config"
)
except ModuleNotFoundError as e:
raise ClickException(f"Error while importing configuration module: {e}") from None
config = getattr(config_module, tortoise_config, None)
if not config:
@@ -52,44 +77,26 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config
_UPGRADE = "##### upgrade #####\n"
_DOWNGRADE = "##### downgrade #####\n"
def get_version_content_from_file(version_file: str) -> Dict:
def get_models_describe(app: str) -> Dict:
"""
get version content
:param version_file:
get app models describe
:param app:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
second = content.index(_DOWNGRADE)
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {"upgrade": upgrade_content.split("\n"), "downgrade": downgrade_content.split("\n")}
ret = {}
for model in Tortoise.apps.get(app).values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret
def write_version_file(version_file: str, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n")
else:
f.write(f"{upgrade[0]};\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

1
aerich/version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "0.7.1rc1"

View File

@@ -36,7 +36,7 @@ def reset_migrate():
Migrate._downgrade_m2m = []
@pytest.yield_fixture(scope="session")
@pytest.fixture(scope="session")
def event_loop():
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
@@ -51,12 +51,6 @@ def event_loop():
@pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request):
tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:"
tortoise_orm["apps"]["diff_models"] = {
"models": ["tests.diff_models"],
"default_connection": "diff_models",
}
await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 75 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

1015
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,13 @@
[tool.poetry]
name = "aerich"
version = "0.4.0"
version = "0.7.1rc1"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
readme = "README.md"
homepage = "https://github.com/long2ice/aerich"
repository = "https://github.com/long2ice/aerich.git"
documentation = "https://github.com/long2ice/aerich"
homepage = "https://github.com/tortoise/aerich"
repository = "https://github.com/tortoise/aerich.git"
documentation = "https://github.com/tortoise/aerich"
keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [
{ include = "aerich" }
@@ -18,22 +18,32 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
python = "^3.7"
tortoise-orm = "*"
click = "*"
pydantic = "*"
aiomysql = {version = "*", optional = true}
asyncpg = { version = "*", optional = true }
asyncmy = { version = "*", optional = true }
pydantic = "*"
dictdiffer = "*"
tomlkit = "*"
[tool.poetry.dev-dependencies]
flake8 = "*"
isort = "*"
black = "^20.8b1"
black = "*"
pytest = "*"
pytest-xdist = "*"
pytest-asyncio = "*"
bandit = "*"
pytest-mock = "*"
cryptography = "*"
pyproject-flake8 = "*"
[tool.poetry.extras]
dbdrivers = ["aiomysql", "asyncpg"]
asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"]
[tool.aerich]
tortoise_orm = "conftest.tortoise_orm"
location = "./migrations"
src_folder = "./."
[build-system]
requires = ["poetry>=0.12"]
@@ -41,3 +51,17 @@ build-backend = "poetry.masonry.api"
[tool.poetry.scripts]
aerich = "aerich.cli:main"
[tool.black]
line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39']
[tool.pytest.ini_options]
asyncio_mode = 'auto'
[tool.mypy]
pretty = true
ignore_missing_imports = true
[tool.flake8]
ignore = 'E501,W503,E203'

View File

@@ -1,2 +0,0 @@
[flake8]
ignore = E501,W503

View File

@@ -1,4 +1,5 @@
import datetime
import uuid
from enum import IntEnum
from tortoise import Model, fields
@@ -23,23 +24,29 @@ class Status(IntEnum):
class User(Model):
username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=200)
password = fields.CharField(max_length=100)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
class Email(Model):
email = fields.CharField(max_length=200)
email_id = fields.IntField(pk=True)
email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
def default_name():
return uuid.uuid4()
class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True)
@@ -47,17 +54,28 @@ class Category(Model):
class Product(Model):
categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200)
type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
pic = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
unique_together = (("name", "type"),)
indexes = (("name", "type"),)
class Config(Model):
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)
status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User")
class NewModel(Model):
name = fields.CharField(max_length=50)

View File

@@ -50,7 +50,9 @@ class Product(Model):
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)

View File

@@ -24,32 +24,36 @@ class Status(IntEnum):
class User(Model):
username = fields.CharField(max_length=20)
password = fields.CharField(max_length=200)
last_login_at = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("diff_models.User", db_constraint=True)
user = fields.ForeignKeyField("models.User", db_constraint=False)
class Category(Model):
slug = fields.CharField(max_length=200)
user = fields.ForeignKeyField("diff_models.User", description="User")
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model):
categories = fields.ManyToManyField("diff_models.Category")
categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
@@ -60,3 +64,6 @@ class Config(Model):
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta:
table = "configs"

View File

@@ -1,11 +1,8 @@
import pytest
from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate
from tests.models import Category, User
from tests.models import Category, Product, User
def test_create_table():
@@ -15,8 +12,8 @@ def test_create_table():
ret
== """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(200) NOT NULL,
`name` VARCHAR(200) NOT NULL,
`slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200),
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`user_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
@@ -28,8 +25,8 @@ def test_create_table():
ret
== """CREATE TABLE IF NOT EXISTS "category" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(200) NOT NULL,
"name" VARCHAR(200) NOT NULL,
"slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200),
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
);"""
@@ -40,8 +37,8 @@ def test_create_table():
ret
== """CREATE TABLE IF NOT EXISTS "category" (
"id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(200) NOT NULL,
"name" VARCHAR(200) NOT NULL,
"slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200),
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
);
@@ -50,7 +47,7 @@ COMMENT ON COLUMN "category"."user_id" IS 'User';"""
def test_drop_table():
ret = Migrate.ddl.drop_table(Category)
ret = Migrate.ddl.drop_table(Category._meta.db_table)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "DROP TABLE IF EXISTS `category`"
else:
@@ -58,85 +55,90 @@ def test_drop_table():
def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name"))
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL"
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
else:
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL'
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'
def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
else:
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret0 == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)'
return
ret0 = Migrate.ddl.modify_column(
Category, Category._meta.fields_map.get("name").describe(False)
)
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
assert (
ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL'
assert (
ret0
== 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)'
)
assert (
ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL'
)
def test_alter_column_default():
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, SqliteDDL):
return
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("intro").describe(False))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
else:
assert ret is None
assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\''
elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''"
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at"))
ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert (
ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
)
else:
assert ret is None
elif isinstance(Migrate.ddl, MysqlDDL):
assert (
ret
== "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)"
)
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar"))
ret = Migrate.ddl.alter_column_default(
Product, Product._meta.fields_map.get("view_num").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\''
else:
assert ret is None
assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0"
def test_alter_column_null():
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.alter_column_null(
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
else:
assert ret is None
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL'
def test_set_comment():
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL):
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."user_id" IS \'User\''
def test_drop_column():
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
ret = Migrate.ddl.drop_column(Category, "name")
else:
ret = Migrate.ddl.drop_column(Category, "name")
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP COLUMN `name`"
@@ -154,10 +156,7 @@ def test_add_index():
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert (
index_u
== 'ALTER TABLE "category" ADD CONSTRAINT "uid_category_name_8b0cb9" UNIQUE ("name")'
)
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
else:
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert (
@@ -173,14 +172,16 @@ def test_drop_index():
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP CONSTRAINT "uid_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
else:
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
def test_add_fk():
ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user"))
ret = Migrate.ddl.add_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert (
ret
@@ -194,7 +195,9 @@ def test_add_fk():
def test_drop_fk():
ret = Migrate.ddl.drop_fk(Category, Category._meta.fields_map.get("user"))
ret = Migrate.ddl.drop_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`"
elif isinstance(Migrate.ddl, PostgresDDL):

File diff suppressed because it is too large Load Diff

6
tests/test_utils.py Normal file
View File

@@ -0,0 +1,6 @@
from aerich.utils import import_py_file
def test_import_py_file():
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file")