91 Commits

Author SHA1 Message Date
Waket Zheng
d6a51bd20e docs: release v0.8.1 2024-12-27 12:34:55 +08:00
Waket Zheng
c1dea4e846 chore: upgrade deps, add tortoise0.23 to ci (#399) 2024-12-27 12:30:34 +08:00
Waket Zheng
5e8a7c7e91 fix: migration with duplicate renaming of columns in some cases (#395)
* fix: migration with duplicate renaming of columns in some cases

* Update var name

* fix downgrade sql error

* fix test error

* docs: update changelog

* Add unittest

* Move not change line to origin position

* Update sort key to make it more frieldly interactive from multi fields rename

* refactor: remove puzzle vars

* docs: fix PR links in changelog

* fix sort key lambda error
2024-12-27 12:09:23 +08:00
Waket Zheng
7d22518c74 Fix create/drop indexes in every migration (#377)
* Add `__eq__` method for `Index`instances

* tests: add Index test case

* refactor: compare index instances before set hash and eq func to class

* fix: sort fields when generating index hash

* docs: update changlog

* fix style issue

* refactor: use CustomIndex instead of postgres special HashIndex

* Check tortoise version before patch Index

* Add comment

* Add comment for why > work

---------

Co-authored-by: dbf <somnium@riseup.net>
2024-12-22 00:24:18 +08:00
Waket Zheng
f93faa8afb fix: add o2o field does not create constraint when migrating (#396)
* fix: add o2o field does not create constraint when migrating

* Add testcase and update changelog

* docs: update migrating list

* refactor: use `_handle_o2o_fields` instead of `is_o2o=True`

* Remove unused line
2024-12-22 00:23:47 +08:00
Waket Zheng
1acb9ed1e7 chore(deps): limit tortoise-orm version to >=0.21 instead of wildcard (*) (#388)
* Limit tortoise-orm version to `>=0.21` instead of wildcard (*)

* ci: fix typos

* docs: update changelog

* refactor: import MEMORY_SQLITE from tortoise

* Update changelog
2024-12-21 11:04:02 +08:00
Waket Zheng
69ce0cafa1 fix: intermediate table for m2m relation not created (#394)
* fix: intermediate table for m2m relation not created

* Add unittest

* docs: update changelog
2024-12-18 00:13:19 +08:00
Waket Zheng
4fc7f324d4 fix: add m2m field with custom m2m through generate duplicated table when migrating (#393)
* fix: m2m table duplicated when using custom model for through

* Add testcase

* docs: update changelog

* tests: add m2m custom through example test
2024-12-17 22:28:06 +08:00
Waket Zheng
d8addadb37 chore(deps): prefer to use tomllib/tomli and mark tomlkit/tomli_w as optional (#392)
* chore(deps): prefer to use tomllib and mark tomlkit as optional

* docs: add toml extra to install command

* docs: update changelog
2024-12-17 01:43:36 +08:00
Waket Zheng
0780919ef3 fix: migrate drop the wrong m2m field when model have multi m2m fields (#390)
* fix: migrate drop the wrong m2m field when model have multi m2m fields

* Make style and update changelog

* refactor: return new lists instead of change argument values in function

* refactor: use custom diff function instead of reorder lists

* docs: fix typo

* Fix hardcoded and rename custom diff function

* Update function doc
2024-12-17 01:20:02 +08:00
Waket Zheng
5af8c9cd56 docs: update changelog 2024-12-16 11:49:52 +08:00
Waket Zheng
56da0e7e3c fix: KeyError raised when removing or renaming an existing model 2024-12-12 10:43:40 +08:00
Lance.Moe
6270c4781e fix: error when there is __init__.py in the migration folder (#272)
* fix: error when there is __init__.py in the migration folder

* fix: check __init__.py in the migration folder

* refactor

* refactor & add test

* refactor

* Update changelog

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-11 21:34:14 +08:00
Waket Zheng
12d0a5dad1 fix: setting null=false on m2m field causes migration to fail (#385)
* fix: setting null=false on m2m field causes migration to fail

* Update changelog
2024-12-11 21:12:04 +08:00
Waket Zheng
56eff1b22f chore: use pytest instead of py.test in Makefile 2024-12-11 16:44:18 +08:00
Waket Zheng
e4a3863f80 fix: aerich upgrade raises OperationalError when unique constraint dropped at migration 1_xxx.py with postgres (#383) 2024-12-11 15:15:29 +08:00
Waket Zheng
5572876714 fix: NonExistentKey when running aerich init without [tool] section in config file (#381)
* fix: NonExistentKey when running `aerich init` without `[tool]` section in config file

* docs: update changelog
2024-12-11 13:26:18 +08:00
Waket Zheng
3d840395f1 docs: update changlog 2024-12-10 23:11:13 +08:00
Tuffy_
accceef24f Fixed two problems when using under windows (#286)
* fix: Fixed an issue where an error would occur when using aerich in windows if the profile contained Chinese characters

* fix: Automatically delete the empty migration directory of the app if the init-db operation fails

* feat: generate migration file in empty directory instead of abort with warning

* tests: fix test fail in ci

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-10 23:02:49 +08:00
Waket Zheng
9c81bc6036 Fix sqlite create/drop index (#379)
* Update add/drop index template for sqlite

* tests: add sqlite migrate/upgrade command test

* tests: add timeout for sqlite migrate command test

* tests: add test cases for add/drop unique field for sqlite

* fix: sqlite failed to add unique field
2024-12-10 16:37:30 +08:00
Waket Zheng
c2ebe9b5e4 Fix postgres fk rename (#378)
* Update postgres drop fk template

* fix test error

* docs: update changlog
2024-12-09 11:41:40 +08:00
Mykola Solodukha
8cefe68c9b [BUG] Sort m2m fields before comparing them with diff(...) (#271)
* 🐛 Sort m2m fields before comparing them with `diff(...)`

* Add test case and upgrade changelog

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-05 17:41:58 +08:00
Waket Zheng
44025823ee chore: upgrade deps and fix ruff lint issues (#374)
* chore: upgrade deps and apply ruff lint for tests/

* style: fix ruff lint issues
2024-12-05 15:56:00 +08:00
Waket Zheng
252cb97767 Bump version to 0.8.0 and update changelog 2024-12-04 00:14:35 +08:00
Waket Zheng
ac3ef9e2eb tests: no need to merge operators when NotSupportError raised (#373) 2024-12-03 23:59:54 +08:00
Waket Zheng
5b04b4422d chore: upgrade deps, update changelog, drop test db before create (#372)
* chore: upgrade deps, update changelog, drop test db before create

* tests: clear operators for sqlite ddl after NotSupportError raised
2024-12-03 23:30:37 +08:00
Waket Zheng
b2f4029a4a Improve type hints of inspectdb (#371) 2024-12-03 12:40:28 +08:00
xsillen
4e46d9d969 Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
Co-authored-by: floodpillar <165008032+floodpillar@users.noreply.github.com>
2024-12-03 11:44:36 +08:00
gck123
c0fd3ec63c Fix KeyError when deleting a field with unqiue=True (#365)
* Fix KeyError when deleting a field with unqiue=True

* refactor: rename `old_data_unique` to `is_unique_field`

* Add testcases for remove unique field

---------

Co-authored-by: gongchangku <gongchangku@anban.tech>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2024-12-03 10:16:44 +08:00
Waket Zheng
103470f4c1 Merge pull request #367 from waketzheng/fix-style-issue
chore: make style, upgrade deps, fix ci error and update changelog
2024-11-29 15:09:25 +08:00
Waket Zheng
dc020358b6 chore: make style, upgrade deps, fix ci error and update changelog 2024-11-25 23:46:48 +08:00
Waket Zheng
095eb48196 Merge pull request #360 from ahmetveburak/fix-package
fix(package): correct the click import
2024-11-25 18:10:27 +08:00
He
fac1de6299 Merge pull request #355 from merlinz01/improve-cli-descriptions
Improve CLI help text and output
2024-11-18 21:25:13 +01:00
ahmetveburak
e049fcee26 fix(package): correct the click import 2024-10-05 20:43:32 +03:00
Merlin
ee144d4a9b update migrate output for consistency 2024-08-07 13:32:49 -04:00
Merlin
dbf96a17d3 Improve command-line descriptions
Changed CLI help texts for some of the options to make them clearer.
2024-08-07 13:31:07 -04:00
long2ice
15d56121ef fix: migrate 2024-08-06 22:46:50 +08:00
long2ice
4851ecfe82 refactor: use asyncclick 2024-08-06 22:41:39 +08:00
long2ice
ad0e7006bc Merge pull request #348 from waketzheng/update-changelog
Update changelog
2024-06-07 11:25:10 +08:00
Waket Zheng
27b29e401b Merge remote-tracking branch 'upstream/dev' into update-changelog 2024-06-07 11:11:47 +08:00
Waket Zheng
7bb35c10e4 Add tag link to change log 2024-06-07 11:08:49 +08:00
long2ice
ed113d491e Merge pull request #347 from waketzheng/add-mypy-check-to-ci
Use mypy for type hint check in ci
2024-06-06 21:38:23 +08:00
Waket Zheng
9e46fbf55d rollback ci command 2024-06-06 18:20:57 +08:00
Waket Zheng
fc68f99c86 fixing cache action 2024-06-06 18:02:58 +08:00
Waket Zheng
480087df07 Add id for cache step 2024-06-06 17:38:51 +08:00
Waket Zheng
24a2087b78 Skip make deps if cache hint in ci 2024-06-06 17:35:00 +08:00
Waket Zheng
bceeb236c2 Checking cache action 2024-06-06 17:21:46 +08:00
Waket Zheng
c42fdab74d Add cache action to ci 2024-06-06 17:11:27 +08:00
Waket Zheng
ceb1e0ffef Activate type hint check in ci 2024-06-06 17:07:09 +08:00
long2ice
13dd44bef7 Merge pull request #340 from waketzheng/type-hint-tests
Improve type hints for tests/
2024-06-06 16:53:01 +08:00
long2ice
219633a926 Merge pull request #341 from waketzheng/type-hint-simple
Simple type hints for aerich/
2024-06-06 16:52:30 +08:00
long2ice
e6302a920a Merge pull request #342 from waketzheng/type-hint-ddl
Improve type hints for ddl and inspectdb
2024-06-06 16:51:37 +08:00
long2ice
79a77d368f Merge pull request #343 from waketzheng/type-hints-aerich.migrate
Add type hints for aerich.migrate
2024-06-06 16:50:53 +08:00
long2ice
4a1fc4cfa0 Merge pull request #339 from waketzheng/update-ci
Update ci action versions and avoid `make ci` install deps twice
2024-06-06 16:50:39 +08:00
Waket Zheng
aee706e29b Update changelog 2024-06-06 16:43:08 +08:00
Waket Zheng
572c13f9dd fix conflict 2024-06-06 16:36:29 +08:00
Waket Zheng
7b733495fb fix conflict 2024-06-06 16:35:09 +08:00
Waket Zheng
c0c217392c Merge remote-tracking branch 'upstream/dev' into type-hint-simple 2024-06-06 16:32:00 +08:00
Waket Zheng
c7a3d164cb fix conflict 2024-06-06 16:31:24 +08:00
Waket Zheng
50add58981 merge dev 2024-06-06 16:27:06 +08:00
long2ice
f3b6f8f264 Merge pull request #345 from waketzheng/dev
Drop python3.7 support
2024-06-06 15:49:31 +08:00
long2ice
d33638471b Merge pull request #346 from waketzheng/fix-336
fix: mysql drop unique index with error name
2024-06-06 15:48:50 +08:00
Waket Zheng
e764bb56f7 Add new column for unique index remove test 2024-06-06 09:07:13 +08:00
Waket Zheng
84d31d63f6 Update readme 2024-06-06 08:42:55 +08:00
Waket Zheng
234495d291 docs: bump up version and update changelog 2024-06-06 00:07:29 +08:00
Waket Zheng
e971653851 fix: mysql drop unique index migrate error 2024-06-05 23:37:30 +08:00
Waket Zheng
58d31b3a05 fix: make ci error with latest tortoise-orm (#344) 2024-06-04 01:44:28 +08:00
Waket Zheng
affffbdae3 Drop python3.7 support 2024-06-03 01:26:48 +08:00
Waket Zheng
6466a852c8 Add type hints for aerich.migrate 2024-06-02 18:09:34 +08:00
Waket Zheng
a917f253c9 Add type hints for inspectdb/sqlite 2024-06-02 17:56:54 +08:00
Waket Zheng
dd11bed5a0 Add type hints for ddl and inspectdb 2024-06-01 22:14:30 +08:00
Waket Zheng
8756f64e3f Simple type hints for aerich/ 2024-06-01 21:16:53 +08:00
Waket Zheng
1addda8178 Improve type hints for tests/ 2024-06-01 20:58:25 +08:00
Waket Zheng
716638752b Update ci action versions and avoid make ci install deps twice 2024-06-01 20:46:08 +08:00
long2ice
51117867a6 Merge pull request #328 from Fl0kse/fix_issue-150
try to fix "Add ManyToManyField will break migrate #150" issues
2024-01-23 22:08:08 +08:00
artem.ulchenko
b25c7671bb try to fix "Add ManyToManyField will break migrate #150" issues 2024-01-18 11:40:59 +03:00
long2ice
b63fbcc7a4 style: fix 2024-01-18 09:50:41 +08:00
long2ice
4370b5ed08 Merge pull request #327 from ar0ne/dev
Added an option to generate empty migration file
2024-01-18 09:46:02 +08:00
ar0ne
2b2e465362 missed changes that add new flag to cli 2023-12-27 12:56:46 +05:30
ar0ne
ede53ade86 added note in readme 2023-12-26 23:20:29 +05:30
ar0ne
ad54b5e9dd remove redundant comma for empty migration 2023-12-26 23:12:11 +05:30
ar0ne
b1ff2418f5 added option to generate empty migration file 2023-12-26 22:55:51 +05:30
long2ice
01264f3f27 fix: pydantic v2. (#322) 2023-08-30 10:04:53 +08:00
long2ice
ea234a5799 fix: default empty str 2023-08-23 10:53:19 +08:00
long2ice
b724f24f1a Merge pull request #319 from strayge/pr/editable-install
Update poetry build backend to support editable install
2023-08-22 10:00:03 +08:00
long2ice
2bc23103dc Merge pull request #320 from strayge/pr/column-description-diff
Fix changed column descriptions in diffs
2023-08-22 09:59:46 +08:00
strayge
6d83c370ad fix changed column descriptions in diffs 2023-08-21 10:58:43 +04:00
strayge
e729bb9b60 update poetry build-backend to support editable install 2023-08-21 10:58:26 +04:00
long2ice
8edd834da6 refactor: make in_transaction default True 2023-08-04 10:36:27 +08:00
long2ice
4adc89d441 Merge pull request #306 from plusiv/add-char-support-mysql
add char support for mysql
2023-07-27 12:52:01 +08:00
Jorge Massih
fd4b9fe7d3 add char support for mysql 2023-06-15 00:40:29 -04:00
31 changed files with 2466 additions and 1241 deletions

View File

@@ -18,17 +18,45 @@ jobs:
POSTGRES_PASSWORD: 123456 POSTGRES_PASSWORD: 123456
POSTGRES_USER: postgres POSTGRES_USER: postgres
options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
tortoise-orm:
- tortoise021
- tortoise022
- tortoise023
- tortoisedev
steps: steps:
- name: Start MySQL - name: Start MySQL
run: sudo systemctl start mysql.service run: sudo systemctl start mysql.service
- uses: actions/checkout@v2 - uses: actions/cache@v4
- uses: actions/setup-python@v2
with: with:
python-version: '3.x' path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Install and configure Poetry - name: Install and configure Poetry
run: | run: |
pip install -U pip poetry pip install -U pip poetry
poetry config virtualenvs.create false poetry config virtualenvs.create false
- name: Install dependencies and check style
run: make check
- name: Install TortoiseORM v0.21
if: matrix.tortoise-orm == 'tortoise021'
run: poetry run pip install --upgrade "tortoise-orm>=0.21,<0.22"
- name: Install TortoiseORM v0.22
if: matrix.tortoise-orm == 'tortoise022'
run: poetry run pip install --upgrade "tortoise-orm>=0.22,<0.23"
- name: Install TortoiseORM v0.23
if: matrix.tortoise-orm == 'tortoise023'
run: poetry run pip install --upgrade "tortoise-orm>=0.23,<0.24"
- name: Install TortoiseORM develop branch
if: matrix.tortoise-orm == 'tortoisedev'
run: poetry run pip install --upgrade "git+https://github.com/tortoise/tortoise-orm"
- name: CI - name: CI
env: env:
MYSQL_PASS: root MYSQL_PASS: root
@@ -37,4 +65,4 @@ jobs:
POSTGRES_PASS: 123456 POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1 POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432 POSTGRES_PORT: 5432
run: make ci run: make _testall

View File

@@ -7,8 +7,8 @@ jobs:
publish: publish:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v4
- uses: actions/setup-python@v2 - uses: actions/setup-python@v5
with: with:
python-version: '3.x' python-version: '3.x'
- name: Install and configure Poetry - name: Install and configure Poetry

View File

@@ -1,8 +1,63 @@
# ChangeLog # ChangeLog
## 0.8
### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27
#### Fixed
- fix: add o2o field does not create constraint when migrating. ([#396])
- Migration with duplicate renaming of columns in some cases. ([#395])
- fix: intermediate table for m2m relation not created. ([#394])
- Migrate add m2m field with custom through generate duplicated table. ([#393])
- Migrate drop the wrong m2m field when model have multi m2m fields. ([#376])
- KeyError raised when removing or renaming an existing model. ([#386])
- fix: error when there is `__init__.py` in the migration folder. ([#272])
- Setting null=false on m2m field causes migration to fail. ([#334])
- Fix NonExistentKey when running `aerich init` without `[tool]` section in config file. ([#284])
- Fix configuration file reading error when containing Chinese characters. ([#286])
- sqlite: failed to create/drop index. ([#302])
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. ([#378])
- Fix create/drop indexes in every migration. ([#377])
- Sort m2m fields before comparing them with diff. ([#271])
#### Changed
- Allow run `aerich init-db` with empty migration directories instead of abort with warnings. ([#286])
- Add version constraint(>=0.21) for tortoise-orm. ([#388])
- Move `tomlkit` to optional and support `pip install aerich[toml]`. ([#392])
[#396]: https://github.com/tortoise/aerich/pull/396
[#395]: https://github.com/tortoise/aerich/pull/395
[#394]: https://github.com/tortoise/aerich/pull/394
[#393]: https://github.com/tortoise/aerich/pull/393
[#376]: https://github.com/tortoise/aerich/pull/376
[#386]: https://github.com/tortoise/aerich/pull/386
[#272]: https://github.com/tortoise/aerich/pull/272
[#334]: https://github.com/tortoise/aerich/pull/334
[#284]: https://github.com/tortoise/aerich/pull/284
[#286]: https://github.com/tortoise/aerich/pull/286
[#302]: https://github.com/tortoise/aerich/pull/302
[#378]: https://github.com/tortoise/aerich/pull/378
[#377]: https://github.com/tortoise/aerich/pull/377
[#271]: https://github.com/tortoise/aerich/pull/271
[#286]: https://github.com/tortoise/aerich/pull/286
[#388]: https://github.com/tortoise/aerich/pull/388
[#392]: https://github.com/tortoise/aerich/pull/392
### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04
- Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
- Fix KeyError when deleting a field with unqiue=True. (#364)
- Correct the click import. (#360)
- Improve CLI help text and output. (#355)
- Fix mysql drop unique index raises OperationalError. (#346)
**Upgrade note:**
1. Use column name as unique key name for mysql
2. Drop support for Python3.7
## 0.7 ## 0.7
### 0.7.2 ### [0.7.2](../../releases/tag/v0.7.2) - 2023-07-20
- Support virtual fields. - Support virtual fields.
- Fix modify multiple times. (#279) - Fix modify multiple times. (#279)

View File

@@ -6,27 +6,31 @@ MYSQL_PORT ?= 3306
MYSQL_PASS ?= "123456" MYSQL_PASS ?= "123456"
POSTGRES_HOST ?= "127.0.0.1" POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432 POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= "123456" POSTGRES_PASS ?= 123456
up: up:
@poetry update @poetry update
deps: deps:
@poetry install -E asyncpg -E asyncmy @poetry install -E asyncpg -E asyncmy -E toml
style: deps _style:
@isort -src $(checkfiles) @isort -src $(checkfiles)
@black $(black_opts) $(checkfiles) @black $(black_opts) $(checkfiles)
style: deps _style
check: deps _check:
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@ruff $(checkfiles) @ruff check $(checkfiles)
@mypy $(checkfiles)
@bandit -r aerich
check: deps _check
test: deps test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: pytest
test_sqlite: test_sqlite:
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: pytest
test_mysql: test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
@@ -34,9 +38,10 @@ test_mysql:
test_postgres: test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
testall: deps test_sqlite test_postgres test_mysql _testall: test_sqlite test_postgres test_mysql
testall: deps _testall
build: deps build: deps
@poetry build @poetry build
ci: check testall ci: build _check _testall

View File

@@ -17,7 +17,7 @@ it\'s own migration solution.
Just install from pypi: Just install from pypi:
```shell ```shell
pip install aerich pip install "aerich[toml]"
``` ```
## Quick Start ## Quick Start
@@ -46,7 +46,7 @@ Commands:
## Usage ## Usage
You need add `aerich.models` to your `Tortoise-ORM` config first. Example: You need to add `aerich.models` to your `Tortoise-ORM` config first. Example:
```python ```python
TORTOISE_ORM = { TORTOISE_ORM = {
@@ -113,6 +113,14 @@ If `aerich` guesses you are renaming a column, it will ask `Rename {old_column}
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data. lose data.
If you need to manually write migration, you could generate empty file:
```shell
> aerich migrate --name add_index --empty
Success migrate 1_202326122220101229_add_index.py
```
### Upgrade to latest version ### Upgrade to latest version
```shell ```shell

View File

@@ -1,6 +1,6 @@
import os import os
from pathlib import Path from pathlib import Path
from typing import List from typing import TYPE_CHECKING, List, Optional, Type
from tortoise import Tortoise, generate_schema_for_client from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@@ -20,6 +20,9 @@ from aerich.utils import (
import_py_file, import_py_file,
) )
if TYPE_CHECKING:
from aerich.inspectdb import Inspect # noqa:F401
class Command: class Command:
def __init__( def __init__(
@@ -27,19 +30,19 @@ class Command:
tortoise_config: dict, tortoise_config: dict,
app: str = "models", app: str = "models",
location: str = "./migrations", location: str = "./migrations",
): ) -> None:
self.tortoise_config = tortoise_config self.tortoise_config = tortoise_config
self.app = app self.app = app
self.location = location self.location = location
Migrate.app = app Migrate.app = app
async def init(self): async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location) await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file): async def _upgrade(self, conn, version_file) -> None:
file_path = Path(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path) m = import_py_file(file_path)
upgrade = getattr(m, "upgrade") upgrade = m.upgrade
await conn.execute_script(await upgrade(conn)) await conn.execute_script(await upgrade(conn))
await Aerich.create( await Aerich.create(
version=version_file, version=version_file,
@@ -47,7 +50,7 @@ class Command:
content=get_models_describe(self.app), content=get_models_describe(self.app),
) )
async def upgrade(self, run_in_transaction: bool): async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
migrated = [] migrated = []
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@@ -65,8 +68,8 @@ class Command:
migrated.append(version_file) migrated.append(version_file)
return migrated return migrated
async def downgrade(self, version: int, delete: bool): async def downgrade(self, version: int, delete: bool) -> List[str]:
ret = [] ret: List[str] = []
if version == -1: if version == -1:
specified_version = await Migrate.get_last_version() specified_version = await Migrate.get_last_version()
else: else:
@@ -79,25 +82,25 @@ class Command:
versions = [specified_version] versions = [specified_version]
else: else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk) versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version in versions: for version_obj in versions:
file = version.version file = version_obj.version
async with in_transaction( async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
file_path = Path(Migrate.migrate_location, file) file_path = Path(Migrate.migrate_location, file)
m = import_py_file(file_path) m = import_py_file(file_path)
downgrade = getattr(m, "downgrade") downgrade = m.downgrade
downgrade_sql = await downgrade(conn) downgrade_sql = await downgrade(conn)
if not downgrade_sql.strip(): if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found") raise DowngradeError("No downgrade items found")
await conn.execute_script(downgrade_sql) await conn.execute_script(downgrade_sql)
await version.delete() await version_obj.delete()
if delete: if delete:
os.unlink(file_path) os.unlink(file_path)
ret.append(file) ret.append(file)
return ret return ret
async def heads(self): async def heads(self) -> List[str]:
ret = [] ret = []
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
for version in versions: for version in versions:
@@ -105,15 +108,15 @@ class Command:
ret.append(version) ret.append(version)
return ret return ret
async def history(self): async def history(self) -> List[str]:
versions = Migrate.get_all_version_files() versions = Migrate.get_all_version_files()
return [version for version in versions] return [version for version in versions]
async def inspectdb(self, tables: List[str] = None) -> str: async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app) connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT dialect = connection.schema_generator.DIALECT
if dialect == "mysql": if dialect == "mysql":
cls = InspectMySQL cls: Type["Inspect"] = InspectMySQL
elif dialect == "postgres": elif dialect == "postgres":
cls = InspectPostgres cls = InspectPostgres
elif dialect == "sqlite": elif dialect == "sqlite":
@@ -123,14 +126,19 @@ class Command:
inspect = cls(connection, tables) inspect = cls(connection, tables)
return await inspect.inspect() return await inspect.inspect()
async def migrate(self, name: str = "update"): async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name) return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool): async def init_db(self, safe: bool) -> None:
location = self.location location = self.location
app = self.app app = self.app
dirname = Path(location, app) dirname = Path(location, app)
if not dirname.exists():
dirname.mkdir(parents=True) dirname.mkdir(parents=True)
else:
# If directory is empty, go ahead, otherwise raise FileExistsError
for unexpected_file in dirname.glob("*"):
raise FileExistsError(str(unexpected_file))
await Tortoise.init(config=self.tortoise_config) await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app) connection = get_app_connection(self.tortoise_config, app)

View File

@@ -1,14 +1,10 @@
import asyncio
import os import os
from functools import wraps import sys
from pathlib import Path from pathlib import Path
from typing import List from typing import Dict, List, cast
import click import asyncclick as click
import tomlkit from asyncclick import Context, UsageError
from click import Context, UsageError
from tomlkit.exceptions import NonExistentKey
from tortoise import Tortoise
from aerich import Command from aerich import Command
from aerich.enums import Color from aerich.enums import Color
@@ -16,26 +12,19 @@ from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config from aerich.utils import add_src_path, get_tortoise_config
from aerich.version import __version__ from aerich.version import __version__
if sys.version_info >= (3, 11):
import tomllib
else:
try:
import tomli as tomllib
except ImportError:
import tomlkit as tomllib # type: ignore
CONFIG_DEFAULT_VALUES = { CONFIG_DEFAULT_VALUES = {
"src_folder": ".", "src_folder": ".",
} }
def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init"]:
loop.run_until_complete(Tortoise.close_connections())
return wrapper
@click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version") @click.version_option(__version__, "-V", "--version")
@click.option( @click.option(
@@ -47,8 +36,7 @@ def coro(f):
) )
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.pass_context @click.pass_context
@coro async def cli(ctx: Context, config, app) -> None:
async def cli(ctx: Context, config, app):
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config_file"] = config ctx.obj["config_file"] = config
@@ -56,57 +44,64 @@ async def cli(ctx: Context, config, app):
if invoked_subcommand != "init": if invoked_subcommand != "init":
config_path = Path(config) config_path = Path(config)
if not config_path.exists(): if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError(
content = config_path.read_text() "You need to run `aerich init` first to create the config file.", ctx=ctx
doc = tomlkit.parse(content) )
content = config_path.read_text("utf-8")
doc: dict = tomllib.loads(content)
try: try:
tool = doc["tool"]["aerich"] tool = cast(Dict[str, str], doc["tool"]["aerich"])
location = tool["location"] location = tool["location"]
tortoise_orm = tool["tortoise_orm"] tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"]) src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except NonExistentKey: except KeyError as e:
raise UsageError("You need run aerich init again when upgrade to 0.6.0+") raise UsageError(
"You need run `aerich init` again when upgrading to aerich 0.6.0+."
) from e
add_src_path(src_folder) add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
app = app or list(tortoise_config.get("apps").keys())[0] if not app:
apps_config = cast(dict, tortoise_config.get("apps"))
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location) command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command ctx.obj["command"] = command
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx) raise UsageError(
"You need to run `aerich init-db` first to initialize the database.", ctx=ctx
)
await command.init() await command.init()
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate a migration file for the current state of the models.")
@click.option("--name", default="update", show_default=True, help="Migrate name.") @click.option("--name", default="update", show_default=True, help="Migration name.")
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
@click.pass_context @click.pass_context
@coro async def migrate(ctx: Context, name, empty) -> None:
async def migrate(ctx: Context, name):
command = ctx.obj["command"] command = ctx.obj["command"]
ret = await command.migrate(name) ret = await command.migrate(name, empty)
if not ret: if not ret:
return click.secho("No changes detected", fg=Color.yellow) return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success migrate {ret}", fg=Color.green) click.secho(f"Success creating migration file {ret}", fg=Color.green)
@cli.command(help="Upgrade to specified version.") @cli.command(help="Upgrade to specified migration version.")
@click.option( @click.option(
"--in-transaction", "--in-transaction",
"-i", "-i",
default=True, default=True,
type=bool, type=bool,
help="Make migrations in transaction or not. Can be helpful for large migrations or creating concurrent indexes.", help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
) )
@click.pass_context @click.pass_context
@coro async def upgrade(ctx: Context, in_transaction: bool) -> None:
async def upgrade(ctx: Context, in_transaction: bool):
command = ctx.obj["command"] command = ctx.obj["command"]
migrated = await command.upgrade(run_in_transaction=in_transaction) migrated = await command.upgrade(run_in_transaction=in_transaction)
if not migrated: if not migrated:
click.secho("No upgrade items found", fg=Color.yellow) click.secho("No upgrade items found", fg=Color.yellow)
else: else:
for version_file in migrated: for version_file in migrated:
click.secho(f"Success upgrade {version_file}", fg=Color.green) click.secho(f"Success upgrading to {version_file}", fg=Color.green)
@cli.command(help="Downgrade to specified version.") @cli.command(help="Downgrade to specified version.")
@@ -115,8 +110,8 @@ async def upgrade(ctx: Context, in_transaction: bool):
"--version", "--version",
default=-1, default=-1,
type=int, type=int,
show_default=True, show_default=False,
help="Specified version, default to last.", help="Specified version, default to last migration.",
) )
@click.option( @click.option(
"-d", "-d",
@@ -124,59 +119,56 @@ async def upgrade(ctx: Context, in_transaction: bool):
is_flag=True, is_flag=True,
default=False, default=False,
show_default=True, show_default=True,
help="Delete version files at the same time.", help="Also delete the migration files.",
) )
@click.pass_context @click.pass_context
@click.confirmation_option( @click.confirmation_option(
prompt="Downgrade is dangerous, which maybe lose your data, are you sure?", prompt="Downgrade is dangerous: you might lose your data! Are you sure?",
) )
@coro async def downgrade(ctx: Context, version: int, delete: bool) -> None:
async def downgrade(ctx: Context, version: int, delete: bool):
command = ctx.obj["command"] command = ctx.obj["command"]
try: try:
files = await command.downgrade(version, delete) files = await command.downgrade(version, delete)
except DowngradeError as e: except DowngradeError as e:
return click.secho(str(e), fg=Color.yellow) return click.secho(str(e), fg=Color.yellow)
for file in files: for file in files:
click.secho(f"Success downgrade {file}", fg=Color.green) click.secho(f"Success downgrading to {file}", fg=Color.green)
@cli.command(help="Show current available heads in migrate location.") @cli.command(help="Show currently available heads (unapplied migrations).")
@click.pass_context @click.pass_context
@coro async def heads(ctx: Context) -> None:
async def heads(ctx: Context):
command = ctx.obj["command"] command = ctx.obj["command"]
head_list = await command.heads() head_list = await command.heads()
if not head_list: if not head_list:
return click.secho("No available heads, try migrate first", fg=Color.green) return click.secho("No available heads.", fg=Color.green)
for version in head_list: for version in head_list:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
@cli.command(help="List all migrate items.") @cli.command(help="List all migrations.")
@click.pass_context @click.pass_context
@coro async def history(ctx: Context) -> None:
async def history(ctx: Context):
command = ctx.obj["command"] command = ctx.obj["command"]
versions = await command.history() versions = await command.history()
if not versions: if not versions:
return click.secho("No history, try migrate", fg=Color.green) return click.secho("No migrations created yet.", fg=Color.green)
for version in versions: for version in versions:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
@cli.command(help="Init config file and generate root migrate location.") @cli.command(help="Initialize aerich config and create migrations folder.")
@click.option( @click.option(
"-t", "-t",
"--tortoise-orm", "--tortoise-orm",
required=True, required=True,
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.", help="Tortoise-ORM config dict location, like `settings.TORTOISE_ORM`.",
) )
@click.option( @click.option(
"--location", "--location",
default="./migrations", default="./migrations",
show_default=True, show_default=True,
help="Migrate store location.", help="Migrations folder.",
) )
@click.option( @click.option(
"-s", "-s",
@@ -186,8 +178,11 @@ async def history(ctx: Context):
help="Folder of the source, relative to the project root.", help="Folder of the source, relative to the project root.",
) )
@click.pass_context @click.pass_context
@coro async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
async def init(ctx: Context, tortoise_orm, location, src_folder): try:
import tomli_w as tomlkit
except ImportError:
import tomlkit # type: ignore
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
if os.path.isabs(src_folder): if os.path.isabs(src_folder):
@@ -200,52 +195,50 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
add_src_path(src_folder) add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm) get_tortoise_config(ctx, tortoise_orm)
config_path = Path(config_file) config_path = Path(config_file)
if config_path.exists(): content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
content = config_path.read_text() doc: dict = tomllib.loads(content)
doc = tomlkit.parse(content) table: dict = getattr(tomlkit, "table", dict)()
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm table["tortoise_orm"] = tortoise_orm
table["location"] = location table["location"] = location
table["src_folder"] = src_folder table["src_folder"] = src_folder
try:
doc["tool"]["aerich"] = table doc["tool"]["aerich"] = table
except KeyError:
doc["tool"] = {"aerich": table}
config_path.write_text(tomlkit.dumps(doc)) config_path.write_text(tomlkit.dumps(doc))
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)
click.secho(f"Success create migrate location {location}", fg=Color.green) click.secho(f"Success creating migrations folder {location}", fg=Color.green)
click.secho(f"Success write config to {config_file}", fg=Color.green) click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migrate location.") @cli.command(help="Generate schema and generate app migration folder.")
@click.option( @click.option(
"-s", "-s",
"--safe", "--safe",
type=bool, type=bool,
is_flag=True, is_flag=True,
default=True, default=True,
help="When set to true, creates the table only when it does not already exist.", help="Create tables only when they do not already exist.",
show_default=True, show_default=True,
) )
@click.pass_context @click.pass_context
@coro async def init_db(ctx: Context, safe: bool) -> None:
async def init_db(ctx: Context, safe: bool):
command = ctx.obj["command"] command = ctx.obj["command"]
app = command.app app = command.app
dirname = Path(command.location, app) dirname = Path(command.location, app)
try: try:
await command.init_db(safe) await command.init_db(safe)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green) click.secho(f"Success creating app migration folder {dirname}", fg=Color.green)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green) click.secho(f'Success generating initial migration file for app "{app}"', fg=Color.green)
except FileExistsError: except FileExistsError:
return click.secho( return click.secho(
f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow f"App {app} is already initialized. Delete {dirname} and try again.", fg=Color.yellow
) )
@cli.command(help="Introspects the database tables to standard output as TortoiseORM model.") @cli.command(help="Prints the current database tables to stdout as Tortoise-ORM models.")
@click.option( @click.option(
"-t", "-t",
"--table", "--table",
@@ -254,14 +247,13 @@ async def init_db(ctx: Context, safe: bool):
required=False, required=False,
) )
@click.pass_context @click.pass_context
@coro async def inspectdb(ctx: Context, table: List[str]) -> None:
async def inspectdb(ctx: Context, table: List[str]):
command = ctx.obj["command"] command = ctx.obj["command"]
ret = await command.inspectdb(table) ret = await command.inspectdb(table)
click.secho(ret) click.secho(ret)
def main(): def main() -> None:
cli() cli()

View File

@@ -1,12 +1,13 @@
import base64 import base64
import json import json
import pickle # nosec: B301,B403 import pickle # nosec: B301,B403
from typing import Any, Union
from tortoise.indexes import Index from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj) -> Any:
if isinstance(obj, Index): if isinstance(obj, Index):
return { return {
"type": "index", "type": "index",
@@ -16,16 +17,16 @@ class JsonEncoder(json.JSONEncoder):
return super().default(obj) return super().default(obj)
def object_hook(obj): def object_hook(obj) -> Any:
_type = obj.get("type") _type = obj.get("type")
if not _type: if not _type:
return obj return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301 return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict): def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder) return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str): def decoder(obj: Union[str, bytes]) -> Any:
return json.loads(obj, object_hook=object_hook) return json.loads(obj, object_hook=object_hook)

View File

@@ -1,8 +1,9 @@
from enum import Enum from enum import Enum
from typing import List, Type from typing import Any, List, Type, cast
from tortoise import BaseDBAsyncClient, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from aerich.utils import is_default_function from aerich.utils import is_default_function
@@ -35,25 +36,26 @@ class BaseDDL:
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"' _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: "BaseDBAsyncClient"): def __init__(self, client: "BaseDBAsyncClient") -> None:
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]"): def create_table(self, model: "Type[Model]") -> str:
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip( return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
";" ";"
) )
def drop_table(self, table_name: str): def drop_table(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m( def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
): ) -> str:
through = field_describe.get("through") through = cast(str, field_describe.get("through"))
description = field_describe.get("description") description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column") pk_field = cast(dict, reference_table_describe.get("pk_field"))
db_field_types = reference_table_describe.get("pk_field").get("db_field_types") reference_id = pk_field.get("db_column")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=through, table_name=through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
@@ -66,22 +68,22 @@ class BaseDDL:
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through), extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator( comment=(
table=through, comment=description self.schema_generator._table_comment_generator(table=through, comment=description)
)
if description if description
else "", else ""
),
) )
def drop_m2m(self, table_name: str): def drop_m2m(self, table_name: str) -> str:
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict): def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any:
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_describe.get("default")
if isinstance(default, Enum): if isinstance(default, Enum):
default = default.value default = default.value
db_column = field_describe.get("db_column") db_column = cast(str, field_describe.get("db_column"))
auto_now_add = field_describe.get("auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False) auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
@@ -106,64 +108,60 @@ class BaseDDL:
default = None default = None
return default return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description") description = field_describe.get("description")
db_column = field_describe.get("db_column") db_column = cast(str, field_describe.get("db_column"))
db_field_types = field_describe.get("db_field_types") db_field_types = cast(dict, field_describe.get("db_field_types"))
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
if default is None: if default is None:
default = "" default = ""
return self._ADD_COLUMN_TEMPLATE.format( if modify:
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
# sqlite does not support alter table to add unique column
unique = (
"UNIQUE"
if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT
else ""
)
template = self._ADD_COLUMN_TEMPLATE
return template.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=db_column, db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")), field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "", unique=unique,
comment=self.schema_generator._column_comment_generator( comment=(
self.schema_generator._column_comment_generator(
table=db_table, table=db_table,
column=db_column, column=db_column,
comment=field_describe.get("description"), comment=description,
) )
if description if description
else "", else ""
),
is_primary_key=is_pk, is_primary_key=is_pk,
default=default, default=default,
), ),
) )
def drop_column(self, model: "Type[Model]", column_name: str): def drop_column(self, model: "Type[Model]", column_name: str) -> str:
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table return self._add_or_modify_column(model, field_describe, is_pk, modify=True)
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_describe.get("db_column"),
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_describe.get("db_column"),
comment=field_describe.get("description"),
)
if field_describe.get("description")
else "",
is_primary_key=is_pk,
default=default,
),
)
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str): def rename_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str
) -> str:
return self._RENAME_COLUMN_TEMPLATE.format( return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -172,7 +170,7 @@ class BaseDDL:
def change_column( def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
): ) -> str:
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@@ -180,7 +178,7 @@ class BaseDDL:
new_column_type=new_column_type, new_column_type=new_column_type,
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "", unique="UNIQUE " if unique else "",
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
@@ -190,7 +188,7 @@ class BaseDDL:
column_names=", ".join(self.schema_generator.quote(f) for f in field_names), column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
) )
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False): def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name( index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names "idx" if not unique else "uid", model, field_names
@@ -198,45 +196,52 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def drop_index_by_name(self, model: "Type[Model]", index_name: str): def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name, index_name=index_name,
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): def _generate_fk_name(
self, db_table, field_describe: dict, reference_table_describe: dict
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
return self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=to_table,
to_field=to_field,
)
def add_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_column = field_describe.get("raw_field") db_column = field_describe.get("raw_field")
reference_id = reference_table_describe.get("pk_field").get("db_column") pk_field = cast(dict, reference_table_describe.get("pk_field"))
fk_name = self.schema_generator._generate_fk_name( reference_id = pk_field.get("db_column")
from_table=db_table,
from_field=db_column,
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
)
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=fk_name, fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe),
db_column=db_column, db_column=db_column,
table=reference_table_describe.get("table"), table=reference_table_describe.get("table"),
field=reference_id, field=reference_id,
on_delete=field_describe.get("on_delete"), on_delete=field_describe.get("on_delete"),
) )
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict): def drop_fk(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format( fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
table_name=db_table, return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field_describe.get("raw_field"),
to_table=reference_table_describe.get("table"),
to_field=reference_table_describe.get("pk_field").get("db_column"),
),
)
def alter_column_default(self, model: "Type[Model]", field_describe: dict): def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
default = self._get_default(model, field_describe) default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format( return self._ALTER_DEFAULT_TEMPLATE.format(
@@ -245,13 +250,13 @@ class BaseDDL:
default="SET" + default if default is not None else "DROP DEFAULT", default="SET" + default if default is not None else "DROP DEFAULT",
) )
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
return self.modify_column(model, field_describe) return self.modify_column(model, field_describe)
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str): def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._RENAME_TABLE_TEMPLATE.format( return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name

View File

@@ -1,7 +1,12 @@
from typing import TYPE_CHECKING, List, Type
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
if TYPE_CHECKING:
from tortoise import Model # noqa:F401
class MysqlDDL(BaseDDL): class MysqlDDL(BaseDDL):
schema_generator_cls = MySQLSchemaGenerator schema_generator_cls = MySQLSchemaGenerator
@@ -30,3 +35,29 @@ class MysqlDDL(BaseDDL):
) )
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`" _RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def _index_name(self, unique: bool, model: "Type[Model]", field_names: List[str]) -> str:
if unique:
if len(field_names) == 1:
# Example: `email = CharField(max_length=50, unique=True)`
# Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
# Unique index key is the same as field name: `email`
return field_names[0]
index_prefix = "uid"
else:
index_prefix = "idx"
return self.schema_generator._generate_index_name(index_prefix, model, field_names)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "",
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
return self._DROP_INDEX_TEMPLATE.format(
index_name=self._index_name(unique, model, field_names),
table_name=model._meta.db_table,
)

View File

@@ -1,4 +1,4 @@
from typing import Type from typing import Type, cast
from tortoise import Model from tortoise import Model
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
@@ -10,15 +10,15 @@ class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})' _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = ( _MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}' 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
) )
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"'
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
@@ -26,9 +26,9 @@ class PostgresDDL(BaseDDL):
set_drop="DROP" if field_describe.get("nullable") else "SET", set_drop="DROP" if field_describe.get("nullable") else "SET",
) )
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types") db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column") db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("") datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
@@ -38,12 +38,14 @@ class PostgresDDL(BaseDDL):
using=f' USING "{db_column}"::{datatype}', using=f' USING "{db_column}"::{datatype}',
) )
def set_comment(self, model: "Type[Model]", field_describe: dict): def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format( return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column") or field_describe.get("raw_field"), column=field_describe.get("db_column") or field_describe.get("raw_field"),
comment="'{}'".format(field_describe.get("description")) comment=(
"'{}'".format(field_describe.get("description"))
if field_describe.get("description") if field_describe.get("description")
else "NULL", else "NULL"
),
) )

View File

@@ -10,6 +10,8 @@ from aerich.exceptions import NotSupportError
class SqliteDDL(BaseDDL): class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT DIALECT = SqliteSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True): def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.") raise NotSupportError("Modify column is unsupported in SQLite.")

View File

@@ -1,24 +1,40 @@
from typing import Any, List, Optional from __future__ import annotations
import contextlib
from typing import Any, Callable, Dict, Optional, TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from tortoise import BaseDBAsyncClient from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel): class Column(BaseModel):
name: str name: str
data_type: str data_type: str
null: bool null: bool
default: Any default: Any
comment: Optional[str] comment: Optional[str] = None
pk: bool pk: bool
unique: bool unique: bool
index: bool index: bool
length: Optional[int] length: Optional[int] = None
extra: Optional[str] extra: Optional[str] = None
decimal_places: Optional[int] decimal_places: Optional[int] = None
max_digits: Optional[int] max_digits: Optional[int] = None
def translate(self) -> dict: def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = "" comment = default = length = index = null = pk = ""
if self.pk: if self.pk:
pk = "pk=True, " pk = "pk=True, "
@@ -28,25 +44,26 @@ class Column(BaseModel):
else: else:
if self.index: if self.index:
index = "index=True, " index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]: if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, " length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]: elif self.data_type in ("decimal", "numeric"):
length_parts = [] length_parts = []
if self.max_digits: if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}") length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places: if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}") length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts) if length_parts:
length = ", ".join(length_parts) + ", "
if self.null: if self.null:
null = "null=True, " null = "null=True, "
if self.default is not None: if self.default is not None:
if self.data_type in ["tinyint", "INT"]: if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, " default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool": elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, " default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]: elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if "CURRENT_TIMESTAMP" == self.default: if self.default == "CURRENT_TIMESTAMP":
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: if self.extra == "DEFAULT_GENERATED on update CURRENT_TIMESTAMP":
default = "auto_now=True, " default = "auto_now=True, "
else: else:
default = "auto_now_add=True, " default = "auto_now_add=True, "
@@ -55,6 +72,8 @@ class Column(BaseModel):
default = f"default={self.default.split('::')[0]}, " default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"): elif self.default.endswith("()"):
default = "" default = ""
elif self.default == "":
default = 'default=""'
else: else:
default = f"default={self.default}, " default = f"default={self.default}, "
@@ -74,16 +93,14 @@ class Column(BaseModel):
class Inspect: class Inspect:
_table_template = "class {table}(Model):\n" _table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn self.conn = conn
try: with contextlib.suppress(AttributeError):
self.database = conn.database self.database = conn.database # type:ignore[attr-defined]
except AttributeError:
pass
self.tables = tables self.tables = tables
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
raise NotImplementedError raise NotImplementedError
async def inspect(self) -> str: async def inspect(self) -> str:
@@ -101,10 +118,10 @@ class Inspect:
tables.append(model + "\n".join(fields)) tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables) return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError raise NotImplementedError
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@@ -1,17 +1,18 @@
from typing import List from __future__ import annotations
from aerich.inspectdb import Column, Inspect from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect): class InspectMySQL(Inspect):
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
return { return {
"int": self.int_field, "int": self.int_field,
"smallint": self.smallint_field, "smallint": self.smallint_field,
"tinyint": self.bool_field, "tinyint": self.bool_field,
"bigint": self.bigint_field, "bigint": self.bigint_field,
"varchar": self.char_field, "varchar": self.char_field,
"char": self.char_field,
"longtext": self.text_field, "longtext": self.text_field,
"text": self.text_field, "text": self.text_field,
"datetime": self.datetime_field, "datetime": self.datetime_field,
@@ -23,12 +24,12 @@ class InspectMySQL(Inspect):
"longblob": self.binary_field, "longblob": self.binary_field,
} }
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database]) ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret)) return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c from information_schema.COLUMNS c
@@ -39,16 +40,11 @@ where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s""" and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table]) ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret: for row in ret:
non_unique = row["NON_UNIQUE"] unique = index = False
if non_unique is None: if (non_unique := row["NON_UNIQUE"]) is not None:
unique = False
else:
unique = not non_unique unique = not non_unique
index_name = row["INDEX_NAME"] if (index_name := row["INDEX_NAME"]) is not None:
if index_name is None: index = index_name != "PRIMARY"
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append( columns.append(
Column( Column(
name=row["COLUMN_NAME"], name=row["COLUMN_NAME"],
@@ -59,7 +55,8 @@ where c.TABLE_SCHEMA = %s
comment=row["COLUMN_COMMENT"], comment=row["COLUMN_COMMENT"],
unique=row["COLUMN_KEY"] == "UNI", unique=row["COLUMN_KEY"] == "UNI",
extra=row["EXTRA"], extra=row["EXTRA"],
unque=unique, # TODO: why `unque`?
unque=unique, # type:ignore
index=index, index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"], length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"], max_digits=row["NUMERIC_PRECISION"],

View File

@@ -1,17 +1,20 @@
from typing import List, Optional from __future__ import annotations
from tortoise import BaseDBAsyncClient from typing import TYPE_CHECKING
from aerich.inspectdb import Column, Inspect from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect): class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
super().__init__(conn, tables) super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public" self.schema = conn.server_settings.get("schema") or "public"
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
return { return {
"int4": self.int_field, "int4": self.int_field,
"int8": self.int_field, "int8": self.int_field,
@@ -33,12 +36,12 @@ class InspectPostgres(Inspect):
"timestamp": self.datetime_field, "timestamp": self.datetime_field,
} }
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2" sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret)) return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = f"""select c.column_name, sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment, col_description('public.{table}'::regclass, ordinal_position) as column_comment,
@@ -55,7 +58,7 @@ from information_schema.constraint_column_usage const
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name) right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1 where c.table_catalog = $1
and c.table_name = $2 and c.table_name = $2
and c.table_schema = $3""" and c.table_schema = $3""" # nosec:B608
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret: for row in ret:
columns.append( columns.append(

View File

@@ -1,11 +1,11 @@
from typing import List from __future__ import annotations
from aerich.inspectdb import Column, Inspect from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect): class InspectSQLite(Inspect):
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
return { return {
"INTEGER": self.int_field, "INTEGER": self.int_field,
"INT": self.bool_field, "INT": self.bool_field,
@@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
"BLOB": self.binary_field, "BLOB": self.binary_field,
} }
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = f"PRAGMA table_info({table})" sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql) ret = await self.conn.execute_query_dict(sql)
@@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
) )
return columns return columns
async def _get_columns_index(self, table: str): async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})" sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql) indexes = await self.conn.execute_query_dict(sql)
ret = {} ret = {}
@@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
ret[index_info["name"]] = "unique" if index["unique"] else "index" ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret return ret
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql) ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret)) return list(map(lambda x: x["tbl_name"], ret))

View File

@@ -1,11 +1,13 @@
from __future__ import annotations
import importlib import importlib
import os import os
from datetime import datetime from datetime import datetime
from hashlib import md5
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
import click import asyncclick as click
import tortoise
from dictdiffer import diff from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError from tortoise.exceptions import OperationalError
@@ -13,7 +15,12 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, get_models_describe, is_default_function from aerich.utils import (
get_app_connection,
get_dict_diff_by_key,
get_models_describe,
is_default_function,
)
MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient MIGRATE_TEMPLATE = """from tortoise import BaseDBAsyncClient
@@ -37,54 +44,66 @@ class Migrate:
_upgrade_m2m: List[str] = [] _upgrade_m2m: List[str] = []
_downgrade_m2m: List[str] = [] _downgrade_m2m: List[str] = []
_aerich = Aerich.__name__ _aerich = Aerich.__name__
_rename_old = [] _rename_fields: Dict[str, Dict[str, str]] = {} # {'model': {'old_field': 'new_field'}}
_rename_new = []
ddl: BaseDDL ddl: BaseDDL
ddl_class: Type[BaseDDL]
_last_version_content: Optional[dict] = None _last_version_content: Optional[dict] = None
app: str app: str
migrate_location: Path migrate_location: Path
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@staticmethod
def get_field_by_name(name: str, fields: List[dict]) -> dict:
return next(filter(lambda x: x.get("name") == name, fields))
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( def get_file_version(file_name: str) -> str:
filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)), return file_name.split("_")[0]
key=lambda x: int(x.split("_")[0]),
) def is_version_file(file_name: str) -> bool:
if not file_name.endswith("py"):
return False
if "_" not in file_name:
return False
return get_file_version(file_name).isdigit()
files = filter(is_version_file, os.listdir(cls.migrate_location))
return sorted(files, key=lambda x: int(get_file_version(x)))
@classmethod @classmethod
def _get_model(cls, model: str) -> Type[Model]: def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model) return Tortoise.apps[cls.app].get(model) # type: ignore
@classmethod @classmethod
async def get_last_version(cls) -> Optional[Aerich]: async def get_last_version(cls) -> Optional[Aerich]:
try: try:
return await Aerich.filter(app=cls.app).first() return await Aerich.filter(app=cls.app).first()
except OperationalError: except OperationalError:
pass return None
@classmethod @classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient): async def _get_db_version(cls, connection: BaseDBAsyncClient) -> None:
if cls.dialect == "mysql": if cls.dialect == "mysql":
sql = "select version() as version" sql = "select version() as version"
ret = await connection.execute_query(sql) ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version") cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def load_ddl_class(cls): async def load_ddl_class(cls) -> Type[BaseDDL]:
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}") ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL") return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod @classmethod
async def init(cls, config: dict, app: str, location: str): async def init(cls, config: dict, app: str, location: str) -> None:
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = Path(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
cls._last_version_content = last_version.content cls._last_version_content = cast(dict, last_version.content)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT cls.dialect = connection.schema_generator.DIALECT
@@ -93,7 +112,7 @@ class Migrate:
await cls._get_db_version(connection) await cls._get_db_version(connection)
@classmethod @classmethod
async def _get_last_version_num(cls): async def _get_last_version_num(cls) -> Optional[int]:
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
if not last_version: if not last_version:
return None return None
@@ -101,7 +120,7 @@ class Migrate:
return int(version.split("_", 1)[0]) return int(version.split("_", 1)[0])
@classmethod @classmethod
async def generate_version(cls, name=None): async def generate_version(cls, name=None) -> str:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num() last_version_num = await cls._get_last_version_num()
if last_version_num is None: if last_version_num is None:
@@ -112,33 +131,31 @@ class Migrate:
return version return version
@classmethod @classmethod
async def _generate_diff_py(cls, name): async def _generate_diff_py(cls, name) -> str:
version = await cls.generate_version(name) version = await cls.generate_version(name)
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
version_file = Path(cls.migrate_location, version) content = cls._get_diff_file_content()
content = MIGRATE_TEMPLATE.format( Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
upgrade_sql=";\n ".join(cls.upgrade_operators) + ";",
downgrade_sql=";\n ".join(cls.downgrade_operators) + ";",
)
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version return version
@classmethod @classmethod
async def migrate(cls, name) -> str: async def migrate(cls, name: str, empty: bool) -> str:
""" """
diff old models and new models to generate diff content diff old models and new models to generate diff content
:param name: :param name: str name for migration
:param empty: bool if True generates empty migration
:return: :return:
""" """
if empty:
return await cls._generate_diff_py(name)
new_version_content = get_models_describe(cls.app) new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content) last_version = cast(dict, cls._last_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False) cls.diff_models(last_version, new_version_content)
cls.diff_models(new_version_content, last_version, False)
cls._merge_operators() cls._merge_operators()
@@ -148,7 +165,23 @@ class Migrate:
return await cls._generate_diff_py(name) return await cls._generate_diff_py(name)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): def _get_diff_file_content(cls) -> str:
"""
builds content for diff file from template
"""
def join_lines(lines: List[str]) -> str:
if not lines:
return ""
return ";\n ".join(lines) + ";"
return MIGRATE_TEMPLATE.format(
upgrade_sql=join_lines(cls.upgrade_operators),
downgrade_sql=join_lines(cls.downgrade_operators),
)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
""" """
add operator,differentiate fk because fk is order limit add operator,differentiate fk because fk is order limit
:param operator: :param operator:
@@ -169,94 +202,76 @@ class Migrate:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod @classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]): def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
ret = [] if tortoise.__version__ > "0.22.2":
for index in indexes: # The min version of tortoise is '0.11.0', so we can compare it by a `>`,
if isinstance(index, Index): # tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
index.__hash__ = lambda self: md5( # nosec: B303 return indexes
self.index_name(cls.ddl.schema_generator, model).encode() if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)):
+ self.__class__.__name__.encode() # Leave magic patch here to compare with older version of tortoise-orm
).hexdigest() # TODO: limit tortoise>0.22.2 in pyproject.toml and remove this function when v0.9.0 released
ret.append(index) for index_cls in index_classes:
return ret if index_cls(fields=("id",)) != index_cls(fields=("id",)):
def _hash(self) -> int:
return hash((tuple(sorted(self.fields)), self.name, self.expressions))
def _eq(self, other) -> bool:
return type(self) is type(other) and self.__dict__ == other.__dict__
setattr(index_cls, "__hash__", _hash)
setattr(index_cls, "__eq__", _eq)
return indexes
@classmethod @classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True): def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
""" indexes: Set[Union[Index, Tuple[str, ...]]] = set()
diff models and add operators for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
:param old_models: if isinstance(x, Index):
:param new_models: indexes.add(x)
:param upgrade:
:return:
"""
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
if new_model_str not in old_models.keys():
if upgrade:
cls._add_operator(cls.add_model(model), upgrade)
else: else:
# we can't find origin model when downgrade, so skip indexes.add(cast(Tuple[str, ...], tuple(x)))
return indexes
@staticmethod
def _validate_custom_m2m_through(field: dict) -> None:
# TODO: Check whether field includes required fk columns
pass pass
else:
old_model_describe = old_models.get(new_model_str) @classmethod
# rename table def _handle_m2m_fields(
new_table = new_model_describe.get("table") cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
old_table = old_model_describe.get("table") ) -> None:
if new_table != old_table: old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields", []))
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade) new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields", []))
old_unique_together = set( new_tables: Dict[str, dict] = {field["table"]: field for field in new_models.values()}
map(lambda x: tuple(x), old_model_describe.get("unique_together")) for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
) if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
)
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
if change[0][0] == "db_constraint":
continue continue
table = change[0][1].get("through") new_value = change[0][1]
if isinstance(new_value, str):
for new_m2m_field in new_m2m_fields:
if new_m2m_field["name"] == new_value:
table = cast(str, new_m2m_field.get("through"))
break
else:
table = new_value.get("through")
if action == "add": if action == "add":
add = False add = False
if upgrade and table not in cls._upgrade_m2m: if upgrade:
if field := new_tables.get(table):
cls._validate_custom_m2m_through(field)
elif table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table) cls._upgrade_m2m.append(table)
add = True add = True
elif not upgrade and table not in cls._downgrade_m2m: else:
if table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table) cls._downgrade_m2m.append(table)
add = True add = True
if add: if add:
ref_desc = cast(dict, new_models.get(new_value.get("model_name")))
cls._add_operator( cls._add_operator(
cls.create_m2m( cls.create_m2m(model, new_value, ref_desc),
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade, upgrade,
fk_m2m_index=True, fk_m2m_index=True,
) )
@@ -270,6 +285,134 @@ class Migrate:
add = True add = True
if add: if add:
cls._add_operator(cls.drop_m2m(table), upgrade, True) cls._add_operator(cls.drop_m2m(table), upgrade, True)
@classmethod
def _handle_relational(
cls,
key: str,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
old_fk_fields = cast(List[dict], old_model_describe.get(key))
new_fk_fields = cast(List[dict], new_model_describe.get(key))
old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
# add
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
fk_field = cls.get_field_by_name(new_fk_field_name, new_fk_fields)
if fk_field.get("db_constraint"):
ref_describe = cast(dict, new_models[fk_field["python_type"]])
sql = cls._add_fk(model, fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)
# drop
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
if old_fk_field.get("db_constraint"):
ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
sql = cls._drop_fk(model, old_fk_field, ref_describe)
cls._add_operator(sql, upgrade, fk_m2m_index=True)
@classmethod
def _handle_fk_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "fk_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)
@classmethod
def _handle_o2o_fields(
cls,
old_model_describe: Dict,
new_model_describe: Dict,
model: Type[Model],
old_models: Dict,
new_models: Dict,
upgrade=True,
) -> None:
key = "o2o_fields"
cls._handle_relational(
key, old_model_describe, new_model_describe, model, old_models, new_models, upgrade
)
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
) -> None:
"""
diff models and add operators
:param old_models:
:param new_models:
:param upgrade:
:return:
"""
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
models_with_rename_field: Set[str] = set() # models that trigger the click.prompt
for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe["name"].split(".")[1])
if new_model_str not in old_models:
if upgrade:
cls._add_operator(cls.add_model(model), upgrade)
cls._handle_m2m_fields({}, new_model_describe, model, new_models, upgrade)
else:
# we can't find origin model when downgrade, so skip
pass
else:
old_model_describe = cast(dict, old_models.get(new_model_str))
# rename table
new_table = cast(str, new_model_describe.get("table"))
old_table = cast(str, old_model_describe.get("table"))
if new_table != old_table:
cls._add_operator(cls.rename_table(model, old_table, new_table), upgrade)
old_unique_together = set(
map(
lambda x: tuple(x),
cast(List[Iterable[str]], old_model_describe.get("unique_together")),
)
)
new_unique_together = set(
map(
lambda x: tuple(x),
cast(List[Iterable[str]], new_model_describe.get("unique_together")),
)
)
old_indexes = cls._get_indexes(model, old_model_describe)
new_indexes = cls._get_indexes(model, new_model_describe)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
changes = diff(old_pk_field, new_pk_field)
for action, option, change in changes:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# fk fields
args = (old_model_describe, new_model_describe, model, old_models, new_models)
cls._handle_fk_fields(*args, upgrade=upgrade)
# o2o fields
cls._handle_o2o_fields(*args, upgrade=upgrade)
old_o2o_columns = [i["raw_field"] for i in old_model_describe.get("o2o_fields", [])]
new_o2o_columns = [i["raw_field"] for i in new_model_describe.get("o2o_fields", [])]
# m2m fields
cls._handle_m2m_fields(
old_model_describe, new_model_describe, model, new_models, upgrade
)
# add unique_together # add unique_together
for index in new_unique_together.difference(old_unique_together): for index in new_unique_together.difference(old_unique_together):
cls._add_operator(cls._add_index(model, index, True), upgrade, True) cls._add_operator(cls._add_index(model, index, True), upgrade, True)
@@ -277,70 +420,90 @@ class Migrate:
for index in old_unique_together.difference(new_unique_together): for index in old_unique_together.difference(new_unique_together):
cls._add_operator(cls._drop_index(model, index, True), upgrade, True) cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
# add indexes # add indexes
for index in new_indexes.difference(old_indexes): for idx in new_indexes.difference(old_indexes):
cls._add_operator(cls._add_index(model, index, False), upgrade, True) cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
# remove indexes # remove indexes
for index in old_indexes.difference(new_indexes): for idx in old_indexes.difference(new_indexes):
cls._add_operator(cls._drop_index(model, index, False), upgrade, True) cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
old_data_fields = list( old_data_fields = list(
filter( filter(
lambda x: x.get("db_field_types") is not None, lambda x: x.get("db_field_types") is not None,
old_model_describe.get("data_fields"), cast(List[dict], old_model_describe.get("data_fields")),
) )
) )
new_data_fields = list( new_data_fields = list(
filter( filter(
lambda x: x.get("db_field_types") is not None, lambda x: x.get("db_field_types") is not None,
new_model_describe.get("data_fields"), cast(List[dict], new_model_describe.get("data_fields")),
) )
) )
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
# add fields or rename fields # add fields or rename fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
new_data_field = next( new_data_field = cls.get_field_by_name(new_data_field_name, new_data_fields)
filter(lambda x: x.get("name") == new_data_field_name, new_data_fields)
)
is_rename = False is_rename = False
for old_data_field in old_data_fields: field_type = new_data_field.get("field_type")
db_column = new_data_field.get("db_column")
new_name = set(new_data_field_name)
for old_data_field in sorted(
old_data_fields,
key=lambda f: (
f.get("field_type") != field_type,
# old field whose name have more same characters with new field's
# should be put in front of the other
len(new_name.symmetric_difference(set(f.get("name", "")))),
),
):
changes = list(diff(old_data_field, new_data_field)) changes = list(diff(old_data_field, new_data_field))
old_data_field_name = old_data_field.get("name") old_data_field_name = cast(str, old_data_field.get("name"))
if len(changes) == 2: if len(changes) == 2:
# rename field # rename field
name_diff = (old_data_field_name, new_data_field_name)
column_diff = (old_data_field.get("db_column"), db_column)
if ( if (
changes[0] changes[0] == ("change", "name", name_diff)
== ( and changes[1] == ("change", "db_column", column_diff)
"change",
"name",
(old_data_field_name, new_data_field_name),
)
and changes[1]
== (
"change",
"db_column",
(
old_data_field.get("db_column"),
new_data_field.get("db_column"),
),
)
and old_data_field_name not in new_data_fields_name and old_data_field_name not in new_data_fields_name
): ):
if upgrade: if upgrade:
if (
rename_fields := cls._rename_fields.get(new_model_str)
) and (
old_data_field_name in rename_fields
or new_data_field_name in rename_fields.values()
):
continue
prefix = f"({new_model_str}) "
if new_model_str not in models_with_rename_field:
if models_with_rename_field:
# When there are multi rename fields with different models,
# print a empty line to warn that is another model
prefix = "\n" + prefix
models_with_rename_field.add(new_model_str)
is_rename = click.prompt( is_rename = click.prompt(
f"Rename {old_data_field_name} to {new_data_field_name}?", f"{prefix}Rename {old_data_field_name} to {new_data_field_name}?",
default=True, default=True,
type=bool, type=bool,
show_choices=True, show_choices=True,
) )
else:
is_rename = old_data_field_name in cls._rename_new
if is_rename: if is_rename:
cls._rename_new.append(new_data_field_name) if rename_fields is None:
cls._rename_old.append(old_data_field_name) rename_fields = cls._rename_fields[new_model_str] = {}
rename_fields[old_data_field_name] = new_data_field_name
else:
is_rename = False
if rename_to := cls._rename_fields.get(new_model_str, {}).get(
new_data_field_name
):
is_rename = True
if rename_to != old_data_field_name:
continue
if is_rename:
# only MySQL8+ has rename syntax # only MySQL8+ has rename syntax
if ( if (
cls.dialect == "mysql" cls.dialect == "mysql"
@@ -359,107 +522,63 @@ class Migrate:
upgrade, upgrade,
) )
if not is_rename: if not is_rename:
cls._add_operator( cls._add_operator(cls._add_field(model, new_data_field), upgrade)
cls._add_field( if (
model, new_data_field["indexed"]
new_data_field, and new_data_field["db_column"] not in new_o2o_columns
), ):
upgrade,
)
if new_data_field["indexed"]:
cls._add_operator( cls._add_operator(
cls._add_index( cls._add_index(
model, {new_data_field["db_column"]}, new_data_field["unique"] model, (new_data_field["db_column"],), new_data_field["unique"]
), ),
upgrade, upgrade,
True, True,
) )
# remove fields # remove fields
rename_fields = cls._rename_fields.get(new_model_str)
for old_data_field_name in set(old_data_fields_name).difference( for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name) set(new_data_fields_name)
): ):
# don't remove field if is renamed # don't remove field if is renamed
if (upgrade and old_data_field_name in cls._rename_old) or ( if rename_fields and (
not upgrade and old_data_field_name in cls._rename_new (upgrade and old_data_field_name in rename_fields)
or (not upgrade and old_data_field_name in rename_fields.values())
): ):
continue continue
old_data_field = next( old_data_field = cls.get_field_by_name(old_data_field_name, old_data_fields)
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields) db_column = cast(str, old_data_field["db_column"])
)
db_column = old_data_field["db_column"]
cls._add_operator( cls._add_operator(
cls._remove_field( cls._remove_field(model, db_column),
model,
db_column,
),
upgrade, upgrade,
) )
if old_data_field["indexed"]: if (
old_data_field["indexed"]
and old_data_field["db_column"] not in old_o2o_columns
):
is_unique_field = old_data_field.get("unique")
cls._add_operator( cls._add_operator(
cls._drop_index( cls._drop_index(model, {db_column}, is_unique_field),
model,
{db_column},
),
upgrade, upgrade,
True, True,
) )
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
old_fk_fields_name = list(map(lambda x: x.get("name"), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get("name"), new_fk_fields))
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name)
):
fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
)
if fk_field.get("db_constraint"):
cls._add_operator(
cls._add_fk(
model, fk_field, new_models.get(fk_field.get("python_type"))
),
upgrade,
fk_m2m_index=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name)
):
old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
)
if old_fk_field.get("db_constraint"):
cls._add_operator(
cls._drop_fk(
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m_index=True,
)
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
old_data_field = next( old_data_field = cls.get_field_by_name(field_name, old_data_fields)
filter(lambda x: x.get("name") == field_name, old_data_fields) new_data_field = cls.get_field_by_name(field_name, new_data_fields)
)
new_data_field = next(
filter(lambda x: x.get("name") == field_name, new_data_fields)
)
changes = diff(old_data_field, new_data_field) changes = diff(old_data_field, new_data_field)
modified = False modified = False
for change in changes: for change in changes:
_, option, old_new = change _, option, old_new = change
if option == "indexed": if option == "indexed":
# change index # change index
unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True: if old_new[0] is False and old_new[1] is True:
unique = new_data_field.get("unique")
cls._add_operator( cls._add_operator(
cls._add_index(model, (field_name,), unique), upgrade, True cls._add_index(model, (field_name,), unique), upgrade, True
) )
else: else:
unique = old_data_field.get("unique")
cls._add_operator( cls._add_operator(
cls._drop_index(model, (field_name,), unique), upgrade, True cls._drop_index(model, (field_name,), unique), upgrade, True
) )
@@ -486,6 +605,9 @@ class Migrate:
elif option == "nullable": elif option == "nullable":
# change nullable # change nullable
cls._add_operator(cls._alter_null(model, new_data_field), upgrade) cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
elif option == "description":
# change comment
cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
else: else:
if modified: if modified:
continue continue
@@ -496,103 +618,118 @@ class Migrate:
) )
modified = True modified = True
for old_model in old_models: for old_model in old_models.keys() - new_models.keys():
if old_model not in new_models.keys(): cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
cls._add_operator(cls.drop_model(old_models.get(old_model).get("table")), upgrade)
@classmethod @classmethod
def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str): def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
return cls.ddl.rename_table(model, old_table_name, new_table_name) return cls.ddl.rename_table(model, old_table_name, new_table_name)
@classmethod @classmethod
def add_model(cls, model: Type[Model]): def add_model(cls, model: Type[Model]) -> str:
return cls.ddl.create_table(model) return cls.ddl.create_table(model)
@classmethod @classmethod
def drop_model(cls, table_name: str): def drop_model(cls, table_name: str) -> str:
return cls.ddl.drop_table(table_name) return cls.ddl.drop_table(table_name)
@classmethod @classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def create_m2m(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.create_m2m(model, field_describe, reference_table_describe) return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod @classmethod
def drop_m2m(cls, table_name: str): def drop_m2m(cls, table_name: str) -> str:
return cls.ddl.drop_m2m(table_name) return cls.ddl.drop_m2m(table_name)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
ret = [] ret = []
for field_name in fields_name: for field_name in fields_name:
try:
field = model._meta.fields_map[field_name] field = model._meta.fields_map[field_name]
if field.source_field: except KeyError:
ret.append(field.source_field) # field dropped or to be add
elif field_name in model._meta.fk_fields: pass
ret.append(field_name + "_id")
else: else:
if field.source_field:
field_name = field.source_field
elif field_name in model._meta.fk_fields:
field_name += "_id"
ret.append(field_name) ret.append(field_name)
return ret return ret
@classmethod @classmethod
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): def _drop_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name( return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model) model, fields_name.index_name(cls.ddl.schema_generator, model)
) )
fields_name = cls._resolve_fk_fields_name(model, fields_name) field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique) return cls.ddl.drop_index(model, field_names, unique)
@classmethod @classmethod
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False): def _add_index(
cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
) -> str:
if isinstance(fields_name, Index): if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False) return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name) field_names = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique) return cls.ddl.add_index(model, field_names, unique)
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False): def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
return cls.ddl.add_column(model, field_describe, is_pk) return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field_describe: dict): def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_default(model, field_describe) return cls.ddl.alter_column_default(model, field_describe)
@classmethod @classmethod
def _alter_null(cls, model: Type[Model], field_describe: dict): def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.alter_column_null(model, field_describe) return cls.ddl.alter_column_null(model, field_describe)
@classmethod @classmethod
def _set_comment(cls, model: Type[Model], field_describe: dict): def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.set_comment(model, field_describe) return cls.ddl.set_comment(model, field_describe)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field_describe: dict): def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
return cls.ddl.modify_column(model, field_describe) return cls.ddl.modify_column(model, field_describe)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _drop_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
return cls.ddl.drop_fk(model, field_describe, reference_table_describe) return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], column_name: str): def _remove_field(cls, model: Type[Model], column_name: str) -> str:
return cls.ddl.drop_column(model, column_name) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str): def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
return cls.ddl.rename_column(model, old_field_name, new_field_name) return cls.ddl.rename_column(model, old_field_name, new_field_name)
@classmethod @classmethod
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): def _change_field(
db_field_types = new_field_describe.get("db_field_types") cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
) -> str:
db_field_types = cast(dict, new_field_describe.get("db_field_types"))
return cls.ddl.change_column( return cls.ddl.change_column(
model, model,
old_field_describe.get("db_column"), cast(str, old_field_describe.get("db_column")),
new_field_describe.get("db_column"), cast(str, new_field_describe.get("db_column")),
db_field_types.get(cls.dialect) or db_field_types.get(""), cast(str, db_field_types.get(cls.dialect) or db_field_types.get("")),
) )
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _add_fk(
cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
""" """
add fk add fk
:param model: :param model:
@@ -603,7 +740,7 @@ class Migrate:
return cls.ddl.add_fk(model, field_describe, reference_table_describe) return cls.ddl.add_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _merge_operators(cls): def _merge_operators(cls) -> None:
""" """
fk/m2m/index must be last when add,first when drop fk/m2m/index must be last when add,first when drop
:return: :return:

View File

@@ -9,7 +9,7 @@ MAX_APP_LENGTH = 100
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=MAX_APP_LENGTH) app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField(encoder=encoder, decoder=decoder) content: dict = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@@ -1,11 +1,15 @@
from __future__ import annotations
import importlib.util import importlib.util
import os import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict from types import ModuleType
from typing import Dict, Generator, Optional, Union
from click import BadOptionUsage, ClickException, Context from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
@@ -84,19 +88,59 @@ def get_models_describe(app: str) -> Dict:
:return: :return:
""" """
ret = {} ret = {}
for model in Tortoise.apps.get(app).values(): for model in Tortoise.apps[app].values():
describe = model.describe() describe = model.describe()
ret[describe.get("name")] = describe ret[describe.get("name")] = describe
return ret return ret
def is_default_function(string: str): def is_default_function(string: str) -> Optional[re.Match]:
return re.match(r"^<function.+>$", str(string or "")) return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path): def import_py_file(file: Union[str, Path]) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1]) module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file) spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) spec.loader.exec_module(module) # type:ignore[union-attr]
return module return module
def get_dict_diff_by_key(
old_fields: list[dict], new_fields: list[dict], key="through"
) -> Generator[tuple]:
"""
Compare two list by key instead of by index
:param old_fields: previous field info list
:param new_fields: current field info list
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
:return: similar to dictdiffer.diff
Example::
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
>>> list(diff(old, new))
[('change', [1, 'through'], ('b', 'c')),
('remove', '', [(2, {'through': 'c'})])]
>>> list(get_dict_diff_by_key(old, new))
[('remove', '', [(0, {'through': 'b'})])]
"""
length_old, length_new = len(old_fields), len(new_fields)
if length_old == 0 or length_new == 0 or length_old == length_new == 1:
yield from diff(old_fields, new_fields)
else:
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)}
additions = set(range(length_new))
for field in old_fields:
value = field[key]
if (index := value_index.get(value)) is not None:
additions.remove(index)
yield from diff([field], [new_fields[index]]) # change
else:
yield from diff([field], []) # remove
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add

View File

@@ -1 +1 @@
__version__ = "0.7.2" __version__ = "0.8.0"

View File

@@ -1,19 +1,23 @@
import asyncio import asyncio
import contextlib
import os import os
from typing import Generator
import pytest import pytest
from tortoise import Tortoise, expand_db_url, generate_schema_for_client from tortoise import Tortoise, expand_db_url, generate_schema_for_client
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.contrib.test import MEMORY_SQLITE
from tortoise.exceptions import DBConnectionError, OperationalError
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate from aerich.migrate import Migrate
db_url = os.getenv("TEST_DB", "sqlite://:memory:") db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:") db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
tortoise_orm = { tortoise_orm = {
"connections": { "connections": {
"default": expand_db_url(db_url, True), "default": expand_db_url(db_url, True),
@@ -27,7 +31,7 @@ tortoise_orm = {
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def reset_migrate(): def reset_migrate() -> None:
Migrate.upgrade_operators = [] Migrate.upgrade_operators = []
Migrate.downgrade_operators = [] Migrate.downgrade_operators = []
Migrate._upgrade_fk_m2m_index_operators = [] Migrate._upgrade_fk_m2m_index_operators = []
@@ -37,20 +41,25 @@ def reset_migrate():
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def event_loop(): def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop() res = policy.new_event_loop()
asyncio.set_event_loop(res) asyncio.set_event_loop(res)
res._close = res.close res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None res.close = lambda: None # type:ignore[method-assign]
yield res yield res
res._close() res._close() # type:ignore[attr-defined]
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request): async def initialize_tests(event_loop, request) -> None:
# Placing init outside the try block since it doesn't
# establish connections to the DB eagerly.
await Tortoise.init(config=tortoise_orm)
with contextlib.suppress(DBConnectionError, OperationalError):
await Tortoise._drop_databases()
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)

1363
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.7.2" version = "0.8.1"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"
@@ -15,29 +15,35 @@ packages = [
include = ["CHANGELOG.md", "LICENSE", "README.md"] include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.8"
tortoise-orm = "*" tortoise-orm = ">=0.21"
click = "*"
asyncpg = { version = "*", optional = true } asyncpg = { version = "*", optional = true }
asyncmy = { version = "^0.2.8rc1", optional = true, allow-prereleases = true } asyncmy = { version = "^0.2.9", optional = true, allow-prereleases = true }
pydantic = "*" pydantic = "^2.0,!=2.7.0"
dictdiffer = "*" dictdiffer = "*"
tomlkit = "*" tomlkit = { version = "*", optional = true, python="<3.11" }
tomli-w = { version = "^1.1.0", optional = true, python=">=3.11" }
asyncclick = "^8.1.7.2"
[tool.poetry.dev-dependencies] [tool.poetry.group.dev.dependencies]
ruff = "*" ruff = "*"
isort = "*" isort = "*"
black = "*" black = "*"
pytest = "*" pytest = "*"
pytest-xdist = "*" pytest-xdist = "*"
pytest-asyncio = "*" # Breaking change in 0.23.*
# https://github.com/pytest-dev/pytest-asyncio/issues/706
pytest-asyncio = "^0.21.2"
bandit = "*" bandit = "*"
pytest-mock = "*" pytest-mock = "*"
cryptography = "*" cryptography = "*"
mypy = "^1.10.0"
[tool.poetry.extras] [tool.poetry.extras]
asyncmy = ["asyncmy"] asyncmy = ["asyncmy"]
asyncpg = ["asyncpg"] asyncpg = ["asyncpg"]
toml = ["tomlkit", "tomli-w"]
[tool.aerich] [tool.aerich]
tortoise_orm = "conftest.tortoise_orm" tortoise_orm = "conftest.tortoise_orm"
@@ -45,22 +51,25 @@ location = "./migrations"
src_folder = "./." src_folder = "./."
[build-system] [build-system]
requires = ["poetry>=0.12"] requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.masonry.api" build-backend = "poetry.core.masonry.api"
[tool.poetry.scripts] [tool.poetry.scripts]
aerich = "aerich.cli:main" aerich = "aerich.cli:main"
[tool.black] [tool.black]
line-length = 100 line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39'] target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = 'auto' asyncio_mode = 'auto'
[tool.mypy] [tool.mypy]
pretty = true pretty = true
python_version = "3.8"
ignore_missing_imports = true ignore_missing_imports = true
[tool.ruff] [tool.ruff]
line-length = 100
[tool.ruff.lint]
ignore = ['E501'] ignore = ['E501']

7
tests/indexes.py Normal file
View File

@@ -0,0 +1,7 @@
from tortoise.indexes import Index
class CustomIndex(Index):
def __init__(self, *args, **kw) -> None:
super().__init__(*args, **kw)
self._foo = ""

View File

@@ -3,6 +3,9 @@ import uuid
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
from tortoise.indexes import Index
from tests.indexes import CustomIndex
class ProductType(IntEnum): class ProductType(IntEnum):
@@ -31,13 +34,20 @@ class User(Model):
intro = fields.TextField(default="") intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8) longitude = fields.DecimalField(max_digits=10, decimal_places=8)
products: fields.ManyToManyRelation["Product"]
class Meta:
# reverse indexes elements
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]
class Email(Model): class Email(Model):
email_id = fields.IntField(pk=True) email_id = fields.IntField(primary_key=True)
email = fields.CharField(max_length=200, index=True) email = fields.CharField(max_length=200, db_index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User") users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config")
def default_name(): def default_name():
@@ -47,12 +57,20 @@ def default_name():
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=100) slug = fields.CharField(max_length=100)
name = fields.CharField(max_length=200, null=True, default=default_name) name = fields.CharField(max_length=200, null=True, default=default_name)
user = fields.ForeignKeyField("models.User", description="User") owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", null=False
)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="products"
)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num", default=0) view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
@@ -63,6 +81,7 @@ class Product(Model):
pic = fields.CharField(max_length=200) pic = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
is_deleted = fields.BooleanField(default=False)
class Meta: class Meta:
unique_together = (("name", "type"),) unique_together = (("name", "type"),)
@@ -70,11 +89,18 @@ class Product(Model):
class Config(Model): class Config(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="category_set"
)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status) status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
email: fields.OneToOneRelation["Email"]
class NewModel(Model): class NewModel(Model):

View File

@@ -34,18 +34,24 @@ class User(Model):
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models_second.User", db_constraint=False) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", db_constraint=False
)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models_second.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models_second.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models_second.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models_second.Category"
)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
@@ -61,5 +67,5 @@ class Product(Model):
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@@ -2,6 +2,9 @@ import datetime
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
from tortoise.indexes import Index
from tests.indexes import CustomIndex
class ProductType(IntEnum): class ProductType(IntEnum):
@@ -31,38 +34,52 @@ class User(Model):
intro = fields.TextField(default="") intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9) longitude = fields.DecimalField(max_digits=12, decimal_places=9)
class Meta:
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False) user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False
)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)
user = fields.ForeignKeyField("models.User", description="User") user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
uid = fields.IntField(source_field="uuid", unique=True)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_review = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField( type = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias" ProductType, description="Product Type", source_field="type_db_alias"
) )
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
is_delete = fields.BooleanField(default=False)
class Config(Model): class Config(Model):
category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="config_set"
)
name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta: class Meta:

View File

@@ -14,9 +14,10 @@ def test_create_table():
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL, `slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200), `name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`user_id` INT NOT NULL COMMENT 'User', `owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4""" ) CHARACTER SET utf8mb4"""
) )
@@ -27,8 +28,9 @@ def test_create_table():
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
)""" )"""
) )
@@ -39,10 +41,11 @@ def test_create_table():
"id" SERIAL NOT NULL PRIMARY KEY, "id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(100) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200),
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
COMMENT ON COLUMN "category"."user_id" IS 'User'""" COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
) )
@@ -60,6 +63,14 @@ def test_add_column():
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)" assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
else: else:
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)' assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'
# add unique column
ret = Migrate.ddl.add_column(User, User._meta.fields_map.get("username").describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL UNIQUE'
else:
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL'
def test_modify_column(): def test_modify_column():
@@ -134,8 +145,8 @@ def test_set_comment():
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False)) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user").describe(False)) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("owner").describe(False))
assert ret == 'COMMENT ON COLUMN "category"."user_id" IS \'User\'' assert ret == 'COMMENT ON COLUMN "category"."owner_id" IS \'User\''
def test_drop_column(): def test_drop_column():
@@ -151,17 +162,10 @@ def test_add_index():
index_u = Migrate.ddl.add_index(Category, ["name"], True) index_u = Migrate.ddl.add_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert ( assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)" else:
)
elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")' assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
else:
assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert (
index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")'
)
def test_drop_index(): def test_drop_index():
@@ -169,38 +173,35 @@ def test_drop_index():
ret_u = Migrate.ddl.drop_index(Category, ["name"], True) ret_u = Migrate.ddl.drop_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`" assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`"
assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`" assert ret_u == "ALTER TABLE `category` DROP INDEX `name`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX "uid_category_name_8b0cb9"'
else: else:
assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"' assert ret == 'DROP INDEX IF EXISTS "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"' assert ret_u == 'DROP INDEX IF EXISTS "uid_category_name_8b0cb9"'
def test_add_fk(): def test_add_fk():
ret = Migrate.ddl.add_fk( ret = Migrate.ddl.add_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False) Category, Category._meta.fields_map.get("owner").describe(False), User.describe(False)
) )
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
ret ret
== "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE" == "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE"
) )
else: else:
assert ( assert (
ret ret
== 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_e2e3874c" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE' == 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE'
) )
def test_drop_fk(): def test_drop_fk():
ret = Migrate.ddl.drop_fk( ret = Migrate.ddl.drop_fk(
Category, Category._meta.fields_map.get("user").describe(False), User.describe(False) Category, Category._meta.fields_map.get("owner").describe(False), User.describe(False)
) )
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`" assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" DROP CONSTRAINT "fk_category_user_e2e3874c"' assert ret == 'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"'
else: else:
assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_e2e3874c"' assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_110d4c63"'

View File

@@ -1,13 +1,21 @@
from pathlib import Path
import pytest import pytest
import tortoise
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from tortoise.indexes import Index
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe from aerich.utils import get_models_describe
from tests.indexes import CustomIndex
# tortoise-orm>=0.21 changes IntField constraints
# from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647}
MIN_INT = 1 if tortoise.__version__ < "0.21" else -2147483648
old_models_describe = { old_models_describe = {
"models.Category": { "models.Category": {
"name": "models.Category", "name": "models.Category",
@@ -30,7 +38,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
@@ -97,9 +105,24 @@ old_models_describe = {
"default": None, "default": None,
"description": "User", "description": "User",
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
{
"name": "title",
"field_type": "CharField",
"db_column": "title",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 20},
"db_field_types": {"": "VARCHAR(20)"},
},
], ],
"fk_fields": [ "fk_fields": [
{ {
@@ -165,10 +188,25 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
{
"name": "name",
"field_type": "CharField",
"db_column": "name",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{ {
"name": "label", "name": "label",
"field_type": "CharField", "field_type": "CharField",
@@ -234,7 +272,48 @@ old_models_describe = {
"backward_fk_fields": [], "backward_fk_fields": [],
"o2o_fields": [], "o2o_fields": [],
"backward_o2o_fields": [], "backward_o2o_fields": [],
"m2m_fields": [], "m2m_fields": [
{
"name": "category",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "configs",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category",
"on_delete": "CASCADE",
"_generated": False,
},
{
"name": "categories",
"field_type": "ManyToManyFieldInstance",
"python_type": "models.Category",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": None,
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "config_set",
"forward_key": "category_id",
"backward_key": "config_id",
"through": "config_category_map",
"on_delete": "CASCADE",
"_generated": False,
},
],
}, },
"models.Email": { "models.Email": {
"name": "models.Email", "name": "models.Email",
@@ -257,7 +336,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
@@ -289,7 +368,12 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"}, "db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
}, },
{ {
"name": "user_id", "name": "user_id",
@@ -303,7 +387,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
], ],
@@ -350,7 +434,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
@@ -369,6 +453,21 @@ old_models_describe = {
"constraints": {"max_length": 50}, "constraints": {"max_length": 50},
"db_field_types": {"": "VARCHAR(50)"}, "db_field_types": {"": "VARCHAR(50)"},
}, },
{
"name": "uid",
"field_type": "IntField",
"db_column": "uuid",
"python_type": "int",
"generated": False,
"nullable": False,
"unique": True,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"ge": -2147483648, "le": 2147483647},
"db_field_types": {"": "INT"},
},
{ {
"name": "view_num", "name": "view_num",
"field_type": "IntField", "field_type": "IntField",
@@ -400,9 +499,9 @@ old_models_describe = {
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
{ {
"name": "is_reviewed", "name": "is_review",
"field_type": "BooleanField", "field_type": "BooleanField",
"db_column": "is_reviewed", "db_column": "is_review",
"python_type": "bool", "python_type": "bool",
"generated": False, "generated": False,
"nullable": False, "nullable": False,
@@ -412,7 +511,12 @@ old_models_describe = {
"description": "Is Reviewed", "description": "Is Reviewed",
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"}, "db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
}, },
{ {
"name": "type", "name": "type",
@@ -480,6 +584,26 @@ old_models_describe = {
"auto_now_add": True, "auto_now_add": True,
"auto_now": False, "auto_now": False,
}, },
{
"name": "is_delete",
"field_type": "BooleanField",
"db_column": "is_delete",
"python_type": "bool",
"generated": False,
"nullable": False,
"unique": False,
"indexed": False,
"default": False,
"description": None,
"docstring": None,
"constraints": {},
"db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
},
], ],
"fk_fields": [], "fk_fields": [],
"backward_fk_fields": [], "backward_fk_fields": [],
@@ -516,7 +640,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [], "indexes": [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@@ -529,7 +653,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
@@ -597,7 +721,12 @@ old_models_describe = {
"description": "Is Active", "description": "Is Active",
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"}, "db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
}, },
{ {
"name": "is_superuser", "name": "is_superuser",
@@ -612,7 +741,12 @@ old_models_describe = {
"description": "Is SuperUser", "description": "Is SuperUser",
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"db_field_types": {"": "BOOL", "sqlite": "INT"}, "db_field_types": {
"": "BOOL",
"mssql": "BIT",
"oracle": "NUMBER(1)",
"sqlite": "INT",
},
}, },
{ {
"name": "avatar", "name": "avatar",
@@ -714,7 +848,7 @@ old_models_describe = {
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": 1, "le": 2147483647}, "constraints": {"ge": MIN_INT, "le": 2147483647},
"db_field_types": {"": "INT"}, "db_field_types": {"": "INT"},
}, },
"data_fields": [ "data_fields": [
@@ -778,25 +912,35 @@ def test_migrate(mocker: MockerFixture):
models.py diff with old_models.py models.py diff with old_models.py
- change email pk: id -> email_id - change email pk: id -> email_id
- add field: Email.address - add field: Email.address
- add fk: Config.user - add fk field: Config.user
- drop fk: Email.user - drop fk field: Email.user
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- add many to many: Email.users - add many to many: Email.users
- remove unique: User.username - add one to one: Email.config
- remove unique: Category.title
- add unique: User.username
- change column: length User.password - change column: length User.password
- add unique_together: (name,type) of Product - add unique_together: (name,type) of Product
- add one more many to many field: Product.users
- drop unique field: Config.name
- alter default: Config.status - alter default: Config.status
- rename column: Product.image -> Product.pic - rename column: Product.image -> Product.pic
- rename column: Product.is_review -> Product.is_reviewed
- rename column: Product.is_delete -> Product.is_deleted
- rename fk column: Category.user -> Category.owner
""" """
mocker.patch("click.prompt", side_effect=(True,)) mocker.patch("asyncclick.prompt", side_effect=(True, True, True, True))
models_describe = get_models_describe("models") models_describe = get_models_describe("models")
Migrate.app = "models" Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): with pytest.raises(NotSupportError):
Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(old_models_describe, models_describe)
Migrate.upgrade_operators.clear()
with pytest.raises(NotSupportError):
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
Migrate.downgrade_operators.clear()
else: else:
Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
@@ -805,13 +949,22 @@ def test_migrate(mocker: MockerFixture):
expected_upgrade_operators = { expected_upgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)", "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(100) NOT NULL",
"ALTER TABLE `category` DROP INDEX `title`",
"ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`",
"ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` DROP INDEX `name`",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL", "ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE",
"ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`", "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
@@ -819,52 +972,63 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'", "ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`",
"ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`",
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
"CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4", "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", "CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category`",
} }
expected_downgrade_operators = { expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`", "ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
"ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`",
"ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`",
"ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL", "ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `email` DROP COLUMN `config_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`",
"ALTER TABLE `config` RENAME TO `configs`", "ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`", "ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`",
"ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`", "ALTER TABLE `user` DROP INDEX `username`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`", "DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`", "DROP TABLE IF EXISTS `newmodel`",
"DROP TABLE IF EXISTS `product_user`",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL", "ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `is_reviewed` BOOL NOT NULL COMMENT 'Is Reviewed'",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'", "ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1",
"ALTER TABLE `user` MODIFY COLUMN `is_superuser` BOOL NOT NULL COMMENT 'Is SuperUser' DEFAULT 0",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL",
"ALTER TABLE `email` MODIFY COLUMN `is_primary` BOOL NOT NULL DEFAULT 0", "CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category_map`",
} }
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
@@ -874,29 +1038,36 @@ def test_migrate(mocker: MockerFixture):
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
expected_upgrade_operators = { expected_upgrade_operators = {
'DROP INDEX IF EXISTS "uid_category_title_f7fc03"',
'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ', 'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"',
'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" DROP COLUMN "name"',
'DROP INDEX IF EXISTS "uid_config_name_2c83c8"',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB', 'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "configs" RENAME TO "config"', 'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE',
'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE',
'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"',
'ALTER TABLE "product" DROP COLUMN "uuid"',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"', 'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT', 'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ', 'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"',
'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ', 'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT', 'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")', 'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")',
@@ -904,38 +1075,51 @@ def test_migrate(mocker: MockerFixture):
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'', 'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"',
} }
expected_downgrade_operators = { expected_downgrade_operators = {
'CREATE UNIQUE INDEX "uid_category_title_f7fc03" ON "category" ("title")',
'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ', 'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "category" RENAME COLUMN "owner_id" TO "user_id"',
'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"',
'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_config_name_2c83c8" ON "config" ("name")',
'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "config" DROP COLUMN "user_id"', 'ALTER TABLE "config" DROP CONSTRAINT IF EXISTS "fk_config_user_17daa970"',
'ALTER TABLE "config" DROP CONSTRAINT "fk_config_user_17daa970"',
'ALTER TABLE "config" RENAME TO "configs"', 'ALTER TABLE "config" RENAME TO "configs"',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB', 'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL', 'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"', 'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"', 'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
'ALTER TABLE "email" ALTER COLUMN "is_primary" TYPE BOOL USING "is_primary"::BOOL', 'ALTER TABLE "email" DROP COLUMN "config_id"',
'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_76a9dc71"',
'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
'CREATE UNIQUE INDEX "uid_product_uuid_d33c18" ON "product" ("uuid")',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"', 'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',
'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ', 'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "is_superuser" TYPE BOOL USING "is_superuser"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT', 'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ', 'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" ALTER COLUMN "is_reviewed" TYPE BOOL USING "is_reviewed"::BOOL',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT', 'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'DROP INDEX "idx_product_name_869427"', 'DROP TABLE IF EXISTS "product_user"',
'DROP INDEX "idx_email_email_4a1a33"', 'DROP INDEX IF EXISTS "idx_product_name_869427"',
'DROP INDEX "idx_user_usernam_9987ab"', 'DROP INDEX IF EXISTS "idx_email_email_4a1a33"',
'DROP INDEX "uid_product_name_869427"', 'DROP INDEX IF EXISTS "uid_user_usernam_9987ab"',
'DROP INDEX IF EXISTS "uid_product_name_869427"',
'DROP TABLE IF EXISTS "email_user"', 'DROP TABLE IF EXISTS "email_user"',
'DROP TABLE IF EXISTS "newmodel"', 'DROP TABLE IF EXISTS "newmodel"',
'CREATE TABLE "config_category" (\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE,\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category_map"',
} }
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators)
assert not set(Migrate.downgrade_operators).symmetric_difference( assert not set(Migrate.downgrade_operators).symmetric_difference(
@@ -966,3 +1150,39 @@ def test_sort_all_version_files(mocker):
"10_datetime_update.py", "10_datetime_update.py",
"11_datetime_update.py", "11_datetime_update.py",
] ]
def test_sort_files_containing_non_migrations(mocker):
mocker.patch(
"os.listdir",
return_value=[
"1_datetime_update.py",
"11_datetime_update.py",
"10_datetime_update.py",
"2_datetime_update.py",
"not_a_migration.py",
"999.py",
"123foo_not_a_migration.py",
],
)
Migrate.migrate_location = "."
assert Migrate.get_all_version_files() == [
"1_datetime_update.py",
"2_datetime_update.py",
"10_datetime_update.py",
"11_datetime_update.py",
]
async def test_empty_migration(mocker, tmp_path: Path) -> None:
mocker.patch("os.listdir", return_value=[])
Migrate.app = "foo"
expected_content = MIGRATE_TEMPLATE.format(upgrade_sql="", downgrade_sql="")
Migrate.migrate_location = tmp_path
migration_file = await Migrate.migrate("update", True)
f = tmp_path / migration_file
assert f.read_text() == expected_content

View File

@@ -0,0 +1,304 @@
import contextlib
import os
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate
if sys.version_info >= (3, 11):
from contextlib import chdir
else:
class chdir(contextlib.AbstractContextManager): # Copied from source code of Python3.13
"""Non thread-safe context manager to change the current working directory."""
def __init__(self, path):
self.path = path
self._old_cwd = []
def __enter__(self):
self._old_cwd.append(os.getcwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
os.chdir(self._old_cwd.pop())
MODELS = """from __future__ import annotations
from tortoise import Model, fields
class Foo(Model):
name = fields.CharField(max_length=60, db_index=False)
"""
SETTINGS = """from __future__ import annotations
TORTOISE_ORM = {
"connections": {"default": "sqlite://db.sqlite3"},
"apps": {"models": {"models": ["models", "aerich.models"]}},
}
"""
CONFTEST = """from __future__ import annotations
import asyncio
from typing import Generator
import pytest
import pytest_asyncio
from tortoise import Tortoise, connections
import settings
@pytest.fixture(scope="session")
def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
asyncio.set_event_loop(res)
res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None # type:ignore[method-assign]
yield res
res._close() # type:ignore[attr-defined]
@pytest_asyncio.fixture(scope="session", autouse=True)
async def api(event_loop, request):
await Tortoise.init(config=settings.TORTOISE_ORM)
request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True)))
"""
TESTS = """from __future__ import annotations
import uuid
import pytest
from tortoise.exceptions import IntegrityError
from models import Foo
@pytest.mark.asyncio
async def test_allow_duplicate() -> None:
await Foo.all().delete()
await Foo.create(name="foo")
obj = await Foo.create(name="foo")
assert (await Foo.all().count()) == 2
await obj.delete()
@pytest.mark.asyncio
async def test_unique_is_true() -> None:
with pytest.raises(IntegrityError):
await Foo.create(name="foo")
@pytest.mark.asyncio
async def test_add_unique_field() -> None:
if not await Foo.filter(age=0).exists():
await Foo.create(name="0_"+uuid.uuid4().hex, age=0)
with pytest.raises(IntegrityError):
await Foo.create(name=uuid.uuid4().hex, age=0)
@pytest.mark.asyncio
async def test_drop_unique_field() -> None:
name = "1_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
assert (await Foo.filter(name=name).exists())
@pytest.mark.asyncio
async def test_with_age_field() -> None:
name = "2_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert obj.age == 0
@pytest.mark.asyncio
async def test_without_age_field() -> None:
name = "3_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert getattr(obj, "age", None) is None
@pytest.mark.asyncio
async def test_m2m_with_custom_through() -> None:
from models import Group, FooGroup
name = "4_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name+"1")
await FooGroup.all().delete()
await foo.groups.add(group)
foo_group = await FooGroup.get(foo=foo, group=group)
assert not foo_group.is_active
@pytest.mark.asyncio
async def test_add_m2m_field_after_init_db() -> None:
from models import Group
name = "5_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name+"1")
await foo.groups.add(group)
assert (await group.users.all().first()) == foo
"""
def run_aerich(cmd: str) -> None:
with contextlib.suppress(subprocess.TimeoutExpired):
if not cmd.startswith("aerich"):
cmd = "aerich " + cmd
subprocess.run(shlex.split(cmd), timeout=2)
def run_shell(cmd: str) -> subprocess.CompletedProcess:
envs = dict(os.environ, PYTHONPATH=".")
return subprocess.run(shlex.split(cmd), env=envs)
def test_sqlite_migrate(tmp_path: Path) -> None:
if (ddl := getattr(Migrate, "ddl", None)) and not isinstance(ddl, SqliteDDL):
return
with chdir(tmp_path):
models_py = Path("models.py")
settings_py = Path("settings.py")
test_py = Path("_test.py")
models_py.write_text(MODELS)
settings_py.write_text(SETTINGS)
test_py.write_text(TESTS)
Path("conftest.py").write_text(CONFTEST)
if (db_file := Path("db.sqlite3")).exists():
db_file.unlink()
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
r = run_shell("pytest _test.py::test_allow_duplicate")
assert r.returncode == 0
# Add index
models_py.write_text(MODELS.replace("index=False", "index=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _test.py::test_allow_duplicate")
assert r.returncode == 0
# Drop index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _test.py::test_allow_duplicate")
assert r.returncode == 0
# Add unique index
models_py.write_text(MODELS.replace("index=False", "index=True, unique=True"))
run_aerich("aerich migrate") # migrations/models/3_
run_aerich("aerich upgrade")
r = run_shell("pytest _test.py::test_unique_is_true")
assert r.returncode == 0
# Drop unique index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/4_
run_aerich("aerich upgrade")
r = run_shell("pytest _test.py::test_allow_duplicate")
assert r.returncode == 0
# Add field with unique=True
with models_py.open("a") as f:
f.write(" age = fields.IntField(unique=True, default=0)")
run_aerich("aerich migrate") # migrations/models/5_
run_aerich("aerich upgrade")
r = run_shell("pytest _test.py::test_add_unique_field")
assert r.returncode == 0
# Drop unique field
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/6_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _test.py::test_drop_unique_field")
assert r.returncode == 0
# Initial with indexed field and then drop it
migrations_dir = Path("migrations/models")
shutil.rmtree(migrations_dir)
db_file.unlink()
models_py.write_text(MODELS + " age = fields.IntField(db_index=True)")
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
migration_file = list(migrations_dir.glob("0_*.py"))[0]
assert "CREATE INDEX" in migration_file.read_text()
r = run_shell("pytest _test.py::test_with_age_field")
assert r.returncode == 0
models_py.write_text(MODELS)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "DROP INDEX" in migration_file_1.read_text()
r = run_shell("pytest _test.py::test_without_age_field")
assert r.returncode == 0
# Generate migration file in emptry directory
db_file.unlink()
run_aerich("aerich init-db")
assert not db_file.exists()
for p in migrations_dir.glob("*"):
if p.is_dir():
shutil.rmtree(p)
else:
p.unlink()
run_aerich("aerich init-db")
assert db_file.exists()
# init without '[tool]' section in pyproject.toml
config_file = Path("pyproject.toml")
config_file.write_text('[project]\nname = "project"')
run_aerich("init -t settings.TORTOISE_ORM")
assert "[tool.aerich]" in config_file.read_text()
# add m2m with custom model for through
new = """
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
models_py.write_text(MODELS + new)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _test.py::test_m2m_with_custom_through")
assert r.returncode == 0
# add m2m field after init-db
new = """
groups = fields.ManyToManyField("models.Group", through="foo_group", related_name="users")
class Group(Model):
name = fields.CharField(max_length=60)
"""
if db_file.exists():
db_file.unlink()
if migrations_dir.exists():
shutil.rmtree(migrations_dir)
models_py.write_text(MODELS)
run_aerich("aerich init-db")
models_py.write_text(MODELS + new)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _test.py::test_add_m2m_field_after_init_db")
assert r.returncode == 0

View File

@@ -1,6 +1,164 @@
from aerich.utils import import_py_file from aerich.utils import get_dict_diff_by_key, import_py_file
def test_import_py_file(): def test_import_py_file() -> None:
m = import_py_file("aerich/utils.py") m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file") assert getattr(m, "import_py_file", None)
class TestDiffFields:
def test_the_same_through_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "members", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert type(get_dict_diff_by_key(old, new)).__name__ == "generator"
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_same_through_with_different_orders(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
{"name": "members", "through": "users_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_the_same_field_name_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_same_field_name_with_different_orders(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_drop_one(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("remove", "", [(0, {"name": "users", "through": "users_group"})])]
def test_add_one(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("add", "", [(0, {"name": "users", "through": "users_group"})])]
def test_drop_some(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "staffs", "through": "staffs_group"})]),
]
def test_add_some(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("add", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "admins", "through": "admins_group"})]),
]
def test_some_through_unchanged(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("add", "", [(0, {"name": "users", "through": "users_group"})]),
]
def test_some_unchanged_without_drop_or_add(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users_new", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("change", [0, "name"], ("users", "users_new")),
]