149 Commits

Author SHA1 Message Date
Waket Zheng
252cb97767 Bump version to 0.8.0 and update changelog 2024-12-04 00:14:35 +08:00
Waket Zheng
ac3ef9e2eb tests: no need to merge operators when NotSupportError raised (#373) 2024-12-03 23:59:54 +08:00
Waket Zheng
5b04b4422d chore: upgrade deps, update changelog, drop test db before create (#372)
* chore: upgrade deps, update changelog, drop test db before create

* tests: clear operators for sqlite ddl after NotSupportError raised
2024-12-03 23:30:37 +08:00
Waket Zheng
b2f4029a4a Improve type hints of inspectdb (#371) 2024-12-03 12:40:28 +08:00
xsillen
4e46d9d969 Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
Co-authored-by: floodpillar <165008032+floodpillar@users.noreply.github.com>
2024-12-03 11:44:36 +08:00
gck123
c0fd3ec63c Fix KeyError when deleting a field with unqiue=True (#365)
* Fix KeyError when deleting a field with unqiue=True

* refactor: rename `old_data_unique` to `is_unique_field`

* Add testcases for remove unique field

---------

Co-authored-by: gongchangku <gongchangku@anban.tech>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-03 10:16:44 +08:00
Waket Zheng
103470f4c1 Merge pull request #367 from waketzheng/fix-style-issue
chore: make style, upgrade deps, fix ci error and update changelog
2024-11-29 15:09:25 +08:00
Waket Zheng
dc020358b6 chore: make style, upgrade deps, fix ci error and update changelog 2024-11-25 23:46:48 +08:00
Waket Zheng
095eb48196 Merge pull request #360 from ahmetveburak/fix-package
fix(package): correct the click import
2024-11-25 18:10:27 +08:00
He
fac1de6299 Merge pull request #355 from merlinz01/improve-cli-descriptions
Improve CLI help text and output
2024-11-18 21:25:13 +01:00
ahmetveburak
e049fcee26 fix(package): correct the click import 2024-10-05 20:43:32 +03:00
Merlin
ee144d4a9b update migrate output for consistency 2024-08-07 13:32:49 -04:00
Merlin
dbf96a17d3 Improve command-line descriptions
Changed CLI help texts for some of the options to make them clearer.
2024-08-07 13:31:07 -04:00
long2ice
15d56121ef fix: migrate 2024-08-06 22:46:50 +08:00
long2ice
4851ecfe82 refactor: use asyncclick 2024-08-06 22:41:39 +08:00
long2ice
ad0e7006bc Merge pull request #348 from waketzheng/update-changelog
Update changelog
2024-06-07 11:25:10 +08:00
Waket Zheng
27b29e401b Merge remote-tracking branch 'upstream/dev' into update-changelog 2024-06-07 11:11:47 +08:00
Waket Zheng
7bb35c10e4 Add tag link to change log 2024-06-07 11:08:49 +08:00
long2ice
ed113d491e Merge pull request #347 from waketzheng/add-mypy-check-to-ci
Use mypy for type hint check in ci
2024-06-06 21:38:23 +08:00
Waket Zheng
9e46fbf55d rollback ci command 2024-06-06 18:20:57 +08:00
Waket Zheng
fc68f99c86 fixing cache action 2024-06-06 18:02:58 +08:00
Waket Zheng
480087df07 Add id for cache step 2024-06-06 17:38:51 +08:00
Waket Zheng
24a2087b78 Skip make deps if cache hint in ci 2024-06-06 17:35:00 +08:00
Waket Zheng
bceeb236c2 Checking cache action 2024-06-06 17:21:46 +08:00
Waket Zheng
c42fdab74d Add cache action to ci 2024-06-06 17:11:27 +08:00
Waket Zheng
ceb1e0ffef Activate type hint check in ci 2024-06-06 17:07:09 +08:00
long2ice
13dd44bef7 Merge pull request #340 from waketzheng/type-hint-tests
Improve type hints for tests/
2024-06-06 16:53:01 +08:00
long2ice
219633a926 Merge pull request #341 from waketzheng/type-hint-simple
Simple type hints for aerich/
2024-06-06 16:52:30 +08:00
long2ice
e6302a920a Merge pull request #342 from waketzheng/type-hint-ddl
Improve type hints for ddl and inspectdb
2024-06-06 16:51:37 +08:00
long2ice
79a77d368f Merge pull request #343 from waketzheng/type-hints-aerich.migrate
Add type hints for aerich.migrate
2024-06-06 16:50:53 +08:00
long2ice
4a1fc4cfa0 Merge pull request #339 from waketzheng/update-ci
Update ci action versions and avoid `make ci` install deps twice
2024-06-06 16:50:39 +08:00
Waket Zheng
aee706e29b Update changelog 2024-06-06 16:43:08 +08:00
Waket Zheng
572c13f9dd fix conflict 2024-06-06 16:36:29 +08:00
Waket Zheng
7b733495fb fix conflict 2024-06-06 16:35:09 +08:00
Waket Zheng
c0c217392c Merge remote-tracking branch 'upstream/dev' into type-hint-simple 2024-06-06 16:32:00 +08:00
Waket Zheng
c7a3d164cb fix conflict 2024-06-06 16:31:24 +08:00
Waket Zheng
50add58981 merge dev 2024-06-06 16:27:06 +08:00
long2ice
f3b6f8f264 Merge pull request #345 from waketzheng/dev
Drop python3.7 support
2024-06-06 15:49:31 +08:00
long2ice
d33638471b Merge pull request #346 from waketzheng/fix-336
fix: mysql drop unique index with error name
2024-06-06 15:48:50 +08:00
Waket Zheng
e764bb56f7 Add new column for unique index remove test 2024-06-06 09:07:13 +08:00
Waket Zheng
84d31d63f6 Update readme 2024-06-06 08:42:55 +08:00
Waket Zheng
234495d291 docs: bump up version and update changelog 2024-06-06 00:07:29 +08:00
Waket Zheng
e971653851 fix: mysql drop unique index migrate error 2024-06-05 23:37:30 +08:00
Waket Zheng
58d31b3a05 fix: make ci error with latest tortoise-orm (#344) 2024-06-04 01:44:28 +08:00
Waket Zheng
affffbdae3 Drop python3.7 support 2024-06-03 01:26:48 +08:00
Waket Zheng
6466a852c8 Add type hints for aerich.migrate 2024-06-02 18:09:34 +08:00
Waket Zheng
a917f253c9 Add type hints for inspectdb/sqlite 2024-06-02 17:56:54 +08:00
Waket Zheng
dd11bed5a0 Add type hints for ddl and inspectdb 2024-06-01 22:14:30 +08:00
Waket Zheng
8756f64e3f Simple type hints for aerich/ 2024-06-01 21:16:53 +08:00
Waket Zheng
1addda8178 Improve type hints for tests/ 2024-06-01 20:58:25 +08:00
Waket Zheng
716638752b Update ci action versions and avoid make ci install deps twice 2024-06-01 20:46:08 +08:00
long2ice
51117867a6 Merge pull request #328 from Fl0kse/fix_issue-150
try to fix "Add ManyToManyField will break migrate #150" issues
2024-01-23 22:08:08 +08:00
artem.ulchenko
b25c7671bb try to fix "Add ManyToManyField will break migrate #150" issues 2024-01-18 11:40:59 +03:00
long2ice
b63fbcc7a4 style: fix 2024-01-18 09:50:41 +08:00
long2ice
4370b5ed08 Merge pull request #327 from ar0ne/dev
Added an option to generate empty migration file
2024-01-18 09:46:02 +08:00
ar0ne
2b2e465362 missed changes that add new flag to cli 2023-12-27 12:56:46 +05:30
ar0ne
ede53ade86 added note in readme 2023-12-26 23:20:29 +05:30
ar0ne
ad54b5e9dd remove redundant comma for empty migration 2023-12-26 23:12:11 +05:30
ar0ne
b1ff2418f5 added option to generate empty migration file 2023-12-26 22:55:51 +05:30
long2ice
01264f3f27 fix: pydantic v2. (#322) 2023-08-30 10:04:53 +08:00
long2ice
ea234a5799 fix: default empty str 2023-08-23 10:53:19 +08:00
long2ice
b724f24f1a Merge pull request #319 from strayge/pr/editable-install
Update poetry build backend to support editable install
2023-08-22 10:00:03 +08:00
long2ice
2bc23103dc Merge pull request #320 from strayge/pr/column-description-diff
Fix changed column descriptions in diffs
2023-08-22 09:59:46 +08:00
strayge
6d83c370ad fix changed column descriptions in diffs 2023-08-21 10:58:43 +04:00
strayge
e729bb9b60 update poetry build-backend to support editable install 2023-08-21 10:58:26 +04:00
long2ice
8edd834da6 refactor: make in_transaction default True 2023-08-04 10:36:27 +08:00
long2ice
4adc89d441 Merge pull request #306 from plusiv/add-char-support-mysql
add char support for mysql
2023-07-27 12:52:01 +08:00
Jorge Massih
fd4b9fe7d3 add char support for mysql 2023-06-15 00:40:29 -04:00
long2ice
467406ed20 Merge pull request #303 from karichevi/dev
Added Documentation in Russian Language
2023-05-23 11:35:09 +08:00
karichevi
484b5900ce Added Documentation in Russian Language 2023-05-22 14:30:36 +03:00
long2ice
b8b6df0b65 chore: update deps 2023-05-12 15:27:50 +08:00
long2ice
f0bc3126e9 fix: generates two semicolons in a row. (#301) 2023-05-12 14:52:17 +08:00
long2ice
dbc0d9e7ef Merge pull request #296 from evstratbg/migrate-transaction
add in-transaction for upgrade
2023-05-05 23:12:08 +08:00
Bogdan
818dd29991 fix styles 2023-05-05 17:29:34 +04:00
Bogdan
e199e03b53 added -i param 2023-05-03 19:48:54 +04:00
Bogdan
d79dc25ee8 enriched changelog 2023-05-03 19:48:25 +04:00
Bogdan
c6d51a4dcf bump version 2023-05-03 19:48:15 +04:00
Bogdan
241b30a710 add in-transaction for upgrade 2023-04-03 20:55:12 +04:00
long2ice
8cf50c58d7 test: fix ci 2023-01-27 15:17:41 +08:00
long2ice
1c9b65cc37 fix: modify multiple times. (#279) 2023-01-27 13:49:07 +08:00
long2ice
3fbf9febfb Merge pull request #281 from CortexPE/patch-1
Fix #280 by removing trailing semicolon
2022-12-20 14:28:51 +08:00
marshall
7b6545d4e1 Fix #280 by removing trailing semicolon 2022-12-20 04:28:40 +08:00
long2ice
52b50a2161 feat: support virtual fields 2022-11-18 23:23:04 +08:00
long2ice
90943a473c Merge pull request #269 from jtraub/dev
Close connections in the init_db wrapper
2022-09-29 08:39:33 +08:00
Konstantin Mikhailov
d7ecd97e88 Close connections in the init_db wrapper 2022-09-29 08:09:59 +10:00
long2ice
20aebc4413 chore: update version 2022-09-27 22:44:12 +08:00
long2ice
f8e1f9ff44 fix: initialize an empty database. (#267) 2022-09-27 22:42:54 +08:00
long2ice
ab31445fb2 fix: test 2022-09-26 18:39:48 +08:00
long2ice
28d19a4b7b - Fix syntax error with python3.8.10. (#265)
- Fix sql generate error. (#263)
2022-09-26 18:36:57 +08:00
long2ice
9da99824fe fix: postgres sql error (#263) 2022-09-23 23:33:49 +08:00
long2ice
75db7cea60 fix: test error 2022-09-23 10:35:33 +08:00
long2ice
d777c9c278 Merge remote-tracking branch 'origin/dev' into dev
# Conflicts:
#	aerich/utils.py
#	tests/test_utils.py
2022-09-23 10:30:46 +08:00
long2ice
e9b76bdd35 feat: use .py version files 2022-09-23 10:29:48 +08:00
long2ice
8b7864d886 Merge pull request #199 from ehdgua01/fix-ddl-format-and-writing-version-file
[Enhancement] Fix version file formats
2022-09-16 09:30:27 +08:00
KDH
bef45941f2 Fix testcase 2022-09-16 10:26:21 +09:00
KDH
7b472d7a84 Fix testcase 2022-09-16 10:08:34 +09:00
KDH
1f0a6dfb50 Fix typo 2022-09-16 09:58:04 +09:00
KDH
36282f123f Merge branch 'dev' into fix-ddl-format-and-writing-version-file
# Conflicts:
#	aerich/utils.py
#	tests/test_migrate.py
2022-09-16 09:54:51 +09:00
KDH
3cd4e24050 Merge branch 'dev' into fix-ddl-format-and-writing-version-file 2022-09-16 09:43:57 +09:00
long2ice
f8c2f1b551 Merge pull request #205 from GDGSNF/dev
Merge repeated `if` statements into single `if`
2022-09-16 08:40:28 +08:00
Yasser Tahiri
131d97a3d6 Merge branch 'dev' into dev 2022-09-15 23:19:59 +01:00
Yasser Tahiri
1a0371e977 Update aerich/utils.py
Co-authored-by: KDH <ehdgua01@naver.com>
2022-09-15 23:19:18 +01:00
long2ice
e5b092fd08 Merge pull request #260 from waketzheng/dev
refactor: use pathlib to read and write text
2022-09-12 11:44:02 +08:00
Waket Zheng
7a109f3c79 refactor: use pathlib to read and write text 2022-09-12 00:57:46 +08:00
Jinlong Peng
8c2ecbaef1 feat: support add/remove field with index 2022-08-26 18:04:20 +08:00
long2ice
b141363c51 Merge pull request #242 from ssilaev/dev
Hotfix for cli group function in v0.6.3
2022-07-22 08:39:02 +08:00
long2ice
9dd474d79f Merge remote-tracking branch 'origin/dev' into dev 2022-06-27 11:42:37 +08:00
long2ice
e4bb9d838e docs: update changelog 2022-06-27 11:41:48 +08:00
long2ice
029d522c79 Merge pull request #249 from tortoise/fix-decimal
Fix decimal field change
2022-06-27 11:38:03 +08:00
long2ice
d6627906c7 test: fix test_migrate 2022-06-27 11:36:09 +08:00
long2ice
3c88833154 fix: decimal field change (#246) 2022-06-27 11:29:47 +08:00
long2ice
8f68f08eba Merge pull request #248 from isaquealves/feature/load_ddl_class_per_dialect
feat: Add support for dynamically load DDL classes
2022-06-22 20:25:42 +08:00
Isaque Alves
60ba6963fd Update changelog 2022-06-22 09:22:26 -03:00
Isaque Alves
4c35c44bd2 feat: Add support for dynamically load DDL classes
Adopt a strategy of loading classes based on their names, allowing to
easily add new database support without changing Migrate class logic
2022-06-22 09:16:11 -03:00
long2ice
bdeaf5495e Merge pull request #247 from isaquealves/feature/postgresql-numeric-type-translate
refactor: Improve db inspection
2022-06-22 08:41:15 +08:00
Isaque Alves
db33059ec9 Resolve style issue 2022-06-20 15:42:21 -03:00
Isaque Alves
44b96058f8 fix(tests/test_migrate.py): Resolve issue with broken tests 2022-06-17 12:36:04 -03:00
Isaque Alves
abff753b6a refactor: Improve postgresql migrate operators tests 2022-06-17 09:45:02 -03:00
Isaque Alves
dcd8441a05 fix: add space following python style guide" 2022-06-17 02:03:41 -03:00
Isaque Alves
b4a735b814 fix: Adjust changelog formatting 2022-06-17 02:02:37 -03:00
Isaque Alves
83ba13e99a Update changelog 2022-06-17 02:00:35 -03:00
Isaque Alves
d7b1c07d13 fix: Add comma to separate value in join 2022-06-17 01:51:47 -03:00
Isaque Alves
1ac16188fc refactor: Improve db inspection
- Add support to postgresql numeric type.
- Improve field configuration handling for numeric and decimal types
2022-06-17 01:38:39 -03:00
long2ice
4abc464ce0 feat: add is_flag to init-db 2022-05-24 11:20:12 +08:00
Sergey Silaev
d4430cec0d Hotfix for cli group function in v0.6.3 2022-05-10 01:20:11 +04:00
long2ice
0b01fa38d8 feat: add index inspect 2022-04-05 19:38:08 +08:00
long2ice
801dde15be feat: inspectdb support sqlite 2022-04-01 20:30:36 +08:00
long2ice
75480e2041 Merge remote-tracking branch 'origin/dev' into dev 2022-04-01 19:57:03 +08:00
long2ice
45129cef9f feat: improve inspectdb and support postgres 2022-04-01 19:56:48 +08:00
long2ice
3a0dd2355d Merge pull request #230 from ssilaev/dev
Increase max length of app column
2022-02-09 15:01:39 +08:00
Sergey Silaev
0e71bc16ae Increase max length of app column 2022-02-08 22:14:55 +03:00
long2ice
c39462820c upgrade deps 2022-01-17 22:26:13 +08:00
long2ice
f15cbaf9e0 Support migration for specified index. (#203) 2021-12-29 21:36:23 +08:00
long2ice
15131469df upgrade deps 2021-12-22 16:26:13 +08:00
long2ice
c60c1610f0 Fix pyproject.toml not existing error. (#217) 2021-12-12 22:11:51 +08:00
long2ice
63e8d06157 remove aiomysql 2021-12-08 14:43:33 +08:00
long2ice
68ef8ac676 Fix ci 2021-12-08 14:38:16 +08:00
long2ice
8b5cf6faa0 inspectdb support DATE. (#215) 2021-12-08 14:33:27 +08:00
Yasser Tahiri
40c7ef7fd6 Merge repeated if statements into single if 2021-10-21 15:43:18 +01:00
KDH
7a826df43f Fix duplicated semicolon in table creation DDL 2021-10-12 11:24:37 +09:00
KDH
b1b9cc1454 Fix M2M table template 2021-10-12 11:23:29 +09:00
long2ice
fac00d45cc Remove pydantic dependency. (#198) 2021-10-04 23:05:20 +08:00
long2ice
6f7893d376 Fix section name 2021-09-28 15:07:10 +08:00
long2ice
b1521c4cc7 update version 2021-09-27 19:55:38 +08:00
long2ice
24c1f4cb7d Change default config file from aerich.ini to pyproject.toml. (#197) 2021-09-27 11:05:20 +08:00
long2ice
661f241dac Compatible with old version in indexes 2021-08-31 17:53:17 +08:00
long2ice
01787558d6 Fix test 2021-08-31 17:41:13 +08:00
long2ice
699b0321a4 Support indexes change. (#193) 2021-08-31 17:36:25 +08:00
long2ice
4a83021892 Update FUNDING.yml 2021-08-26 20:39:31 +08:00
32 changed files with 2712 additions and 1475 deletions

2
.github/FUNDING.yml vendored
View File

@@ -1 +1 @@
custom: ["https://sponsor.long2ice.cn"] custom: ["https://sponsor.long2ice.io"]

View File

@@ -2,10 +2,10 @@ name: ci
on: on:
push: push:
branches-ignore: branches-ignore:
- master - main
pull_request: pull_request:
branches-ignore: branches-ignore:
- master - main
jobs: jobs:
ci: ci:
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -18,17 +18,26 @@ jobs:
POSTGRES_PASSWORD: 123456 POSTGRES_PASSWORD: 123456
POSTGRES_USER: postgres POSTGRES_USER: postgres
options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
steps: steps:
- name: Start MySQL - name: Start MySQL
run: sudo systemctl start mysql.service run: sudo systemctl start mysql.service
- uses: actions/checkout@v2 - uses: actions/cache@v4
- uses: actions/setup-python@v2
with: with:
python-version: '3.x' path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install and configure Poetry - name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1 run: |
with: pip install -U pip poetry
virtualenvs-create: false poetry config virtualenvs.create false
- name: CI - name: CI
env: env:
MYSQL_PASS: root MYSQL_PASS: root

View File

@@ -7,14 +7,14 @@ jobs:
publish: publish:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- uses: actions/setup-python@v2 - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'
- name: Install and configure Poetry - name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1 run: |
with: pip install -U pip poetry
virtualenvs-create: false poetry config virtualenvs.create false
- name: Build dists - name: Build dists
run: make build run: make build
- name: Pypi Publish - name: Pypi Publish

View File

@@ -1,7 +1,83 @@
# ChangeLog # ChangeLog
## 0.8
### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04
- Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
- Fix KeyError when deleting a field with unqiue=True. (#364)
- Correct the click import. (#360)
- Improve CLI help text and output. (#355)
- Fix mysql drop unique index raises OperationalError. (#346)
**Upgrade note:**
1. Use column name as unique key name for mysql
2. Drop support for Python3.7
## 0.7
### [0.7.2](../../releases/tag/v0.7.2) - 2023-07-20
- Support virtual fields.
- Fix modify multiple times. (#279)
- Added `-i` and `--in-transaction` options to `aerich migrate` command. (#296)
- Fix generates two semicolons in a row. (#301)
### 0.7.1
- Fix syntax error with python3.8.10. (#265)
- Fix sql generate error. (#263)
- Fix initialize an empty database. (#267)
### 0.7.1rc1
- Fix postgres sql error (#263)
### 0.7.0
**Now aerich use `.py` file to record versions.**
Upgrade Note:
1. Drop `aerich` table
2. Delete `migrations/models` folder
3. Run `aerich init-db`
- Improve `inspectdb` adding support to `postgresql::numeric` data type
- Add support for dynamically load DDL classes easing to add support to
new databases without changing `Migrate` class logic
- Fix decimal field change. (#246)
- Support add/remove field with index.
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2
- Support migration for specified index. (#203)
### 0.6.1
- Fix `pyproject.toml` not existing error. (#217)
### 0.6.0
- Change default config file from `aerich.ini` to `pyproject.toml`. (#197)
**Upgrade note:**
1. Run `aerich init -t config.TORTOISE_ORM`.
2. Remove `aerich.ini`.
- Remove `pydantic` dependency. (#198)
- `inspectdb` support `DATE`. (#215)
## 0.5 ## 0.5
### 0.5.8
- Support `indexes` change. (#193)
### 0.5.7 ### 0.5.7
- Fix no module found error. (#188) (#189) - Fix no module found error. (#188) (#189)

View File

@@ -12,16 +12,22 @@ up:
@poetry update @poetry update
deps: deps:
@poetry install -E asyncpg -E asyncmy -E aiomysql @poetry install -E asyncpg -E asyncmy
style: deps _style:
isort -src $(checkfiles) @isort -src $(checkfiles)
black $(black_opts) $(checkfiles) @black $(black_opts) $(checkfiles)
style: deps _style
check: deps _check:
black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
flake8 $(checkfiles) @ruff check $(checkfiles)
bandit -x tests -r $(checkfiles) @mypy $(checkfiles)
ifneq ($(shell python -c 'import sys;is_py38=sys.version_info<(3,9);rc=int(is_py38);sys.exit(rc)'),)
# Run bandit with Python3.9+, as the `usedforsecurity=...` parameter of `hashlib.new` is only added from Python 3.9 onwards.
@bandit -r aerich
endif
check: deps _check
test: deps test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: py.test
@@ -35,9 +41,10 @@ test_mysql:
test_postgres: test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql _testall: test_sqlite test_postgres test_mysql
testall: deps _testall
build: deps build: deps
@poetry build @poetry build
ci: check testall ci: check _testall

View File

@@ -5,9 +5,11 @@
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi) [![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci) [![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
English | [Русский](./README_RU.md)
## Introduction ## Introduction
Aerich is a database migrations tool for Tortoise-ORM, which is like alembic for SQLAlchemy, or like Django ORM with Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with
it\'s own migration solution. it\'s own migration solution.
## Install ## Install
@@ -15,7 +17,7 @@ it\'s own migration solution.
Just install from pypi: Just install from pypi:
```shell ```shell
> pip install aerich pip install aerich
``` ```
## Quick Start ## Quick Start
@@ -27,11 +29,8 @@ Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options: Options:
-V, --version Show the version and exit. -V, --version Show the version and exit.
-c, --config TEXT Config file. [default: aerich.ini] -c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name. --app TEXT Tortoise-ORM app name.
-n, --name TEXT Name of section in .ini file to use for aerich config.
[default: aerich]
-h, --help Show this message and exit. -h, --help Show this message and exit.
Commands: Commands:
@@ -47,7 +46,7 @@ Commands:
## Usage ## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first. Example: You need to add `aerich.models` to your `Tortoise-ORM` config first. Example:
```python ```python
TORTOISE_ORM = { TORTOISE_ORM = {
@@ -70,10 +69,9 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location. Init config file and generate root migrate location.
OOptions: Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like -t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required] settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations] --location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root. -s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit. -h, --help Show this message and exit.
@@ -85,7 +83,7 @@ Initialize the config file and migrations location:
> aerich init -t tests.backends.mysql.TORTOISE_ORM > aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations Success create migrate location ./migrations
Success generate config file aerich.ini Success write config to pyproject.toml
``` ```
### Init db ### Init db
@@ -105,22 +103,30 @@ e.g. `aerich --app other_models init-db`.
```shell ```shell
> aerich migrate --name drop_column > aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.sql Success migrate 1_202029051520102929_drop_column.py
``` ```
Format of migrate filename is Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`. `{version_num}_{datetime}_{name|update}.py`.
If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data. lose data.
If you need to manually write migration, you could generate empty file:
```shell
> aerich migrate --name add_index --empty
Success migrate 1_202326122220101229_add_index.py
```
### Upgrade to latest version ### Upgrade to latest version
```shell ```shell
> aerich upgrade > aerich upgrade
Success upgrade 1_202029051520102929_drop_column.sql Success upgrade 1_202029051520102929_drop_column.py
``` ```
Now your db is migrated to latest. Now your db is migrated to latest.
@@ -146,7 +152,7 @@ Options:
```shell ```shell
> aerich downgrade > aerich downgrade
Success downgrade 1_202029051520102929_drop_column.sql Success downgrade 1_202029051520102929_drop_column.py
``` ```
Now your db is rolled back to the specified version. Now your db is rolled back to the specified version.
@@ -156,7 +162,7 @@ Now your db is rolled back to the specified version.
```shell ```shell
> aerich history > aerich history
1_202029051520102929_drop_column.sql 1_202029051520102929_drop_column.py
``` ```
### Show heads to be migrated ### Show heads to be migrated
@@ -164,12 +170,12 @@ Now your db is rolled back to the specified version.
```shell ```shell
> aerich heads > aerich heads
1_202029051520102929_drop_column.sql 1_202029051520102929_drop_column.py
``` ```
### Inspect db tables to TortoiseORM model ### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL. Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell ```shell
Usage: aerich inspectdb [OPTIONS] Usage: aerich inspectdb [OPTIONS]
@@ -193,7 +199,44 @@ Inspect a specified table in the default app and redirect to `models.py`:
aerich inspectdb -t user > models.py aerich inspectdb -t user > models.py
``` ```
Note that this command is limited and cannot infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others. For example, you table is:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Now run `aerich inspectdb -t test` to see the generated model:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
### Multiple databases ### Multiple databases

274
README_RU.md Normal file
View File

@@ -0,0 +1,274 @@
# Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
[English](./README.md) | Русский
## Введение
Aerich - это инструмент для миграции базы данных для TortoiseORM, который аналогичен Alembic для SQLAlchemy или встроенному решению миграций в Django ORM.
## Установка
Просто установите из pypi:
```shell
pip install aerich
```
## Быстрый старт
```shell
> aerich -h
Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options:
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name.
-h, --help Show this message and exit.
Commands:
downgrade Downgrade to specified version.
heads Show current available heads in migrate location.
history List all migrate items.
init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file.
upgrade Upgrade to specified version.
```
## Использование
Сначала вам нужно добавить aerich.models в конфигурацию вашего Tortoise-ORM. Пример:
```python
TORTOISE_ORM = {
"connections": {"default": "mysql://root:123456@127.0.0.1:3306/test"},
"apps": {
"models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
},
}
```
### Инициализация
```shell
> aerich init -h
Usage: aerich init [OPTIONS]
Init config file and generate root migrate location.
Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit.
```
Инициализируйте файл конфигурации и задайте местоположение миграций:
```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations
Success write config to pyproject.toml
```
### Инициализация базы данных
```shell
> aerich init-db
Success create app migrate location ./migrations/models
Success generate schema for app "models"
```
Если ваше приложение Tortoise-ORM не является приложением по умолчанию с именем models, вы должны указать правильное имя приложения с помощью параметра --app, например: aerich --app other_models init-db.
### Обновление моделей и создание миграции
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.py
```
Формат имени файла миграции следующий: `{версия}_{дата_и_время}_{имя|обновление}.py`.
Если aerich предполагает, что вы переименовываете столбец, он спросит:
Переименовать `{старый_столбец} в {новый_столбец} [True]`. Вы можете выбрать `True`,
чтобы переименовать столбец без удаления столбца, или выбрать `False`, чтобы удалить столбец,
а затем создать новый. Обратите внимание, что последний вариант может привести к потере данных.
### Обновление до последней версии
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных обновлена до последней версии.
### Откат до указанной версии
```shell
> aerich downgrade -h
Usage: aerich downgrade [OPTIONS]
Downgrade to specified version.
Options:
-v, --version INTEGER Specified version, default to last. [default: -1]
-d, --delete Delete version files at the same time. [default:
False]
--yes Confirm the action without prompting.
-h, --help Show this message and exit.
```
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных откатилась до указанной версии.
### Показать историю
```shell
> aerich history
1_202029051520102929_drop_column.py
```
### Чтобы узнать, какие миграции должны быть применены, можно использовать команду:
```shell
> aerich heads
1_202029051520102929_drop_column.py
```
### Осмотр таблиц базы данных для модели TortoiseORM
В настоящее время inspectdb поддерживает MySQL, Postgres и SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Посмотреть все таблицы и вывести их на консоль:
```shell
aerich --app models inspectdb
```
Осмотреть указанную таблицу в приложении по умолчанию и перенаправить в models.py:
```shell
aerich inspectdb -t user > models.py
```
Например, ваша таблица выглядит следующим образом:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Теперь выполните команду aerich inspectdb -t test, чтобы увидеть сгенерированную модель:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Обратите внимание, что эта команда имеет ограничения и не может автоматически определить некоторые поля, такие как `IntEnumField`, `ForeignKeyField` и другие.
### Несколько баз данных
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["tests.models_second"], "default_connection": "second", },
},
}
```
Вам нужно указать `aerich.models` только в одном приложении и должны указывать `--app` при запуске команды `aerich migrate` и т.д.
## Восстановление рабочего процесса aerich
В некоторых случаях, например, при возникновении проблем после обновления `aerich`, вы не можете запустить `aerich migrate` или `aerich upgrade`. В таком случае вы можете выполнить следующие шаги:
1. удалите таблицы `aerich`.
2. удалите директорию `migrations/{app}`.
3. rerun `aerich init-db`.
Обратите внимание, что эти действия безопасны, и вы можете использовать их для сброса миграций, если у вас слишком много файлов миграции.
## Использование aerich в приложении
Вы можете использовать `aerich` вне командной строки, используя класс `Command`.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
```
## Лицензия
Этот проект лицензирован в соответствии с лицензией
[Apache-2.0](https://github.com/long2ice/aerich/blob/master/LICENSE) Лицензия.

View File

@@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import List from typing import TYPE_CHECKING, List, Optional, Type
from tortoise import Tortoise, generate_schema_for_client from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@@ -8,17 +8,21 @@ from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError from aerich.exceptions import DowngradeError
from aerich.inspectdb import InspectDb from aerich.inspectdb.mysql import InspectMySQL
from aerich.migrate import Migrate from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich from aerich.models import Aerich
from aerich.utils import ( from aerich.utils import (
get_app_connection, get_app_connection,
get_app_connection_name, get_app_connection_name,
get_models_describe, get_models_describe,
get_version_content_from_file, import_py_file,
write_version_file,
) )
if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401
class Command: class Command:
def __init__( def __init__(
@@ -26,16 +30,27 @@ class Command:
tortoise_config: dict, tortoise_config: dict,
app: str = "models", app: str = "models",
location: str = "./migrations", location: str = "./migrations",
): ) -> None:
self.tortoise_config = tortoise_config self.tortoise_config = tortoise_config
self.app = app self.app = app
self.location = location self.location = location
Migrate.app = app Migrate.app = app
async def init(self): async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location) await Migrate.init(self.tortoise_config, self.app, self.location)
async def upgrade(self): async def _upgrade(self, conn, version_file) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
await conn.execute_script(await upgrade(conn))
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
migrated = [] migrated = []
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@@ -43,24 +58,18 @@ class Command:
except OperationalError: except OperationalError:
exists = False exists = False
if not exists: if not exists:
async with in_transaction( app_conn_name = get_app_connection_name(self.tortoise_config, self.app)
get_app_connection_name(self.tortoise_config, self.app) if run_in_transaction:
) as conn: async with in_transaction(app_conn_name) as conn:
file_path = Path(Migrate.migrate_location, version_file) await self._upgrade(conn, version_file)
content = get_version_content_from_file(file_path) else:
upgrade_query_list = content.get("upgrade") app_conn = get_app_connection(self.tortoise_config, self.app)
for upgrade_query in upgrade_query_list: await self._upgrade(app_conn, version_file)
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
migrated.append(version_file) migrated.append(version_file)
return migrated return migrated
async def downgrade(self, version: int, delete: bool): async def downgrade(self, version: int, delete: bool) -> List[str]:
ret = [] ret: List[str] = []
if version == -1: if version == -1:
specified_version = await Migrate.get_last_version() specified_version = await Migrate.get_last_version()
else: else:
@@ -73,25 +82,25 @@ class Command:
versions = [specified_version] versions = [specified_version]
else: else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk) versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions: for version_obj in versions:
file = version.version file = version_obj.version
async with in_transaction( async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
file_path = Path(Migrate.migrate_location, file) file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path) m = import_py_file(file_path)
downgrade_query_list = content.get("downgrade") downgrade = getattr(m, "downgrade")
if not downgrade_query_list: downgrade_sql = await downgrade(conn)
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found") raise DowngradeError("No downgrade items found")
for downgrade_query in downgrade_query_list: await conn.execute_script(downgrade_sql)
await conn.execute_query(downgrade_query) await version_obj.delete()
await version.delete()
if delete: if delete:
os.unlink(file_path) os.unlink(file_path)
ret.append(file) ret.append(file)
return ret return ret
async def heads(self): async def heads(self) -> List[str]:
ret = [] ret = []
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
for version in versions: for version in versions:
@@ -99,22 +108,28 @@ class Command:
ret.append(version) ret.append(version)
return ret return ret
async def history(self): async def history(self) -> List[str]:
ret = []
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
for version in versions: return [version for version in versions]
ret.append(version)
return ret
async def inspectdb(self, tables: List[str]): async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app) connection = get_app_connection(self.tortoise_config, self.app)
inspect = InspectDb(connection, tables) dialect = connection.schema_generator.DIALECT
await inspect.inspect() if dialect == "mysql":
cls: Type["Inspect"] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"): async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name) return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool): async def init_db(self, safe: bool) -> None:
location = self.location location = self.location
app = self.app app = self.app
dirname = Path(location, app) dirname = Path(location, app)
@@ -132,7 +147,7 @@ class Command:
app=app, app=app,
content=get_models_describe(app), content=get_models_describe(app),
) )
content = { version_file = Path(dirname, version)
"upgrade": [schema], content = MIGRATE_TEMPLATE.format(upgrade_sql=schema, downgrade_sql="")
} with open(version_file, "w", encoding="utf-8") as f:
write_version_file(Path(dirname, version), content) f.write(content)

View File

@@ -1,110 +1,98 @@
import asyncio
import os import os
from configparser import ConfigParser
from functools import wraps
from pathlib import Path from pathlib import Path
from typing import List from typing import Dict, List, cast
import click import asyncclick as click
from click import Context, UsageError import tomlkit
from tortoise import Tortoise from asyncclick import Context, UsageError
from tomlkit.exceptions import NonExistentKey
from aerich import Command
from aerich.enums import Color
from aerich.exceptions import DowngradeError from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config from aerich.utils import add_src_path, get_tortoise_config
from aerich.version import __version__
from . import Command
from .enums import Color
from .version import __version__
parser = ConfigParser()
CONFIG_DEFAULT_VALUES = { CONFIG_DEFAULT_VALUES = {
"src_folder": ".", "src_folder": ".",
} }
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ != "cli":
loop.run_until_complete(Tortoise.close_connections())
return wrapper
@click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version") @click.version_option(__version__, "-V", "--version")
@click.option( @click.option(
"-c", "-c",
"--config", "--config",
default="aerich.ini", default="pyproject.toml",
show_default=True, show_default=True,
help="Config file.", help="Config file.",
) )
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
"-n",
"--name",
default="aerich",
show_default=True,
help="Name of section in .ini file to use for aerich config.",
)
@click.pass_context @click.pass_context
@coro async def cli(ctx: Context, config, app) -> None:
async def cli(ctx: Context, config, app, name):
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config_file"] = config ctx.obj["config_file"] = config
ctx.obj["name"] = name
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
if not Path(config).exists(): config_path = Path(config)
raise UsageError("You must exec init first", ctx=ctx) if not config_path.exists():
parser.read(config) raise UsageError(
"You need to run `aerich init` first to create the config file.", ctx=ctx
location = parser[name]["location"] )
tortoise_orm = parser[name]["tortoise_orm"] content = config_path.read_text()
src_folder = parser[name].get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"]) doc: dict = tomlkit.parse(content)
try:
tool = cast(Dict[str, str], doc["tool"]["aerich"])
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run `aerich init` again when upgrading to aerich 0.6.0+.")
add_src_path(src_folder) add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] if not app:
apps_config = cast(dict, tortoise_config.get("apps"))
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location) command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command ctx.obj["command"] = command
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx) raise UsageError(
"You need to run `aerich init-db` first to initialize the database.", ctx=ctx
)
await command.init() await command.init()
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate a migration file for the current state of the models.")
@click.option("--name", default="update", show_default=True, help="Migrate name.") @click.option("--name", default="update", show_default=True, help="Migration name.")
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
@click.pass_context @click.pass_context
@coro async def migrate(ctx: Context, name, empty) -> None:
async def migrate(ctx: Context, name):
command = ctx.obj["command"] command = ctx.obj["command"]
ret = await command.migrate(name) ret = await command.migrate(name, empty)
if not ret: if not ret:
return click.secho("No changes detected", fg=Color.yellow) return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success migrate {ret}", fg=Color.green) click.secho(f"Success creating migration file {ret}", fg=Color.green)
@cli.command(help="Upgrade to specified version.") @cli.command(help="Upgrade to specified migration version.")
@click.option(
"--in-transaction",
"-i",
default=True,
type=bool,
help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
)
@click.pass_context @click.pass_context
@coro async def upgrade(ctx: Context, in_transaction: bool) -> None:
async def upgrade(ctx: Context):
command = ctx.obj["command"] command = ctx.obj["command"]
migrated = await command.upgrade() migrated = await command.upgrade(run_in_transaction=in_transaction)
if not migrated: if not migrated:
click.secho("No upgrade items found", fg=Color.yellow) click.secho("No upgrade items found", fg=Color.yellow)
else: else:
for version_file in migrated: for version_file in migrated:
click.secho(f"Success upgrade {version_file}", fg=Color.green) click.secho(f"Success upgrading to {version_file}", fg=Color.green)
@cli.command(help="Downgrade to specified version.") @cli.command(help="Downgrade to specified version.")
@@ -113,8 +101,8 @@ async def upgrade(ctx: Context):
"--version", "--version",
default=-1, default=-1,
type=int, type=int,
show_default=True, show_default=False,
help="Specified version, default to last.", help="Specified version, default to last migration.",
) )
@click.option( @click.option(
"-d", "-d",
@@ -122,59 +110,56 @@ async def upgrade(ctx: Context):
is_flag=True, is_flag=True,
default=False, default=False,
show_default=True, show_default=True,
help="Delete version files at the same time.", help="Also delete the migration files.",
) )
@click.pass_context @click.pass_context
@click.confirmation_option( @click.confirmation_option(
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?", prompt="Downgrade is dangerous: you might lose your data! Are you sure?",
) )
@coro async def downgrade(ctx: Context, version: int, delete: bool) -> None:
async def downgrade(ctx: Context, version: int, delete: bool):
command = ctx.obj["command"] command = ctx.obj["command"]
try: try:
files = await command.downgrade(version, delete) files = await command.downgrade(version, delete)
except DowngradeError as e: except DowngradeError as e:
return click.secho(str(e), fg=Color.yellow) return click.secho(str(e), fg=Color.yellow)
for file in files: for file in files:
click.secho(f"Success downgrade {file}", fg=Color.green) click.secho(f"Success downgrading to {file}", fg=Color.green)
@cli.command(help="Show current available heads in migrate location.") @cli.command(help="Show currently available heads (unapplied migrations).")
@click.pass_context @click.pass_context
@coro async def heads(ctx: Context) -> None:
async def heads(ctx: Context):
command = ctx.obj["command"] command = ctx.obj["command"]
head_list = await command.heads() head_list = await command.heads()
if not head_list: if not head_list:
return click.secho("No available heads, try migrate first", fg=Color.green) return click.secho("No available heads.", fg=Color.green)
for version in head_list: for version in head_list:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
@cli.command(help="List all migrate items.") @cli.command(help="List all migrations.")
@click.pass_context @click.pass_context
@coro async def history(ctx: Context) -> None:
async def history(ctx: Context):
command = ctx.obj["command"] command = ctx.obj["command"]
versions = await command.history() versions = await command.history()
if not versions: if not versions:
return click.secho("No history, try migrate", fg=Color.green) return click.secho("No migrations created yet.", fg=Color.green)
for version in versions: for version in versions:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
@cli.command(help="Init config file and generate root migrate location.") @cli.command(help="Initialize aerich config and create migrations folder.")
@click.option( @click.option(
"-t", "-t",
"--tortoise-orm", "--tortoise-orm",
required=True, required=True,
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.", help="Tortoise-ORM config dict location, like `settings.TORTOISE_ORM`.",
) )
@click.option( @click.option(
"--location", "--location",
default="./migrations", default="./migrations",
show_default=True, show_default=True,
help="Migrate store location.", help="Migrations folder.",
) )
@click.option( @click.option(
"-s", "-s",
@@ -184,12 +169,8 @@ async def history(ctx: Context):
help="Folder of the source, relative to the project root.", help="Folder of the source, relative to the project root.",
) )
@click.pass_context @click.pass_context
@coro async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
async def init(ctx: Context, tortoise_orm, location, src_folder):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if Path(config_file).exists():
return click.secho("Configuration file already created", fg=Color.yellow)
if os.path.isabs(src_folder): if os.path.isabs(src_folder):
src_folder = os.path.relpath(os.getcwd(), src_folder) src_folder = os.path.relpath(os.getcwd(), src_folder)
@@ -200,46 +181,52 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
# check that we can find the configuration, if not we can fail before the config file gets created # check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder) add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm) get_tortoise_config(ctx, tortoise_orm)
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
else:
content = "[tool.aerich]"
doc: dict = tomlkit.parse(content)
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
parser.add_section(name) config_path.write_text(tomlkit.dumps(doc))
parser.set(name, "tortoise_orm", tortoise_orm)
parser.set(name, "location", location)
parser.set(name, "src_folder", src_folder)
with open(config_file, "w", encoding="utf-8") as f:
parser.write(f)
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green) click.secho(f"Success creating migrations folder {location}", fg=Color.green)
click.secho(f"Success generate config file {config_file}", fg=Color.green) click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.") @cli.command(help="Generate schema and generate app migration folder.")
@click.option( @click.option(
"-s",
"--safe", "--safe",
type=bool, type=bool,
is_flag=True,
default=True, default=True,
help="When set to true, creates the table only when it does not already exist.", help="Create tables only when they do not already exist.",
show_default=True, show_default=True,
) )
@click.pass_context @click.pass_context
@coro async def init_db(ctx: Context, safe: bool) -> None:
async def init_db(ctx: Context, safe):
command = ctx.obj["command"] command = ctx.obj["command"]
app = command.app app = command.app
dirname = Path(command.location, app) dirname = Path(command.location, app)
try: try:
await command.init_db(safe) await command.init_db(safe)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green) click.secho(f"Success creating app migration folder {dirname}", fg=Color.green)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green) click.secho(f'Success generating initial migration file for app "{app}"', fg=Color.green)
except FileExistsError: except FileExistsError:
return click.secho( return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow f"App {app} is already initialized. Delete {dirname} and try again.", fg=Color.yellow
) )
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.") @cli.command(help="Prints the current database tables to stdout as Tortoise-ORM models.")
@click.option( @click.option(
"-t", "-t",
"--table", "--table",
@@ -248,13 +235,13 @@ async def init_db(ctx: Context, safe):
required=False, required=False,
) )
@click.pass_context @click.pass_context
@coro async def inspectdb(ctx: Context, table: List[str]) -> None:
async def inspectdb(ctx: Context, table: List[str]):
command = ctx.obj["command"] command = ctx.obj["command"]
await command.inspectdb(table) ret = await command.inspectdb(table)
click.secho(ret)
def main(): def main() -> None:
cli() cli()

32
aerich/coder.py Normal file
View File

@@ -0,0 +1,32 @@
import base64
import json
import pickle # nosec: B301,B403
from typing import Any, Union
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj) -> Any:
if isinstance(obj, Index):
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj) -> Any:
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: Union[str, bytes]) -> Any:
return json.loads(obj, object_hook=object_hook)

View File

@@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import List, Type from typing import Any, List, Type, cast
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
@@ -23,30 +23,38 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}' _M2M_TABLE_TEMPLATE = (
'CREATE TABLE "{table_name}" (\n'
' "{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,\n'
' "{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}\n'
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = ( _CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"' _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient") -> None:
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]"): def create_table(self, model: "Type[Model]") -> str:
return self.schema_generator._get_table_sql(model, True)["table_creation_string"] return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
";"
)
def drop_table(self, table_name: str): def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m( def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
): ) -> str:
through = field_describe.get("through") through = cast(str, field_describe.get("through"))
description = field_describe.get("description") description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column") pk_field = cast(dict, reference_table_describe.get("pk_field"))
db_field_types = reference_table_describe.get("pk_field").get("db_field_types") reference_id = pk_field.get("db_column")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=through, table_name=through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
@@ -59,34 +67,30 @@ class BaseDDL:
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through), extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator( comment=(
table=through, comment=description self.schema_generator._table_comment_generator(table=through, comment=description)
) if description
if description else ""
else "", ),
) )
def drop_m2m(self, table_name: str): def drop_m2m(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict): def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any:
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_describe.get("default")
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
db_column = field_describe.get("db_column") db_column = cast(str, field_describe.get("db_column"))
auto_now_add = field_describe.get("auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False) auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
if ( if field_describe.get("field_type") in [
field_describe.get("field_type") "UUIDField",
in [ "TextField",
"UUIDField", "JSONField",
"TextField", ] or is_default_function(default):
"JSONField",
]
or is_default_function(default)
):
default = "" default = ""
else: else:
try: try:
@@ -103,64 +107,55 @@ class BaseDDL:
default = None default = None
return default return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description") description = field_describe.get("description")
db_column = field_describe.get("db_column") db_column = cast(str, field_describe.get("db_column"))
db_field_types = field_describe.get("db_field_types") db_field_types = cast(dict, field_describe.get("db_field_types"))
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
if default is None: if default is None:
default = "" default = ""
return self._ADD_COLUMN_TEMPLATE.format( if modify:
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
unique = "UNIQUE" if field_describe.get("unique") else ""
template = self._ADD_COLUMN_TEMPLATE
return template.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=db_column, db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "", unique=unique,
comment=self.schema_generator._column_comment_generator( comment=(
table=db_table, self.schema_generator._column_comment_generator(
column=db_column, table=db_table,
comment=field_describe.get("description"), column=db_column,
) comment=description,
if description )
else "", if description
else ""
),
is_primary_key=is_pk, is_primary_key=is_pk,
default=default, default=default,
), ),
) )
def drop_column(self, model: "Type[Model]", column_name: str): def drop_column(self, model: "Type[Model]", column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table return self._add_or_modify_column(model, field_describe, is_pk, modify=True)
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_describe.get("description")
else "",
is_primary_key=is_pk,
default=default,
),
)
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str): def rename_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str
) -> str:
return self._RENAME_COLUMN_TEMPLATE.format( return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -169,7 +164,7 @@ class BaseDDL:
def change_column( def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
): ) -> str:
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -177,17 +172,17 @@ class BaseDDL:
new_column_type=new_column_type, new_column_type=new_column_type,
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "", unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names "idx" if not unique else "uid", model, field_names
), ),
table_name=model._meta.db_table, table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]), column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
) )
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False): def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names "idx" if not unique else "uid", model, field_names
@@ -195,39 +190,52 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)
def _generate_fk_name(
self, db_table, field_describe: dict, reference_table_describe: dict
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
return self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=to_table,
to_field=to_field,
)
def add_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_column = field_describe.get("raw_field") db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column") pk_field = cast(dict, reference_table_describe.get("pk_field"))
fk_name = self.schema_generator._generate_fk_name( reference_id = pk_field.get("db_column")
from_table=db_table,
from_field=db_column,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=fk_name, fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe),
db_column=db_column, db_column=db_column,
table=reference_table_describe.get("table"), table=reference_table_describe.get("table"),
field=reference_id, field=reference_id,
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
) )
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): def drop_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format( fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
table_name=db_table, return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
def alter_column_default(self, model: "Type[Model]", field_describe: dict): def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format( return self._ALTER_DEFAULT_TEMPLATE.format(
@@ -236,13 +244,13 @@ class BaseDDL:
default="SET" + default if default is not None else "DROP DEFAULT", default="SET" + default if default is not None else "DROP DEFAULT",
) )
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str): def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format( return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name

View File

@@ -1,7 +1,12 @@
from typing import TYPE_CHECKING, List, Type
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
if TYPE_CHECKING:
from tortoise import Model # noqa:F401
class MysqlDDL(BaseDDL): class MysqlDDL(BaseDDL):
schema_generator_cls = MySQLSchemaGenerator schema_generator_cls = MySQLSchemaGenerator
@@ -22,6 +27,37 @@ class MysqlDDL(BaseDDL):
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}" _M2M_TABLE_TEMPLATE = (
"CREATE TABLE `{table_name}` (\n"
" `{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,\n"
" `{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE\n"
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`" _RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def _index_name(self, unique: bool, model: "Type[Model]", field_names: List[str]) -> str:
if unique:
if len(field_names) == 1:
# Example: `email = CharField(max_length=50, unique=True)`
# Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
# Unique index key is the same as field name: `email`
return field_names[0]
index_prefix = "uid"
else:
index_prefix = "idx"
return self.schema_generator._generate_index_name(index_prefix, model, field_names)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
)

View File

@@ -1,4 +1,4 @@
from typing import Type from typing import Type, cast
from tortoise import Model from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
@@ -18,7 +18,7 @@ class PostgresDDL(BaseDDL):
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
@@ -26,9 +26,9 @@ class PostgresDDL(BaseDDL):
set_drop="DROP" if field_describe.get("nullable") else "SET", set_drop="DROP" if field_describe.get("nullable") else "SET",
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types") db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column") db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("") datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
@@ -38,12 +38,14 @@ class PostgresDDL(BaseDDL):
using=f' USING "{db_column}"::{datatype}', using=f' USING "{db_column}"::{datatype}',
) )
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format( return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column") or field_describe.get("raw_field"), column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_describe.get("description")) comment=(
if field_describe.get("description") "'{}'".format(field_describe.get("description"))
else "NULL", if field_describe.get("description")
else "NULL"
),
) )

View File

@@ -1,86 +0,0 @@
import sys
from typing import List, Optional
from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient
class InspectDb:
_table_template = "class {table}(Model):\n"
_field_template_mapping = {
"INT": " {field} = fields.IntField({pk}{unique}{comment})",
"SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})",
"TINYINT": " {field} = fields.BooleanField({null}{default}{comment})",
"VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
"FLOAT": " {field} = fields.FloatField({null}{default}{comment})",
}
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.tables = tables
self.DIALECT = conn.schema_generator.DIALECT
async def show_create_tables(self):
if self.DIALECT == "mysql":
if not self.tables:
sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';" # nosec: B608
ret = await self.conn.execute_query(sql_tables)
self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
for table in self.tables:
sql_show_create_table = f"SHOW CREATE TABLE {table}"
ret = await self.conn.execute_query(sql_show_create_table)
yield ret[1][0]["Create Table"]
else:
raise NotImplementedError("Currently only support MySQL")
async def inspect(self):
ddl_list = self.show_create_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
async for ddl in ddl_list:
parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
table = parser.parse()
name = table.name.title()
columns = table.columns
fields = []
model = self._table_template.format(table=name)
for column_name, column in columns.items():
comment = default = length = unique = null = pk = ""
if column.primary_key:
pk = "pk=True, "
if column.unique:
unique = "unique=True, "
if column.data_type == "VARCHAR":
length = f"max_length={column.length}, "
if not column.not_null:
null = "null=True, "
if column.default is not None:
if column.data_type == "TINYINT":
default = f"default={'True' if column.default == '1' else 'False'}, "
elif column.data_type == "DATETIME":
if "CURRENT_TIMESTAMP" in column.default:
if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
default = "auto_now_add=True, "
else:
default = "auto_now=True, "
else:
default = f"default={column.default}, "
if column.comment:
comment = f"description='{column.comment}', "
field = self._field_template_mapping[column.data_type].format(
field=column_name,
pk=pk,
unique=unique,
length=length,
null=null,
default=default,
comment=comment,
)
fields.append(field)
tables.append(model + "\n".join(fields))
sys.stdout.write(result + "\n\n\n".join(tables))

View File

@@ -0,0 +1,186 @@
from __future__ import annotations
from typing import Any, Callable, Dict, Optional, TypedDict
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str] = None
pk: bool
unique: bool
index: bool
length: Optional[int] = None
extra: Optional[str] = None
decimal_places: Optional[int] = None
max_digits: Optional[int] = None
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
elif self.data_type in ("decimal", "numeric"):
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
if length_parts:
length = ", ".join(length_parts) + ", "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
elif self.default == "":
default = 'default=""'
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn
try:
self.database = conn.database # type:ignore[attr-defined]
except AttributeError:
pass
self.tables = tables
@property
def field_map(self) -> FieldMapDict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError
async def get_all_tables(self) -> list[str]:
raise NotImplementedError
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

71
aerich/inspectdb/mysql.py Normal file
View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> FieldMapDict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"char": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"],
# TODO: why `unque`?
unque=unique, # type:ignore
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File

@@ -0,0 +1,79 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect):
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
super().__init__(conn, tables)
self.schema = conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> FieldMapDict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"numeric": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3""" # nosec:B608
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> FieldMapDict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,21 +1,32 @@
import hashlib
import importlib
import os import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import click import asyncclick as click
from dictdiffer import diff from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
from tortoise.indexes import Index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import ( from aerich.utils import get_app_connection, get_models_describe, is_default_function
get_app_connection,
get_models_describe, MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
is_default_function,
write_version_file,
) async def upgrade(db: BaseDBAsyncClient) -> str:
return \"\"\"
{upgrade_sql}\"\"\"
async def downgrade(db: BaseDBAsyncClient) -> str:
return \"\"\"
{downgrade_sql}\"\"\"
"""
class Migrate: class Migrate:
@@ -26,68 +37,68 @@ class Migrate:
_upgrade_m2m: List[str] = [] _upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = [] _downgrade_m2m: List[str] = []
_aerich = Aerich.__name__ _aerich = Aerich.__name__
_rename_old = [] _rename_old: List[str] = []
_rename_new = [] _rename_new: List[str] = []
ddl: BaseDDL ddl: BaseDDL
ddl_class: Type[BaseDDL]
_last_version_content: Optional[dict] = None _last_version_content: Optional[dict] = None
app: str app: str
migrate_location: str migrate_location: Path
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@staticmethod
def get_field_by_name(name: str, fields: List[dict]) -> dict:
return next(filter(lambda x: x.get("name") == name, fields))
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( return sorted(
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)), filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]), key=lambda x: int(x.split("_")[0]),
) )
@classmethod @classmethod
def _get_model(cls, model: str) -> Type[Model]: def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model) return Tortoise.apps[cls.app][model]
@classmethod @classmethod
async def get_last_version(cls) -> Optional[Aerich]: async def get_last_version(cls) -> Optional[Aerich]:
try: try:
return await Aerich.filter(app=cls.app).first() return await Aerich.filter(app=cls.app).first()
except OperationalError: except OperationalError:
pass return None
@classmethod @classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient): async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
if cls.dialect == "mysql": if cls.dialect == "mysql":
sql = "select version() as version" sql = "select version() as version"
ret = await connection.execute_query(sql) ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version") cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def init(cls, config: dict, app: str, location: str): async def load_ddl_class(cls) -> Type[BaseDDL]:
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod
async def init(cls, config: dict, app: str, location: str) -> None:
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = Path(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
cls._last_version_content = last_version.content cls._last_version_content = cast(dict, last_version.content)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql": cls.ddl_class = await cls.load_ddl_class()
from aerich.ddl.mysql import MysqlDDL cls.ddl = cls.ddl_class(connection)
cls.ddl = MysqlDDL(connection)
elif cls.dialect == "sqlite":
from aerich.ddl.sqlite import SqliteDDL
cls.ddl = SqliteDDL(connection)
elif cls.dialect == "postgres":
from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection)
await cls._get_db_version(connection) await cls._get_db_version(connection)
@classmethod @classmethod
async def _get_last_version_num(cls): async def _get_last_version_num(cls) -> Optional[int]:
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
if not last_version: if not last_version:
return None return None
@@ -95,70 +106,119 @@ class Migrate:
return int(version.split("_", 1)[0]) return int(version.split("_", 1)[0])
@classmethod @classmethod
async def generate_version(cls, name=None): async def generate_version(cls, name=None) -> str:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num() last_version_num = await cls._get_last_version_num()
if last_version_num is None: if last_version_num is None:
return f"0_{now}_init.sql" return f"0_{now}_init.py"
version = f"{last_version_num + 1}_{now}_{name}.sql" version = f"{last_version_num + 1}_{now}_{name}.py"
if len(version) > MAX_VERSION_LENGTH: if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})") raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version return version
@classmethod @classmethod
async def _generate_diff_sql(cls, name): async def _generate_diff_py(cls, name) -> str:
version = await cls.generate_version(name) version = await cls.generate_version(name)
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = {
"upgrade": list(dict.fromkeys(cls.upgrade_operators)), content = cls._get_diff_file_content()
"downgrade": list(dict.fromkeys(cls.downgrade_operators)), Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
}
write_version_file(Path(cls.migrate_location, version), content)
return version return version
@classmethod @classmethod
async def migrate(cls, name) -> str: async def migrate(cls, name: str, empty: bool) -> str:
""" """
diff old models and new models to generate diff content diff old models and new models to generate diff content
:param name: :param name: str name for migration
:param empty: bool if True generates empty migration
:return: :return:
""" """
if empty:
return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app) new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content) last_version = cast(dict, cls._last_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False) cls.diff_models(last_version, new_version_content)
cls.diff_models(new_version_content, last_version, False)
cls._merge_operators() cls._merge_operators()
if not cls.upgrade_operators: if not cls.upgrade_operators:
return "" return ""
return await cls._generate_diff_sql(name) return await cls._generate_diff_py(name)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m=False): def _get_diff_file_content(cls) -> str:
"""
builds content for diff file from template
"""
def join_lines(lines: List[str]) -> str:
if not lines:
return ""
return ";\n ".join(lines) + ";"
return MIGRATE_TEMPLATE.format(
upgrade_sql=join_lines(cls.upgrade_operators),
downgrade_sql=join_lines(cls.downgrade_operators),
)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
:param upgrade: :param upgrade:
:param fk_m2m: :param fk_m2m_index:
:return: :return:
""" """
operator = operator.rstrip(";")
if upgrade: if upgrade:
if fk_m2m: if fk_m2m_index:
cls._upgrade_fk_m2m_index_operators.append(operator) cls._upgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.upgrade_operators.append(operator) cls.upgrade_operators.append(operator)
else: else:
if fk_m2m: if fk_m2m_index:
cls._downgrade_fk_m2m_index_operators.append(operator) cls._downgrade_fk_m2m_index_operators.append(operator)
else: else:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod @classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True): def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret: list = []
def index_hash(self) -> str:
h = hashlib.new("MD5", usedforsecurity=False) # type:ignore[call-arg]
h.update(
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
)
return h.hexdigest()
for index in indexes:
if isinstance(index, Index):
index.__hash__ = index_hash # type:ignore[method-assign,assignment]
ret.append(index)
return ret
@classmethod
def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
indexes: Set[Union[Index, Tuple[str, ...]]] = set()
for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
if isinstance(x, Index):
indexes.add(x)
else:
indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
) -> None:
""" """
diff models and add operators diff models and add operators
:param old_models: :param old_models:
@@ -171,28 +231,35 @@ class Migrate:
new_models.pop(_aerich, None) new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items(): for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1]) model = cls._get_model(new_model_describe["name"].split(".")[1])
if new_model_str not in old_models.keys(): if new_model_str not in old_models:
if upgrade: if upgrade:
cls._add_operator(cls.add_model(model), upgrade) cls._add_operator(cls.add_model(model), upgrade)
else: else:
# we can't find origin model when downgrade, so skip # we can't find origin model when downgrade, so skip
pass pass
else: else:
old_model_describe = old_models.get(new_model_str) old_model_describe = cast(dict, old_models.get(new_model_str))
# rename table # rename table
new_table = new_model_describe.get("table") new_table = cast(str, new_model_describe.get("table"))
old_table = old_model_describe.get("table") old_table = cast(str, old_model_describe.get("table"))
if new_table != old_table: if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade) cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set( old_unique_together = set(
map(lambda x: tuple(x), old_model_describe.get("unique_together")) map(
lambda x: tuple(x),
cast(List[Iterable[str]], old_model_describe.get("unique_together")),
)
) )
new_unique_together = set( new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together")) map(
lambda x: tuple(x),
cast(List[Iterable[str]], new_model_describe.get("unique_together")),
)
) )
old_indexes = cls._get_indexes(model, old_model_describe)
new_indexes = cls._get_indexes(model, new_model_describe)
old_pk_field = old_model_describe.get("pk_field") old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field") new_pk_field = new_model_describe.get("pk_field")
# pk field # pk field
@@ -202,12 +269,19 @@ class Migrate:
if action == "change" and option == "name": if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade) cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields # m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields") old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields"))
new_m2m_fields = new_model_describe.get("m2m_fields") new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields"))
for action, option, change in diff(old_m2m_fields, new_m2m_fields): for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint": if change[0][0] == "db_constraint":
continue continue
table = change[0][1].get("through") new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add": if action == "add":
add = False add = False
if upgrade and table not in cls._upgrade_m2m: if upgrade and table not in cls._upgrade_m2m:
@@ -217,14 +291,11 @@ class Migrate:
cls._downgrade_m2m.append(table) cls._downgrade_m2m.append(table)
add = True add = True
if add: if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator( cls._add_operator(
cls.create_m2m( cls.create_m2m(model, new_value, ref_desc),
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade, upgrade,
fk_m2m=True, fk_m2m_index=True,
) )
elif action == "remove": elif action == "remove":
add = False add = False
@@ -235,31 +306,44 @@ class Migrate:
cls._downgrade_m2m.append(table) cls._downgrade_m2m.append(table)
add = True add = True
if add: if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True) cls._add_operator(cls.drop_m2m(table), upgrade, True)
# add unique_together # add unique_together
for index in new_unique_together.difference(old_unique_together): for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True) cls._add_operator(cls._add_index(model, index, True), upgrade, True)
# remove unique_together # remove unique_together
for index in old_unique_together.difference(new_unique_together): for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True) cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes
for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
# remove indexes
for idx in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
old_data_fields = list(
filter(
lambda x: x.get("db_field_types") is not None,
cast(List[dict], old_model_describe.get("data_fields")),
)
)
new_data_fields = list(
filter(
lambda x: x.get("db_field_types") is not None,
cast(List[dict], new_model_describe.get("data_fields")),
)
)
old_data_fields = old_model_describe.get("data_fields") old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
new_data_fields = new_model_describe.get("data_fields") new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
# add fields or rename fields # add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
new_data_field = next( new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False is_rename = False
for old_data_field in old_data_fields: for old_data_field in old_data_fields:
changes = list(diff(old_data_field, new_data_field)) changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name") old_data_field_name = cast(str, old_data_field.get("name"))
if len(changes) == 2: if len(changes) == 2:
# rename field # rename field
if ( if (
@@ -317,87 +401,98 @@ class Migrate:
), ),
upgrade, upgrade,
) )
if new_data_field["indexed"]:
cls._add_operator(
cls._add_index(
model, (new_data_field["db_column"],), new_data_field["unique"]
),
upgrade,
True,
)
# remove fields # remove fields
for old_data_field_name in set(old_data_fields_name).difference( for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name) set(new_data_fields_name)
): ):
# don't remove field if is rename # don't remove field if is renamed
if (upgrade and old_data_field_name in cls._rename_old) or ( if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new not upgrade and old_data_field_name in cls._rename_new
): ):
continue continue
old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
db_column = cast(str, old_data_field["db_column"])
cls._add_operator( cls._add_operator(
cls._remove_field( cls._remove_field(model, db_column),
model,
next(
filter(
lambda x: x.get("name") == old_data_field_name, old_data_fields
)
).get("db_column"),
),
upgrade, upgrade,
) )
old_fk_fields = old_model_describe.get("fk_fields") if old_data_field["indexed"]:
new_fk_fields = new_model_describe.get("fk_fields") is_unique_field = old_data_field.get("unique")
cls._add_operator(
cls._drop_index(model, {db_column}, is_unique_field),
upgrade,
True,
)
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields)) old_fk_fields = cast(List[dict], old_model_describe.get("fk_fields"))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields)) new_fk_fields = cast(List[dict], new_model_describe.get("fk_fields"))
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
# add fk # add fk
for new_fk_field_name in set(new_fk_fields_name).difference( for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name) set(old_fk_fields_name)
): ):
fk_field = next( fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
if fk_field.get("db_constraint"): if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
cls._add_operator( cls._add_operator(
cls._add_fk( cls._add_fk(model, fk_field, ref_describe),
model, fk_field, new_models.get(fk_field.get("python_type"))
),
upgrade, upgrade,
fk_m2m=True, fk_m2m_index=True,
) )
# drop fk # drop fk
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name) set(new_fk_fields_name)
): ):
old_fk_field = next( old_fk_field = cls.get_field_by_name(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) old_fk_field_name, cast(List[dict], old_fk_fields)
) )
if old_fk_field.get("db_constraint"): if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
cls._add_operator( cls._add_operator(
cls._drop_fk( cls._drop_fk(model, old_fk_field, ref_describe),
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade, upgrade,
fk_m2m=True, fk_m2m_index=True,
) )
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next( old_data_field = cls.get_field_by_name(field_name, old_data_fields)
filter(lambda x: x.get("name") == field_name, old_data_fields) new_data_field = cls.get_field_by_name(field_name, new_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field) changes = diff(old_data_field, new_data_field)
modified = False
for change in changes: for change in changes:
_, option, old_new = change _, option, old_new = change
if option == "indexed": if option == "indexed":
# change index # change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True: if old_new[0] is False and old_new[1] is True:
unique = new_data_field.get("unique")
cls._add_operator( cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True cls._add_index(model, (field_name,), unique), upgrade, True
) )
else: else:
unique = old_data_field.get("unique")
cls._add_operator( cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True cls._drop_index(model, (field_name,), unique), upgrade, True
) )
elif option == "db_field_types.": elif option == "db_field_types.":
# continue since repeated with others if new_data_field.get("field_type") == "DecimalField":
continue # modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
else:
continue
elif option == "default": elif option == "default":
if not ( if not (
is_default_function(old_new[0]) or is_default_function(old_new[1]) is_default_function(old_new[0]) or is_default_function(old_new[1])
@@ -412,104 +507,131 @@ class Migrate:
elif option == "nullable": elif option == "nullable":
# change nullable # change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade) cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
elif option == "description":
# change comment
cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
else: else:
if modified:
continue
# modify column # modify column
cls._add_operator( cls._add_operator(
cls._modify_field(model, new_data_field), cls._modify_field(model, new_data_field),
upgrade, upgrade,
) )
modified = True
for old_model in old_models: for old_model in old_models.keys() - new_models.keys():
if old_model not in new_models.keys(): cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod @classmethod
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str): def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
return cls.ddl.rename_table(model, old_table_name, new_table_name) return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod @classmethod
def add_model(cls, model: Type[Model]): def add_model(cls, model: Type[Model]) -> str:
return cls.ddl.create_table(model) return cls.ddl.create_table(model)
@classmethod @classmethod
def drop_model(cls, table_name: str): def drop_model(cls, table_name: str) -> str:
return cls.ddl.drop_table(table_name) return cls.ddl.drop_table(table_name)
@classmethod @classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def create_m2m(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.create_m2m(model, field_describe, reference_table_describe) return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod @classmethod
def drop_m2m(cls, table_name: str): def drop_m2m(cls, table_name: str) -> str:
return cls.ddl.drop_m2m(table_name) return cls.ddl.drop_m2m(table_name)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
ret = [] ret = []
for field_name in fields_name: for field_name in fields_name:
field = model._meta.fields_map[field_name] try:
if field.source_field: field = model._meta.fields_map[field_name]
ret.append(field.source_field) except KeyError:
elif field_name in model._meta.fk_fields: # field dropped or to be add
ret.append(field_name + "_id") pass
else: else:
ret.append(field_name) if field.source_field:
field_name = field.source_field
elif field_name in model._meta.fk_fields:
field_name += "_id"
ret.append(field_name)
return ret return ret
@classmethod @classmethod
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False): def _drop_index(
fields_name = cls._resolve_fk_fields_name(model, fields_name) cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
return cls.ddl.drop_index(model, fields_name, unique) ) -> str:
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, field_names, unique)
@classmethod @classmethod
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False): def _add_index(
fields_name = cls._resolve_fk_fields_name(model, fields_name) cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
return cls.ddl.add_index(model, fields_name, unique) ) -> str:
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, field_names, unique)
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False): def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
return cls.ddl.add_column(model, field_describe, is_pk) return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field_describe: dict): def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_default(model, field_describe) return cls.ddl.alter_column_default(model, field_describe)
@classmethod @classmethod
def _alter_null(cls, model: Type[Model], field_describe: dict): def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_null(model, field_describe) return cls.ddl.alter_column_null(model, field_describe)
@classmethod @classmethod
def _set_comment(cls, model: Type[Model], field_describe: dict): def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.set_comment(model, field_describe) return cls.ddl.set_comment(model, field_describe)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field_describe: dict): def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.modify_column(model, field_describe) return cls.ddl.modify_column(model, field_describe)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _drop_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.drop_fk(model, field_describe, reference_table_describe) return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], column_name: str): def _remove_field(cls, model: Type[Model], column_name: str) -> str:
return cls.ddl.drop_column(model, column_name) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str): def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
return cls.ddl.rename_column(model, old_field_name, new_field_name) return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod @classmethod
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): def _change_field(
db_field_types = new_field_describe.get("db_field_types") cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
) -> str:
db_field_types = cast(dict, new_field_describe.get("db_field_types"))
return cls.ddl.change_column( return cls.ddl.change_column(
model, model,
old_field_describe.get("db_column"), cast(str, old_field_describe.get("db_column")),
new_field_describe.get("db_column"), cast(str, new_field_describe.get("db_column")),
db_field_types.get(cls.dialect) or db_field_types.get(""), cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
) )
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _add_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
""" """
add fk add fk
:param model: :param model:
@@ -520,7 +642,7 @@ class Migrate:
return cls.ddl.add_fk(model, field_describe, reference_table_describe) return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _merge_operators(cls): def _merge_operators(cls) -> None:
""" """
fk/m2m/index must be last when add,first when drop fk/m2m/index must be last when add,first when drop
:return: :return:

View File

@@ -1,12 +1,15 @@
from tortoise import Model, fields from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255 MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20) app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField() content: dict = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@@ -1,17 +1,18 @@
import importlib import importlib.util
import os import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Union from types import ModuleType
from typing import Dict, Optional, Union
from click import BadOptionUsage, ClickException, Context from asyncclick import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str: def add_src_path(path: str) -> str:
""" """
add a folder to the paths so we can import from there add a folder to the paths, so we can import from there
:param path: path to add :param path: path to add
:return: absolute path :return: absolute path
""" """
@@ -77,60 +78,6 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config return config
_UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: Path, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict: def get_models_describe(app: str) -> Dict:
""" """
get app models describe get app models describe
@@ -138,11 +85,19 @@ def get_models_describe(app: str) -> Dict:
:return: :return:
""" """
ret = {} ret = {}
for model in Tortoise.apps.get(app).values(): for model in Tortoise.apps[app].values():
describe = model.describe() describe = model.describe()
ret[describe.get("name")] = describe ret[describe.get("name")] = describe
return ret return ret
def is_default_function(string: str): def is_default_function(string: str) -> Optional[re.Match]:
return re.match(r"^<function.+>$", str(string or "")) return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Union[str, Path]) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module

View File

@@ -1 +1 @@
__version__ = "0.5.7" __version__ = "0.8.0"

View File

@@ -1,19 +1,22 @@
import asyncio import asyncio
import os import os
from typing import Generator
import pytest import pytest
from tortoise import Tortoise, expand_db_url, generate_schema_for_client from tortoise import Tortoise, expand_db_url, generate_schema_for_client
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.exceptions import DBConnectionError, OperationalError
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate from aerich.migrate import Migrate
db_url = os.getenv("TEST_DB", "sqlite://:memory:") MEMORY_SQLITE = "sqlite://:memory:"
db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:") db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
tortoise_orm = { tortoise_orm = {
"connections": { "connections": {
"default": expand_db_url(db_url, True), "default": expand_db_url(db_url, True),
@@ -27,7 +30,7 @@ tortoise_orm = {
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def reset_migrate(): def reset_migrate() -> None:
Migrate.upgrade_operators = [] Migrate.upgrade_operators = []
Migrate.downgrade_operators = [] Migrate.downgrade_operators = []
Migrate._upgrade_fk_m2m_index_operators = [] Migrate._upgrade_fk_m2m_index_operators = []
@@ -37,20 +40,27 @@ def reset_migrate():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def event_loop(): def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop() res = policy.new_event_loop()
asyncio.set_event_loop(res) asyncio.set_event_loop(res)
res._close = res.close res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None res.close = lambda: None # type:ignore[method-assign]
yield res yield res
res._close() res._close() # type:ignore[attr-defined]
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request): async def initialize_tests(event_loop, request) -> None:
# Placing init outside the try block since it doesn't
# establish connections to the DB eagerly.
await Tortoise.init(config=tortoise_orm)
try:
await Tortoise._drop_databases()
except (DBConnectionError, OperationalError):
pass
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

1602
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.5.7" version = "0.8.0"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"
@@ -15,34 +15,57 @@ packages = [
include = ["CHANGELOG.md", "LICENSE", "README.md"] include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.8"
tortoise-orm = "*" tortoise-orm = "*"
click = "*"
pydantic = "*"
aiomysql = { version = "*", optional = true }
asyncpg = { version = "*", optional = true } asyncpg = { version = "*", optional = true }
ddlparse = "*" asyncmy = { version = "^0.2.9", optional = true, allow-prereleases = true }
pydantic = "^2.0"
dictdiffer = "*" dictdiffer = "*"
tomlkit = "*"
asyncclick = "^8.1.7.2"
[tool.poetry.dev-dependencies] [tool.poetry.group.dev.dependencies]
flake8 = "*" ruff = "*"
isort = "*" isort = "*"
black = "*" black = "*"
pytest = "*" pytest = "*"
pytest-xdist = "*" pytest-xdist = "*"
pytest-asyncio = "*" # Breaking change in 0.23.*
# https://github.com/pytest-dev/pytest-asyncio/issues/706
pytest-asyncio = "^0.21.2"
bandit = "*" bandit = "*"
pytest-mock = "*" pytest-mock = "*"
cryptography = "*" cryptography = "*"
mypy = "^1.10.0"
[tool.poetry.extras] [tool.poetry.extras]
asyncmy = ["asyncmy"] asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"] asyncpg = ["asyncpg"]
aiomysql = ["aiomysql"]
[tool.aerich]
tortoise_orm = "conftest.tortoise_orm"
location = "./migrations"
src_folder = "./."
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts] [tool.poetry.scripts]
aerich = "aerich.cli:main" aerich = "aerich.cli:main"
[tool.black]
line-length = 100
target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
[tool.pytest.ini_options]
asyncio_mode = 'auto'
[tool.mypy]
pretty = true
python_version = "3.8"
ignore_missing_imports = true
[tool.ruff.lint]
ignore = ['E501']

View File

@@ -1,2 +0,0 @@
[flake8]
ignore = E501,W503

View File

@@ -29,14 +29,15 @@ class User(Model):
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
intro = fields.TextField(default="") intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
class Email(Model): class Email(Model):
email_id = fields.IntField(pk=True) email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, index=True) email = fields.CharField(max_length=200, db_index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User") users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
def default_name(): def default_name():
@@ -46,12 +47,15 @@ def default_name():
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=100) slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True, default=default_name) name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num", default=0) view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
@@ -65,14 +69,17 @@ class Product(Model):
class Meta: class Meta:
unique_together = (("name", "type"),) unique_together = (("name", "type"),)
indexes = (("name", "type"),)
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status) status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
class NewModel(Model): class NewModel(Model):

View File

@@ -34,18 +34,24 @@ class User(Model):
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models_second.User", db_constraint=False) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", db_constraint=False
)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models_second.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models_second.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models_second.Category"
)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
@@ -61,5 +67,5 @@ class Product(Model):
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@@ -29,23 +29,29 @@ class User(Model):
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="") avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="") intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False
)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
uid = fields.IntField(source_field="uuid", unique=True)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
@@ -59,9 +65,10 @@ class Product(Model):
class Config(Model): class Config(Model):
name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta: class Meta:

View File

@@ -14,10 +14,11 @@ def test_create_table():
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL, `slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200), `name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`user_id` INT NOT NULL COMMENT 'User', `user_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4;""" ) CHARACTER SET utf8mb4"""
) )
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
@@ -27,9 +28,10 @@ def test_create_table():
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
);""" )"""
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
@@ -39,10 +41,11 @@ def test_create_table():
"id" SERIAL NOT NULL PRIMARY KEY, "id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
COMMENT ON COLUMN "category"."user_id" IS 'User';""" COMMENT ON COLUMN "category"."user_id" IS 'User'"""
) )
@@ -72,18 +75,16 @@ def test_modify_column():
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False)) ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)" assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
assert (
ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
)
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
ret0 ret0
== 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)' == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)'
) )
if isinstance(Migrate.ddl, MysqlDDL):
assert (
ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL' ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL'
) )
@@ -153,9 +154,7 @@ def test_add_index():
index_u = Migrate.ddl.add_index(Category, ["name"], True) index_u = Migrate.ddl.add_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert ( assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)"
)
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")' assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
@@ -171,7 +170,7 @@ def test_drop_index():
ret_u = Migrate.ddl.drop_index(Category, ["name"], True) ret_u = Migrate.ddl.drop_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`" assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`"
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`" assert ret_u == "ALTER TABLE `category` DROP INDEX `name`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"' assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"' assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'

View File

@@ -1,11 +1,15 @@
from pathlib import Path
from typing import List, cast
import pytest import pytest
import tortoise
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe from aerich.utils import get_models_describe
old_models_describe = { old_models_describe = {
@@ -17,6 +21,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -99,6 +104,21 @@ old_models_describe = {
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
{
"name": "title",
"field_type": "CharField",
"db_column": "title",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
], ],
"fk_fields": [ "fk_fields": [
{ {
@@ -151,6 +171,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -167,6 +188,21 @@ old_models_describe = {
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{ {
"name": "label", "name": "label",
"field_type": "CharField", "field_type": "CharField",
@@ -242,6 +278,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -334,6 +371,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -365,6 +403,21 @@ old_models_describe = {
"constraints": {"max_length": 50}, "constraints": {"max_length": 50},
"db_field_types": {"": "VARCHAR(50)"}, "db_field_types": {"": "VARCHAR(50)"},
}, },
{
"name": "uid",
"field_type": "IntField",
"db_column": "uuid",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{ {
"name": "view_num", "name": "view_num",
"field_type": "IntField", "field_type": "IntField",
@@ -512,6 +565,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -639,6 +693,21 @@ old_models_describe = {
"constraints": {}, "constraints": {},
"db_field_types": {"": "TEXT", "mysql": "LONGTEXT"}, "db_field_types": {"": "TEXT", "mysql": "LONGTEXT"},
}, },
{
"name": "longitude",
"unique": False,
"default": None,
"indexed": False,
"nullable": False,
"db_column": "longitude",
"docstring": None,
"generated": False,
"field_type": "DecimalField",
"constraints": {},
"description": None,
"python_type": "decimal.Decimal",
"db_field_types": {"": "DECIMAL(12,9)", "sqlite": "VARCHAR(40)"},
},
], ],
"fk_fields": [], "fk_fields": [],
"backward_fk_fields": [ "backward_fk_fields": [
@@ -681,6 +750,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -752,6 +822,16 @@ old_models_describe = {
} }
def should_add_user_id_column_type_alter_sql() -> bool:
if tortoise.__version__ < "0.21":
return False
# tortoise-orm>=0.21 changes IntField constraints
# from {"ge": 1, "le": 2147483647} to {"ge": -2147483648,"le": 2147483647}
data_fields = cast(List[dict], old_models_describe["models.Category"]["data_fields"])
user_id_constraints = data_fields[-1]["constraints"]
return tortoise.fields.data.IntField.constraints != user_id_constraints
def test_migrate(mocker: MockerFixture): def test_migrate(mocker: MockerFixture):
""" """
models.py diff with old_models.py models.py diff with old_models.py
@@ -762,119 +842,200 @@ def test_migrate(mocker: MockerFixture):
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- add many to many: Email.users - add many to many: Email.users
- remove unique: User.username - remove unique: Category.title
- add unique: User.username
- change column: length User.password - change column: length User.password
- add unique_together: (name,type) of Product - add unique_together: (name,type) of Product
- drop unique field: Config.name
- alter default: Config.status - alter default: Config.status
- rename column: Product.image -> Product.pic - rename column: Product.image -> Product.pic
""" """
mocker.patch("click.prompt", side_effect=(True,)) mocker.patch("asyncclick.prompt", side_effect=(True,))
models_describe = get_models_describe("models") models_describe = get_models_describe("models")
Migrate.app = "models" Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): with pytest.raises(NotSupportError):
Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(old_models_describe, models_describe)
Migrate.upgrade_operators.clear()
with pytest.raises(NotSupportError):
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
Migrate.downgrade_operators.clear()
else: else:
Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
Migrate._merge_operators() Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert sorted(Migrate.upgrade_operators) == sorted( expected_upgrade_operators = {
[ "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)",
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)", "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL", "ALTER TABLE `category` DROP INDEX `title`",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", "ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `config` DROP INDEX `name`",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", "ALTER TABLE `product` DROP INDEX `uuid`",
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4;", "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
] "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
) "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'",
"ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
"CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
}
expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
"ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `username`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
}
if should_add_user_id_column_type_alter_sql():
sql = "ALTER TABLE `category` MODIFY COLUMN `user_id` INT NOT NULL COMMENT 'User'"
expected_upgrade_operators.add(sql)
expected_downgrade_operators.add(sql)
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert sorted(Migrate.downgrade_operators) == sorted( assert not set(Migrate.downgrade_operators).symmetric_difference(
[ expected_downgrade_operators
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`",
]
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert sorted(Migrate.upgrade_operators) == sorted( expected_upgrade_operators = {
[ 'DROP INDEX "uid_category_title_f7fc03"',
'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "config" DROP COLUMN "name"',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', 'DROP INDEX "uid_config_name_2c83c8"',
'ALTER TABLE "configs" RENAME TO "config"', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"', 'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', 'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")', 'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)', 'DROP INDEX "uid_product_uuid_d33c18"',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';', 'ALTER TABLE "product" DROP COLUMN "uuid"',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
] 'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
) 'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
assert sorted(Migrate.downgrade_operators) == sorted( 'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
[ 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)', 'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', 'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "config" DROP COLUMN "user_id"', 'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"', 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "config" RENAME TO "configs"', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL', 'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'ALTER TABLE "email" DROP COLUMN "address"', 'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', 'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', 'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)', }
'DROP INDEX "idx_email_email_4a1a33"', expected_downgrade_operators = {
'DROP INDEX "idx_user_usernam_9987ab"', 'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")',
'DROP INDEX "uid_product_name_869427"', 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'DROP TABLE IF EXISTS "email_user"', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'DROP TABLE IF EXISTS "newmodel"', 'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
] 'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_config_name_2c83c8" ON "config" ("name")',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" RENAME TO "configs"',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'DROP INDEX "idx_product_name_869427"',
'DROP INDEX "idx_email_email_4a1a33"',
'DROP INDEX "uid_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_869427"',
'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"',
}
if should_add_user_id_column_type_alter_sql():
sql = 'ALTER TABLE "category" ALTER COLUMN "user_id" TYPE INT USING "user_id"::INT'
expected_upgrade_operators.add(sql)
expected_downgrade_operators.add(sql)
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference(
expected_downgrade_operators
) )
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == [] assert Migrate.upgrade_operators == []
assert Migrate.downgrade_operators == [] assert Migrate.downgrade_operators == []
@@ -882,18 +1043,30 @@ def test_sort_all_version_files(mocker):
mocker.patch( mocker.patch(
"os.listdir", "os.listdir",
return_value=[ return_value=[
"1_datetime_update.sql", "1_datetime_update.py",
"11_datetime_update.sql", "11_datetime_update.py",
"10_datetime_update.sql", "10_datetime_update.py",
"2_datetime_update.sql", "2_datetime_update.py",
], ],
) )
Migrate.migrate_location = "." Migrate.migrate_location = "."
assert Migrate.get_all_version_files() == [ assert Migrate.get_all_version_files() == [
"1_datetime_update.sql", "1_datetime_update.py",
"2_datetime_update.sql", "2_datetime_update.py",
"10_datetime_update.sql", "10_datetime_update.py",
"11_datetime_update.sql", "11_datetime_update.py",
] ]
async def test_empty_migration(mocker, tmp_path: Path) -> None:
mocker.patch("os.listdir", return_value=[])
Migrate.app = "foo"
expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
Migrate.migrate_location = tmp_path
migration_file = await Migrate.migrate("update", True)
f = tmp_path / migration_file
assert f.read_text() == expected_content

6
tests/test_utils.py Normal file
View File

@@ -0,0 +1,6 @@
from aerich.utils import import_py_file
def test_import_py_file() -> None:
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file")