Compare commits

..

129 Commits
v0.7.1 ... dev

Author SHA1 Message Date
Waket Zheng
9c3ba7e273
fix: aerich init-db process is suspended (#435) 2025-03-06 13:39:56 +08:00
Waket Zheng
074ba9b743
fix: ci failed with m2m field migrate test (#434)
* fix style issue

* fixing m2m test error
2025-03-05 10:28:41 +08:00
Waket Zheng
5d9adbdb54
chore: improve type hints (#432)
* chore: improve type hints

* chore: set `warn_unused_ignores` true for mypy

* refactor: use function to compare tortoise version

* refactor: change function name
2025-03-04 14:52:12 +08:00
Waket Zheng
8609435815
Release 0.8.2 (#429) 2025-02-28 20:24:06 +08:00
Waket Zheng
a624d1b43b
fix: migrate does not recognise attribute changes for string primary key (#428)
* refactor: show warning for unsupported pk field changes

* fix: migrate does not recognise attribute changes for string primary key

* docs: update changelog

* refactor: reduce indents

* chore: update docs
2025-02-27 22:23:26 +08:00
Waket Zheng
e299f8e1d6
feat: aerich.Command support async with syntax (#427)
* feat: `aerich.Command` support `async with` syntax

* docs: update readme
2025-02-27 10:55:48 +08:00
Waket Zheng
db0cf656fc chore: show friendly message when config missing 'apps' section 2025-02-26 18:08:12 +08:00
Waket Zheng
49bfbf4e6b
feat: support psycopg (#425) 2025-02-26 17:11:31 +08:00
Waket Zheng
0364ae3f83
feat: add project section (#424)
* refactor: apply future style type hints

* chore: use project section

* ci: upgrade to poetry v2

* ci: explicit declare python version for poetry

* fix error for generate index name

* fix _generate_fk_name

* ci: verify aiomysql support

* tests: poetry add

* Add patch to fix tortoise 0.24.1

* docs: update changelog
2025-02-26 14:24:02 +08:00
Waket Zheng
91adf9334e
feat: support skip table migration by set managed=False (#397) 2025-02-21 17:08:03 +08:00
Waket Zheng
41df464e8b
fix: no migration occurs when adding unique true to indexed field (#414)
* feat: alter unique for indexed column

* chore: update docs and change some var names
2025-02-20 16:58:32 +08:00
程序猿过家家
c35282c2a3
fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL
* increase:
1. Inspectdb adds DECIMAL, DOUBLE, CHAR, TIME data type matching;
2. Add exception handling, avoid the need to manually create the entire table because a certain data type is not supported.

* fix: aerich inspectdb raise KeyError for double in MySQL

* feat: support command `python -m aerich`

* docs: update changelog

* tests: verify mysql inspectdb for float field

* fix mysql uuid field inspect to be charfield

* refactor: use `db_index=True` instead of `index=True` for inspectdb

* docs: update changelog

---------

Co-authored-by: xiechen <xiechen@jinse.com>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-19 16:04:15 +08:00
Waket Zheng
557271c8e1
feat: support command python -m aerich (#417)
* feat: support command `python -m aerich`

* docs: update changelog
2025-02-18 15:44:02 +08:00
radluz
7f8c5dcddc
fix: update asyncio event loop policy on Windows (#251)
* fix: update asyncio event loop policy on Windows

* Use `platform.system` instead of `sys.platform`

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-17 18:06:10 +08:00
Waket Zheng
1793dab43d
refactor: apply future type hints style (#416)
* refactor: apply future style type hints

* chore: put cryptography out of dev dependencies
2025-02-17 11:42:56 +08:00
Waket Zheng
6bdfdfc6db
fix: aerich migrate raises tortoise.exceptions.FieldError when index.INDEX_TYPE is not empty (#415)
* fix: aerich migrate raises `tortoise.exceptions.FieldError` when `index.INDEX_TYPE` is not empty

* feat: add `IF NOT EXISTS` to postgres create index template

* chore: explicit declare type hints of function parameters
2025-02-13 18:48:45 +08:00
alistairmaclean
0be5c1b545
Remove system dependency on libsqlite3.so on command.upgrade (#413)
* Remove system dependency on libsqlite3.so on command.upgrade

* Fix styling using `make style` command
2025-02-07 20:09:04 +08:00
Abdeldjalil Hezouat
d6b35ab0ac
change hardcoded version (#412)
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-07 19:50:41 +08:00
Waket Zheng
b46ceafb2e
feat: support --fake for aerich upgrade (#398)
* feat: support `--fake` for aerich upgrade

* Add `--fake` to downgrade

* tests: check --fake result for aerich upgrade and downgrade

* Update readme

* Fix unittest failed because of `db_field_types` changed

* refactor: improve type hints and document
2025-02-07 19:44:15 +08:00
Waket Zheng
ac847ba616
refactor: avoid updating inited config file (#402)
* refactor: avoid updating config file if init config items not changed

* fix unittest error with tortoise develop branch

* Remove extra space

* fix mysql test error

* fix mysql create index error
2025-01-04 09:08:14 +08:00
Waket Zheng
f5d7d56fa5
fix: inspectdb raise KeyError 'int2' for smallint (#401)
* fix: inspectdb raise KeyError 'int2' for smallint

* fix ci error

* no ask confirm for ci

* docs: update changelog
2024-12-27 23:49:53 +08:00
Waket Zheng
d6a51bd20e docs: release v0.8.1 2024-12-27 12:34:55 +08:00
Waket Zheng
c1dea4e846
chore: upgrade deps, add tortoise0.23 to ci (#399) 2024-12-27 12:30:34 +08:00
Waket Zheng
5e8a7c7e91
fix: migration with duplicate renaming of columns in some cases (#395)
* fix: migration with duplicate renaming of columns in some cases

* Update var name

* fix downgrade sql error

* fix test error

* docs: update changelog

* Add unittest

* Move not change line to origin position

* Update sort key to make it more frieldly interactive from multi fields rename

* refactor: remove puzzle vars

* docs: fix PR links in changelog

* fix sort key lambda error
2024-12-27 12:09:23 +08:00
Waket Zheng
7d22518c74
Fix create/drop indexes in every migration (#377)
* Add `__eq__` method for `Index`instances

* tests: add Index test case

* refactor: compare index instances before set hash and eq func to class

* fix: sort fields when generating index hash

* docs: update changlog

* fix style issue

* refactor: use CustomIndex instead of postgres special HashIndex

* Check tortoise version before patch Index

* Add comment

* Add comment for why > work

---------

Co-authored-by: dbf <somnium@riseup.net>
2024-12-22 00:24:18 +08:00
Waket Zheng
f93faa8afb
fix: add o2o field does not create constraint when migrating (#396)
* fix: add o2o field does not create constraint when migrating

* Add testcase and update changelog

* docs: update migrating list

* refactor: use `_handle_o2o_fields` instead of `is_o2o=True`

* Remove unused line
2024-12-22 00:23:47 +08:00
Waket Zheng
1acb9ed1e7
chore(deps): limit tortoise-orm version to >=0.21 instead of wildcard (*) (#388)
* Limit tortoise-orm version to `>=0.21` instead of wildcard (*)

* ci: fix typos

* docs: update changelog

* refactor: import MEMORY_SQLITE from tortoise

* Update changelog
2024-12-21 11:04:02 +08:00
Waket Zheng
69ce0cafa1
fix: intermediate table for m2m relation not created (#394)
* fix: intermediate table for m2m relation not created

* Add unittest

* docs: update changelog
2024-12-18 00:13:19 +08:00
Waket Zheng
4fc7f324d4
fix: add m2m field with custom m2m through generate duplicated table when migrating (#393)
* fix: m2m table duplicated when using custom model for through

* Add testcase

* docs: update changelog

* tests: add m2m custom through example test
2024-12-17 22:28:06 +08:00
Waket Zheng
d8addadb37
chore(deps): prefer to use tomllib/tomli and mark tomlkit/tomli_w as optional (#392)
* chore(deps): prefer to use tomllib and mark tomlkit as optional

* docs: add toml extra to install command

* docs: update changelog
2024-12-17 01:43:36 +08:00
Waket Zheng
0780919ef3
fix: migrate drop the wrong m2m field when model have multi m2m fields (#390)
* fix: migrate drop the wrong m2m field when model have multi m2m fields

* Make style and update changelog

* refactor: return new lists instead of change argument values in function

* refactor: use custom diff function instead of reorder lists

* docs: fix typo

* Fix hardcoded and rename custom diff function

* Update function doc
2024-12-17 01:20:02 +08:00
Waket Zheng
5af8c9cd56 docs: update changelog 2024-12-16 11:49:52 +08:00
Waket Zheng
56da0e7e3c fix: KeyError raised when removing or renaming an existing model 2024-12-12 10:43:40 +08:00
Lance.Moe
6270c4781e
fix: error when there is __init__.py in the migration folder (#272)
* fix: error when there is __init__.py in the migration folder

* fix: check __init__.py in the migration folder

* refactor

* refactor & add test

* refactor

* Update changelog

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-11 21:34:14 +08:00
Waket Zheng
12d0a5dad1
fix: setting null=false on m2m field causes migration to fail (#385)
* fix: setting null=false on m2m field causes migration to fail

* Update changelog
2024-12-11 21:12:04 +08:00
Waket Zheng
56eff1b22f chore: use pytest instead of py.test in Makefile 2024-12-11 16:44:18 +08:00
Waket Zheng
e4a3863f80
fix: aerich upgrade raises OperationalError when unique constraint dropped at migration 1_xxx.py with postgres (#383) 2024-12-11 15:15:29 +08:00
Waket Zheng
5572876714
fix: NonExistentKey when running aerich init without [tool] section in config file (#381)
* fix: NonExistentKey when running `aerich init` without `[tool]` section in config file

* docs: update changelog
2024-12-11 13:26:18 +08:00
Waket Zheng
3d840395f1 docs: update changlog 2024-12-10 23:11:13 +08:00
Tuffy_
accceef24f
Fixed two problems when using under windows (#286)
* fix: Fixed an issue where an error would occur when using aerich in windows if the profile contained Chinese characters

* fix: Automatically delete the empty migration directory of the app if the init-db operation fails

* feat: generate migration file in empty directory instead of abort with warning

* tests: fix test fail in ci

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-10 23:02:49 +08:00
Waket Zheng
9c81bc6036
Fix sqlite create/drop index (#379)
* Update add/drop index template for sqlite

* tests: add sqlite migrate/upgrade command test

* tests: add timeout for sqlite migrate command test

* tests: add test cases for add/drop unique field for sqlite

* fix: sqlite failed to add unique field
2024-12-10 16:37:30 +08:00
Waket Zheng
c2ebe9b5e4
Fix postgres fk rename (#378)
* Update postgres drop fk template

* fix test error

* docs: update changlog
2024-12-09 11:41:40 +08:00
Mykola Solodukha
8cefe68c9b
[BUG] Sort m2m fields before comparing them with diff(...) (#271)
* 🐛 Sort m2m fields before comparing them with `diff(...)`

* Add test case and upgrade changelog

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-05 17:41:58 +08:00
Waket Zheng
44025823ee
chore: upgrade deps and fix ruff lint issues (#374)
* chore: upgrade deps and apply ruff lint for tests/

* style: fix ruff lint issues
2024-12-05 15:56:00 +08:00
Waket Zheng
252cb97767 Bump version to 0.8.0 and update changelog 2024-12-04 00:14:35 +08:00
Waket Zheng
ac3ef9e2eb
tests: no need to merge operators when NotSupportError raised (#373) 2024-12-03 23:59:54 +08:00
Waket Zheng
5b04b4422d
chore: upgrade deps, update changelog, drop test db before create (#372)
* chore: upgrade deps, update changelog, drop test db before create

* tests: clear operators for sqlite ddl after NotSupportError raised
2024-12-03 23:30:37 +08:00
Waket Zheng
b2f4029a4a
Improve type hints of inspectdb (#371) 2024-12-03 12:40:28 +08:00
xsillen
4e46d9d969
Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
Co-authored-by: floodpillar <165008032+floodpillar@users.noreply.github.com>
2024-12-03 11:44:36 +08:00
gck123
c0fd3ec63c
Fix KeyError when deleting a field with unqiue=True (#365)
* Fix KeyError when deleting a field with unqiue=True

* refactor: rename `old_data_unique` to `is_unique_field`

* Add testcases for remove unique field

---------

Co-authored-by: gongchangku <gongchangku@anban.tech>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-03 10:16:44 +08:00
Waket Zheng
103470f4c1
Merge pull request #367 from waketzheng/fix-style-issue
chore: make style, upgrade deps, fix ci error and update changelog
2024-11-29 15:09:25 +08:00
Waket Zheng
dc020358b6 chore: make style, upgrade deps, fix ci error and update changelog 2024-11-25 23:46:48 +08:00
Waket Zheng
095eb48196
Merge pull request #360 from ahmetveburak/fix-package
fix(package): correct the click import
2024-11-25 18:10:27 +08:00
He
fac1de6299
Merge pull request #355 from merlinz01/improve-cli-descriptions
Improve CLI help text and output
2024-11-18 21:25:13 +01:00
ahmetveburak
e049fcee26
fix(package): correct the click import 2024-10-05 20:43:32 +03:00
Merlin
ee144d4a9b
update migrate output for consistency 2024-08-07 13:32:49 -04:00
Merlin
dbf96a17d3
Improve command-line descriptions
Changed CLI help texts for some of the options to make them clearer.
2024-08-07 13:31:07 -04:00
long2ice
15d56121ef fix: migrate 2024-08-06 22:46:50 +08:00
long2ice
4851ecfe82 refactor: use asyncclick 2024-08-06 22:41:39 +08:00
long2ice
ad0e7006bc
Merge pull request #348 from waketzheng/update-changelog
Update changelog
2024-06-07 11:25:10 +08:00
Waket Zheng
27b29e401b Merge remote-tracking branch 'upstream/dev' into update-changelog 2024-06-07 11:11:47 +08:00
Waket Zheng
7bb35c10e4 Add tag link to change log 2024-06-07 11:08:49 +08:00
long2ice
ed113d491e
Merge pull request #347 from waketzheng/add-mypy-check-to-ci
Use mypy for type hint check in ci
2024-06-06 21:38:23 +08:00
Waket Zheng
9e46fbf55d rollback ci command 2024-06-06 18:20:57 +08:00
Waket Zheng
fc68f99c86 fixing cache action 2024-06-06 18:02:58 +08:00
Waket Zheng
480087df07 Add id for cache step 2024-06-06 17:38:51 +08:00
Waket Zheng
24a2087b78 Skip make deps if cache hint in ci 2024-06-06 17:35:00 +08:00
Waket Zheng
bceeb236c2 Checking cache action 2024-06-06 17:21:46 +08:00
Waket Zheng
c42fdab74d Add cache action to ci 2024-06-06 17:11:27 +08:00
Waket Zheng
ceb1e0ffef Activate type hint check in ci 2024-06-06 17:07:09 +08:00
long2ice
13dd44bef7
Merge pull request #340 from waketzheng/type-hint-tests
Improve type hints for tests/
2024-06-06 16:53:01 +08:00
long2ice
219633a926
Merge pull request #341 from waketzheng/type-hint-simple
Simple type hints for aerich/
2024-06-06 16:52:30 +08:00
long2ice
e6302a920a
Merge pull request #342 from waketzheng/type-hint-ddl
Improve type hints for ddl and inspectdb
2024-06-06 16:51:37 +08:00
long2ice
79a77d368f
Merge pull request #343 from waketzheng/type-hints-aerich.migrate
Add type hints for aerich.migrate
2024-06-06 16:50:53 +08:00
long2ice
4a1fc4cfa0
Merge pull request #339 from waketzheng/update-ci
Update ci action versions and avoid `make ci` install deps twice
2024-06-06 16:50:39 +08:00
Waket Zheng
aee706e29b Update changelog 2024-06-06 16:43:08 +08:00
Waket Zheng
572c13f9dd fix conflict 2024-06-06 16:36:29 +08:00
Waket Zheng
7b733495fb fix conflict 2024-06-06 16:35:09 +08:00
Waket Zheng
c0c217392c Merge remote-tracking branch 'upstream/dev' into type-hint-simple 2024-06-06 16:32:00 +08:00
Waket Zheng
c7a3d164cb fix conflict 2024-06-06 16:31:24 +08:00
Waket Zheng
50add58981 merge dev 2024-06-06 16:27:06 +08:00
long2ice
f3b6f8f264
Merge pull request #345 from waketzheng/dev
Drop python3.7 support
2024-06-06 15:49:31 +08:00
long2ice
d33638471b
Merge pull request #346 from waketzheng/fix-336
fix: mysql drop unique index with error name
2024-06-06 15:48:50 +08:00
Waket Zheng
e764bb56f7 Add new column for unique index remove test 2024-06-06 09:07:13 +08:00
Waket Zheng
84d31d63f6 Update readme 2024-06-06 08:42:55 +08:00
Waket Zheng
234495d291 docs: bump up version and update changelog 2024-06-06 00:07:29 +08:00
Waket Zheng
e971653851 fix: mysql drop unique index migrate error 2024-06-05 23:37:30 +08:00
Waket Zheng
58d31b3a05 fix: make ci error with latest tortoise-orm (#344) 2024-06-04 01:44:28 +08:00
Waket Zheng
affffbdae3 Drop python3.7 support 2024-06-03 01:26:48 +08:00
Waket Zheng
6466a852c8 Add type hints for aerich.migrate 2024-06-02 18:09:34 +08:00
Waket Zheng
a917f253c9 Add type hints for inspectdb/sqlite 2024-06-02 17:56:54 +08:00
Waket Zheng
dd11bed5a0 Add type hints for ddl and inspectdb 2024-06-01 22:14:30 +08:00
Waket Zheng
8756f64e3f Simple type hints for aerich/ 2024-06-01 21:16:53 +08:00
Waket Zheng
1addda8178 Improve type hints for tests/ 2024-06-01 20:58:25 +08:00
Waket Zheng
716638752b Update ci action versions and avoid make ci install deps twice 2024-06-01 20:46:08 +08:00
long2ice
51117867a6
Merge pull request #328 from Fl0kse/fix_issue-150
try to fix "Add ManyToManyField will break migrate #150" issues
2024-01-23 22:08:08 +08:00
artem.ulchenko
b25c7671bb try to fix "Add ManyToManyField will break migrate #150" issues 2024-01-18 11:40:59 +03:00
long2ice
b63fbcc7a4 style: fix 2024-01-18 09:50:41 +08:00
long2ice
4370b5ed08
Merge pull request #327 from ar0ne/dev
Added an option to generate empty migration file
2024-01-18 09:46:02 +08:00
ar0ne
2b2e465362 missed changes that add new flag to cli 2023-12-27 12:56:46 +05:30
ar0ne
ede53ade86 added note in readme 2023-12-26 23:20:29 +05:30
ar0ne
ad54b5e9dd remove redundant comma for empty migration 2023-12-26 23:12:11 +05:30
ar0ne
b1ff2418f5 added option to generate empty migration file 2023-12-26 22:55:51 +05:30
long2ice
01264f3f27 fix: pydantic v2. (#322) 2023-08-30 10:04:53 +08:00
long2ice
ea234a5799 fix: default empty str 2023-08-23 10:53:19 +08:00
long2ice
b724f24f1a
Merge pull request #319 from strayge/pr/editable-install
Update poetry build backend to support editable install
2023-08-22 10:00:03 +08:00
long2ice
2bc23103dc
Merge pull request #320 from strayge/pr/column-description-diff
Fix changed column descriptions in diffs
2023-08-22 09:59:46 +08:00
strayge
6d83c370ad fix changed column descriptions in diffs 2023-08-21 10:58:43 +04:00
strayge
e729bb9b60 update poetry build-backend to support editable install 2023-08-21 10:58:26 +04:00
long2ice
8edd834da6 refactor: make in_transaction default True 2023-08-04 10:36:27 +08:00
long2ice
4adc89d441
Merge pull request #306 from plusiv/add-char-support-mysql
add char support for mysql
2023-07-27 12:52:01 +08:00
Jorge Massih
fd4b9fe7d3
add char support for mysql 2023-06-15 00:40:29 -04:00
long2ice
467406ed20
Merge pull request #303 from karichevi/dev
Added Documentation in Russian Language
2023-05-23 11:35:09 +08:00
karichevi
484b5900ce Added Documentation in Russian Language 2023-05-22 14:30:36 +03:00
long2ice
b8b6df0b65 chore: update deps 2023-05-12 15:27:50 +08:00
long2ice
f0bc3126e9 fix: generates two semicolons in a row. (#301) 2023-05-12 14:52:17 +08:00
long2ice
dbc0d9e7ef
Merge pull request #296 from evstratbg/migrate-transaction
add in-transaction for upgrade
2023-05-05 23:12:08 +08:00
Bogdan
818dd29991 fix styles 2023-05-05 17:29:34 +04:00
Bogdan
e199e03b53 added -i param 2023-05-03 19:48:54 +04:00
Bogdan
d79dc25ee8 enriched changelog 2023-05-03 19:48:25 +04:00
Bogdan
c6d51a4dcf bump version 2023-05-03 19:48:15 +04:00
Bogdan
241b30a710 add in-transaction for upgrade 2023-04-03 20:55:12 +04:00
long2ice
8cf50c58d7 test: fix ci 2023-01-27 15:17:41 +08:00
long2ice
1c9b65cc37 fix: modify multiple times. (#279) 2023-01-27 13:49:07 +08:00
long2ice
3fbf9febfb
Merge pull request #281 from CortexPE/patch-1
Fix #280 by removing trailing semicolon
2022-12-20 14:28:51 +08:00
marshall
7b6545d4e1
Fix #280 by removing trailing semicolon 2022-12-20 04:28:40 +08:00
long2ice
52b50a2161 feat: support virtual fields 2022-11-18 23:23:04 +08:00
long2ice
90943a473c
Merge pull request #269 from jtraub/dev
Close connections in the init_db wrapper
2022-09-29 08:39:33 +08:00
Konstantin Mikhailov
d7ecd97e88
Close connections in the init_db wrapper 2022-09-29 08:09:59 +10:00
46 changed files with 5103 additions and 1470 deletions

View File

@ -18,17 +18,67 @@ jobs:
POSTGRES_PASSWORD: 123456
POSTGRES_USER: postgres
options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
tortoise-orm:
- tortoise021
- tortoise022
- tortoise023
- tortoise024
# TODO: add dev back when drop python3.8 support
# - tortoisedev
steps:
- name: Start MySQL
run: sudo systemctl start mysql.service
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.x'
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install and configure Poetry
run: |
pip install -U pip poetry
poetry config virtualenvs.create false
pip install -U pip
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
# poetry2.0+ does not support installed by python3.8, but can manage project using py38
python3.12 -m pip install "poetry>=2.0"
else
pip install "poetry>=2.0"
fi
poetry env use python${{ matrix.python-version }}
- name: Install dependencies and check style
run: poetry run make check
- name: Install TortoiseORM v0.21
if: matrix.tortoise-orm == 'tortoise021'
run: poetry run pip install --upgrade "tortoise-orm>=0.21,<0.22"
- name: Install TortoiseORM v0.22
if: matrix.tortoise-orm == 'tortoise022'
run: poetry run pip install --upgrade "tortoise-orm>=0.22,<0.23"
- name: Install TortoiseORM v0.23
if: matrix.tortoise-orm == 'tortoise023'
run: poetry run pip install --upgrade "tortoise-orm>=0.23,<0.24"
- name: Install TortoiseORM v0.24
if: matrix.tortoise-orm == 'tortoise024'
run: |
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
echo "Skip test for tortoise v0.24 as it does not support Python3.8"
else
poetry run pip install --upgrade "tortoise-orm>=0.24,<0.25"
fi
- name: Install TortoiseORM develop branch
if: matrix.tortoise-orm == 'tortoisedev'
run: |
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
echo "Skip test for tortoise develop branch as it does not support Python3.8"
else
poetry run pip uninstall -y tortoise-orm
poetry run pip install --upgrade "git+https://github.com/tortoise/tortoise-orm"
fi
- name: CI
env:
MYSQL_PASS: root
@ -37,4 +87,23 @@ jobs:
POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432
run: make ci
run: poetry run make _testall
- name: Verify aiomysql support
# Only check the latest version of tortoise
if: matrix.tortoise-orm == 'tortoise024'
run: |
poetry run pip uninstall -y asyncmy
poetry run make test_mysql
poetry run pip install asyncmy
env:
MYSQL_PASS: root
MYSQL_HOST: 127.0.0.1
MYSQL_PORT: 3306
- name: Verify psycopg support
# Only check the latest version of tortoise
if: matrix.tortoise-orm == 'tortoise024'
run: poetry run make test_psycopg
env:
POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432

View File

@ -7,8 +7,8 @@ jobs:
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install and configure Poetry

View File

@ -1,7 +1,104 @@
# ChangeLog
## 0.8
### [0.8.3]**(Unreleased)**
#### Fixed
- fix: `aerich init-db` process is suspended. ([#435])
[#435]: https://github.com/tortoise/aerich/pull/435
### [0.8.2](../../releases/tag/v0.8.2) - 2025-02-28
#### Added
- Support changes `max_length` or int type for primary key field. ([#428])
- feat: support psycopg. ([#425])
- Support run `poetry add aerich` in project that inited by poetry v2. ([#424])
- feat: support command `python -m aerich`. ([#417])
- feat: add --fake to upgrade/downgrade. ([#398])
- Support ignore table by settings `managed=False` in `Meta` class. ([#397])
#### Fixed
- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
- No migration occurs as expected when adding `unique=True` to indexed field. ([#404])
- fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
- fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187])
### Changed
- Refactored version management to use `importlib.metadata.version(__package__)` instead of hardcoded version string ([#412])
[#397]: https://github.com/tortoise/aerich/pull/397
[#398]: https://github.com/tortoise/aerich/pull/398
[#401]: https://github.com/tortoise/aerich/pull/401
[#404]: https://github.com/tortoise/aerich/pull/404
[#412]: https://github.com/tortoise/aerich/pull/412
[#415]: https://github.com/tortoise/aerich/pull/415
[#417]: https://github.com/tortoise/aerich/pull/417
[#424]: https://github.com/tortoise/aerich/pull/424
[#425]: https://github.com/tortoise/aerich/pull/425
### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27
#### Fixed
- fix: add o2o field does not create constraint when migrating. ([#396])
- Migration with duplicate renaming of columns in some cases. ([#395])
- fix: intermediate table for m2m relation not created. ([#394])
- Migrate add m2m field with custom through generate duplicated table. ([#393])
- Migrate drop the wrong m2m field when model have multi m2m fields. ([#376])
- KeyError raised when removing or renaming an existing model. ([#386])
- fix: error when there is `__init__.py` in the migration folder. ([#272])
- Setting null=false on m2m field causes migration to fail. ([#334])
- Fix NonExistentKey when running `aerich init` without `[tool]` section in config file. ([#284])
- Fix configuration file reading error when containing Chinese characters. ([#286])
- sqlite: failed to create/drop index. ([#302])
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. ([#378])
- Fix create/drop indexes in every migration. ([#377])
- Sort m2m fields before comparing them with diff. ([#271])
#### Changed
- Allow run `aerich init-db` with empty migration directories instead of abort with warnings. ([#286])
- Add version constraint(>=0.21) for tortoise-orm. ([#388])
- Move `tomlkit` to optional and support `pip install aerich[toml]`. ([#392])
[#396]: https://github.com/tortoise/aerich/pull/396
[#395]: https://github.com/tortoise/aerich/pull/395
[#394]: https://github.com/tortoise/aerich/pull/394
[#393]: https://github.com/tortoise/aerich/pull/393
[#392]: https://github.com/tortoise/aerich/pull/392
[#388]: https://github.com/tortoise/aerich/pull/388
[#386]: https://github.com/tortoise/aerich/pull/386
[#378]: https://github.com/tortoise/aerich/pull/378
[#377]: https://github.com/tortoise/aerich/pull/377
[#376]: https://github.com/tortoise/aerich/pull/376
[#334]: https://github.com/tortoise/aerich/pull/334
[#302]: https://github.com/tortoise/aerich/pull/302
[#286]: https://github.com/tortoise/aerich/pull/286
[#284]: https://github.com/tortoise/aerich/pull/284
[#272]: https://github.com/tortoise/aerich/pull/272
[#271]: https://github.com/tortoise/aerich/pull/271
### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04
- Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
- Fix KeyError when deleting a field with unqiue=True. (#364)
- Correct the click import. (#360)
- Improve CLI help text and output. (#355)
- Fix mysql drop unique index raises OperationalError. (#346)
**Upgrade note:**
1. Use column name as unique key name for mysql
2. Drop support for Python3.7
## 0.7
### [0.7.2](../../releases/tag/v0.7.2) - 2023-07-20
- Support virtual fields.
- Fix modify multiple times. (#279)
- Added `-i` and `--in-transaction` options to `aerich migrate` command. (#296)
- Fix generates two semicolons in a row. (#301)
### 0.7.1
- Fix syntax error with python3.8.10. (#265)

View File

@ -1,32 +1,43 @@
checkfiles = aerich/ tests/ conftest.py
black_opts = -l 100 -t py38
py_warn = PYTHONDEVMODE=1
MYSQL_HOST ?= "127.0.0.1"
MYSQL_PORT ?= 3306
MYSQL_PASS ?= "123456"
POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= "123456"
POSTGRES_PASS ?= 123456
up:
@poetry update
deps:
@poetry install -E asyncpg -E asyncmy
@poetry install --all-extras --all-groups
style: deps
@isort -src $(checkfiles)
@black $(black_opts) $(checkfiles)
_style:
@ruff check --fix $(checkfiles)
@ruff format $(checkfiles)
style: deps _style
check: deps
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles)
_check:
@ruff format --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@ruff check $(checkfiles)
@mypy $(checkfiles)
@bandit -r aerich
check: deps _check
_lint: _build
@ruff format $(checkfiles)
ruff check --fix $(checkfiles)
mypy $(checkfiles)
bandit -c pyproject.toml -r $(checkfiles)
twine check dist/*
lint: deps _lint
test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test
$(py_warn) TEST_DB=sqlite://:memory: pytest
test_sqlite:
$(py_warn) TEST_DB=sqlite://:memory: py.test
$(py_warn) TEST_DB=sqlite://:memory: pytest
test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
@ -34,9 +45,14 @@ test_mysql:
test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql
test_psycopg:
$(py_warn) TEST_DB="psycopg://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
build: deps
_testall: test_sqlite test_postgres test_mysql
testall: deps _testall
_build:
@poetry build
build: deps _build
ci: check testall
ci: build _check _testall

View File

@ -5,6 +5,8 @@
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
English | [Русский](./README_RU.md)
## Introduction
Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with
@ -15,7 +17,7 @@ it\'s own migration solution.
Just install from pypi:
```shell
pip install aerich
pip install "aerich[toml]"
```
## Quick Start
@ -44,7 +46,7 @@ Commands:
## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first. Example:
You need to add `aerich.models` to your `Tortoise-ORM` config first. Example:
```python
TORTOISE_ORM = {
@ -111,6 +113,14 @@ If `aerich` guesses you are renaming a column, it will ask `Rename {old_column}
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data.
If you need to manually write migration, you could generate empty file:
```shell
> aerich migrate --name add_index --empty
Success migrate 1_202326122220101229_add_index.py
```
### Upgrade to latest version
```shell
@ -216,14 +226,14 @@ from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
date = fields.DateField(null=True)
datetime = fields.DatetimeField(auto_now=True)
decimal = fields.DecimalField(max_digits=10, decimal_places=2)
float = fields.FloatField(null=True)
id = fields.IntField(primary_key=True)
string = fields.CharField(max_length=200, null=True)
time = fields.TimeField(null=True)
tinyint = fields.BooleanField(null=True)
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
@ -233,8 +243,8 @@ Note that this command is limited and can't infer some fields, such as `IntEnumF
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
"default": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db1",
"second": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db2",
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@ -243,7 +253,7 @@ tortoise_orm = {
}
```
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`.
## Restore `aerich` workflow
@ -263,11 +273,38 @@ You can use `aerich` out of cli by use `Command` class.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
async with Command(tortoise_config=config, app='models') as command:
await command.migrate('test')
await command.upgrade()
```
## Upgrade/Downgrade with `--fake` option
Marks the migrations up to the latest one(or back to the target one) as applied, but without actually running the SQL to change your database schema.
- Upgrade
```bash
aerich upgrade --fake
aerich --app models upgrade --fake
```
- Downgrade
```bash
aerich downgrade --fake -v 2
aerich --app models downgrade --fake -v 2
```
### Ignore tables
You can tell aerich to ignore table by setting `managed=False` in the `Meta` class, e.g.:
```py
class MyModel(Model):
class Meta:
managed = False
```
**Note** `managed=False` does not recognized by `tortoise-orm` and `aerich init-db`, it is only for `aerich migrate`.
## License
This project is licensed under the

274
README_RU.md Normal file
View File

@ -0,0 +1,274 @@
# Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
[English](./README.md) | Русский
## Введение
Aerich - это инструмент для миграции базы данных для TortoiseORM, который аналогичен Alembic для SQLAlchemy или встроенному решению миграций в Django ORM.
## Установка
Просто установите из pypi:
```shell
pip install aerich
```
## Быстрый старт
```shell
> aerich -h
Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options:
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name.
-h, --help Show this message and exit.
Commands:
downgrade Downgrade to specified version.
heads Show current available heads in migrate location.
history List all migrate items.
init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file.
upgrade Upgrade to specified version.
```
## Использование
Сначала вам нужно добавить aerich.models в конфигурацию вашего Tortoise-ORM. Пример:
```python
TORTOISE_ORM = {
"connections": {"default": "mysql://root:123456@127.0.0.1:3306/test"},
"apps": {
"models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
},
}
```
### Инициализация
```shell
> aerich init -h
Usage: aerich init [OPTIONS]
Init config file and generate root migrate location.
Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit.
```
Инициализируйте файл конфигурации и задайте местоположение миграций:
```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations
Success write config to pyproject.toml
```
### Инициализация базы данных
```shell
> aerich init-db
Success create app migrate location ./migrations/models
Success generate schema for app "models"
```
Если ваше приложение Tortoise-ORM не является приложением по умолчанию с именем models, вы должны указать правильное имя приложения с помощью параметра --app, например: aerich --app other_models init-db.
### Обновление моделей и создание миграции
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.py
```
Формат имени файла миграции следующий: `{версия}_{дата_и_время}_{имя|обновление}.py`.
Если aerich предполагает, что вы переименовываете столбец, он спросит:
Переименовать `{старый_столбец} в {новый_столбец} [True]`. Вы можете выбрать `True`,
чтобы переименовать столбец без удаления столбца, или выбрать `False`, чтобы удалить столбец,
а затем создать новый. Обратите внимание, что последний вариант может привести к потере данных.
### Обновление до последней версии
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных обновлена до последней версии.
### Откат до указанной версии
```shell
> aerich downgrade -h
Usage: aerich downgrade [OPTIONS]
Downgrade to specified version.
Options:
-v, --version INTEGER Specified version, default to last. [default: -1]
-d, --delete Delete version files at the same time. [default:
False]
--yes Confirm the action without prompting.
-h, --help Show this message and exit.
```
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных откатилась до указанной версии.
### Показать историю
```shell
> aerich history
1_202029051520102929_drop_column.py
```
### Чтобы узнать, какие миграции должны быть применены, можно использовать команду:
```shell
> aerich heads
1_202029051520102929_drop_column.py
```
### Осмотр таблиц базы данных для модели TortoiseORM
В настоящее время inspectdb поддерживает MySQL, Postgres и SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Посмотреть все таблицы и вывести их на консоль:
```shell
aerich --app models inspectdb
```
Осмотреть указанную таблицу в приложении по умолчанию и перенаправить в models.py:
```shell
aerich inspectdb -t user > models.py
```
Например, ваша таблица выглядит следующим образом:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Теперь выполните команду aerich inspectdb -t test, чтобы увидеть сгенерированную модель:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Обратите внимание, что эта команда имеет ограничения и не может автоматически определить некоторые поля, такие как `IntEnumField`, `ForeignKeyField` и другие.
### Несколько баз данных
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["tests.models_second"], "default_connection": "second", },
},
}
```
Вам нужно указать `aerich.models` только в одном приложении и должны указывать `--app` при запуске команды `aerich migrate` и т.д.
## Восстановление рабочего процесса aerich
В некоторых случаях, например, при возникновении проблем после обновления `aerich`, вы не можете запустить `aerich migrate` или `aerich upgrade`. В таком случае вы можете выполнить следующие шаги:
1. удалите таблицы `aerich`.
2. удалите директорию `migrations/{app}`.
3. rerun `aerich init-db`.
Обратите внимание, что эти действия безопасны, и вы можете использовать их для сброса миграций, если у вас слишком много файлов миграции.
## Использование aerich в приложении
Вы можете использовать `aerich` вне командной строки, используя класс `Command`.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
```
## Лицензия
Этот проект лицензирован в соответствии с лицензией
[Apache-2.0](https://github.com/long2ice/aerich/blob/master/LICENSE) Лицензия.

View File

@ -1,8 +1,13 @@
import os
from pathlib import Path
from typing import List
from __future__ import annotations
from tortoise import Tortoise, generate_schema_for_client
import os
import platform
from contextlib import AbstractAsyncContextManager
from pathlib import Path
from typing import TYPE_CHECKING
import tortoise
from tortoise import Tortoise, connections, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
@ -20,23 +25,155 @@ from aerich.utils import (
import_py_file,
)
if TYPE_CHECKING:
from tortoise import Model
from tortoise.fields.relational import ManyToManyFieldInstance # NOQA:F401
class Command:
from aerich.inspectdb import Inspect
def _init_asyncio_patch():
"""
Select compatible event loop for psycopg3.
As of Python 3.8+, the default event loop on Windows is `proactor`,
however psycopg3 requires the old default "selector" event loop.
See https://www.psycopg.org/psycopg3/docs/advanced/async.html
"""
if platform.system() == "Windows":
try:
from asyncio import WindowsSelectorEventLoopPolicy # type:ignore
except ImportError:
pass # Can't assign a policy which doesn't exist.
else:
from asyncio import get_event_loop_policy, set_event_loop_policy
if not isinstance(get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
set_event_loop_policy(WindowsSelectorEventLoopPolicy())
def _init_tortoise_0_24_1_patch():
# this patch is for "tortoise-orm==0.24.1" to fix:
# https://github.com/tortoise/tortoise-orm/issues/1893
if tortoise.__version__ != "0.24.1":
return
from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re
def _get_m2m_tables(
self, model: type[Model], db_table: str, safe: bool, models_tables: list[str]
) -> list[str]: # Copied from tortoise-orm
m2m_tables_for_create = []
for m2m_field in model._meta.m2m_fields:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
if field_object._generated or field_object.through in models_tables:
continue
backward_key, forward_key = field_object.backward_key, field_object.forward_key
if field_object.db_constraint:
backward_fk = self._create_fk_string(
"",
backward_key,
db_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
forward_fk = self._create_fk_string(
"",
forward_key,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
)
else:
backward_fk = forward_fk = ""
exists = "IF NOT EXISTS " if safe else ""
through_table_name = field_object.through
backward_type = self._get_pk_field_sql_type(model._meta.pk)
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
comment = ""
if desc := field_object.description:
comment = self._table_comment_generator(table=through_table_name, comment=desc)
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
table_name=through_table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
backward_type=backward_type,
forward_key=forward_key,
forward_type=forward_type,
extra=self._table_generate_extra(table=field_object.through),
comment=comment,
)
if not field_object.db_constraint:
m2m_create_string = m2m_create_string.replace(
""",
,
""",
"",
) # may have better way
m2m_create_string += self._post_table_hook()
if getattr(field_object, "create_unique_index", field_object.unique):
unique_index_create_sql = self._get_unique_index_sql(
exists, through_table_name, [backward_key, forward_key]
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
else:
lines = m2m_create_string.splitlines()
lines[-2] += ","
indent = m.group() if (m := re.match(r"\s+", lines[-2])) else ""
lines.insert(-1, indent + unique_index_create_sql)
m2m_create_string = "\n".join(lines)
m2m_tables_for_create.append(m2m_create_string)
return m2m_tables_for_create
setattr(BaseSchemaGenerator, "_get_m2m_tables", _get_m2m_tables)
_init_asyncio_patch()
_init_tortoise_0_24_1_patch()
class Command(AbstractAsyncContextManager):
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
):
) -> None:
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self):
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)
async def upgrade(self):
async def __aenter__(self) -> Command:
await self.init()
return self
async def close(self) -> None:
await connections.close_all()
async def __aexit__(self, *args, **kw) -> None:
await self.close()
async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = m.upgrade
if not fake:
await conn.execute_script(await upgrade(conn))
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> list[str]:
migrated = []
for version_file in Migrate.get_all_version_files():
try:
@ -44,23 +181,18 @@ class Command:
except OperationalError:
exists = False
if not exists:
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
await conn.execute_script(await upgrade(conn))
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
app_conn_name = get_app_connection_name(self.tortoise_config, self.app)
if run_in_transaction:
async with in_transaction(app_conn_name) as conn:
await self._upgrade(conn, version_file, fake=fake)
else:
app_conn = get_app_connection(self.tortoise_config, self.app)
await self._upgrade(app_conn, version_file, fake=fake)
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool):
ret = []
async def downgrade(self, version: int, delete: bool, fake: bool = False) -> list[str]:
ret: list[str] = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
@ -73,25 +205,26 @@ class Command:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
for version_obj in versions:
file = version_obj.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
m = import_py_file(file_path)
downgrade = getattr(m, "downgrade")
downgrade = m.downgrade
downgrade_sql = await downgrade(conn)
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql)
await version.delete()
if not fake:
await conn.execute_script(downgrade_sql)
await version_obj.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self):
async def heads(self) -> list[str]:
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
@ -99,15 +232,15 @@ class Command:
ret.append(version)
return ret
async def history(self):
async def history(self) -> list[str]:
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: List[str] = None) -> str:
async def inspectdb(self, tables: list[str] | None = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
cls: type[Inspect] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
@ -117,14 +250,19 @@ class Command:
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"):
return await Migrate.migrate(name)
async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool):
async def init_db(self, safe: bool) -> None:
location = self.location
app = self.app
dirname = Path(location, app)
dirname.mkdir(parents=True)
if not dirname.exists():
dirname.mkdir(parents=True)
else:
# If directory is empty, go ahead, otherwise raise FileExistsError
for unexpected_file in dirname.glob("*"):
raise FileExistsError(str(unexpected_file))
await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app)

3
aerich/__main__.py Normal file
View File

@ -0,0 +1,3 @@
from .cli import main
main()

28
aerich/_compat.py Normal file
View File

@ -0,0 +1,28 @@
# mypy: disable-error-code="no-redef"
from __future__ import annotations
import sys
from types import ModuleType
import tortoise
if sys.version_info >= (3, 11):
import tomllib
else:
try:
import tomli as tomllib
except ImportError:
import tomlkit as tomllib
def imports_tomlkit() -> ModuleType:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit
return tomlkit
def tortoise_version_less_than(version: str) -> bool:
# The min version of tortoise is '0.11.0', so we can compare it by a `<`,
return tortoise.__version__ < version

View File

@ -1,16 +1,14 @@
import asyncio
import os
from functools import wraps
from pathlib import Path
from typing import List
from __future__ import annotations
import click
import tomlkit
from click import Context, UsageError
from tomlkit.exceptions import NonExistentKey
from tortoise import Tortoise
import os
from pathlib import Path
from typing import cast
import asyncclick as click
from asyncclick import Context, UsageError
from aerich import Command
from aerich._compat import imports_tomlkit, tomllib
from aerich.enums import Color
from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config
@ -21,19 +19,20 @@ CONFIG_DEFAULT_VALUES = {
}
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
def _patch_context_to_close_tortoise_connections_when_exit() -> None:
from tortoise import Tortoise, connections
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init_db", "init"]:
loop.run_until_complete(Tortoise.close_connections())
origin_aexit = Context.__aexit__
return wrapper
async def aexit(*args, **kw) -> None:
await origin_aexit(*args, **kw)
if Tortoise._inited:
await connections.close_all()
Context.__aexit__ = aexit # type:ignore[method-assign]
_patch_context_to_close_tortoise_connections_when_exit()
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@ -47,8 +46,7 @@ def coro(f):
)
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.pass_context
@coro
async def cli(ctx: Context, config, app):
async def cli(ctx: Context, config, app) -> None:
ctx.ensure_object(dict)
ctx.obj["config_file"] = config
@ -56,50 +54,78 @@ async def cli(ctx: Context, config, app):
if invoked_subcommand != "init":
config_path = Path(config)
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx)
content = config_path.read_text()
doc = tomlkit.parse(content)
raise UsageError(
"You need to run `aerich init` first to create the config file.", ctx=ctx
)
content = config_path.read_text("utf-8")
doc: dict = tomllib.loads(content)
try:
tool = doc["tool"]["aerich"]
tool = cast("dict[str, str]", doc["tool"]["aerich"])
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+")
except KeyError as e:
raise UsageError(
"You need run `aerich init` again when upgrading to aerich 0.6.0+."
) from e
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0]
if not app:
try:
apps_config = cast(dict, tortoise_config["apps"])
except KeyError:
raise UsageError('Config must define "apps" section')
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db":
if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx)
raise UsageError(
"You need to run `aerich init-db` first to initialize the database.", ctx=ctx
)
await command.init()
@cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.")
@cli.command(help="Generate a migration file for the current state of the models.")
@click.option("--name", default="update", show_default=True, help="Migration name.")
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
@click.pass_context
@coro
async def migrate(ctx: Context, name):
async def migrate(ctx: Context, name, empty) -> None:
command = ctx.obj["command"]
ret = await command.migrate(name)
ret = await command.migrate(name, empty)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success migrate {ret}", fg=Color.green)
click.secho(f"Success creating migration file {ret}", fg=Color.green)
@cli.command(help="Upgrade to specified version.")
@cli.command(help="Upgrade to specified migration version.")
@click.option(
"--in-transaction",
"-i",
default=True,
type=bool,
help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
)
@click.option(
"--fake",
default=False,
is_flag=True,
help="Mark migrations as run without actually running them.",
)
@click.pass_context
@coro
async def upgrade(ctx: Context):
async def upgrade(ctx: Context, in_transaction: bool, fake: bool) -> None:
command = ctx.obj["command"]
migrated = await command.upgrade()
migrated = await command.upgrade(run_in_transaction=in_transaction, fake=fake)
if not migrated:
click.secho("No upgrade items found", fg=Color.yellow)
else:
for version_file in migrated:
click.secho(f"Success upgrade {version_file}", fg=Color.green)
if fake:
click.echo(
f"Upgrading to {version_file}... " + click.style("FAKED", fg=Color.green)
)
else:
click.secho(f"Success upgrading to {version_file}", fg=Color.green)
@cli.command(help="Downgrade to specified version.")
@ -108,8 +134,8 @@ async def upgrade(ctx: Context):
"--version",
default=-1,
type=int,
show_default=True,
help="Specified version, default to last.",
show_default=False,
help="Specified version, default to last migration.",
)
@click.option(
"-d",
@ -117,59 +143,75 @@ async def upgrade(ctx: Context):
is_flag=True,
default=False,
show_default=True,
help="Delete version files at the same time.",
help="Also delete the migration files.",
)
@click.option(
"--fake",
default=False,
is_flag=True,
help="Mark migrations as run without actually running them.",
)
@click.pass_context
@click.confirmation_option(
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
prompt="Downgrade is dangerous: you might lose your data! Are you sure?",
)
@coro
async def downgrade(ctx: Context, version: int, delete: bool):
async def downgrade(ctx: Context, version: int, delete: bool, fake: bool) -> None:
command = ctx.obj["command"]
try:
files = await command.downgrade(version, delete)
files = await command.downgrade(version, delete, fake=fake)
except DowngradeError as e:
return click.secho(str(e), fg=Color.yellow)
for file in files:
click.secho(f"Success downgrade {file}", fg=Color.green)
if fake:
click.echo(f"Downgrading to {file}... " + click.style("FAKED", fg=Color.green))
else:
click.secho(f"Success downgrading to {file}", fg=Color.green)
@cli.command(help="Show current available heads in migrate location.")
@cli.command(help="Show currently available heads (unapplied migrations).")
@click.pass_context
@coro
async def heads(ctx: Context):
async def heads(ctx: Context) -> None:
command = ctx.obj["command"]
head_list = await command.heads()
if not head_list:
return click.secho("No available heads, try migrate first", fg=Color.green)
return click.secho("No available heads.", fg=Color.green)
for version in head_list:
click.secho(version, fg=Color.green)
@cli.command(help="List all migrate items.")
@cli.command(help="List all migrations.")
@click.pass_context
@coro
async def history(ctx: Context):
async def history(ctx: Context) -> None:
command = ctx.obj["command"]
versions = await command.history()
if not versions:
return click.secho("No history, try migrate", fg=Color.green)
return click.secho("No migrations created yet.", fg=Color.green)
for version in versions:
click.secho(version, fg=Color.green)
@cli.command(help="Init config file and generate root migrate location.")
def _write_config(config_path, doc, table) -> None:
tomlkit = imports_tomlkit()
try:
doc["tool"]["aerich"] = table
except KeyError:
doc["tool"] = {"aerich": table}
config_path.write_text(tomlkit.dumps(doc))
@cli.command(help="Initialize aerich config and create migrations folder.")
@click.option(
"-t",
"--tortoise-orm",
required=True,
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
help="Tortoise-ORM config dict location, like `settings.TORTOISE_ORM`.",
)
@click.option(
"--location",
default="./migrations",
show_default=True,
help="Migrate store location.",
help="Migrations folder.",
)
@click.option(
"-s",
@ -179,8 +221,7 @@ async def history(ctx: Context):
help="Folder of the source, relative to the project root.",
)
@click.pass_context
@coro
async def init(ctx: Context, tortoise_orm, location, src_folder):
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder):
@ -193,52 +234,48 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
doc: dict = tomllib.loads(content)
config_path.write_text(tomlkit.dumps(doc))
table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder}
if (aerich_config := doc.get("tool", {}).get("aerich")) and all(
aerich_config.get(k) == v for k, v in table.items()
):
click.echo(f"Aerich config {config_file} already inited.")
else:
_write_config(config_path, doc, table)
click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success write config to {config_file}", fg=Color.green)
click.secho(f"Success creating migrations folder {location}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.")
@cli.command(help="Generate schema and generate app migration folder.")
@click.option(
"-s",
"--safe",
type=bool,
is_flag=True,
default=True,
help="When set to true, creates the table only when it does not already exist.",
help="Create tables only when they do not already exist.",
show_default=True,
)
@click.pass_context
@coro
async def init_db(ctx: Context, safe: bool):
async def init_db(ctx: Context, safe: bool) -> None:
command = ctx.obj["command"]
app = command.app
dirname = Path(command.location, app)
try:
await command.init_db(safe)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
click.secho(f"Success creating app migration folder {dirname}", fg=Color.green)
click.secho(f'Success generating initial migration file for app "{app}"', fg=Color.green)
except FileExistsError:
return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
f"App {app} is already initialized. Delete {dirname} and try again.", fg=Color.yellow
)
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.")
@cli.command(help="Prints the current database tables to stdout as Tortoise-ORM models.")
@click.option(
"-t",
"--table",
@ -247,14 +284,13 @@ async def init_db(ctx: Context, safe: bool):
required=False,
)
@click.pass_context
@coro
async def inspectdb(ctx: Context, table: List[str]):
async def inspectdb(ctx: Context, table: list[str]) -> None:
command = ctx.obj["command"]
ret = await command.inspectdb(table)
click.secho(ret)
def main():
def main() -> None:
cli()

View File

@ -1,13 +1,19 @@
from __future__ import annotations
import base64
import json
import pickle # nosec: B301,B403
from typing import Any
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
def default(self, obj) -> Any:
if isinstance(obj, Index):
if hasattr(obj, "describe"):
# For tortoise>=0.24
return obj.describe()
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
@ -16,16 +22,28 @@ class JsonEncoder(json.JSONEncoder):
return super().default(obj)
def object_hook(obj):
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def object_hook(obj) -> Any:
if (type_ := obj.get("type")) and type_ == "index" and (val := obj.get("val")):
return pickle.loads(base64.b64decode(val)) # nosec: B301
return obj
def encoder(obj: dict):
def load_index(obj: dict) -> Index:
"""Convert a dict that generated by `Index.decribe()` to a Index instance"""
try:
index = Index(fields=obj["fields"] or obj["expressions"], name=obj.get("name"))
except KeyError:
return object_hook(obj)
if extra := obj.get("extra"):
index.extra = extra
if idx_type := obj.get("type"):
index.INDEX_TYPE = idx_type
return index
def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
def decoder(obj: str | bytes) -> Any:
return json.loads(obj, object_hook=object_hook)

View File

@ -1,14 +1,20 @@
from enum import Enum
from typing import List, Type
from __future__ import annotations
import re
from enum import Enum
from typing import TYPE_CHECKING, Any, cast
from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from aerich._compat import tortoise_version_less_than
from aerich.utils import is_default_function
if TYPE_CHECKING:
from tortoise import BaseDBAsyncClient, Model
class BaseDDL:
schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
schema_generator_cls: type[BaseSchemaGenerator] = BaseSchemaGenerator
DIALECT = "sql"
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
@ -17,10 +23,8 @@ class BaseDDL:
_RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
)
_ADD_INDEX_TEMPLATE = (
'ALTER TABLE "{table_name}" ADD {unique}INDEX "{index_name}" ({column_names})'
)
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}'
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = (
@ -35,23 +39,32 @@ class BaseDDL:
)
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"):
def __init__(self, client: BaseDBAsyncClient) -> None:
self.client = client
self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
@staticmethod
def get_table_name(model: type[Model]) -> str:
return model._meta.db_table
def drop_table(self, table_name: str):
def create_table(self, model: type[Model]) -> str:
schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
if tortoise_version_less_than("0.23.1"):
# Remove extra space
schema = re.sub(r'(["()A-Za-z]) (["()A-Za-z])', r"\1 \2", schema)
return schema.rstrip(";")
def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
through = cast(str, field_describe.get("through"))
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format(
table_name=through,
backward_table=model._meta.db_table,
@ -64,22 +77,22 @@ class BaseDDL:
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator(
table=through, comment=description
)
if description
else "",
comment=(
self.schema_generator._table_comment_generator(table=through, comment=description)
if description
else ""
),
)
def drop_m2m(self, table_name: str):
def drop_m2m(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict):
def _get_default(self, model: type[Model], field_describe: dict) -> Any:
db_table = model._meta.db_table
default = field_describe.get("default")
if isinstance(default, Enum):
default = default.value
db_column = field_describe.get("db_column")
db_column = cast(str, field_describe.get("db_column"))
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
@ -100,68 +113,58 @@ class BaseDDL:
)
except NotImplementedError:
default = ""
else:
default = None
return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(
self, model: type[Model], field_describe: dict, is_pk: bool, modify: bool = False
) -> str:
db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
db_column = cast(str, field_describe.get("db_column"))
db_field_types = cast(dict, field_describe.get("db_field_types"))
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "",
comment=self.schema_generator._column_comment_generator(
if modify:
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
# sqlite does not support alter table to add unique column
unique = " UNIQUE" if field_describe.get("unique") and self.DIALECT != "sqlite" else ""
template = self._ADD_COLUMN_TEMPLATE
column = self.schema_generator._create_string(
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable=" NOT NULL" if not field_describe.get("nullable") else "",
unique=unique,
comment=(
self.schema_generator._column_comment_generator(
table=db_table,
column=db_column,
comment=field_describe.get("description"),
comment=description,
)
if description
else "",
is_primary_key=is_pk,
default=default,
else ""
),
is_primary_key=is_pk,
default=default,
)
if tortoise_version_less_than("0.23.1"):
column = column.replace(" ", " ")
return template.format(table_name=db_table, column=column)
def drop_column(self, model: "Type[Model]", column_name: str):
def drop_column(self, model: type[Model], column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name
)
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_describe.get("description")
else "",
is_primary_key=is_pk,
default=default,
),
)
def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk, modify=True)
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
def rename_column(self, model: type[Model], old_column_name: str, new_column_name: str) -> str:
return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
@ -169,8 +172,8 @@ class BaseDDL:
)
def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
self, model: type[Model], old_column_name: str, new_column_name: str, new_column_type: str
) -> str:
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
old_column_name=old_column_name,
@ -178,63 +181,92 @@ class BaseDDL:
new_column_type=new_column_type,
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
func_name = "_get_index_name"
if not hasattr(self.schema_generator, func_name):
# For tortoise-orm<0.24.1
func_name = "_generate_index_name"
return getattr(self.schema_generator, func_name)(
"idx" if not unique else "uid", model, field_names
)
def add_index(
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
index_type: str = "",
extra: str | None = "",
) -> str:
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
index_name=name or self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
index_type=f"{index_type} " if index_type else "",
extra=f"{extra}" if extra else "",
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
def drop_index(
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
index_name=name or self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
)
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
def drop_index_by_name(self, model: type[Model], index_name: str) -> str:
return self.drop_index(model, [], name=index_name)
def _generate_fk_name(
self, db_table: str, field_describe: dict, reference_table_describe: dict
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
func_name = "_get_fk_name"
if not hasattr(self.schema_generator, func_name):
# For tortoise-orm<0.24.1
func_name = "_generate_fk_name"
return getattr(self.schema_generator, func_name)(
from_table=db_table,
from_field=db_column,
to_table=to_table,
to_field=to_field,
)
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
def add_fk(
self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table
db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column")
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe),
db_column=db_column,
table=reference_table_describe.get("table"),
field=reference_id,
on_delete=field_describe.get("on_delete"),
)
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
def drop_fk(
self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format(
table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
def alter_column_default(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
@ -243,14 +275,28 @@ class BaseDDL:
default="SET" + default if default is not None else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict):
def set_comment(self, model: type[Model], field_describe: dict) -> str:
return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
def rename_table(self, model: type[Model], old_table_name: str, new_table_name: str) -> str:
db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
)
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
"""Change unique constraint for indexed field, e.g.: Field(db_index=True) --> Field(unique=True)"""
fields = [field_name]
if drop:
drop_unique = self.drop_index(model, fields, unique=True)
add_normal_index = self.add_index(model, fields, unique=False)
return [drop_unique, add_normal_index]
else:
drop_index = self.drop_index(model, fields, unique=False)
add_unique_index = self.add_index(model, fields, unique=True)
return [drop_index, add_unique_index]

View File

@ -1,7 +1,14 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL
if TYPE_CHECKING:
from tortoise import Model
class MysqlDDL(BaseDDL):
schema_generator_cls = MySQLSchemaGenerator
@ -16,10 +23,14 @@ class MysqlDDL(BaseDDL):
_RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
)
_ADD_INDEX_TEMPLATE = (
"ALTER TABLE `{table_name}` ADD {unique}INDEX `{index_name}` ({column_names})"
)
_ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}"
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)"
)
_DROP_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)"
)
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = (
@ -30,3 +41,21 @@ class MysqlDDL(BaseDDL):
)
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
if unique and len(field_names) == 1:
# Example: `email = CharField(max_length=50, unique=True)`
# Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
# Unique index key is the same as field name: `email`
return field_names[0]
return super()._index_name(unique, model, field_names)
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
# if drop is false: Drop index and add unique
# else: Drop unique index and add normal index
template = self._DROP_INDEXED_UNIQUE_TEMPLATE if drop else self._ADD_INDEXED_UNIQUE_TEMPLATE
table = self.get_table_name(model)
index = self._index_name(unique=False, model=model, field_names=[field_name])
return [template.format(table_name=table, index_name=index, column_name=field_name)]

View File

@ -1,24 +1,26 @@
from typing import Type
from __future__ import annotations
from typing import cast
from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
schema_generator_cls = BasePostgresSchemaGenerator
DIALECT = BasePostgresSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
)
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"'
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table,
@ -26,9 +28,9 @@ class PostgresDDL(BaseDDL):
set_drop="DROP" if field_describe.get("nullable") else "SET",
)
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format(
@ -38,12 +40,14 @@ class PostgresDDL(BaseDDL):
using=f' USING "{db_column}"::{datatype}',
)
def set_comment(self, model: "Type[Model]", field_describe: dict):
def set_comment(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL",
comment=(
"'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL"
),
)

View File

@ -1,4 +1,4 @@
from typing import Type
from __future__ import annotations
from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
@ -10,15 +10,17 @@ from aerich.exceptions import NotSupportError
class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
def modify_column(self, model: type[Model], field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
def alter_column_default(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
def alter_column_null(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
def set_comment(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

View File

@ -1,52 +1,69 @@
from typing import Any, List, Optional
from __future__ import annotations
import contextlib
from typing import Any, Callable, Dict, TypedDict
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
# TODO: use dict to replace typing.Dict when dropping support for Python3.8
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str]
comment: str | None = None
pk: bool
unique: bool
index: bool
length: Optional[int]
extra: Optional[str]
decimal_places: Optional[int]
max_digits: Optional[int]
length: int | None = None
extra: str | None = None
decimal_places: int | None = None
max_digits: int | None = None
def translate(self) -> dict:
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
pk = "primary_key=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
elif self.index:
index = "db_index=True, "
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]:
elif self.data_type in ("decimal", "numeric"):
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)
if length_parts:
length = ", ".join(length_parts) + ", "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
if self.default is not None and not self.pk:
if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if self.default == "CURRENT_TIMESTAMP":
if self.extra == "DEFAULT_GENERATED on update CURRENT_TIMESTAMP":
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
@ -55,6 +72,8 @@ class Column(BaseModel):
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
elif self.default == "":
default = 'default=""'
else:
default = f"default={self.default}, "
@ -74,16 +93,14 @@ class Column(BaseModel):
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
with contextlib.suppress(AttributeError):
self.database = conn.database # type:ignore[attr-defined]
self.tables = tables
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
raise NotImplementedError
async def inspect(self) -> str:
@ -101,68 +118,75 @@ class Inspect:
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
raise NotImplementedError
@staticmethod
def get_field_string(
field_class: str, arguments: str = "{null}{default}{comment}", **kwargs
) -> str:
name = kwargs["name"]
field_params = arguments.format(**kwargs).strip().rstrip(",")
return f"{name} = fields.{field_class}({field_params})"
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
return cls.get_field_string("DecimalField", **kwargs)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("TimeField", **kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("DateField", **kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("FloatField", **kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("DatetimeField", **kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("TextField", **kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
arguments = "{pk}{index}{length}{null}{default}{comment}"
return cls.get_field_string("CharField", arguments, **kwargs)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
def int_field(cls, field_class="IntField", **kwargs) -> str:
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string(field_class, arguments, **kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
return cls.int_field("SmallIntField", **kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
return cls.int_field("BigIntField", **kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("BooleanField", **kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string("UUIDField", arguments, **kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("JSONField", **kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)
return cls.get_field_string("BinaryField", **kwargs)

View File

@ -1,21 +1,23 @@
from typing import List
from __future__ import annotations
from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"char": self.uuid_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"double": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
@ -23,12 +25,12 @@ class InspectMySQL(Inspect):
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
@ -39,16 +41,13 @@ where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = index = False
if (non_unique := row["NON_UNIQUE"]) is not None:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
elif row["COLUMN_KEY"] == "UNI":
unique = True
if (index_name := row["INDEX_NAME"]) is not None:
index = index_name != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
@ -57,9 +56,8 @@ where c.TABLE_SCHEMA = %s
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI",
unique=unique,
extra=row["EXTRA"],
unque=unique,
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],

View File

@ -1,24 +1,29 @@
from typing import List, Optional
from __future__ import annotations
from tortoise import BaseDBAsyncClient
import re
from typing import TYPE_CHECKING
from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> None:
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
self.schema = conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int2": self.smallint_field,
"int4": self.int_field,
"int8": self.int_field,
"int8": self.bigint_field,
"smallint": self.smallint_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
@ -33,12 +38,12 @@ class InspectPostgres(Inspect):
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
@ -55,7 +60,9 @@ from information_schema.constraint_column_usage const
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3"""
and c.table_schema = $3""" # nosec:B608
if "psycopg" in str(type(self.conn)).lower():
sql = re.sub(r"\$[123]", "%s", sql)
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(

View File

@ -1,11 +1,11 @@
from typing import List
from __future__ import annotations
from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
)
return columns
async def _get_columns_index(self, table: str):
async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,7 @@ MAX_APP_LENGTH = 100
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField(encoder=encoder, decoder=decoder)
content: dict = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta:
ordering = ["-id"]

View File

@ -1,11 +1,15 @@
from __future__ import annotations
import importlib.util
import os
import re
import sys
from collections.abc import Generator
from pathlib import Path
from typing import Dict
from types import ModuleType
from click import BadOptionUsage, ClickException, Context
from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise
@ -30,23 +34,19 @@ def get_app_connection_name(config, app_name: str) -> str:
get connection name
:param config:
:param app_name:
:return:
:return: the default connection name (Usally it is 'default')
"""
app = config.get("apps").get(app_name)
if app:
if app := config.get("apps").get(app_name):
return app.get("default_connection", "default")
raise BadOptionUsage(
option_name="--app",
message=f'Can\'t get app named "{app_name}"',
)
raise BadOptionUsage(option_name="--app", message=f"Can't get app named {app_name!r}")
def get_app_connection(config, app) -> BaseDBAsyncClient:
"""
get connection name
get connection client
:param config:
:param app:
:return:
:return: client instance
"""
return Tortoise.get_connection(get_app_connection_name(config, app))
@ -77,26 +77,67 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config
def get_models_describe(app: str) -> Dict:
def get_models_describe(app: str) -> dict:
"""
get app models describe
:param app:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
for model in Tortoise.apps[app].values():
managed = getattr(model.Meta, "managed", None)
describe = model.describe()
ret[describe.get("name")] = describe
ret[describe.get("name")] = dict(describe, managed=managed)
return ret
def is_default_function(string: str):
def is_default_function(string: str) -> re.Match | None:
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
def import_py_file(file: str | Path) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module
def get_dict_diff_by_key(
old_fields: list[dict], new_fields: list[dict], key="through"
) -> Generator[tuple]:
"""
Compare two list by key instead of by index
:param old_fields: previous field info list
:param new_fields: current field info list
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
:return: similar to dictdiffer.diff
Example::
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
>>> list(diff(old, new))
[('change', [1, 'through'], ('b', 'c')),
('remove', '', [(2, {'through': 'c'})])]
>>> list(get_dict_diff_by_key(old, new))
[('remove', '', [(0, {'through': 'b'})])]
"""
length_old, length_new = len(old_fields), len(new_fields)
if length_old == 0 or length_new == 0 or length_old == length_new == 1:
yield from diff(old_fields, new_fields)
else:
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)}
additions = set(range(length_new))
for field in old_fields:
value = field[key]
if (index := value_index.get(value)) is not None:
additions.remove(index)
yield from diff([field], [new_fields[index]]) # change
else:
yield from diff([field], []) # remove
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add

View File

@ -1 +1,3 @@
__version__ = "0.7.1"
from importlib.metadata import version
__version__ = version(__package__)

View File

@ -1,23 +1,30 @@
from __future__ import annotations
import asyncio
import os
import sys
from collections.abc import Generator
from pathlib import Path
import pytest
from tortoise import Tortoise, expand_db_url, generate_schema_for_client
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise import Tortoise, expand_db_url
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.contrib.test import MEMORY_SQLITE
from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate
from tests._utils import chdir, copy_files, init_db, run_shell
db_url = os.getenv("TEST_DB", "sqlite://:memory:")
db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:")
db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
"default": expand_db_url(db_url, testing=True),
"second": expand_db_url(db_url_second, testing=True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@ -27,7 +34,7 @@ tortoise_orm = {
@pytest.fixture(scope="function", autouse=True)
def reset_migrate():
def reset_migrate() -> None:
Migrate.upgrade_operators = []
Migrate.downgrade_operators = []
Migrate._upgrade_fk_m2m_index_operators = []
@ -37,29 +44,54 @@ def reset_migrate():
@pytest.fixture(scope="session")
def event_loop():
def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
asyncio.set_event_loop(res)
res._close = res.close
res.close = lambda: None
res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None # type:ignore[method-assign]
yield res
res._close()
res._close() # type:ignore[attr-defined]
@pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request):
await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
async def initialize_tests(event_loop, request) -> None:
await init_db(tortoise_orm)
client = Tortoise.get_connection("default")
if client.schema_generator is MySQLSchemaGenerator:
Migrate.ddl = MysqlDDL(client)
elif client.schema_generator is SqliteSchemaGenerator:
Migrate.ddl = SqliteDDL(client)
elif client.schema_generator is AsyncpgSchemaGenerator:
elif issubclass(client.schema_generator, BasePostgresSchemaGenerator):
Migrate.ddl = PostgresDDL(client)
Migrate.dialect = Migrate.ddl.DIALECT
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))
@pytest.fixture
def new_aerich_project(tmp_path: Path):
test_dir = Path(__file__).parent / "tests"
asset_dir = test_dir / "assets" / "fake"
settings_py = asset_dir / "settings.py"
_tests_py = asset_dir / "_tests.py"
db_py = asset_dir / "db.py"
models_py = test_dir / "models.py"
models_second_py = test_dir / "models_second.py"
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
dst_dir = tmp_path / "tests"
dst_dir.mkdir()
dst_dir.joinpath("__init__.py").touch()
copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)
if should_remove := str(tmp_path) not in sys.path:
sys.path.append(str(tmp_path))
with chdir(tmp_path):
run_shell("python db.py create", capture_output=False)
try:
yield
finally:
if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
run_shell("python db.py drop", capture_output=False)
if should_remove:
sys.path.remove(str(tmp_path))

2227
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,44 +1,60 @@
[tool.poetry]
[project]
name = "aerich"
version = "0.7.1"
version = "0.8.2"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
authors = [{name="long2ice", email="long2ice@gmail.com>"}]
license = { text = "Apache-2.0" }
readme = "README.md"
keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [{ include = "aerich" }]
include = ["CHANGELOG.md", "LICENSE", "README.md"]
requires-python = ">=3.8"
dependencies = [
"tortoise-orm (>=0.21.0,<1.0.0); python_version < '4.0'",
"pydantic (>=2.0.2,!=2.1.0,!=2.7.0,<3.0.0)",
"dictdiffer (>=0.9.0,<1.0.0)",
"asyncclick (>=8.1.7,<9.0.0)",
"eval-type-backport (>=0.2.2,<1.0.0); python_version < '3.10'",
]
[project.optional-dependencies]
toml = [
"tomli-w (>=1.1.0,<2.0.0); python_version >= '3.11'",
"tomlkit (>=0.11.4,<1.0.0); python_version < '3.11'",
]
# Need asyncpg or psyncopg for PostgreSQL
asyncpg = ["asyncpg"]
psycopg = ["psycopg[pool,binary] (>=3.0.12,<4.0.0)"]
# Need asyncmy or aiomysql for MySQL
asyncmy = ["asyncmy>=0.2.9; python_version < '4.0'"]
mysql = ["aiomysql>=0.2.0"]
[project.urls]
homepage = "https://github.com/tortoise/aerich"
repository = "https://github.com/tortoise/aerich.git"
documentation = "https://github.com/tortoise/aerich"
keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [
{ include = "aerich" }
]
include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies]
python = "^3.7"
tortoise-orm = "*"
click = "*"
asyncpg = { version = "*", optional = true }
asyncmy = { version = "*", optional = true }
pydantic = "*"
dictdiffer = "*"
tomlkit = "*"
[project.scripts]
aerich = "aerich.cli:main"
[tool.poetry.dev-dependencies]
flake8 = "*"
isort = "*"
black = "*"
pytest = "*"
pytest-xdist = "*"
pytest-asyncio = "*"
bandit = "*"
pytest-mock = "*"
cryptography = "*"
pyproject-flake8 = "*"
[tool.poetry]
requires-poetry = ">=2.0"
[tool.poetry.extras]
asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"]
[tool.poetry.group.dev.dependencies]
ruff = "^0.9.0"
bandit = "^1.7.0"
mypy = "^1.10.0"
twine = "^6.1.0"
[tool.poetry.group.test.dependencies]
pytest = "^8.3.0"
pytest-mock = "^3.14.0"
pytest-xdist = "^3.6.0"
# Breaking change in 0.23.*
# https://github.com/pytest-dev/pytest-asyncio/issues/706
pytest-asyncio = "^0.21.2"
# required for sha256_password by asyncmy
cryptography = {version="^44.0.1", python="!=3.9.0,!=3.9.1"}
[tool.aerich]
tortoise_orm = "conftest.tortoise_orm"
@ -46,22 +62,55 @@ location = "./migrations"
src_folder = "./."
[build-system]
requires = ["poetry>=0.12"]
build-backend = "poetry.masonry.api"
[tool.poetry.scripts]
aerich = "aerich.cli:main"
[tool.black]
line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39']
requires = ["poetry-core>=2.0.0"]
build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
asyncio_mode = 'auto'
[tool.coverage.run]
branch = true
source = ["aerich"]
[tool.coverage.report]
exclude_also = [
"if TYPE_CHECKING:"
]
[tool.mypy]
pretty = true
python_version = "3.8"
check_untyped_defs = true
warn_unused_ignores = true
disallow_incomplete_defs = false
exclude = ["tests/assets", "migrations"]
[[tool.mypy.overrides]]
module = [
'dictdiffer.*',
'tomlkit',
'tomli_w',
'tomli',
]
ignore_missing_imports = true
[tool.flake8]
ignore = 'E501,W503,E203'
[tool.ruff]
line-length = 100
[tool.ruff.lint]
extend-select = [
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"SIM", # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
"FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
"UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
"RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
[tool.ruff.lint.per-file-ignores]
# TODO: Remove this line when dropping support for Python3.8
"aerich/inspectdb/__init__.py" = ["UP006", "UP035"]
"aerich/_compat.py" = ["F401"]
[tool.bandit]
exclude_dirs = ["tests", "conftest.py"]

87
tests/_utils.py Normal file
View File

@ -0,0 +1,87 @@
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import DBConnectionError, OperationalError
if sys.version_info >= (3, 11):
from contextlib import chdir
else:
class chdir(contextlib.AbstractContextManager): # Copied from source code of Python3.13
"""Non thread-safe context manager to change the current working directory."""
def __init__(self, path):
self.path = path
self._old_cwd = []
def __enter__(self):
self._old_cwd.append(os.getcwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
os.chdir(self._old_cwd.pop())
async def drop_db(tortoise_orm) -> None:
# Placing init outside the try-block(suppress) since it doesn't
# establish connections to the DB eagerly.
await Tortoise.init(config=tortoise_orm)
with contextlib.suppress(DBConnectionError, OperationalError):
await Tortoise._drop_databases()
async def init_db(tortoise_orm, generate_schemas=True) -> None:
await drop_db(tortoise_orm)
await Tortoise.init(config=tortoise_orm, _create_db=True)
if generate_schemas:
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
def copy_files(*src_files: Path, target_dir: Path) -> None:
for src in src_files:
shutil.copy(src, target_dir)
class Dialect:
test_db_url: str
@classmethod
def load_env(cls) -> None:
if getattr(cls, "test_db_url", None) is None:
cls.test_db_url = os.getenv("TEST_DB", "")
@classmethod
def is_postgres(cls) -> bool:
cls.load_env()
return "postgres" in cls.test_db_url
@classmethod
def is_mysql(cls) -> bool:
cls.load_env()
return "mysql" in cls.test_db_url
@classmethod
def is_sqlite(cls) -> bool:
cls.load_env()
return not cls.test_db_url or "sqlite" in cls.test_db_url
WINDOWS = platform.system() == "Windows"
def run_shell(command: str, capture_output=True, **kw) -> str:
if WINDOWS and command.startswith("aerich "):
command = "python -m " + command
r = subprocess.run(shlex.split(command), capture_output=capture_output)
if r.returncode != 0 and r.stderr:
return r.stderr.decode()
if not r.stdout:
return ""
return r.stdout.decode()

View File

@ -0,0 +1,80 @@
import pytest
from models import NewModel
from models_second import Config
from settings import TORTOISE_ORM
from tortoise import Tortoise
from tortoise.exceptions import OperationalError
try:
# This error does not translate to tortoise's OperationalError
from psycopg.errors import UndefinedColumn
except ImportError:
errors = (OperationalError,)
else:
errors = (OperationalError, UndefinedColumn)
@pytest.fixture(scope="session")
def anyio_backend() -> str:
return "asyncio"
@pytest.fixture(autouse=True)
async def init_connections():
await Tortoise.init(TORTOISE_ORM)
try:
yield
finally:
await Tortoise.close_connections()
@pytest.mark.anyio
async def test_init_db():
m1 = await NewModel.filter(name="")
assert isinstance(m1, list)
m2 = await Config.filter(key="")
assert isinstance(m2, list)
await NewModel.create(name="")
await Config.create(key="", label="", value={})
@pytest.mark.anyio
async def test_fake_field_1():
assert "field_1" in NewModel._meta.fields_map
assert "field_1" in Config._meta.fields_map
with pytest.raises(errors):
await NewModel.create(name="", field_1=1)
with pytest.raises(errors):
await Config.create(key="", label="", value={}, field_1=1)
obj1 = NewModel(name="", field_1=1)
with pytest.raises(errors):
await obj1.save()
obj1 = NewModel(name="")
with pytest.raises(errors):
await obj1.save()
with pytest.raises(errors):
obj1 = await NewModel.first()
obj1 = await NewModel.all().first().values("id", "name")
assert obj1 and obj1["id"]
obj2 = Config(key="", label="", value={}, field_1=1)
with pytest.raises(errors):
await obj2.save()
obj2 = Config(key="", label="", value={})
with pytest.raises(errors):
await obj2.save()
with pytest.raises(errors):
obj2 = await Config.first()
obj2 = await Config.all().first().values("id", "key")
assert obj2 and obj2["id"]
@pytest.mark.anyio
async def test_fake_field_2():
assert "field_2" in NewModel._meta.fields_map
assert "field_2" in Config._meta.fields_map
with pytest.raises(errors):
await NewModel.create(name="")
with pytest.raises(errors):
await Config.create(key="", label="", value={})

28
tests/assets/fake/db.py Normal file
View File

@ -0,0 +1,28 @@
import asyncclick as click
from settings import TORTOISE_ORM
from tests._utils import drop_db, init_db
@click.group()
def cli(): ...
@cli.command()
async def create():
await init_db(TORTOISE_ORM, False)
click.echo(f"Success to create databases for {TORTOISE_ORM['connections']}")
@cli.command()
async def drop():
await drop_db(TORTOISE_ORM)
click.echo(f"Dropped databases for {TORTOISE_ORM['connections']}")
def main():
cli()
if __name__ == "__main__":
main()

View File

@ -0,0 +1,22 @@
import os
from datetime import date
from tortoise.contrib.test import MEMORY_SQLITE
DB_URL = (
_u.replace("\\{\\}", f"aerich_fake_{date.today():%Y%m%d}")
if (_u := os.getenv("TEST_DB"))
else MEMORY_SQLITE
)
DB_URL_SECOND = (DB_URL + "_second") if DB_URL != MEMORY_SQLITE else MEMORY_SQLITE
TORTOISE_ORM = {
"connections": {
"default": DB_URL.replace(MEMORY_SQLITE, "sqlite://db.sqlite3"),
"second": DB_URL_SECOND.replace(MEMORY_SQLITE, "sqlite://db_second.sqlite3"),
},
"apps": {
"models": {"models": ["models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["models_second"], "default_connection": "second"},
},
}

View File

@ -0,0 +1,76 @@
import uuid
import pytest
from models import Foo
from tortoise.exceptions import IntegrityError
@pytest.mark.asyncio
async def test_allow_duplicate() -> None:
await Foo.all().delete()
await Foo.create(name="foo")
obj = await Foo.create(name="foo")
assert (await Foo.all().count()) == 2
await obj.delete()
@pytest.mark.asyncio
async def test_unique_is_true() -> None:
with pytest.raises(IntegrityError):
await Foo.create(name="foo")
await Foo.create(name="foo")
@pytest.mark.asyncio
async def test_add_unique_field() -> None:
if not await Foo.filter(age=0).exists():
await Foo.create(name="0_" + uuid.uuid4().hex, age=0)
with pytest.raises(IntegrityError):
await Foo.create(name=uuid.uuid4().hex, age=0)
@pytest.mark.asyncio
async def test_drop_unique_field() -> None:
name = "1_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
assert await Foo.filter(name=name).exists()
@pytest.mark.asyncio
async def test_with_age_field() -> None:
name = "2_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert obj.age == 0
@pytest.mark.asyncio
async def test_without_age_field() -> None:
name = "3_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert getattr(obj, "age", None) is None
@pytest.mark.asyncio
async def test_m2m_with_custom_through() -> None:
from models import FooGroup, Group
name = "4_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name + "1")
await FooGroup.all().delete()
await foo.groups.add(group)
foo_group = await FooGroup.get(foo=foo, group=group)
assert not foo_group.is_active
@pytest.mark.asyncio
async def test_add_m2m_field_after_init_db() -> None:
from models import Group
name = "5_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name + "1")
await foo.groups.add(group)
assert (await group.users.all().first()) == foo

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import asyncio
from collections.abc import Generator
import pytest
import pytest_asyncio
import settings
from tortoise import Tortoise, connections
@pytest.fixture(scope="session")
def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
asyncio.set_event_loop(res)
res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None # type:ignore[method-assign]
yield res
res._close() # type:ignore[attr-defined]
@pytest_asyncio.fixture(scope="session", autouse=True)
async def api(event_loop, request):
await Tortoise.init(config=settings.TORTOISE_ORM)
request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True)))

View File

@ -0,0 +1,5 @@
from tortoise import Model, fields
class Foo(Model):
name = fields.CharField(max_length=60, db_index=False)

View File

@ -0,0 +1,4 @@
TORTOISE_ORM = {
"connections": {"default": "sqlite://db.sqlite3"},
"apps": {"models": {"models": ["models", "aerich.models"]}},
}

7
tests/indexes.py Normal file
View File

@ -0,0 +1,7 @@
from tortoise.indexes import Index
class CustomIndex(Index):
def __init__(self, *args, **kw) -> None:
super().__init__(*args, **kw)
self._foo = ""

View File

@ -1,8 +1,16 @@
from __future__ import annotations
import datetime
import uuid
from enum import IntEnum
from tortoise import Model, fields
from tortoise.contrib.mysql.indexes import FullTextIndex
from tortoise.contrib.postgres.indexes import HashIndex
from tortoise.indexes import Index
from tests._utils import Dialect
from tests.indexes import CustomIndex
class ProductType(IntEnum):
@ -31,13 +39,21 @@ class User(Model):
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
products: fields.ManyToManyRelation[Product]
class Meta:
# reverse indexes elements
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]
class Email(Model):
email_id = fields.IntField(pk=True)
email = fields.CharField(max_length=200, index=True)
email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, db_index=True)
company = fields.CharField(max_length=100, db_index=True, unique=True)
is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
config: fields.OneToOneRelation[Config] = fields.OneToOneField("models.Config")
def default_name():
@ -47,34 +63,78 @@ def default_name():
class Category(Model):
slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User")
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
if Dialect.is_postgres():
indexes = [HashIndex(fields=("slug",))]
elif Dialect.is_mysql():
indexes = [FullTextIndex(fields=("slug",))] # type:ignore
else:
indexes = [Index(fields=("slug",))] # type:ignore
class Product(Model):
categories = fields.ManyToManyField("models.Category")
id = fields.BigIntField(primary_key=True)
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", null=False
)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="products"
)
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(
type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
pic = fields.CharField(max_length=200)
body = fields.TextField()
price = fields.FloatField(null=True)
no = fields.UUIDField(db_index=True)
created_at = fields.DatetimeField(auto_now_add=True)
is_deleted = fields.BooleanField(default=False)
class Meta:
unique_together = (("name", "type"),)
indexes = (("name", "type"),)
managed = True
class Config(Model):
slug = fields.CharField(primary_key=True, max_length=20)
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="category_set"
)
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
email: fields.OneToOneRelation[Email]
class Meta:
managed = True
class DontManageMe(Model):
name = fields.CharField(max_length=50)
class Meta:
managed = False
class Ignore(Model):
class Meta:
managed = False
class NewModel(Model):

View File

@ -34,23 +34,29 @@ class User(Model):
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models_second.User", db_constraint=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", db_constraint=False
)
class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models_second.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model):
categories = fields.ManyToManyField("models_second.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models_second.Category"
)
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(
type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200)
@ -61,5 +67,5 @@ class Product(Model):
class Config(Model):
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@ -2,6 +2,9 @@ import datetime
from enum import IntEnum
from tortoise import Model, fields
from tortoise.indexes import Index
from tests.indexes import CustomIndex
class ProductType(IntEnum):
@ -31,39 +34,96 @@ class User(Model):
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)
class Meta:
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]
class Email(Model):
email = fields.CharField(max_length=200)
company = fields.CharField(max_length=100, db_index=True)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False
)
class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User")
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=True)
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
indexes = [Index(fields=("slug",))]
class Product(Model):
categories = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
uid = fields.IntField(source_field="uuid", unique=True)
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(
is_review = fields.BooleanField(description="Is Reviewed")
type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
is_delete = fields.BooleanField(default=False)
class Config(Model):
slug = fields.CharField(primary_key=True, max_length=10)
category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="config_set"
)
name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta:
table = "configs"
class DontManageMe(Model):
name = fields.CharField(max_length=50)
class Meta:
table = "dont_manage"
class Ignore(Model):
name = fields.CharField(max_length=50)
class Meta:
managed = True
def main() -> None:
"""Generate a python file for the old_models_describe"""
from pathlib import Path
from tortoise import run_async
from tortoise.contrib.test import init_memory_sqlite
from aerich.utils import get_models_describe
@init_memory_sqlite
async def run() -> None:
old_models_describe = get_models_describe("models")
p = Path("old_models_describe.py")
p.write_text(f"{old_models_describe = }", encoding="utf-8")
print(f"Write value to {p}\nYou can reformat it by `ruff format {p}`")
run_async(run())
if __name__ == "__main__":
main()

11
tests/test_command.py Normal file
View File

@ -0,0 +1,11 @@
from aerich import Command
from conftest import tortoise_orm
async def test_command(mocker):
mocker.patch("os.listdir", return_value=[])
async with Command(tortoise_orm) as command:
history = await command.history()
heads = await command.heads()
assert history == []
assert heads == []

View File

@ -1,3 +1,5 @@
import tortoise
from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
@ -8,28 +10,48 @@ from tests.models import Category, Product, User
def test_create_table():
ret = Migrate.ddl.create_table(Category)
if isinstance(Migrate.ddl, MysqlDDL):
if tortoise.__version__ >= "0.24":
assert (
ret
== """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE,
FULLTEXT KEY `idx_category_slug_e9bcff` (`slug`)
) CHARACTER SET utf8mb4"""
)
return
assert (
ret
== """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200),
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`user_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4;"""
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX `idx_category_slug_e9bcff` ON `category` (`slug`)"""
)
elif isinstance(Migrate.ddl, SqliteDDL):
exists = "IF NOT EXISTS " if tortoise.__version__ >= "0.24" else ""
assert (
ret
== """CREATE TABLE IF NOT EXISTS "category" (
== f"""CREATE TABLE IF NOT EXISTS "category" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200),
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
);"""
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
);
CREATE INDEX {exists}"idx_category_slug_e9bcff" ON "category" ("slug")"""
)
elif isinstance(Migrate.ddl, PostgresDDL):
@ -39,10 +61,12 @@ def test_create_table():
"id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200),
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
);
COMMENT ON COLUMN "category"."user_id" IS 'User';"""
CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug");
COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
)
@ -55,26 +79,32 @@ def test_drop_table():
def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map["name"].describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
else:
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'
# add unique column
ret = Migrate.ddl.add_column(User, User._meta.fields_map["username"].describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL UNIQUE'
else:
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL'
def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL):
return
ret0 = Migrate.ddl.modify_column(
Category, Category._meta.fields_map.get("name").describe(False)
)
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map["name"].describe(False))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map["is_active"].describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
assert (
ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert (
@ -90,14 +120,14 @@ def test_modify_column():
def test_alter_column_default():
if isinstance(Migrate.ddl, SqliteDDL):
return
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("intro").describe(False))
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map["intro"].describe(False))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\''
elif isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''"
ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False)
Category, Category._meta.fields_map["created_at"].describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert (
@ -110,7 +140,7 @@ def test_alter_column_default():
)
ret = Migrate.ddl.alter_column_default(
Product, Product._meta.fields_map.get("view_num").describe(False)
Product, Product._meta.fields_map["view_num"].describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
@ -121,9 +151,7 @@ def test_alter_column_default():
def test_alter_column_null():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.alter_column_null(
Category, Category._meta.fields_map.get("name").describe(False)
)
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map["name"].describe(False))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL'
@ -131,11 +159,11 @@ def test_alter_column_null():
def test_set_comment():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["name"].describe(False))
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."user_id" IS \'User\''
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["owner"].describe(False))
assert ret == 'COMMENT ON COLUMN "category"."owner_id" IS \'User\''
def test_drop_column():
@ -151,17 +179,18 @@ def test_add_index():
index_u = Migrate.ddl.add_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert (
index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)"
)
assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
elif isinstance(Migrate.ddl, PostgresDDL):
assert (
index == 'CREATE INDEX IF NOT EXISTS "idx_category_name_8b0cb9" ON "category" ("name")'
)
assert (
index_u
== 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_name_8b0cb9" ON "category" ("name")'
)
else:
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
else:
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert (
index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")'
)
def test_drop_index():
@ -169,38 +198,35 @@ def test_drop_index():
ret_u = Migrate.ddl.drop_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`"
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
assert ret_u == "ALTER TABLE `category` DROP INDEX `name`"
else:
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
assert ret == 'DROP INDEX IF EXISTS "idx_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX IF EXISTS "uid_category_name_8b0cb9"'
def test_add_fk():
ret = Migrate.ddl.add_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert (
ret
== "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE"
== "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE"
)
else:
assert (
ret
== 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_e2e3874c" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE'
== 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE'
)
def test_drop_fk():
ret = Migrate.ddl.drop_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False)
Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`"
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" DROP CONSTRAINT "fk_category_user_e2e3874c"'
assert ret == 'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"'
else:
assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_e2e3874c"'
assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_110d4c63"'

106
tests/test_fake.py Normal file
View File

@ -0,0 +1,106 @@
from __future__ import annotations
import os
import re
from pathlib import Path
from tests._utils import Dialect, run_shell
def _append_field(*files: str, name="field_1") -> None:
for file in files:
p = Path(file)
field = f" {name} = fields.IntField(default=0)"
with p.open("a") as f:
f.write(os.linesep + field)
def test_fake(new_aerich_project):
if Dialect.is_sqlite():
# TODO: go ahead if sqlite alter-column supported
return
output = run_shell("aerich init -t settings.TORTOISE_ORM")
assert "Success" in output
output = run_shell("aerich init-db")
assert "Success" in output
output = run_shell("aerich --app models_second init-db")
assert "Success" in output
output = run_shell("pytest _tests.py::test_init_db")
assert "error" not in output.lower()
_append_field("models.py", "models_second.py")
output = run_shell("aerich migrate")
assert "Success" in output
output = run_shell("aerich --app models_second migrate")
assert "Success" in output
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("pytest _tests.py::test_fake_field_1")
assert "error" not in output.lower()
_append_field("models.py", "models_second.py", name="field_2")
output = run_shell("aerich migrate")
assert "Success" in output
output = run_shell("aerich --app models_second migrate")
assert "Success" in output
output = run_shell("aerich heads")
assert "_update.py" in output
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("pytest _tests.py::test_fake_field_2")
assert "error" not in output.lower()
output = run_shell("aerich heads")
assert "No available heads." in output
output = run_shell("aerich --app models_second heads")
assert "No available heads." in output
_append_field("models.py", "models_second.py", name="field_3")
run_shell("aerich migrate", capture_output=False)
run_shell("aerich --app models_second migrate", capture_output=False)
run_shell("aerich upgrade --fake", capture_output=False)
run_shell("aerich --app models_second upgrade --fake", capture_output=False)
output = run_shell("aerich downgrade --fake -v 2 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 2 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert not re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert not re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." in output
output = run_shell("aerich --app models_second heads")
assert "No available heads." in output
output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)

17
tests/test_inspectdb.py Normal file
View File

@ -0,0 +1,17 @@
from tests._utils import Dialect, run_shell
def test_inspect(new_aerich_project):
if Dialect.is_sqlite():
# TODO: test sqlite after #384 fixed
return
run_shell("aerich init -t settings.TORTOISE_ORM")
run_shell("aerich init-db")
ret = run_shell("aerich inspectdb -t product")
assert ret.startswith("from tortoise import Model, fields")
assert "primary_key=True" in ret
assert "fields.DatetimeField" in ret
assert "fields.FloatField" in ret
assert "fields.UUIDField" in ret
if Dialect.is_mysql():
assert "db_index=True" in ret

View File

@ -1,13 +1,34 @@
import pytest
from pytest_mock import MockerFixture
from __future__ import annotations
from pathlib import Path
import pytest
import tortoise
from pytest_mock import MockerFixture
from tortoise.indexes import Index
from aerich._compat import tortoise_version_less_than
from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe
from tests.indexes import CustomIndex
def describe_index(idx: Index) -> Index | dict:
# tortoise-orm>=0.24 changes Index desribe to be dict
if tortoise_version_less_than("0.24"):
return idx
if hasattr(idx, "describe"):
return idx.describe()
return idx
# tortoise-orm>=0.21 changes IntField constraints
# from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647}
MIN_INT = 1 if tortoise.__version__ < "0.21" else -2147483648
old_models_describe = {
"models.Category": {
"name": "models.Category",
@ -17,7 +38,7 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"indexes": [describe_index(Index(fields=("slug",)))],
"pk_field": {
"name": "id",
"field_type": "IntField",
@ -30,7 +51,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
@ -97,9 +118,24 @@ old_models_describe = {
"default": None,
"description": "User",
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "title",
"field_type": "CharField",
"db_column": "title",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
],
"fk_fields": [
{
@ -154,21 +190,36 @@ old_models_describe = {
"unique_together": [],
"indexes": [],
"pk_field": {
"name": "id",
"field_type": "IntField",
"db_column": "id",
"python_type": "int",
"generated": True,
"name": "slug",
"field_type": "CharField",
"db_column": "slug",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"db_field_types": {"": "INT"},
"constraints": {"max_length": 10},
"db_field_types": {"": "VARCHAR(10)"},
},
"data_fields": [
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{
"name": "label",
"field_type": "CharField",
@ -234,7 +285,48 @@ old_models_describe = {
"backward_fk_fields": [],
"o2o_fields": [],
"backward_o2o_fields": [],
"m2m_fields": [],
"m2m_fields": [
{
"name": "category",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "configs",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category",
"on_delete": "CASCADE",
"_generated": False,
},
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "config_set",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category_map",
"on_delete": "CASCADE",
"_generated": False,
},
],
},
"models.Email": {
"name": "models.Email",
@ -257,7 +349,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
@ -276,6 +368,21 @@ old_models_describe = {
"constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"},
},
{
"name": "company",
"field_type": "CharField",
"db_column": "company",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{
"name": "is_primary",
"field_type": "BooleanField",
@ -289,7 +396,12 @@ old_models_describe = {
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
{
"name": "user_id",
@ -303,7 +415,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
],
@ -350,7 +462,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
@ -369,6 +481,21 @@ old_models_describe = {
"constraints": {"max_length": 50},
"db_field_types": {"": "VARCHAR(50)"},
},
{
"name": "uid",
"field_type": "IntField",
"db_column": "uuid",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{
"name": "view_num",
"field_type": "IntField",
@ -400,9 +527,9 @@ old_models_describe = {
"db_field_types": {"": "INT"},
},
{
"name": "is_reviewed",
"name": "is_review",
"field_type": "BooleanField",
"db_column": "is_reviewed",
"db_column": "is_review",
"python_type": "bool",
"generated": False,
"nullable": False,
@ -412,7 +539,12 @@ old_models_describe = {
"description": "Is Reviewed",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
{
"name": "type",
@ -480,6 +612,26 @@ old_models_describe = {
"auto_now_add": True,
"auto_now": False,
},
{
"name": "is_delete",
"field_type": "BooleanField",
"db_column": "is_delete",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
],
"fk_fields": [],
"backward_fk_fields": [],
@ -516,7 +668,10 @@ old_models_describe = {
"description": None,
"docstring": None,
"unique_together": [],
"indexes": [],
"indexes": [
describe_index(Index(fields=("username", "is_active"))),
describe_index(CustomIndex(fields=("is_superuser",))),
],
"pk_field": {
"name": "id",
"field_type": "IntField",
@ -529,7 +684,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
@ -597,7 +752,12 @@ old_models_describe = {
"description": "Is Active",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
{
"name": "is_superuser",
@ -612,7 +772,12 @@ old_models_describe = {
"description": "Is SuperUser",
"docstring": None,
"constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
{
"name": "avatar",
@ -714,7 +879,7 @@ old_models_describe = {
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": 1, "le": 2147483647},
"constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"},
},
"data_fields": [
@ -777,170 +942,247 @@ def test_migrate(mocker: MockerFixture):
"""
models.py diff with old_models.py
- change email pk: id -> email_id
- change product pk field type: IntField -> BigIntField
- change config pk field attribute: max_length=10 -> max_length=20
- add field: Email.address
- add fk: Config.user
- drop fk: Email.user
- add fk field: Config.user
- drop fk field: Email.user
- drop field: User.avatar
- add index: Email.email
- add unique to indexed field: Email.company
- change index type for indexed field: Email.slug
- add many to many: Email.users
- remove unique: User.username
- add one to one: Email.config
- remove unique: Category.title
- add unique: User.username
- change column: length User.password
- add unique_together: (name,type) of Product
- add one more many to many field: Product.users
- drop unique field: Config.name
- alter default: Config.status
- rename column: Product.image -> Product.pic
- rename column: Product.is_review -> Product.is_reviewed
- rename column: Product.is_delete -> Product.is_deleted
- rename fk column: Category.user -> Category.owner
"""
mocker.patch("click.prompt", side_effect=(True,))
mocker.patch("asyncclick.prompt", side_effect=(True, True, True, True))
models_describe = get_models_describe("models")
Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
Migrate.diff_models(old_models_describe, models_describe)
Migrate.upgrade_operators.clear()
with pytest.raises(NotSupportError):
Migrate.diff_models(models_describe, old_models_describe, False)
Migrate.downgrade_operators.clear()
else:
Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False)
Migrate._merge_operators()
Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL):
expected_upgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `category` DROP INDEX `title`",
"ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`",
"ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `category` ADD FULLTEXT INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` DROP INDEX `name`",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_88e28c1b` FOREIGN KEY (`config_id`) REFERENCES `config` (`slug`) ON DELETE CASCADE",
"ALTER TABLE `email` ADD `config_id` VARCHAR(20) NOT NULL UNIQUE",
"ALTER TABLE `email` DROP INDEX `idx_email_company_1c9234`, ADD UNIQUE (`company`)",
"ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `product` ADD `price` DOUBLE",
"ALTER TABLE `product` ADD `no` CHAR(36) NOT NULL",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ADD INDEX `idx_product_no_e4d701` (`no`)",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'",
"ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`",
"ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`",
"ALTER TABLE `product` MODIFY COLUMN `id` BIGINT NOT NULL",
"ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
"CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4;",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
"CREATE TABLE `product_user` (\n `product_id` BIGINT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` VARCHAR(20) NOT NULL REFERENCES `config` (`slug`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category`",
"ALTER TABLE `config` MODIFY COLUMN `slug` VARCHAR(20) NOT NULL",
}
upgrade_operators = set(Migrate.upgrade_operators)
upgrade_more_than_expected = upgrade_operators - expected_upgrade_operators
assert not upgrade_more_than_expected
upgrade_less_than_expected = expected_upgrade_operators - upgrade_operators
assert not upgrade_less_than_expected
expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
"ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`",
"ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`",
"ALTER TABLE `category` ADD INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` MODIFY COLUMN `slug` VARCHAR(10) NOT NULL",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` DROP COLUMN `config_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_88e28c1b`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `company`, ADD INDEX (`idx_email_company_1c9234`)",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `product` DROP COLUMN `price`",
"ALTER TABLE `product` DROP COLUMN `no`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`",
"ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`",
"ALTER TABLE `product` MODIFY COLUMN `id` INT NOT NULL",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `username`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"DROP TABLE IF EXISTS `product_user`",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0",
"CREATE TABLE `config_category` (\n `config_id` VARCHAR(20) NOT NULL REFERENCES `config` (`slug`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category_map`",
}
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference(
expected_downgrade_operators
)
downgrade_operators = set(Migrate.downgrade_operators)
downgrade_more_than_expected = downgrade_operators - expected_downgrade_operators
assert not downgrade_more_than_expected
downgrade_less_than_expected = expected_downgrade_operators - downgrade_operators
assert not downgrade_less_than_expected
elif isinstance(Migrate.ddl, PostgresDDL):
expected_upgrade_operators = {
'DROP INDEX IF EXISTS "uid_category_title_f7fc03"',
'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"',
'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug")',
'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "config" DROP COLUMN "name"',
'DROP INDEX IF EXISTS "uid_config_name_2c83c8"',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "config" ALTER COLUMN "slug" TYPE VARCHAR(20) USING "slug"::VARCHAR(20)',
'ALTER TABLE "email" ADD "config_id" VARCHAR(20) NOT NULL UNIQUE',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_88e28c1b" FOREIGN KEY ("config_id") REFERENCES "config" ("slug") ON DELETE CASCADE',
'DROP INDEX IF EXISTS "idx_email_company_1c9234"',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_email_company_1c9234" ON "email" ("company")',
'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"',
'ALTER TABLE "product" DROP COLUMN "uuid"',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"',
'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"',
'ALTER TABLE "product" ADD "price" DOUBLE PRECISION',
'ALTER TABLE "product" ADD "no" UUID NOT NULL',
'ALTER TABLE "product" ALTER COLUMN "id" TYPE BIGINT USING "id"::BIGINT',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE INDEX IF NOT EXISTS "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX IF NOT EXISTS "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE INDEX IF NOT EXISTS "idx_product_no_e4d701" ON "product" ("no")',
'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\';',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" BIGINT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" VARCHAR(20) NOT NULL REFERENCES "config" ("slug") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"',
}
upgrade_operators = set(Migrate.upgrade_operators)
upgrade_more_than_expected = upgrade_operators - expected_upgrade_operators
assert not upgrade_more_than_expected
upgrade_less_than_expected = expected_upgrade_operators - upgrade_operators
assert not upgrade_less_than_expected
expected_downgrade_operators = {
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_title_f7fc03" ON "category" ("title")',
'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "category" RENAME COLUMN "owner_id" TO "user_id"',
'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"',
'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" ("slug")',
'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_config_name_2c83c8" ON "config" ("name")',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" DROP CONSTRAINT IF EXISTS "fk_config_user_17daa970"',
'ALTER TABLE "config" RENAME TO "configs"',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "config" ALTER COLUMN "slug" TYPE VARCHAR(10) USING "slug"::VARCHAR(10)',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL',
'ALTER TABLE "email" DROP COLUMN "config_id"',
'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_88e28c1b"',
'CREATE INDEX IF NOT EXISTS "idx_email_company_1c9234" ON "email" ("company")',
'DROP INDEX IF EXISTS "uid_email_company_1c9234"',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',
'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"',
'ALTER TABLE "product" DROP COLUMN "price"',
'ALTER TABLE "product" DROP COLUMN "no"',
'ALTER TABLE "product" ALTER COLUMN "id" TYPE INT USING "id"::INT',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'DROP INDEX "idx_product_name_869427"',
'DROP INDEX "idx_email_email_4a1a33"',
'DROP INDEX "idx_user_usernam_9987ab"',
'DROP INDEX "uid_product_name_869427"',
'DROP TABLE IF EXISTS "product_user"',
'DROP INDEX IF EXISTS "idx_product_name_869427"',
'DROP INDEX IF EXISTS "idx_email_email_4a1a33"',
'DROP INDEX IF EXISTS "uid_user_usernam_9987ab"',
'DROP INDEX IF EXISTS "uid_product_name_869427"',
'DROP INDEX IF EXISTS "idx_product_no_e4d701"',
'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"',
'CREATE TABLE "config_category" (\n "config_id" VARCHAR(20) NOT NULL REFERENCES "config" ("slug") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category_map"',
}
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference(
expected_downgrade_operators
)
downgrade_operators = set(Migrate.downgrade_operators)
downgrade_more_than_expected = downgrade_operators - expected_downgrade_operators
assert not downgrade_more_than_expected
downgrade_less_than_expected = expected_downgrade_operators - downgrade_operators
assert not downgrade_less_than_expected
elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == []
@ -958,7 +1200,7 @@ def test_sort_all_version_files(mocker):
],
)
Migrate.migrate_location = "."
Migrate.migrate_location = Path(".")
assert Migrate.get_all_version_files() == [
"1_datetime_update.py",
@ -966,3 +1208,39 @@ def test_sort_all_version_files(mocker):
"10_datetime_update.py",
"11_datetime_update.py",
]
def test_sort_files_containing_non_migrations(mocker):
mocker.patch(
"os.listdir",
return_value=[
"1_datetime_update.py",
"11_datetime_update.py",
"10_datetime_update.py",
"2_datetime_update.py",
"not_a_migration.py",
"999.py",
"123foo_not_a_migration.py",
],
)
Migrate.migrate_location = Path(".")
assert Migrate.get_all_version_files() == [
"1_datetime_update.py",
"2_datetime_update.py",
"10_datetime_update.py",
"11_datetime_update.py",
]
async def test_empty_migration(mocker, tmp_path: Path) -> None:
mocker.patch("os.listdir", return_value=[])
Migrate.app = "foo"
expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
Migrate.migrate_location = tmp_path
migration_file = await Migrate.migrate("update", True)
f = tmp_path / migration_file
assert f.read_text() == expected_content

18
tests/test_python_m.py Normal file
View File

@ -0,0 +1,18 @@
import subprocess # nosec
from pathlib import Path
from aerich.version import __version__
from tests._utils import chdir, run_shell
def test_python_m_aerich():
assert __version__ in run_shell("python -m aerich --version")
def test_poetry_add(tmp_path: Path):
package = Path(__file__).parent.resolve().parent
with chdir(tmp_path):
subprocess.run(["poetry", "new", "foo"]) # nosec
with chdir("foo"):
r = subprocess.run(["poetry", "add", package]) # nosec
assert r.returncode == 0

View File

@ -0,0 +1,213 @@
from __future__ import annotations
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from tests._utils import Dialect, chdir, copy_files
def run_aerich(cmd: str) -> subprocess.CompletedProcess | None:
if not cmd.startswith("poetry") and not cmd.startswith("python"):
if not cmd.startswith("aerich"):
cmd = "aerich " + cmd
if platform.system() == "Windows":
cmd = "python -m " + cmd
r = None
with contextlib.suppress(subprocess.TimeoutExpired):
r = subprocess.run(shlex.split(cmd), timeout=2)
return r
def run_shell(cmd: str) -> subprocess.CompletedProcess:
envs = dict(os.environ, PYTHONPATH=".")
return subprocess.run(shlex.split(cmd), env=envs)
def _get_empty_db() -> Path:
if (db_file := Path("db.sqlite3")).exists():
db_file.unlink()
return db_file
@contextmanager
def prepare_sqlite_project(tmp_path: Path) -> Generator[tuple[Path, str]]:
test_dir = Path(__file__).parent
asset_dir = test_dir / "assets" / "sqlite_migrate"
with chdir(tmp_path):
files = ("models.py", "settings.py", "_tests.py")
copy_files(*(asset_dir / f for f in files), target_dir=Path())
models_py, settings_py, test_py = (Path(f) for f in files)
copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py"))
_get_empty_db()
yield models_py, models_py.read_text("utf-8")
def test_close_tortoise_connections_patch(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
run_aerich("aerich init -t settings.TORTOISE_ORM")
r = run_aerich("aerich init-db")
assert r is not None
def test_sqlite_migrate_alter_indexed_unique(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "unique=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_unique_is_true")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
M2M_WITH_CUSTOM_THROUGH = """
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
def test_sqlite_migrate(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
MODELS = models_text
run_aerich("aerich init -t settings.TORTOISE_ORM")
config_file = Path("pyproject.toml")
modify_time = config_file.stat().st_mtime
run_aerich("aerich init-db")
run_aerich("aerich init -t settings.TORTOISE_ORM")
assert modify_time == config_file.stat().st_mtime
r = run_shell("pytest _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add index
models_py.write_text(MODELS.replace("index=False", "index=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Drop index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add unique index
models_py.write_text(MODELS.replace("index=False", "index=True, unique=True"))
run_aerich("aerich migrate") # migrations/models/3_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_unique_is_true")
assert r.returncode == 0
# Drop unique index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/4_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add field with unique=True
with models_py.open("a") as f:
f.write(" age = fields.IntField(unique=True, default=0)")
run_aerich("aerich migrate") # migrations/models/5_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_add_unique_field")
assert r.returncode == 0
# Drop unique field
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/6_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_drop_unique_field")
assert r.returncode == 0
# Initial with indexed field and then drop it
migrations_dir = Path("migrations/models")
shutil.rmtree(migrations_dir)
db_file = _get_empty_db()
models_py.write_text(MODELS + " age = fields.IntField(db_index=True)")
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
migration_file = list(migrations_dir.glob("0_*.py"))[0]
assert "CREATE INDEX" in migration_file.read_text()
r = run_shell("pytest _tests.py::test_with_age_field")
assert r.returncode == 0
models_py.write_text(MODELS)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "DROP INDEX" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_without_age_field")
assert r.returncode == 0
# Generate migration file in emptry directory
db_file.unlink()
run_aerich("aerich init-db")
assert not db_file.exists()
for p in migrations_dir.glob("*"):
if p.is_dir():
shutil.rmtree(p)
else:
p.unlink()
run_aerich("aerich init-db")
assert db_file.exists()
# init without '[tool]' section in pyproject.toml
config_file = Path("pyproject.toml")
config_file.write_text('[project]\nname = "project"')
run_aerich("init -t settings.TORTOISE_ORM")
assert "[tool.aerich]" in config_file.read_text()
# add m2m with custom model for through
models_py.write_text(MODELS + M2M_WITH_CUSTOM_THROUGH)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_m2m_with_custom_through")
assert r.returncode == 0
# add m2m field after init-db
new = """
groups = fields.ManyToManyField("models.Group", through="foo_group", related_name="users")
class Group(Model):
name = fields.CharField(max_length=60)
"""
_get_empty_db()
if migrations_dir.exists():
shutil.rmtree(migrations_dir)
models_py.write_text(MODELS)
run_aerich("aerich init-db")
models_py.write_text(MODELS + new)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_add_m2m_field_after_init_db")
assert r.returncode == 0

View File

@ -1,6 +1,164 @@
from aerich.utils import import_py_file
from aerich.utils import get_dict_diff_by_key, import_py_file
def test_import_py_file():
def test_import_py_file() -> None:
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file")
assert getattr(m, "import_py_file", None)
class TestDiffFields:
def test_the_same_through_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "members", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert type(get_dict_diff_by_key(old, new)).__name__ == "generator"
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_same_through_with_different_orders(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
{"name": "members", "through": "users_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_the_same_field_name_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_same_field_name_with_different_orders(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_drop_one(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("remove", "", [(0, {"name": "users", "through": "users_group"})])]
def test_add_one(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("add", "", [(0, {"name": "users", "through": "users_group"})])]
def test_drop_some(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "staffs", "through": "staffs_group"})]),
]
def test_add_some(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("add", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "admins", "through": "admins_group"})]),
]
def test_some_through_unchanged(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("add", "", [(0, {"name": "users", "through": "users_group"})]),
]
def test_some_unchanged_without_drop_or_add(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users_new", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("change", [0, "name"], ("users", "users_new")),
]