Compare commits

..

No commits in common. "dev" and "v0.4.0" have entirely different histories.
dev ... v0.4.0

54 changed files with 1887 additions and 7143 deletions

1
.github/FUNDING.yml vendored
View File

@ -1 +0,0 @@
custom: ["https://sponsor.long2ice.io"]

View File

@ -1,109 +0,0 @@
name: ci
on:
push:
branches-ignore:
- main
pull_request:
branches-ignore:
- main
jobs:
ci:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:latest
ports:
- 5432:5432
env:
POSTGRES_PASSWORD: 123456
POSTGRES_USER: postgres
options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"]
tortoise-orm:
- tortoise021
- tortoise022
- tortoise023
- tortoise024
# TODO: add dev back when drop python3.8 support
# - tortoisedev
steps:
- name: Start MySQL
run: sudo systemctl start mysql.service
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install and configure Poetry
run: |
pip install -U pip
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
# poetry2.0+ does not support installed by python3.8, but can manage project using py38
python3.12 -m pip install "poetry>=2.0"
else
pip install "poetry>=2.0"
fi
poetry env use python${{ matrix.python-version }}
- name: Install dependencies and check style
run: poetry run make check
- name: Install TortoiseORM v0.21
if: matrix.tortoise-orm == 'tortoise021'
run: poetry run pip install --upgrade "tortoise-orm>=0.21,<0.22"
- name: Install TortoiseORM v0.22
if: matrix.tortoise-orm == 'tortoise022'
run: poetry run pip install --upgrade "tortoise-orm>=0.22,<0.23"
- name: Install TortoiseORM v0.23
if: matrix.tortoise-orm == 'tortoise023'
run: poetry run pip install --upgrade "tortoise-orm>=0.23,<0.24"
- name: Install TortoiseORM v0.24
if: matrix.tortoise-orm == 'tortoise024'
run: |
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
echo "Skip test for tortoise v0.24 as it does not support Python3.8"
else
poetry run pip install --upgrade "tortoise-orm>=0.24,<0.25"
fi
- name: Install TortoiseORM develop branch
if: matrix.tortoise-orm == 'tortoisedev'
run: |
if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
echo "Skip test for tortoise develop branch as it does not support Python3.8"
else
poetry run pip uninstall -y tortoise-orm
poetry run pip install --upgrade "git+https://github.com/tortoise/tortoise-orm"
fi
- name: CI
env:
MYSQL_PASS: root
MYSQL_HOST: 127.0.0.1
MYSQL_PORT: 3306
POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432
run: poetry run make _testall
- name: Verify aiomysql support
# Only check the latest version of tortoise
if: matrix.tortoise-orm == 'tortoise024'
run: |
poetry run pip uninstall -y asyncmy
poetry run make test_mysql
poetry run pip install asyncmy
env:
MYSQL_PASS: root
MYSQL_HOST: 127.0.0.1
MYSQL_PORT: 3306
- name: Verify psycopg support
# Only check the latest version of tortoise
if: matrix.tortoise-orm == 'tortoise024'
run: poetry run make test_psycopg
env:
POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432

View File

@ -7,14 +7,14 @@ jobs:
publish: publish:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v2
- uses: actions/setup-python@v5 - uses: actions/setup-python@v2
with: with:
python-version: '3.x' python-version: '3.x'
- name: Install and configure Poetry - name: Install and configure Poetry
run: | uses: snok/install-poetry@v1.1.1
pip install -U pip poetry with:
poetry config virtualenvs.create false virtualenvs-create: false
- name: Build dists - name: Build dists
run: make build run: make build
- name: Pypi Publish - name: Pypi Publish

34
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,34 @@
name: test
on: [ push, pull_request ]
jobs:
testall:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:latest
ports:
- 5432:5432
env:
POSTGRES_PASSWORD: 123456
POSTGRES_USER: postgres
options: --health-cmd=pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
steps:
- name: Start MySQL
run: sudo systemctl start mysql.service
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: '3.x'
- name: Install and configure Poetry
uses: snok/install-poetry@v1.1.1
with:
virtualenvs-create: false
- name: CI
env:
MYSQL_PASS: root
MYSQL_HOST: 127.0.0.1
MYSQL_PORT: 3306
POSTGRES_PASS: 123456
POSTGRES_HOST: 127.0.0.1
POSTGRES_PORT: 5432
run: make ci

1
.gitignore vendored
View File

@ -146,4 +146,3 @@ aerich.ini
src src
.vscode .vscode
.DS_Store .DS_Store
.python-version

View File

@ -1,224 +1,7 @@
# ChangeLog # ChangeLog
## 0.8
### [0.8.3]**(Unreleased)**
#### Fixed
- fix: `aerich init-db` process is suspended. ([#435])
[#435]: https://github.com/tortoise/aerich/pull/435
### [0.8.2](../../releases/tag/v0.8.2) - 2025-02-28
#### Added
- Support changes `max_length` or int type for primary key field. ([#428])
- feat: support psycopg. ([#425])
- Support run `poetry add aerich` in project that inited by poetry v2. ([#424])
- feat: support command `python -m aerich`. ([#417])
- feat: add --fake to upgrade/downgrade. ([#398])
- Support ignore table by settings `managed=False` in `Meta` class. ([#397])
#### Fixed
- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
- No migration occurs as expected when adding `unique=True` to indexed field. ([#404])
- fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
- fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187])
### Changed
- Refactored version management to use `importlib.metadata.version(__package__)` instead of hardcoded version string ([#412])
[#397]: https://github.com/tortoise/aerich/pull/397
[#398]: https://github.com/tortoise/aerich/pull/398
[#401]: https://github.com/tortoise/aerich/pull/401
[#404]: https://github.com/tortoise/aerich/pull/404
[#412]: https://github.com/tortoise/aerich/pull/412
[#415]: https://github.com/tortoise/aerich/pull/415
[#417]: https://github.com/tortoise/aerich/pull/417
[#424]: https://github.com/tortoise/aerich/pull/424
[#425]: https://github.com/tortoise/aerich/pull/425
### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27
#### Fixed
- fix: add o2o field does not create constraint when migrating. ([#396])
- Migration with duplicate renaming of columns in some cases. ([#395])
- fix: intermediate table for m2m relation not created. ([#394])
- Migrate add m2m field with custom through generate duplicated table. ([#393])
- Migrate drop the wrong m2m field when model have multi m2m fields. ([#376])
- KeyError raised when removing or renaming an existing model. ([#386])
- fix: error when there is `__init__.py` in the migration folder. ([#272])
- Setting null=false on m2m field causes migration to fail. ([#334])
- Fix NonExistentKey when running `aerich init` without `[tool]` section in config file. ([#284])
- Fix configuration file reading error when containing Chinese characters. ([#286])
- sqlite: failed to create/drop index. ([#302])
- PostgreSQL: Cannot drop constraint after deleting or rename FK on a model. ([#378])
- Fix create/drop indexes in every migration. ([#377])
- Sort m2m fields before comparing them with diff. ([#271])
#### Changed
- Allow run `aerich init-db` with empty migration directories instead of abort with warnings. ([#286])
- Add version constraint(>=0.21) for tortoise-orm. ([#388])
- Move `tomlkit` to optional and support `pip install aerich[toml]`. ([#392])
[#396]: https://github.com/tortoise/aerich/pull/396
[#395]: https://github.com/tortoise/aerich/pull/395
[#394]: https://github.com/tortoise/aerich/pull/394
[#393]: https://github.com/tortoise/aerich/pull/393
[#392]: https://github.com/tortoise/aerich/pull/392
[#388]: https://github.com/tortoise/aerich/pull/388
[#386]: https://github.com/tortoise/aerich/pull/386
[#378]: https://github.com/tortoise/aerich/pull/378
[#377]: https://github.com/tortoise/aerich/pull/377
[#376]: https://github.com/tortoise/aerich/pull/376
[#334]: https://github.com/tortoise/aerich/pull/334
[#302]: https://github.com/tortoise/aerich/pull/302
[#286]: https://github.com/tortoise/aerich/pull/286
[#284]: https://github.com/tortoise/aerich/pull/284
[#272]: https://github.com/tortoise/aerich/pull/272
[#271]: https://github.com/tortoise/aerich/pull/271
### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04
- Fix the issue of parameter concatenation when generating ORM with inspectdb (#331)
- Fix KeyError when deleting a field with unqiue=True. (#364)
- Correct the click import. (#360)
- Improve CLI help text and output. (#355)
- Fix mysql drop unique index raises OperationalError. (#346)
**Upgrade note:**
1. Use column name as unique key name for mysql
2. Drop support for Python3.7
## 0.7
### [0.7.2](../../releases/tag/v0.7.2) - 2023-07-20
- Support virtual fields.
- Fix modify multiple times. (#279)
- Added `-i` and `--in-transaction` options to `aerich migrate` command. (#296)
- Fix generates two semicolons in a row. (#301)
### 0.7.1
- Fix syntax error with python3.8.10. (#265)
- Fix sql generate error. (#263)
- Fix initialize an empty database. (#267)
### 0.7.1rc1
- Fix postgres sql error (#263)
### 0.7.0
**Now aerich use `.py` file to record versions.**
Upgrade Note:
1. Drop `aerich` table
2. Delete `migrations/models` folder
3. Run `aerich init-db`
- Improve `inspectdb` adding support to `postgresql::numeric` data type
- Add support for dynamically load DDL classes easing to add support to
new databases without changing `Migrate` class logic
- Fix decimal field change. (#246)
- Support add/remove field with index.
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2
- Support migration for specified index. (#203)
### 0.6.1
- Fix `pyproject.toml` not existing error. (#217)
### 0.6.0
- Change default config file from `aerich.ini` to `pyproject.toml`. (#197)
**Upgrade note:**
1. Run `aerich init -t config.TORTOISE_ORM`.
2. Remove `aerich.ini`.
- Remove `pydantic` dependency. (#198)
- `inspectdb` support `DATE`. (#215)
## 0.5
### 0.5.8
- Support `indexes` change. (#193)
### 0.5.7
- Fix no module found error. (#188) (#189)
### 0.5.6
- Add `Command` class. (#148) (#141) (#123) (#106)
- Fix: migrate doesn't use source_field in unique_together. (#181)
### 0.5.5
- Fix KeyError: 'src_folder' after upgrading aerich to 0.5.4. (#176)
- Fix MySQL 5.X rename column.
- Fix `db_constraint` when fk changed. (#179)
### 0.5.4
- Fix incorrect index creation order. (#151)
- Not catch exception when import config. (#164)
- Support `drop column` for sqlite. (#40)
### 0.5.3
- Fix postgre alter null. (#142)
- Fix default function when migrate. (#147)
### 0.5.2
- Fix rename field on the field add. (#134)
- Fix postgres field type change error. (#135)
- Fix inspectdb for `FloatField`. (#138)
- Support `rename table`. (#139)
### 0.5.1
- Fix tortoise connections not being closed properly. (#120)
- Fix bug for field change. (#119)
- Fix drop model in the downgrade. (#132)
### 0.5.0
- Refactor core code, now has no limitation for everything.
## 0.4 ## 0.4
### 0.4.4
- Fix unnecessary import. (#113)
### 0.4.3
- Replace migrations separator to sql standard comment.
- Add `inspectdb` command.
### 0.4.2
- Use `pathlib` for path resolving. (#89)
- Fix upgrade in new db. (#96)
- Fix packaging error. (#92)
### 0.4.1
- Bug fix. (#91 #93)
### 0.4.0 ### 0.4.0
- Use `.sql` instead of `.json` to store version file. - Use `.sql` instead of `.json` to store version file.

View File

@ -1,58 +1,55 @@
checkfiles = aerich/ tests/ conftest.py checkfiles = aerich/ tests/ conftest.py
black_opts = -l 100 -t py38
py_warn = PYTHONDEVMODE=1 py_warn = PYTHONDEVMODE=1
MYSQL_HOST ?= "127.0.0.1" MYSQL_HOST ?= "127.0.0.1"
MYSQL_PORT ?= 3306 MYSQL_PORT ?= 3306
MYSQL_PASS ?= "123456" MYSQL_PASS ?= "123456"
POSTGRES_HOST ?= "127.0.0.1" POSTGRES_HOST ?= "127.0.0.1"
POSTGRES_PORT ?= 5432 POSTGRES_PORT ?= 5432
POSTGRES_PASS ?= 123456 POSTGRES_PASS ?= "123456"
help:
@echo "Aerich development makefile"
@echo
@echo "usage: make <target>"
@echo "Targets:"
@echo " up Updates dev/test dependencies"
@echo " deps Ensure dev/test dependencies are installed"
@echo " check Checks that build is sane"
@echo " lint Reports all linter violations"
@echo " test Runs all tests"
@echo " style Auto-formats the code"
up: up:
@poetry update @poetry update
deps: deps:
@poetry install --all-extras --all-groups @poetry install -E dbdrivers
_style: style: deps
@ruff check --fix $(checkfiles) isort -src $(checkfiles)
@ruff format $(checkfiles) black $(black_opts) $(checkfiles)
style: deps _style
_check: check: deps
@ruff format --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@ruff check $(checkfiles) flake8 $(checkfiles)
@mypy $(checkfiles) bandit -x tests -r $(checkfiles)
@bandit -r aerich
check: deps _check
_lint: _build
@ruff format $(checkfiles)
ruff check --fix $(checkfiles)
mypy $(checkfiles)
bandit -c pyproject.toml -r $(checkfiles)
twine check dist/*
lint: deps _lint
test: deps test: deps
$(py_warn) TEST_DB=sqlite://:memory: pytest $(py_warn) TEST_DB=sqlite://:memory: py.test
test_sqlite: test_sqlite:
$(py_warn) TEST_DB=sqlite://:memory: pytest $(py_warn) TEST_DB=sqlite://:memory: py.test
test_mysql: test_mysql:
$(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="mysql://root:$(MYSQL_PASS)@$(MYSQL_HOST):$(MYSQL_PORT)/test_\{\}" pytest -vv -s
test_postgres: test_postgres:
$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s $(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest
test_psycopg: testall: deps test_sqlite test_postgres test_mysql
$(py_warn) TEST_DB="psycopg://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
_testall: test_sqlite test_postgres test_mysql build: deps
testall: deps _testall
_build:
@poetry build @poetry build
build: deps _build
ci: build _check _testall ci: check testall

195
README.md
View File

@ -1,23 +1,23 @@
# Aerich # Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich) [![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich) [![image](https://img.shields.io/github/license/long2ice/aerich)](https://github.com/long2ice/aerich)
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi) [![image](https://github.com/long2ice/aerich/workflows/pypi/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci) [![image](https://github.com/long2ice/aerich/workflows/test/badge.svg)](https://github.com/long2ice/aerich/actions?query=workflow:test)
English | [Русский](./README_RU.md)
## Introduction ## Introduction
Aerich is a database migrations tool for TortoiseORM, which is like alembic for SQLAlchemy, or like Django ORM with Aerich is a database migrations tool for Tortoise-ORM, which like alembic for SQLAlchemy, or Django ORM with it\'s
it\'s own migration solution. own migrations solution.
**Important: You can only use absolutely import in your `models.py` to make `aerich` work.**
## Install ## Install
Just install from pypi: Just install from pypi:
```shell ```shell
pip install "aerich[toml]" > pip install aerich
``` ```
## Quick Start ## Quick Start
@ -28,9 +28,10 @@ pip install "aerich[toml]"
Usage: aerich [OPTIONS] COMMAND [ARGS]... Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options: Options:
-V, --version Show the version and exit. -c, --config TEXT Config file. [default: aerich.ini]
-c, --config TEXT Config file. [default: pyproject.toml] --app TEXT Tortoise-ORM app name. [default: models]
--app TEXT Tortoise-ORM app name. -n, --name TEXT Name of section in .ini file to use for aerich config.
[default: aerich]
-h, --help Show this message and exit. -h, --help Show this message and exit.
Commands: Commands:
@ -39,14 +40,14 @@ Commands:
history List all migrate items. history List all migrate items.
init Init config file and generate root migrate location. init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location. init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file. migrate Generate migrate changes file.
upgrade Upgrade to specified version. upgrade Upgrade to latest version.
``` ```
## Usage ## Usage
You need to add `aerich.models` to your `Tortoise-ORM` config first. Example: You need add `aerich.models` to your `Tortoise-ORM` config first,
example:
```python ```python
TORTOISE_ORM = { TORTOISE_ORM = {
@ -70,20 +71,19 @@ Usage: aerich init [OPTIONS]
Init config file and generate root migrate location. Init config file and generate root migrate location.
Options: Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like -t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.
settings.TORTOISE_ORM. [required] [required]
--location TEXT Migrate store location. [default: ./migrations] --location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit. -h, --help Show this message and exit.
``` ```
Initialize the config file and migrations location: Init config file and location:
```shell ```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM > aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations Success create migrate location ./migrations
Success write config to pyproject.toml Success generate config file aerich.ini
``` ```
### Init db ### Init db
@ -95,38 +95,28 @@ Success create app migrate location ./migrations/models
Success generate schema for app "models" Success generate schema for app "models"
``` ```
If your Tortoise-ORM app is not the default `models`, you must specify the correct app via `--app`, If your Tortoise-ORM app is not default `models`, you must specify
e.g. `aerich --app other_models init-db`. `--app` like `aerich --app other_models init-db`.
### Update models and make migrate ### Update models and make migrate
```shell ```shell
> aerich migrate --name drop_column > aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.py Success migrate 1_202029051520102929_drop_column.sql
``` ```
Format of migrate filename is Format of migrate filename is
`{version_num}_{datetime}_{name|update}.py`. `{version_num}_{datetime}_{name|update}.sql`.
If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose And if `aerich` guess you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`, you can choice `True` to rename column without column drop, or choice `False` to drop column then create.
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
lose data.
If you need to manually write migration, you could generate empty file:
```shell
> aerich migrate --name add_index --empty
Success migrate 1_202326122220101229_add_index.py
```
### Upgrade to latest version ### Upgrade to latest version
```shell ```shell
> aerich upgrade > aerich upgrade
Success upgrade 1_202029051520102929_drop_column.py Success upgrade 1_202029051520102929_drop_column.sql
``` ```
Now your db is migrated to latest. Now your db is migrated to latest.
@ -134,7 +124,7 @@ Now your db is migrated to latest.
### Downgrade to specified version ### Downgrade to specified version
```shell ```shell
> aerich downgrade -h > aerich init -h
Usage: aerich downgrade [OPTIONS] Usage: aerich downgrade [OPTIONS]
@ -152,17 +142,17 @@ Options:
```shell ```shell
> aerich downgrade > aerich downgrade
Success downgrade 1_202029051520102929_drop_column.py Success downgrade 1_202029051520102929_drop_column.sql
``` ```
Now your db is rolled back to the specified version. Now your db rollback to specified version.
### Show history ### Show history
```shell ```shell
> aerich history > aerich history
1_202029051520102929_drop_column.py 1_202029051520102929_drop_column.sql
``` ```
### Show heads to be migrated ### Show heads to be migrated
@ -170,140 +160,31 @@ Now your db is rolled back to the specified version.
```shell ```shell
> aerich heads > aerich heads
1_202029051520102929_drop_column.py 1_202029051520102929_drop_column.sql
``` ```
### Inspect db tables to TortoiseORM model
Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Inspect all tables and print to console:
```shell
aerich --app models inspectdb
```
Inspect a specified table in the default app and redirect to `models.py`:
```shell
aerich inspectdb -t user > models.py
```
For example, you table is:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Now run `aerich inspectdb -t test` to see the generated model:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True)
datetime = fields.DatetimeField(auto_now=True)
decimal = fields.DecimalField(max_digits=10, decimal_places=2)
float = fields.FloatField(null=True)
id = fields.IntField(primary_key=True)
string = fields.CharField(max_length=200, null=True)
time = fields.TimeField(null=True)
tinyint = fields.BooleanField(null=True)
```
Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
### Multiple databases ### Multiple databases
```python ```python
tortoise_orm = { tortoise_orm = {
"connections": { "connections": {
"default": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db1", "default": expand_db_url(db_url, True),
"second": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db2", "second": expand_db_url(db_url_second, True),
}, },
"apps": { "apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"}, "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["tests.models_second"], "default_connection": "second", }, "models_second": {"models": ["tests.models_second"], "default_connection": "second",},
}, },
} }
``` ```
You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`. You need only specify `aerich.models` in one app, and must specify `--app` when run `aerich migrate` and so on.
## Restore `aerich` workflow ## Support this project
In some cases, such as broken changes from upgrade of `aerich`, you can't run `aerich migrate` or `aerich upgrade`, you | AliPay | WeChatPay | PayPal |
can make the following steps: | -------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |
| <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/alipay.jpeg"/> | <img width="200" src="https://github.com/long2ice/aerich/raw/dev/images/wechatpay.jpeg"/> | [PayPal](https://www.paypal.me/long2ice) to my account long2ice. |
1. drop `aerich` table.
2. delete `migrations/{app}` directory.
3. rerun `aerich init-db`.
Note that these actions is safe, also you can do that to reset your migrations if your migration files is too many.
## Use `aerich` in application
You can use `aerich` out of cli by use `Command` class.
```python
from aerich import Command
async with Command(tortoise_config=config, app='models') as command:
await command.migrate('test')
await command.upgrade()
```
## Upgrade/Downgrade with `--fake` option
Marks the migrations up to the latest one(or back to the target one) as applied, but without actually running the SQL to change your database schema.
- Upgrade
```bash
aerich upgrade --fake
aerich --app models upgrade --fake
```
- Downgrade
```bash
aerich downgrade --fake -v 2
aerich --app models downgrade --fake -v 2
```
### Ignore tables
You can tell aerich to ignore table by setting `managed=False` in the `Meta` class, e.g.:
```py
class MyModel(Model):
class Meta:
managed = False
```
**Note** `managed=False` does not recognized by `tortoise-orm` and `aerich init-db`, it is only for `aerich migrate`.
## License ## License

View File

@ -1,274 +0,0 @@
# Aerich
[![image](https://img.shields.io/pypi/v/aerich.svg?style=flat)](https://pypi.python.org/pypi/aerich)
[![image](https://img.shields.io/github/license/tortoise/aerich)](https://github.com/tortoise/aerich)
[![image](https://github.com/tortoise/aerich/workflows/pypi/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:pypi)
[![image](https://github.com/tortoise/aerich/workflows/ci/badge.svg)](https://github.com/tortoise/aerich/actions?query=workflow:ci)
[English](./README.md) | Русский
## Введение
Aerich - это инструмент для миграции базы данных для TortoiseORM, который аналогичен Alembic для SQLAlchemy или встроенному решению миграций в Django ORM.
## Установка
Просто установите из pypi:
```shell
pip install aerich
```
## Быстрый старт
```shell
> aerich -h
Usage: aerich [OPTIONS] COMMAND [ARGS]...
Options:
-V, --version Show the version and exit.
-c, --config TEXT Config file. [default: pyproject.toml]
--app TEXT Tortoise-ORM app name.
-h, --help Show this message and exit.
Commands:
downgrade Downgrade to specified version.
heads Show current available heads in migrate location.
history List all migrate items.
init Init config file and generate root migrate location.
init-db Generate schema and generate app migrate location.
inspectdb Introspects the database tables to standard output as...
migrate Generate migrate changes file.
upgrade Upgrade to specified version.
```
## Использование
Сначала вам нужно добавить aerich.models в конфигурацию вашего Tortoise-ORM. Пример:
```python
TORTOISE_ORM = {
"connections": {"default": "mysql://root:123456@127.0.0.1:3306/test"},
"apps": {
"models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
},
}
```
### Инициализация
```shell
> aerich init -h
Usage: aerich init [OPTIONS]
Init config file and generate root migrate location.
Options:
-t, --tortoise-orm TEXT Tortoise-ORM config module dict variable, like
settings.TORTOISE_ORM. [required]
--location TEXT Migrate store location. [default: ./migrations]
-s, --src_folder TEXT Folder of the source, relative to the project root.
-h, --help Show this message and exit.
```
Инициализируйте файл конфигурации и задайте местоположение миграций:
```shell
> aerich init -t tests.backends.mysql.TORTOISE_ORM
Success create migrate location ./migrations
Success write config to pyproject.toml
```
### Инициализация базы данных
```shell
> aerich init-db
Success create app migrate location ./migrations/models
Success generate schema for app "models"
```
Если ваше приложение Tortoise-ORM не является приложением по умолчанию с именем models, вы должны указать правильное имя приложения с помощью параметра --app, например: aerich --app other_models init-db.
### Обновление моделей и создание миграции
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.py
```
Формат имени файла миграции следующий: `{версия}_{дата_и_время}_{имя|обновление}.py`.
Если aerich предполагает, что вы переименовываете столбец, он спросит:
Переименовать `{старый_столбец} в {новый_столбец} [True]`. Вы можете выбрать `True`,
чтобы переименовать столбец без удаления столбца, или выбрать `False`, чтобы удалить столбец,
а затем создать новый. Обратите внимание, что последний вариант может привести к потере данных.
### Обновление до последней версии
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных обновлена до последней версии.
### Откат до указанной версии
```shell
> aerich downgrade -h
Usage: aerich downgrade [OPTIONS]
Downgrade to specified version.
Options:
-v, --version INTEGER Specified version, default to last. [default: -1]
-d, --delete Delete version files at the same time. [default:
False]
--yes Confirm the action without prompting.
-h, --help Show this message and exit.
```
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.py
```
Теперь ваша база данных откатилась до указанной версии.
### Показать историю
```shell
> aerich history
1_202029051520102929_drop_column.py
```
### Чтобы узнать, какие миграции должны быть применены, можно использовать команду:
```shell
> aerich heads
1_202029051520102929_drop_column.py
```
### Осмотр таблиц базы данных для модели TortoiseORM
В настоящее время inspectdb поддерживает MySQL, Postgres и SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]
Introspects the database tables to standard output as TortoiseORM model.
Options:
-t, --table TEXT Which tables to inspect.
-h, --help Show this message and exit.
```
Посмотреть все таблицы и вывести их на консоль:
```shell
aerich --app models inspectdb
```
Осмотреть указанную таблицу в приложении по умолчанию и перенаправить в models.py:
```shell
aerich inspectdb -t user > models.py
```
Например, ваша таблица выглядит следующим образом:
```sql
CREATE TABLE `test`
(
`id` int NOT NULL AUTO_INCREMENT,
`decimal` decimal(10, 2) NOT NULL,
`date` date DEFAULT NULL,
`datetime` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
`time` time DEFAULT NULL,
`float` float DEFAULT NULL,
`string` varchar(200) COLLATE utf8mb4_general_ci DEFAULT NULL,
`tinyint` tinyint DEFAULT NULL,
PRIMARY KEY (`id`),
KEY `asyncmy_string_index` (`string`)
) ENGINE = InnoDB
DEFAULT CHARSET = utf8mb4
COLLATE = utf8mb4_general_ci
```
Теперь выполните команду aerich inspectdb -t test, чтобы увидеть сгенерированную модель:
```python
from tortoise import Model, fields
class Test(Model):
date = fields.DateField(null=True, )
datetime = fields.DatetimeField(auto_now=True, )
decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
float = fields.FloatField(null=True, )
id = fields.IntField(pk=True, )
string = fields.CharField(max_length=200, null=True, )
time = fields.TimeField(null=True, )
tinyint = fields.BooleanField(null=True, )
```
Обратите внимание, что эта команда имеет ограничения и не может автоматически определить некоторые поля, такие как `IntEnumField`, `ForeignKeyField` и другие.
### Несколько баз данных
```python
tortoise_orm = {
"connections": {
"default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["tests.models_second"], "default_connection": "second", },
},
}
```
Вам нужно указать `aerich.models` только в одном приложении и должны указывать `--app` при запуске команды `aerich migrate` и т.д.
## Восстановление рабочего процесса aerich
В некоторых случаях, например, при возникновении проблем после обновления `aerich`, вы не можете запустить `aerich migrate` или `aerich upgrade`. В таком случае вы можете выполнить следующие шаги:
1. удалите таблицы `aerich`.
2. удалите директорию `migrations/{app}`.
3. rerun `aerich init-db`.
Обратите внимание, что эти действия безопасны, и вы можете использовать их для сброса миграций, если у вас слишком много файлов миграции.
## Использование aerich в приложении
Вы можете использовать `aerich` вне командной строки, используя класс `Command`.
```python
from aerich import Command
command = Command(tortoise_config=config, app='models')
await command.init()
await command.migrate('test')
```
## Лицензия
Этот проект лицензирован в соответствии с лицензией
[Apache-2.0](https://github.com/long2ice/aerich/blob/master/LICENSE) Лицензия.

View File

@ -1,282 +1 @@
from __future__ import annotations __version__ = "0.4.0"
import os
import platform
from contextlib import AbstractAsyncContextManager
from pathlib import Path
from typing import TYPE_CHECKING
import tortoise
from tortoise import Tortoise, connections, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
import_py_file,
)
if TYPE_CHECKING:
from tortoise import Model
from tortoise.fields.relational import ManyToManyFieldInstance # NOQA:F401
from aerich.inspectdb import Inspect
def _init_asyncio_patch():
"""
Select compatible event loop for psycopg3.
As of Python 3.8+, the default event loop on Windows is `proactor`,
however psycopg3 requires the old default "selector" event loop.
See https://www.psycopg.org/psycopg3/docs/advanced/async.html
"""
if platform.system() == "Windows":
try:
from asyncio import WindowsSelectorEventLoopPolicy # type:ignore
except ImportError:
pass # Can't assign a policy which doesn't exist.
else:
from asyncio import get_event_loop_policy, set_event_loop_policy
if not isinstance(get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
set_event_loop_policy(WindowsSelectorEventLoopPolicy())
def _init_tortoise_0_24_1_patch():
# this patch is for "tortoise-orm==0.24.1" to fix:
# https://github.com/tortoise/tortoise-orm/issues/1893
if tortoise.__version__ != "0.24.1":
return
from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re
def _get_m2m_tables(
self, model: type[Model], db_table: str, safe: bool, models_tables: list[str]
) -> list[str]: # Copied from tortoise-orm
m2m_tables_for_create = []
for m2m_field in model._meta.m2m_fields:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
if field_object._generated or field_object.through in models_tables:
continue
backward_key, forward_key = field_object.backward_key, field_object.forward_key
if field_object.db_constraint:
backward_fk = self._create_fk_string(
"",
backward_key,
db_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
forward_fk = self._create_fk_string(
"",
forward_key,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
)
else:
backward_fk = forward_fk = ""
exists = "IF NOT EXISTS " if safe else ""
through_table_name = field_object.through
backward_type = self._get_pk_field_sql_type(model._meta.pk)
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
comment = ""
if desc := field_object.description:
comment = self._table_comment_generator(table=through_table_name, comment=desc)
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
table_name=through_table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
backward_type=backward_type,
forward_key=forward_key,
forward_type=forward_type,
extra=self._table_generate_extra(table=field_object.through),
comment=comment,
)
if not field_object.db_constraint:
m2m_create_string = m2m_create_string.replace(
""",
,
""",
"",
) # may have better way
m2m_create_string += self._post_table_hook()
if getattr(field_object, "create_unique_index", field_object.unique):
unique_index_create_sql = self._get_unique_index_sql(
exists, through_table_name, [backward_key, forward_key]
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
else:
lines = m2m_create_string.splitlines()
lines[-2] += ","
indent = m.group() if (m := re.match(r"\s+", lines[-2])) else ""
lines.insert(-1, indent + unique_index_create_sql)
m2m_create_string = "\n".join(lines)
m2m_tables_for_create.append(m2m_create_string)
return m2m_tables_for_create
setattr(BaseSchemaGenerator, "_get_m2m_tables", _get_m2m_tables)
_init_asyncio_patch()
_init_tortoise_0_24_1_patch()
class Command(AbstractAsyncContextManager):
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
) -> None:
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)
async def __aenter__(self) -> Command:
await self.init()
return self
async def close(self) -> None:
await connections.close_all()
async def __aexit__(self, *args, **kw) -> None:
await self.close()
async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = m.upgrade
if not fake:
await conn.execute_script(await upgrade(conn))
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> list[str]:
migrated = []
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=self.app)
except OperationalError:
exists = False
if not exists:
app_conn_name = get_app_connection_name(self.tortoise_config, self.app)
if run_in_transaction:
async with in_transaction(app_conn_name) as conn:
await self._upgrade(conn, version_file, fake=fake)
else:
app_conn = get_app_connection(self.tortoise_config, self.app)
await self._upgrade(app_conn, version_file, fake=fake)
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool, fake: bool = False) -> list[str]:
ret: list[str] = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(
app=self.app, version__startswith=f"{version}_"
).first()
if not specified_version:
raise DowngradeError("No specified version found")
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version_obj in versions:
file = version_obj.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
m = import_py_file(file_path)
downgrade = m.downgrade
downgrade_sql = await downgrade(conn)
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found")
if not fake:
await conn.execute_script(downgrade_sql)
await version_obj.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self) -> list[str]:
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
if not await Aerich.exists(version=version, app=self.app):
ret.append(version)
return ret
async def history(self) -> list[str]:
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: list[str] | None = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls: type[Inspect] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool) -> None:
location = self.location
app = self.app
dirname = Path(location, app)
if not dirname.exists():
dirname.mkdir(parents=True)
else:
# If directory is empty, go ahead, otherwise raise FileExistsError
for unexpected_file in dirname.glob("*"):
raise FileExistsError(str(unexpected_file))
await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=get_models_describe(app),
)
version_file = Path(dirname, version)
content = MIGRATE_TEMPLATE.format(upgrade_sql=schema, downgrade_sql="")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)

View File

@ -1,3 +0,0 @@
from .cli import main
main()

View File

@ -1,28 +0,0 @@
# mypy: disable-error-code="no-redef"
from __future__ import annotations
import sys
from types import ModuleType
import tortoise
if sys.version_info >= (3, 11):
import tomllib
else:
try:
import tomli as tomllib
except ImportError:
import tomlkit as tomllib
def imports_tomlkit() -> ModuleType:
try:
import tomli_w as tomlkit
except ImportError:
import tomlkit
return tomlkit
def tortoise_version_less_than(version: str) -> bool:
# The min version of tortoise is '0.11.0', so we can compare it by a `<`,
return tortoise.__version__ < version

View File

@ -1,38 +1,44 @@
from __future__ import annotations import asyncio
import os import os
from pathlib import Path import sys
from typing import cast from configparser import ConfigParser
from functools import wraps
import asyncclick as click import click
from asyncclick import Context, UsageError from click import Context, UsageError
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich import Command from aerich.migrate import Migrate
from aerich._compat import imports_tomlkit, tomllib from aerich.utils import (
from aerich.enums import Color get_app_connection,
from aerich.exceptions import DowngradeError get_app_connection_name,
from aerich.utils import add_src_path, get_tortoise_config get_tortoise_config,
from aerich.version import __version__ get_version_content_from_file,
write_version_file,
)
CONFIG_DEFAULT_VALUES = { from . import __version__
"src_folder": ".", from .enums import Color
} from .models import Aerich
parser = ConfigParser()
def _patch_context_to_close_tortoise_connections_when_exit() -> None: def coro(f):
from tortoise import Tortoise, connections @wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs))
loop.run_until_complete(Tortoise.close_connections())
app = ctx.obj.get("app")
if app:
Migrate.remove_old_model_file(app, ctx.obj["location"])
origin_aexit = Context.__aexit__ return wrapper
async def aexit(*args, **kw) -> None:
await origin_aexit(*args, **kw)
if Tortoise._inited:
await connections.close_all()
Context.__aexit__ = aexit # type:ignore[method-assign]
_patch_context_to_close_tortoise_connections_when_exit()
@click.group(context_settings={"help_option_names": ["-h", "--help"]}) @click.group(context_settings={"help_option_names": ["-h", "--help"]})
@ -40,92 +46,85 @@ _patch_context_to_close_tortoise_connections_when_exit()
@click.option( @click.option(
"-c", "-c",
"--config", "--config",
default="pyproject.toml", default="aerich.ini",
show_default=True, show_default=True,
help="Config file.", help="Config file.",
) )
@click.option("--app", required=False, help="Tortoise-ORM app name.") @click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
"-n",
"--name",
default="aerich",
show_default=True,
help="Name of section in .ini file to use for aerich config.",
)
@click.pass_context @click.pass_context
async def cli(ctx: Context, config, app) -> None: @coro
async def cli(ctx: Context, config, app, name):
ctx.ensure_object(dict) ctx.ensure_object(dict)
ctx.obj["config_file"] = config ctx.obj["config_file"] = config
ctx.obj["name"] = name
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
config_path = Path(config) if not os.path.exists(config):
if not config_path.exists(): raise UsageError("You must exec init first", ctx=ctx)
raise UsageError( parser.read(config)
"You need to run `aerich init` first to create the config file.", ctx=ctx
) location = parser[name]["location"]
content = config_path.read_text("utf-8") tortoise_orm = parser[name]["tortoise_orm"]
doc: dict = tomllib.loads(content)
try:
tool = cast("dict[str, str]", doc["tool"]["aerich"])
location = tool["location"]
tortoise_orm = tool["tortoise_orm"]
src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
except KeyError as e:
raise UsageError(
"You need run `aerich init` again when upgrading to aerich 0.6.0+."
) from e
add_src_path(src_folder)
tortoise_config = get_tortoise_config(ctx, tortoise_orm) tortoise_config = get_tortoise_config(ctx, tortoise_orm)
if not app: app = app or list(tortoise_config.get("apps").keys())[0]
try: ctx.obj["config"] = tortoise_config
apps_config = cast(dict, tortoise_config["apps"]) ctx.obj["location"] = location
except KeyError: ctx.obj["app"] = app
raise UsageError('Config must define "apps" section') Migrate.app = app
app = list(apps_config.keys())[0]
command = Command(tortoise_config=tortoise_config, app=app, location=location)
ctx.obj["command"] = command
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): await Migrate.init_with_old_models(tortoise_config, app, location)
raise UsageError(
"You need to run `aerich init-db` first to initialize the database.", ctx=ctx
)
await command.init()
@cli.command(help="Generate a migration file for the current state of the models.") @cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migration name.") @click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.option("--empty", default=False, is_flag=True, help="Generate an empty migration file.")
@click.pass_context @click.pass_context
async def migrate(ctx: Context, name, empty) -> None: @coro
command = ctx.obj["command"] async def migrate(ctx: Context, name):
ret = await command.migrate(name, empty) ret = await Migrate.migrate(name)
if not ret: if not ret:
return click.secho("No changes detected", fg=Color.yellow) return click.secho("No changes detected", fg=Color.yellow)
click.secho(f"Success creating migration file {ret}", fg=Color.green) click.secho(f"Success migrate {ret}", fg=Color.green)
@cli.command(help="Upgrade to specified migration version.") @cli.command(help="Upgrade to specified version.")
@click.option(
"--in-transaction",
"-i",
default=True,
type=bool,
help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
)
@click.option(
"--fake",
default=False,
is_flag=True,
help="Mark migrations as run without actually running them.",
)
@click.pass_context @click.pass_context
async def upgrade(ctx: Context, in_transaction: bool, fake: bool) -> None: @coro
command = ctx.obj["command"] async def upgrade(ctx: Context):
migrated = await command.upgrade(run_in_transaction=in_transaction, fake=fake) config = ctx.obj["config"]
if not migrated: app = ctx.obj["app"]
click.secho("No upgrade items found", fg=Color.yellow) location = ctx.obj["location"]
else: migrated = False
for version_file in migrated: for version_file in Migrate.get_all_version_files():
if fake: try:
click.echo( exists = await Aerich.exists(version=version_file, app=app)
f"Upgrading to {version_file}... " + click.style("FAKED", fg=Color.green) except OperationalError:
exists = False
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
print(upgrade_query_list)
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=app,
content=Migrate.get_models_content(config, app, location),
) )
else: click.secho(f"Success upgrade {version_file}", fg=Color.green)
click.secho(f"Success upgrading to {version_file}", fg=Color.green) migrated = True
if not migrated:
click.secho("No migrate items", fg=Color.yellow)
@cli.command(help="Downgrade to specified version.") @cli.command(help="Downgrade to specified version.")
@ -134,8 +133,8 @@ async def upgrade(ctx: Context, in_transaction: bool, fake: bool) -> None:
"--version", "--version",
default=-1, default=-1,
type=int, type=int,
show_default=False, show_default=True,
help="Specified version, default to last migration.", help="Specified version, default to last.",
) )
@click.option( @click.option(
"-d", "-d",
@ -143,154 +142,152 @@ async def upgrade(ctx: Context, in_transaction: bool, fake: bool) -> None:
is_flag=True, is_flag=True,
default=False, default=False,
show_default=True, show_default=True,
help="Also delete the migration files.", help="Delete version files at the same time.",
)
@click.option(
"--fake",
default=False,
is_flag=True,
help="Mark migrations as run without actually running them.",
) )
@click.pass_context @click.pass_context
@click.confirmation_option( @click.confirmation_option(
prompt="Downgrade is dangerous: you might lose your data! Are you sure?", prompt="Downgrade is dangerous, which maybe lose your data, are you sure?",
) )
async def downgrade(ctx: Context, version: int, delete: bool, fake: bool) -> None: @coro
command = ctx.obj["command"] async def downgrade(ctx: Context, version: int, delete: bool):
try: app = ctx.obj["app"]
files = await command.downgrade(version, delete, fake=fake) config = ctx.obj["config"]
except DowngradeError as e: if version == -1:
return click.secho(str(e), fg=Color.yellow) specified_version = await Migrate.get_last_version()
for file in files:
if fake:
click.echo(f"Downgrading to {file}... " + click.style("FAKED", fg=Color.green))
else: else:
click.secho(f"Success downgrading to {file}", fg=Color.green) specified_version = await Aerich.filter(app=app, version__startswith=f"{version}_").first()
if not specified_version:
return click.secho("No specified version found", fg=Color.yellow)
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=app, pk__gte=specified_version.pk)
for version in versions:
file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
return click.secho("No downgrade items found", fg=Color.yellow)
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
if delete:
os.unlink(file_path)
click.secho(f"Success downgrade {file}", fg=Color.green)
@cli.command(help="Show currently available heads (unapplied migrations).") @cli.command(help="Show current available heads in migrate location.")
@click.pass_context @click.pass_context
async def heads(ctx: Context) -> None: @coro
command = ctx.obj["command"] async def heads(ctx: Context):
head_list = await command.heads() app = ctx.obj["app"]
if not head_list: versions = Migrate.get_all_version_files()
return click.secho("No available heads.", fg=Color.green) is_heads = False
for version in head_list: for version in versions:
if not await Aerich.exists(version=version, app=app):
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
is_heads = True
if not is_heads:
click.secho("No available heads,try migrate first", fg=Color.green)
@cli.command(help="List all migrations.") @cli.command(help="List all migrate items.")
@click.pass_context @click.pass_context
async def history(ctx: Context) -> None: @coro
command = ctx.obj["command"] async def history(ctx: Context):
versions = await command.history() versions = Migrate.get_all_version_files()
if not versions:
return click.secho("No migrations created yet.", fg=Color.green)
for version in versions: for version in versions:
click.secho(version, fg=Color.green) click.secho(version, fg=Color.green)
if not versions:
click.secho("No history,try migrate", fg=Color.green)
def _write_config(config_path, doc, table) -> None: @cli.command(help="Init config file and generate root migrate location.")
tomlkit = imports_tomlkit()
try:
doc["tool"]["aerich"] = table
except KeyError:
doc["tool"] = {"aerich": table}
config_path.write_text(tomlkit.dumps(doc))
@cli.command(help="Initialize aerich config and create migrations folder.")
@click.option( @click.option(
"-t", "-t",
"--tortoise-orm", "--tortoise-orm",
required=True, required=True,
help="Tortoise-ORM config dict location, like `settings.TORTOISE_ORM`.", help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
) )
@click.option( @click.option(
"--location", "--location",
default="./migrations", default="./migrations",
show_default=True, show_default=True,
help="Migrations folder.", help="Migrate store location.",
)
@click.option(
"-s",
"--src_folder",
default=CONFIG_DEFAULT_VALUES["src_folder"],
show_default=False,
help="Folder of the source, relative to the project root.",
) )
@click.pass_context @click.pass_context
async def init(ctx: Context, tortoise_orm, location, src_folder) -> None: @coro
async def init(
ctx: Context,
tortoise_orm,
location,
):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
if os.path.exists(config_file):
return click.secho("You have inited", fg=Color.yellow)
if os.path.isabs(src_folder): parser.add_section(name)
src_folder = os.path.relpath(os.getcwd(), src_folder) parser.set(name, "tortoise_orm", tortoise_orm)
# Add ./ so it's clear that this is relative path parser.set(name, "location", location)
if not src_folder.startswith("./"):
src_folder = "./" + src_folder
# check that we can find the configuration, if not we can fail before the config file gets created with open(config_file, "w", encoding="utf-8") as f:
add_src_path(src_folder) parser.write(f)
get_tortoise_config(ctx, tortoise_orm)
config_path = Path(config_file)
content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
doc: dict = tomllib.loads(content)
table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder} if not os.path.isdir(location):
if (aerich_config := doc.get("tool", {}).get("aerich")) and all( os.mkdir(location)
aerich_config.get(k) == v for k, v in table.items()
):
click.echo(f"Aerich config {config_file} already inited.")
else:
_write_config(config_path, doc, table)
click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)
Path(location).mkdir(parents=True, exist_ok=True) click.secho(f"Success create migrate location {location}", fg=Color.green)
click.secho(f"Success creating migrations folder {location}", fg=Color.green) click.secho(f"Success generate config file {config_file}", fg=Color.green)
@cli.command(help="Generate schema and generate app migration folder.") @cli.command(help="Generate schema and generate app migrate location.")
@click.option( @click.option(
"-s",
"--safe", "--safe",
type=bool, type=bool,
is_flag=True,
default=True, default=True,
help="Create tables only when they do not already exist.", help="When set to true, creates the table only when it does not already exist.",
show_default=True, show_default=True,
) )
@click.pass_context @click.pass_context
async def init_db(ctx: Context, safe: bool) -> None: @coro
command = ctx.obj["command"] async def init_db(ctx: Context, safe):
app = command.app config = ctx.obj["config"]
dirname = Path(command.location, app) location = ctx.obj["location"]
try: app = ctx.obj["app"]
await command.init_db(safe)
click.secho(f"Success creating app migration folder {dirname}", fg=Color.green) dirname = os.path.join(location, app)
click.secho(f'Success generating initial migration file for app "{app}"', fg=Color.green) if not os.path.isdir(dirname):
except FileExistsError: os.mkdir(dirname)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
else:
return click.secho( return click.secho(
f"App {app} is already initialized. Delete {dirname} and try again.", fg=Color.yellow f"Inited {app} already, or delete {dirname} and try again.", fg=Color.yellow
) )
await Tortoise.init(config=config)
connection = get_app_connection(config, app)
await generate_schema_for_client(connection, safe)
@cli.command(help="Prints the current database tables to stdout as Tortoise-ORM models.") schema = get_schema_sql(connection, safe)
@click.option(
"-t", version = await Migrate.generate_version()
"--table", await Aerich.create(
help="Which tables to inspect.", version=version,
multiple=True, app=app,
required=False, content=Migrate.get_models_content(config, app, location),
) )
@click.pass_context content = {
async def inspectdb(ctx: Context, table: list[str]) -> None: "upgrade": [schema],
command = ctx.obj["command"] }
ret = await command.inspectdb(table) write_version_file(os.path.join(dirname, version), content)
click.secho(ret) click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
def main() -> None: def main():
sys.path.insert(0, ".")
cli() cli()

View File

@ -1,49 +0,0 @@
from __future__ import annotations
import base64
import json
import pickle # nosec: B301,B403
from typing import Any
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj) -> Any:
if isinstance(obj, Index):
if hasattr(obj, "describe"):
# For tortoise>=0.24
return obj.describe()
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj) -> Any:
if (type_ := obj.get("type")) and type_ == "index" and (val := obj.get("val")):
return pickle.loads(base64.b64decode(val)) # nosec: B301
return obj
def load_index(obj: dict) -> Index:
"""Convert a dict that generated by `Index.decribe()` to a Index instance"""
try:
index = Index(fields=obj["fields"] or obj["expressions"], name=obj.get("name"))
except KeyError:
return object_hook(obj)
if extra := obj.get("extra"):
index.extra = extra
if idx_type := obj.get("type"):
index.INDEX_TYPE = idx_type
return index
def encoder(obj: dict) -> str:
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str | bytes) -> Any:
return json.loads(obj, object_hook=object_hook)

View File

@ -1,108 +1,75 @@
from __future__ import annotations from typing import List, Type
import re
from enum import Enum
from typing import TYPE_CHECKING, Any, cast
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
from aerich._compat import tortoise_version_less_than
from aerich.utils import is_default_function
if TYPE_CHECKING:
from tortoise import BaseDBAsyncClient, Model
class BaseDDL: class BaseDDL:
schema_generator_cls: type[BaseSchemaGenerator] = BaseSchemaGenerator schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
DIALECT = "sql" DIALECT = "sql"
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"' _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}' _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"' _DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"' 'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
) )
_ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}' _ADD_INDEX_TEMPLATE = (
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"' 'ALTER TABLE "{table_name}" ADD {unique} INDEX "{index_name}" ({column_names})'
)
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = ( _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
'CREATE TABLE "{table_name}" (\n'
' "{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,\n'
' "{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}\n'
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = ( _CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
) )
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
def __init__(self, client: BaseDBAsyncClient) -> None: def __init__(self, client: "BaseDBAsyncClient"):
self.client = client self.client = client
self.schema_generator = self.schema_generator_cls(client) self.schema_generator = self.schema_generator_cls(client)
@staticmethod def create_table(self, model: "Type[Model]"):
def get_table_name(model: type[Model]) -> str: return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
return model._meta.db_table
def create_table(self, model: type[Model]) -> str: def drop_table(self, model: "Type[Model]"):
schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"] return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
if tortoise_version_less_than("0.23.1"):
# Remove extra space
schema = re.sub(r'(["()A-Za-z]) (["()A-Za-z])', r"\1 \2", schema)
return schema.rstrip(";")
def drop_table(self, table_name: str) -> str: def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def create_m2m(
self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
through = cast(str, field_describe.get("through"))
description = field_describe.get("description")
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
db_field_types = cast(dict, pk_field.get("db_field_types"))
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=through, table_name=field.through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
forward_table=reference_table_describe.get("table"), forward_table=field.related_model._meta.db_table,
backward_field=model._meta.db_pk_column, backward_field=model._meta.db_pk_column,
forward_field=reference_id, forward_field=field.related_model._meta.db_pk_column,
backward_key=field_describe.get("backward_key"), backward_key=field.backward_key,
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field_describe.get("forward_key"), forward_key=field.forward_key,
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=field_describe.get("on_delete"), on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=through), extra=self.schema_generator._table_generate_extra(table=field.through),
comment=( comment=self.schema_generator._table_comment_generator(
self.schema_generator._table_comment_generator(table=through, comment=description) table=field.through, comment=field.description
if description )
else "" if field.description
), else "",
) )
def drop_m2m(self, table_name: str) -> str: def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name) return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def _get_default(self, model: type[Model], field_describe: dict) -> Any: def _get_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_object.default
if isinstance(default, Enum): db_column = field_object.model_field_name
default = default.value auto_now_add = getattr(field_object, "auto_now_add", False)
db_column = cast(str, field_describe.get("db_column")) auto_now = getattr(field_object, "auto_now", False)
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
if field_describe.get("field_type") in [ if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)):
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = "" default = ""
else: else:
default = field_object.to_db_value(default, model)
try: try:
default = self.schema_generator._column_default_generator( default = self.schema_generator._column_default_generator(
db_table, db_table,
@ -113,58 +80,59 @@ class BaseDDL:
) )
except NotImplementedError: except NotImplementedError:
default = "" default = ""
else:
default = ""
return default return default
def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str: def add_column(self, model: "Type[Model]", field_object: Field):
return self._add_or_modify_column(model, field_describe, is_pk)
def _add_or_modify_column(
self, model: type[Model], field_describe: dict, is_pk: bool, modify: bool = False
) -> str:
db_table = model._meta.db_table db_table = model._meta.db_table
description = field_describe.get("description")
db_column = cast(str, field_describe.get("db_column"))
db_field_types = cast(dict, field_describe.get("db_field_types"))
default = self._get_default(model, field_describe)
if default is None:
default = ""
if modify:
unique = ""
template = self._MODIFY_COLUMN_TEMPLATE
else:
# sqlite does not support alter table to add unique column
unique = " UNIQUE" if field_describe.get("unique") and self.DIALECT != "sqlite" else ""
template = self._ADD_COLUMN_TEMPLATE
column = self.schema_generator._create_string(
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable=" NOT NULL" if not field_describe.get("nullable") else "",
unique=unique,
comment=(
self.schema_generator._column_comment_generator(
table=db_table,
column=db_column,
comment=description,
)
if description
else ""
),
is_primary_key=is_pk,
default=default,
)
if tortoise_version_less_than("0.23.1"):
column = column.replace(" ", " ")
return template.format(table_name=db_table, column=column)
def drop_column(self, model: type[Model], column_name: str) -> str: return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="UNIQUE" if field_object.unique else "",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
)
if field_object.description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
),
)
def drop_column(self, model: "Type[Model]", column_name: str):
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str: def modify_column(self, model: "Type[Model]", field_object: Field):
return self._add_or_modify_column(model, field_describe, is_pk, modify=True) db_table = model._meta.db_table
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
)
if field_object.description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
),
)
def rename_column(self, model: type[Model], old_column_name: str, new_column_name: str) -> str: def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
return self._RENAME_COLUMN_TEMPLATE.format( return self._RENAME_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@ -172,8 +140,8 @@ class BaseDDL:
) )
def change_column( def change_column(
self, model: type[Model], old_column_name: str, new_column_name: str, new_column_type: str self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
) -> str: ):
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
old_column_name=old_column_name, old_column_name=old_column_name,
@ -181,122 +149,66 @@ class BaseDDL:
new_column_type=new_column_type, new_column_type=new_column_type,
) )
def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str: def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
func_name = "_get_index_name"
if not hasattr(self.schema_generator, func_name):
# For tortoise-orm<0.24.1
func_name = "_generate_index_name"
return getattr(self.schema_generator, func_name)(
"idx" if not unique else "uid", model, field_names
)
def add_index(
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
index_type: str = "",
extra: str | None = "",
) -> str:
return self._ADD_INDEX_TEMPLATE.format( return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE " if unique else "", unique="UNIQUE" if unique else "",
index_name=name or self._index_name(unique, model, field_names), index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table, table_name=model._meta.db_table,
column_names=", ".join(self.schema_generator.quote(f) for f in field_names), column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
index_type=f"{index_type} " if index_type else "",
extra=f"{extra}" if extra else "",
) )
def drop_index( def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
self,
model: type[Model],
field_names: list[str],
unique: bool | None = False,
name: str | None = None,
) -> str:
return self._DROP_INDEX_TEMPLATE.format( return self._DROP_INDEX_TEMPLATE.format(
index_name=name or self._index_name(unique, model, field_names), index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def drop_index_by_name(self, model: type[Model], index_name: str) -> str: def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
return self.drop_index(model, [], name=index_name) db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
def _generate_fk_name( db_column = field.source_field or field.model_field_name + "_id"
self, db_table: str, field_describe: dict, reference_table_describe: dict fk_name = self.schema_generator._generate_fk_name(
) -> str:
"""Generate fk name"""
db_column = cast(str, field_describe.get("raw_field"))
pk_field = cast(dict, reference_table_describe.get("pk_field"))
to_field = cast(str, pk_field.get("db_column"))
to_table = cast(str, reference_table_describe.get("table"))
func_name = "_get_fk_name"
if not hasattr(self.schema_generator, func_name):
# For tortoise-orm<0.24.1
func_name = "_generate_fk_name"
return getattr(self.schema_generator, func_name)(
from_table=db_table, from_table=db_table,
from_field=db_column, from_field=db_column,
to_table=to_table, to_table=field.related_model._meta.db_table,
to_field=to_field, to_field=to_field_name,
) )
def add_fk(
self, model: type[Model], field_describe: dict, reference_table_describe: dict
) -> str:
db_table = model._meta.db_table
db_column = field_describe.get("raw_field")
pk_field = cast(dict, reference_table_describe.get("pk_field"))
reference_id = pk_field.get("db_column")
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe), fk_name=fk_name,
db_column=db_column, db_column=db_column,
table=reference_table_describe.get("table"), table=field.related_model._meta.db_table,
field=reference_id, field=to_field_name,
on_delete=field_describe.get("on_delete"), on_delete=field.on_delete,
) )
def drop_fk( def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
self, model: type[Model], field_describe: dict, reference_table_describe: dict to_field_name = field.to_field_instance.source_field
) -> str: if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_table = model._meta.db_table db_table = model._meta.db_table
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe) return self._DROP_FK_TEMPLATE.format(
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
def alter_column_default(self, model: type[Model], field_describe: dict) -> str:
db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column"), fk_name=self.schema_generator._generate_fk_name(
default="SET" + default if default is not None else "DROP DEFAULT", from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id",
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
),
) )
def alter_column_null(self, model: type[Model], field_describe: dict) -> str: def alter_column_default(self, model: "Type[Model]", field_object: Field):
return self.modify_column(model, field_describe) pass
def set_comment(self, model: type[Model], field_describe: dict) -> str: def alter_column_null(self, model: "Type[Model]", field_object: Field):
return self.modify_column(model, field_describe) pass
def rename_table(self, model: type[Model], old_table_name: str, new_table_name: str) -> str: def set_comment(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table pass
return self._RENAME_TABLE_TEMPLATE.format(
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
)
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
"""Change unique constraint for indexed field, e.g.: Field(db_index=True) --> Field(unique=True)"""
fields = [field_name]
if drop:
drop_unique = self.drop_index(model, fields, unique=True)
add_normal_index = self.add_index(model, fields, unique=False)
return [drop_unique, add_normal_index]
else:
drop_index = self.drop_index(model, fields, unique=False)
add_unique_index = self.add_index(model, fields, unique=True)
return [drop_index, add_unique_index]

View File

@ -1,61 +1,22 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
if TYPE_CHECKING:
from tortoise import Model
class MysqlDDL(BaseDDL): class MysqlDDL(BaseDDL):
schema_generator_cls = MySQLSchemaGenerator schema_generator_cls = MySQLSchemaGenerator
DIALECT = MySQLSchemaGenerator.DIALECT DIALECT = MySQLSchemaGenerator.DIALECT
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`" _DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}" _ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
_CHANGE_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
)
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`" _DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`" "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
) )
_ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}" _ADD_INDEX_TEMPLATE = (
"ALTER TABLE `{table_name}` ADD {unique} INDEX `{index_name}` ({column_names})"
)
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)"
)
_DROP_INDEXED_UNIQUE_TEMPLATE = (
"ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)"
)
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = ( _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
"CREATE TABLE `{table_name}` (\n"
" `{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,\n"
" `{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE\n"
"){extra}{comment}"
)
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
_RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"
def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
if unique and len(field_names) == 1:
# Example: `email = CharField(max_length=50, unique=True)`
# Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
# Unique index key is the same as field name: `email`
return field_names[0]
return super()._index_name(unique, model, field_names)
def alter_indexed_column_unique(
self, model: type[Model], field_name: str, drop: bool = False
) -> list[str]:
# if drop is false: Drop index and add unique
# else: Drop unique index and add normal index
template = self._DROP_INDEXED_UNIQUE_TEMPLATE if drop else self._ADD_INDEXED_UNIQUE_TEMPLATE
table = self.get_table_name(model)
index = self._index_name(unique=False, model=model, field_names=[field_name])
return [template.format(table_name=table, index_name=index, column_name=field_name)]

View File

@ -1,53 +1,75 @@
from __future__ import annotations from typing import List, Type
from typing import cast
from tortoise import Model from tortoise import Model
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL): class PostgresDDL(BaseDDL):
schema_generator_cls = BasePostgresSchemaGenerator schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = BasePostgresSchemaGenerator.DIALECT DIALECT = AsyncpgSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}' _ADD_INDEX_TEMPLATE = 'CREATE INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"' _ADD_UNIQUE_TEMPLATE = (
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{index_name}" UNIQUE ({column_names})'
_MODIFY_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}{using}'
) )
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_null(self, model: type[Model], field_describe: dict) -> str: def alter_column_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column"), column=field_object.model_field_name,
set_drop="DROP" if field_describe.get("nullable") else "SET", set_drop="DROP" if field_object.null else "SET",
) )
def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str: def modify_column(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = cast(dict, field_describe.get("db_field_types"))
db_column = field_describe.get("db_column")
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=db_column, column=field_object.model_field_name,
datatype=datatype, datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
using=f' USING "{db_column}"::{datatype}',
) )
def set_comment(self, model: type[Model], field_describe: dict) -> str: def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._ADD_UNIQUE_TEMPLATE if unique else self._ADD_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
template = self._DROP_UNIQUE_TEMPLATE if unique else self._DROP_INDEX_TEMPLATE
return template.format(
index_name=self.schema_generator._generate_index_name(
"uid" if unique else "idx", model, field_names
),
table_name=model._meta.db_table,
)
def set_comment(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format( return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_describe.get("db_column") or field_describe.get("raw_field"), column=field_object.model_field_name,
comment=( comment="'{}'".format(field_object.description) if field_object.description else "NULL",
"'{}'".format(field_describe.get("description"))
if field_describe.get("description")
else "NULL"
),
) )

View File

@ -1,7 +1,8 @@
from __future__ import annotations from typing import Type
from tortoise import Model from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
@ -10,17 +11,9 @@ from aerich.exceptions import NotSupportError
class SqliteDDL(BaseDDL): class SqliteDDL(BaseDDL):
schema_generator_cls = SqliteSchemaGenerator schema_generator_cls = SqliteSchemaGenerator
DIALECT = SqliteSchemaGenerator.DIALECT DIALECT = SqliteSchemaGenerator.DIALECT
_ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
_DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
def modify_column(self, model: type[Model], field_object: dict, is_pk: bool = True): def drop_column(self, model: "Type[Model]", column_name: str):
raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: Field):
raise NotSupportError("Modify column is unsupported in SQLite.") raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: type[Model], field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

View File

@ -2,9 +2,3 @@ class NotSupportError(Exception):
""" """
raise when features not support raise when features not support
""" """
class DowngradeError(Exception):
"""
raise when downgrade error
"""

View File

@ -1,192 +0,0 @@
from __future__ import annotations
import contextlib
from typing import Any, Callable, Dict, TypedDict
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
# TODO: use dict to replace typing.Dict when dropping support for Python3.8
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: str | None = None
pk: bool
unique: bool
index: bool
length: int | None = None
extra: str | None = None
decimal_places: int | None = None
max_digits: int | None = None
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "primary_key=True, "
else:
if self.unique:
index = "unique=True, "
elif self.index:
index = "db_index=True, "
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
elif self.data_type in ("decimal", "numeric"):
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
if length_parts:
length = ", ".join(length_parts) + ", "
if self.null:
null = "null=True, "
if self.default is not None and not self.pk:
if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if self.default == "CURRENT_TIMESTAMP":
if self.extra == "DEFAULT_GENERATED on update CURRENT_TIMESTAMP":
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
elif self.default == "":
default = 'default=""'
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn
with contextlib.suppress(AttributeError):
self.database = conn.database # type:ignore[attr-defined]
self.tables = tables
@property
def field_map(self) -> FieldMapDict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError
async def get_all_tables(self) -> list[str]:
raise NotImplementedError
@staticmethod
def get_field_string(
field_class: str, arguments: str = "{null}{default}{comment}", **kwargs
) -> str:
name = kwargs["name"]
field_params = arguments.format(**kwargs).strip().rstrip(",")
return f"{name} = fields.{field_class}({field_params})"
@classmethod
def decimal_field(cls, **kwargs) -> str:
return cls.get_field_string("DecimalField", **kwargs)
@classmethod
def time_field(cls, **kwargs) -> str:
return cls.get_field_string("TimeField", **kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return cls.get_field_string("DateField", **kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return cls.get_field_string("FloatField", **kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return cls.get_field_string("DatetimeField", **kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return cls.get_field_string("TextField", **kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
arguments = "{pk}{index}{length}{null}{default}{comment}"
return cls.get_field_string("CharField", arguments, **kwargs)
@classmethod
def int_field(cls, field_class="IntField", **kwargs) -> str:
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string(field_class, arguments, **kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return cls.int_field("SmallIntField", **kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return cls.int_field("BigIntField", **kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return cls.get_field_string("BooleanField", **kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
arguments = "{pk}{index}{default}{comment}"
return cls.get_field_string("UUIDField", arguments, **kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return cls.get_field_string("JSONField", **kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return cls.get_field_string("BinaryField", **kwargs)

View File

@ -1,67 +0,0 @@
from __future__ import annotations
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> FieldMapDict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"char": self.uuid_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"double": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
unique = index = False
if (non_unique := row["NON_UNIQUE"]) is not None:
unique = not non_unique
elif row["COLUMN_KEY"] == "UNI":
unique = True
if (index_name := row["INDEX_NAME"]) is not None:
index = index_name != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=unique,
extra=row["EXTRA"],
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File

@ -1,83 +0,0 @@
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect):
def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> None:
super().__init__(conn, tables)
self.schema = conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> FieldMapDict:
return {
"int2": self.smallint_field,
"int4": self.int_field,
"int8": self.bigint_field,
"smallint": self.smallint_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"text": self.text_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"numeric": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3""" # nosec:B608
if "psycopg" in str(type(self.conn)).lower():
sql = re.sub(r"\$[123]", "%s", sql)
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

View File

@ -1,61 +0,0 @@
from __future__ import annotations
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> FieldMapDict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))

File diff suppressed because it is too large Load Diff

View File

@ -1,15 +1,12 @@
from tortoise import Model, fields from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255 MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=MAX_APP_LENGTH) app = fields.CharField(max_length=20)
content: dict = fields.JSONField(encoder=encoder, decoder=decoder) content = fields.TextField()
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@ -1,52 +1,26 @@
from __future__ import annotations import importlib
from typing import Dict
import importlib.util from click import BadOptionUsage, Context
import os
import re
import sys
from collections.abc import Generator
from pathlib import Path
from types import ModuleType
from asyncclick import BadOptionUsage, ClickException, Context
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str: def get_app_connection_name(config, app) -> str:
"""
add a folder to the paths, so we can import from there
:param path: path to add
:return: absolute path
"""
if not os.path.isabs(path):
# use the absolute path, otherwise some other things (e.g. __file__) won't work properly
path = os.path.abspath(path)
if not os.path.isdir(path):
raise ClickException(f"Specified source folder does not exist: {path}")
if path not in sys.path:
sys.path.insert(0, path)
return path
def get_app_connection_name(config, app_name: str) -> str:
""" """
get connection name get connection name
:param config: :param config:
:param app_name: :param app:
:return: the default connection name (Usally it is 'default') :return:
""" """
if app := config.get("apps").get(app_name): return config.get("apps").get(app).get("default_connection", "default")
return app.get("default_connection", "default")
raise BadOptionUsage(option_name="--app", message=f"Can't get app named {app_name!r}")
def get_app_connection(config, app) -> BaseDBAsyncClient: def get_app_connection(config, app) -> BaseDBAsyncClient:
""" """
get connection client get connection name
:param config: :param config:
:param app: :param app:
:return: client instance :return:
""" """
return Tortoise.get_connection(get_app_connection_name(config, app)) return Tortoise.get_connection(get_app_connection_name(config, app))
@ -61,11 +35,12 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
splits = tortoise_orm.split(".") splits = tortoise_orm.split(".")
config_path = ".".join(splits[:-1]) config_path = ".".join(splits[:-1])
tortoise_config = splits[-1] tortoise_config = splits[-1]
try: try:
config_module = importlib.import_module(config_path) config_module = importlib.import_module(config_path)
except ModuleNotFoundError as e: except (ModuleNotFoundError, AttributeError):
raise ClickException(f"Error while importing configuration module: {e}") from None raise BadOptionUsage(
ctx=ctx, message=f'No config named "{config_path}"', option_name="--config"
)
config = getattr(config_module, tortoise_config, None) config = getattr(config_module, tortoise_config, None)
if not config: if not config:
@ -77,67 +52,44 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config return config
def get_models_describe(app: str) -> dict: _UPGRADE = "##### upgrade #####\n"
_DOWNGRADE = "##### downgrade #####\n"
def get_version_content_from_file(version_file: str) -> Dict:
""" """
get app models describe get version content
:param app: :param version_file:
:return: :return:
""" """
ret = {} with open(version_file, "r", encoding="utf-8") as f:
for model in Tortoise.apps[app].values(): content = f.read()
managed = getattr(model.Meta, "managed", None) first = content.index(_UPGRADE)
describe = model.describe() second = content.index(_DOWNGRADE)
ret[describe.get("name")] = dict(describe, managed=managed) upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {"upgrade": upgrade_content.split("\n"), "downgrade": downgrade_content.split("\n")}
return ret return ret
def is_default_function(string: str) -> re.Match | None: def write_version_file(version_file: str, content: Dict):
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: str | Path) -> ModuleType:
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec) # type:ignore[arg-type]
spec.loader.exec_module(module) # type:ignore[union-attr]
return module
def get_dict_diff_by_key(
old_fields: list[dict], new_fields: list[dict], key="through"
) -> Generator[tuple]:
""" """
Compare two list by key instead of by index write version file
:param version_file:
:param old_fields: previous field info list :param content:
:param new_fields: current field info list :return:
:param key: if two dicts have the same value of this key, action is change; otherwise, is remove/add
:return: similar to dictdiffer.diff
Example::
>>> old = [{'through': 'a'}, {'through': 'b'}, {'through': 'c'}]
>>> new = [{'through': 'a'}, {'through': 'c'}] # remove the second element
>>> list(diff(old, new))
[('change', [1, 'through'], ('b', 'c')),
('remove', '', [(2, {'through': 'c'})])]
>>> list(get_dict_diff_by_key(old, new))
[('remove', '', [(0, {'through': 'b'})])]
""" """
length_old, length_new = len(old_fields), len(new_fields) with open(version_file, "w", encoding="utf-8") as f:
if length_old == 0 or length_new == 0 or length_old == length_new == 1: f.write(_UPGRADE)
yield from diff(old_fields, new_fields) upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n")
else: else:
value_index: dict[str, int] = {f[key]: i for i, f in enumerate(new_fields)} f.write(f"{upgrade[0]};\n")
additions = set(range(length_new)) downgrade = content.get("downgrade")
for field in old_fields: if downgrade:
value = field[key] f.write(_DOWNGRADE)
if (index := value_index.get(value)) is not None: if len(downgrade) > 1:
additions.remove(index) f.write(";\n".join(downgrade) + ";\n")
yield from diff([field], [new_fields[index]]) # change
else: else:
yield from diff([field], []) # remove f.write(f"{downgrade[0]};\n")
if additions:
for index in sorted(additions):
yield from diff([], [new_fields[index]]) # add

View File

@ -1,3 +0,0 @@
from importlib.metadata import version
__version__ = version(__package__)

View File

@ -1,30 +1,23 @@
from __future__ import annotations
import asyncio import asyncio
import os import os
import sys
from collections.abc import Generator
from pathlib import Path
import pytest import pytest
from tortoise import Tortoise, expand_db_url from tortoise import Tortoise, expand_db_url, generate_schema_for_client
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.contrib.test import MEMORY_SQLITE
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests._utils import chdir, copy_files, init_db, run_shell
db_url = os.getenv("TEST_DB", MEMORY_SQLITE) db_url = os.getenv("TEST_DB", "sqlite://:memory:")
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE) db_url_second = os.getenv("TEST_DB_SECOND", "sqlite://:memory:")
tortoise_orm = { tortoise_orm = {
"connections": { "connections": {
"default": expand_db_url(db_url, testing=True), "default": expand_db_url(db_url, True),
"second": expand_db_url(db_url_second, testing=True), "second": expand_db_url(db_url_second, True),
}, },
"apps": { "apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"}, "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@ -34,7 +27,7 @@ tortoise_orm = {
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def reset_migrate() -> None: def reset_migrate():
Migrate.upgrade_operators = [] Migrate.upgrade_operators = []
Migrate.downgrade_operators = [] Migrate.downgrade_operators = []
Migrate._upgrade_fk_m2m_index_operators = [] Migrate._upgrade_fk_m2m_index_operators = []
@ -43,55 +36,36 @@ def reset_migrate() -> None:
Migrate._downgrade_m2m = [] Migrate._downgrade_m2m = []
@pytest.fixture(scope="session") @pytest.yield_fixture(scope="session")
def event_loop() -> Generator: def event_loop():
policy = asyncio.get_event_loop_policy() policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop() res = policy.new_event_loop()
asyncio.set_event_loop(res) asyncio.set_event_loop(res)
res._close = res.close # type:ignore[attr-defined] res._close = res.close
res.close = lambda: None # type:ignore[method-assign] res.close = lambda: None
yield res yield res
res._close() # type:ignore[attr-defined] res._close()
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request) -> None: async def initialize_tests(event_loop, request):
await init_db(tortoise_orm) tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:"
tortoise_orm["apps"]["diff_models"] = {
"models": ["tests.diff_models"],
"default_connection": "diff_models",
}
await Tortoise.init(config=tortoise_orm, _create_db=True)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
client = Tortoise.get_connection("default") client = Tortoise.get_connection("default")
if client.schema_generator is MySQLSchemaGenerator: if client.schema_generator is MySQLSchemaGenerator:
Migrate.ddl = MysqlDDL(client) Migrate.ddl = MysqlDDL(client)
elif client.schema_generator is SqliteSchemaGenerator: elif client.schema_generator is SqliteSchemaGenerator:
Migrate.ddl = SqliteDDL(client) Migrate.ddl = SqliteDDL(client)
elif issubclass(client.schema_generator, BasePostgresSchemaGenerator): elif client.schema_generator is AsyncpgSchemaGenerator:
Migrate.ddl = PostgresDDL(client) Migrate.ddl = PostgresDDL(client)
Migrate.dialect = Migrate.ddl.DIALECT Migrate.dialect = Migrate.ddl.DIALECT
request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases())) request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))
@pytest.fixture
def new_aerich_project(tmp_path: Path):
test_dir = Path(__file__).parent / "tests"
asset_dir = test_dir / "assets" / "fake"
settings_py = asset_dir / "settings.py"
_tests_py = asset_dir / "_tests.py"
db_py = asset_dir / "db.py"
models_py = test_dir / "models.py"
models_second_py = test_dir / "models_second.py"
copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
dst_dir = tmp_path / "tests"
dst_dir.mkdir()
dst_dir.joinpath("__init__.py").touch()
copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)
if should_remove := str(tmp_path) not in sys.path:
sys.path.append(str(tmp_path))
with chdir(tmp_path):
run_shell("python db.py create", capture_output=False)
try:
yield
finally:
if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
run_shell("python db.py drop", capture_output=False)
if should_remove:
sys.path.remove(str(tmp_path))

BIN
images/alipay.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 75 KiB

BIN
images/wechatpay.jpeg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

2374
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -1,116 +1,43 @@
[project]
name = "aerich"
version = "0.8.2"
description = "A database migrations tool for Tortoise ORM."
authors = [{name="long2ice", email="long2ice@gmail.com>"}]
license = { text = "Apache-2.0" }
readme = "README.md"
keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [{ include = "aerich" }]
include = ["CHANGELOG.md", "LICENSE", "README.md"]
requires-python = ">=3.8"
dependencies = [
"tortoise-orm (>=0.21.0,<1.0.0); python_version < '4.0'",
"pydantic (>=2.0.2,!=2.1.0,!=2.7.0,<3.0.0)",
"dictdiffer (>=0.9.0,<1.0.0)",
"asyncclick (>=8.1.7,<9.0.0)",
"eval-type-backport (>=0.2.2,<1.0.0); python_version < '3.10'",
]
[project.optional-dependencies]
toml = [
"tomli-w (>=1.1.0,<2.0.0); python_version >= '3.11'",
"tomlkit (>=0.11.4,<1.0.0); python_version < '3.11'",
]
# Need asyncpg or psyncopg for PostgreSQL
asyncpg = ["asyncpg"]
psycopg = ["psycopg[pool,binary] (>=3.0.12,<4.0.0)"]
# Need asyncmy or aiomysql for MySQL
asyncmy = ["asyncmy>=0.2.9; python_version < '4.0'"]
mysql = ["aiomysql>=0.2.0"]
[project.urls]
homepage = "https://github.com/tortoise/aerich"
repository = "https://github.com/tortoise/aerich.git"
documentation = "https://github.com/tortoise/aerich"
[project.scripts]
aerich = "aerich.cli:main"
[tool.poetry] [tool.poetry]
requires-poetry = ">=2.0" name = "aerich"
version = "0.4.0"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"
readme = "README.md"
homepage = "https://github.com/long2ice/aerich"
repository = "https://github.com/long2ice/aerich.git"
documentation = "https://github.com/long2ice/aerich"
keywords = ["migrate", "Tortoise-ORM", "mysql"]
packages = [
{ include = "aerich" }
]
include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.group.dev.dependencies] [tool.poetry.dependencies]
ruff = "^0.9.0" python = "^3.7"
bandit = "^1.7.0" tortoise-orm = "*"
mypy = "^1.10.0" click = "*"
twine = "^6.1.0" pydantic = "*"
aiomysql = {version = "*", optional = true}
asyncpg = {version = "*", optional = true}
[tool.poetry.group.test.dependencies] [tool.poetry.dev-dependencies]
pytest = "^8.3.0" flake8 = "*"
pytest-mock = "^3.14.0" isort = "*"
pytest-xdist = "^3.6.0" black = "^20.8b1"
# Breaking change in 0.23.* pytest = "*"
# https://github.com/pytest-dev/pytest-asyncio/issues/706 pytest-xdist = "*"
pytest-asyncio = "^0.21.2" pytest-asyncio = "*"
# required for sha256_password by asyncmy bandit = "*"
cryptography = {version="^44.0.1", python="!=3.9.0,!=3.9.1"} pytest-mock = "*"
[tool.aerich] [tool.poetry.extras]
tortoise_orm = "conftest.tortoise_orm" dbdrivers = ["aiomysql", "asyncpg"]
location = "./migrations"
src_folder = "./."
[build-system] [build-system]
requires = ["poetry-core>=2.0.0"] requires = ["poetry>=0.12"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.masonry.api"
[tool.pytest.ini_options] [tool.poetry.scripts]
asyncio_mode = 'auto' aerich = "aerich.cli:main"
[tool.coverage.run]
branch = true
source = ["aerich"]
[tool.coverage.report]
exclude_also = [
"if TYPE_CHECKING:"
]
[tool.mypy]
pretty = true
python_version = "3.8"
check_untyped_defs = true
warn_unused_ignores = true
disallow_incomplete_defs = false
exclude = ["tests/assets", "migrations"]
[[tool.mypy.overrides]]
module = [
'dictdiffer.*',
'tomlkit',
'tomli_w',
'tomli',
]
ignore_missing_imports = true
[tool.ruff]
line-length = 100
[tool.ruff.lint]
extend-select = [
"I", # https://docs.astral.sh/ruff/rules/#isort-i
"SIM", # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
"FA", # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
"UP", # https://docs.astral.sh/ruff/rules/#pyupgrade-up
"RUF100", # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
]
ignore = ["UP031"] # https://docs.astral.sh/ruff/rules/printf-string-formatting/
[tool.ruff.lint.per-file-ignores]
# TODO: Remove this line when dropping support for Python3.8
"aerich/inspectdb/__init__.py" = ["UP006", "UP035"]
"aerich/_compat.py" = ["F401"]
[tool.bandit]
exclude_dirs = ["tests", "conftest.py"]

2
setup.cfg Normal file
View File

@ -0,0 +1,2 @@
[flake8]
ignore = E501,W503

View File

@ -1,87 +0,0 @@
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
import sys
from pathlib import Path
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import DBConnectionError, OperationalError
if sys.version_info >= (3, 11):
from contextlib import chdir
else:
class chdir(contextlib.AbstractContextManager): # Copied from source code of Python3.13
"""Non thread-safe context manager to change the current working directory."""
def __init__(self, path):
self.path = path
self._old_cwd = []
def __enter__(self):
self._old_cwd.append(os.getcwd())
os.chdir(self.path)
def __exit__(self, *excinfo):
os.chdir(self._old_cwd.pop())
async def drop_db(tortoise_orm) -> None:
# Placing init outside the try-block(suppress) since it doesn't
# establish connections to the DB eagerly.
await Tortoise.init(config=tortoise_orm)
with contextlib.suppress(DBConnectionError, OperationalError):
await Tortoise._drop_databases()
async def init_db(tortoise_orm, generate_schemas=True) -> None:
await drop_db(tortoise_orm)
await Tortoise.init(config=tortoise_orm, _create_db=True)
if generate_schemas:
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
def copy_files(*src_files: Path, target_dir: Path) -> None:
for src in src_files:
shutil.copy(src, target_dir)
class Dialect:
test_db_url: str
@classmethod
def load_env(cls) -> None:
if getattr(cls, "test_db_url", None) is None:
cls.test_db_url = os.getenv("TEST_DB", "")
@classmethod
def is_postgres(cls) -> bool:
cls.load_env()
return "postgres" in cls.test_db_url
@classmethod
def is_mysql(cls) -> bool:
cls.load_env()
return "mysql" in cls.test_db_url
@classmethod
def is_sqlite(cls) -> bool:
cls.load_env()
return not cls.test_db_url or "sqlite" in cls.test_db_url
WINDOWS = platform.system() == "Windows"
def run_shell(command: str, capture_output=True, **kw) -> str:
if WINDOWS and command.startswith("aerich "):
command = "python -m " + command
r = subprocess.run(shlex.split(command), capture_output=capture_output)
if r.returncode != 0 and r.stderr:
return r.stderr.decode()
if not r.stdout:
return ""
return r.stdout.decode()

View File

@ -1,80 +0,0 @@
import pytest
from models import NewModel
from models_second import Config
from settings import TORTOISE_ORM
from tortoise import Tortoise
from tortoise.exceptions import OperationalError
try:
# This error does not translate to tortoise's OperationalError
from psycopg.errors import UndefinedColumn
except ImportError:
errors = (OperationalError,)
else:
errors = (OperationalError, UndefinedColumn)
@pytest.fixture(scope="session")
def anyio_backend() -> str:
return "asyncio"
@pytest.fixture(autouse=True)
async def init_connections():
await Tortoise.init(TORTOISE_ORM)
try:
yield
finally:
await Tortoise.close_connections()
@pytest.mark.anyio
async def test_init_db():
m1 = await NewModel.filter(name="")
assert isinstance(m1, list)
m2 = await Config.filter(key="")
assert isinstance(m2, list)
await NewModel.create(name="")
await Config.create(key="", label="", value={})
@pytest.mark.anyio
async def test_fake_field_1():
assert "field_1" in NewModel._meta.fields_map
assert "field_1" in Config._meta.fields_map
with pytest.raises(errors):
await NewModel.create(name="", field_1=1)
with pytest.raises(errors):
await Config.create(key="", label="", value={}, field_1=1)
obj1 = NewModel(name="", field_1=1)
with pytest.raises(errors):
await obj1.save()
obj1 = NewModel(name="")
with pytest.raises(errors):
await obj1.save()
with pytest.raises(errors):
obj1 = await NewModel.first()
obj1 = await NewModel.all().first().values("id", "name")
assert obj1 and obj1["id"]
obj2 = Config(key="", label="", value={}, field_1=1)
with pytest.raises(errors):
await obj2.save()
obj2 = Config(key="", label="", value={})
with pytest.raises(errors):
await obj2.save()
with pytest.raises(errors):
obj2 = await Config.first()
obj2 = await Config.all().first().values("id", "key")
assert obj2 and obj2["id"]
@pytest.mark.anyio
async def test_fake_field_2():
assert "field_2" in NewModel._meta.fields_map
assert "field_2" in Config._meta.fields_map
with pytest.raises(errors):
await NewModel.create(name="")
with pytest.raises(errors):
await Config.create(key="", label="", value={})

View File

@ -1,28 +0,0 @@
import asyncclick as click
from settings import TORTOISE_ORM
from tests._utils import drop_db, init_db
@click.group()
def cli(): ...
@cli.command()
async def create():
await init_db(TORTOISE_ORM, False)
click.echo(f"Success to create databases for {TORTOISE_ORM['connections']}")
@cli.command()
async def drop():
await drop_db(TORTOISE_ORM)
click.echo(f"Dropped databases for {TORTOISE_ORM['connections']}")
def main():
cli()
if __name__ == "__main__":
main()

View File

@ -1,22 +0,0 @@
import os
from datetime import date
from tortoise.contrib.test import MEMORY_SQLITE
DB_URL = (
_u.replace("\\{\\}", f"aerich_fake_{date.today():%Y%m%d}")
if (_u := os.getenv("TEST_DB"))
else MEMORY_SQLITE
)
DB_URL_SECOND = (DB_URL + "_second") if DB_URL != MEMORY_SQLITE else MEMORY_SQLITE
TORTOISE_ORM = {
"connections": {
"default": DB_URL.replace(MEMORY_SQLITE, "sqlite://db.sqlite3"),
"second": DB_URL_SECOND.replace(MEMORY_SQLITE, "sqlite://db_second.sqlite3"),
},
"apps": {
"models": {"models": ["models", "aerich.models"], "default_connection": "default"},
"models_second": {"models": ["models_second"], "default_connection": "second"},
},
}

View File

@ -1,76 +0,0 @@
import uuid
import pytest
from models import Foo
from tortoise.exceptions import IntegrityError
@pytest.mark.asyncio
async def test_allow_duplicate() -> None:
await Foo.all().delete()
await Foo.create(name="foo")
obj = await Foo.create(name="foo")
assert (await Foo.all().count()) == 2
await obj.delete()
@pytest.mark.asyncio
async def test_unique_is_true() -> None:
with pytest.raises(IntegrityError):
await Foo.create(name="foo")
await Foo.create(name="foo")
@pytest.mark.asyncio
async def test_add_unique_field() -> None:
if not await Foo.filter(age=0).exists():
await Foo.create(name="0_" + uuid.uuid4().hex, age=0)
with pytest.raises(IntegrityError):
await Foo.create(name=uuid.uuid4().hex, age=0)
@pytest.mark.asyncio
async def test_drop_unique_field() -> None:
name = "1_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
assert await Foo.filter(name=name).exists()
@pytest.mark.asyncio
async def test_with_age_field() -> None:
name = "2_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert obj.age == 0
@pytest.mark.asyncio
async def test_without_age_field() -> None:
name = "3_" + uuid.uuid4().hex
await Foo.create(name=name, age=0)
obj = await Foo.get(name=name)
assert getattr(obj, "age", None) is None
@pytest.mark.asyncio
async def test_m2m_with_custom_through() -> None:
from models import FooGroup, Group
name = "4_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name + "1")
await FooGroup.all().delete()
await foo.groups.add(group)
foo_group = await FooGroup.get(foo=foo, group=group)
assert not foo_group.is_active
@pytest.mark.asyncio
async def test_add_m2m_field_after_init_db() -> None:
from models import Group
name = "5_" + uuid.uuid4().hex
foo = await Foo.create(name=name)
group = await Group.create(name=name + "1")
await foo.groups.add(group)
assert (await group.users.all().first()) == foo

View File

@ -1,28 +0,0 @@
from __future__ import annotations
import asyncio
from collections.abc import Generator
import pytest
import pytest_asyncio
import settings
from tortoise import Tortoise, connections
@pytest.fixture(scope="session")
def event_loop() -> Generator:
policy = asyncio.get_event_loop_policy()
res = policy.new_event_loop()
asyncio.set_event_loop(res)
res._close = res.close # type:ignore[attr-defined]
res.close = lambda: None # type:ignore[method-assign]
yield res
res._close() # type:ignore[attr-defined]
@pytest_asyncio.fixture(scope="session", autouse=True)
async def api(event_loop, request):
await Tortoise.init(config=settings.TORTOISE_ORM)
request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True)))

View File

@ -1,5 +0,0 @@
from tortoise import Model, fields
class Foo(Model):
name = fields.CharField(max_length=60, db_index=False)

View File

@ -1,4 +0,0 @@
TORTOISE_ORM = {
"connections": {"default": "sqlite://db.sqlite3"},
"apps": {"models": {"models": ["models", "aerich.models"]}},
}

62
tests/diff_models.py Normal file
View File

@ -0,0 +1,62 @@
import datetime
from enum import IntEnum
from tortoise import Model, fields
class ProductType(IntEnum):
article = 1
page = 2
class PermissionAction(IntEnum):
create = 1
delete = 2
update = 3
read = 4
class Status(IntEnum):
on = 1
off = 0
class User(Model):
username = fields.CharField(max_length=20)
password = fields.CharField(max_length=200)
last_login_at = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
class Email(Model):
email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("diff_models.User", db_constraint=True)
class Category(Model):
slug = fields.CharField(max_length=200)
user = fields.ForeignKeyField("diff_models.User", description="User")
created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model):
categories = fields.ManyToManyField("diff_models.Category")
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type")
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
class Config(Model):
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@ -1,7 +0,0 @@
from tortoise.indexes import Index
class CustomIndex(Index):
def __init__(self, *args, **kw) -> None:
super().__init__(*args, **kw)
self._foo = ""

View File

@ -1,16 +1,7 @@
from __future__ import annotations
import datetime import datetime
import uuid
from enum import IntEnum from enum import IntEnum
from tortoise import Model, fields from tortoise import Model, fields
from tortoise.contrib.mysql.indexes import FullTextIndex
from tortoise.contrib.postgres.indexes import HashIndex
from tortoise.indexes import Index
from tests._utils import Dialect
from tests.indexes import CustomIndex
class ProductType(IntEnum): class ProductType(IntEnum):
@ -32,110 +23,41 @@ class Status(IntEnum):
class User(Model): class User(Model):
username = fields.CharField(max_length=20, unique=True) username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=100) password = fields.CharField(max_length=200)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="") intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
products: fields.ManyToManyRelation[Product]
class Meta:
# reverse indexes elements
indexes = [CustomIndex(fields=("is_superuser",)), Index(fields=("username", "is_active"))]
class Email(Model): class Email(Model):
email_id = fields.IntField(primary_key=True) email = fields.CharField(max_length=200)
email = fields.CharField(max_length=200, db_index=True)
company = fields.CharField(max_length=100, db_index=True, unique=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) user = fields.ForeignKeyField("models.User", db_constraint=False)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
config: fields.OneToOneRelation[Config] = fields.OneToOneField("models.Config")
def default_name():
return uuid.uuid4()
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=100) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200, null=True, default=default_name) name = fields.CharField(max_length=200)
owner: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( user = fields.ForeignKeyField("models.User", description="User")
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=False)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
if Dialect.is_postgres():
indexes = [HashIndex(fields=("slug",))]
elif Dialect.is_mysql():
indexes = [FullTextIndex(fields=("slug",))] # type:ignore
else:
indexes = [Index(fields=("slug",))] # type:ignore
class Product(Model): class Product(Model):
id = fields.BigIntField(primary_key=True) categories = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", null=False
)
users: fields.ManyToManyRelation[User] = fields.ManyToManyField(
"models.User", related_name="products"
)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num", default=0) view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type: int = fields.IntEnumField( type = fields.IntEnumField(ProductType, description="Product Type")
ProductType, description="Product Type", source_field="type_db_alias" image = fields.CharField(max_length=200)
)
pic = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
price = fields.FloatField(null=True)
no = fields.UUIDField(db_index=True)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
is_deleted = fields.BooleanField(default=False)
class Meta:
unique_together = (("name", "type"),)
indexes = (("name", "type"),)
managed = True
class Config(Model): class Config(Model):
slug = fields.CharField(primary_key=True, max_length=20)
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="category_set"
)
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value: dict = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status) status: Status = fields.IntEnumField(Status, default=Status.on)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
email: fields.OneToOneRelation[Email]
class Meta:
managed = True
class DontManageMe(Model):
name = fields.CharField(max_length=50)
class Meta:
managed = False
class Ignore(Model):
class Meta:
managed = False
class NewModel(Model):
name = fields.CharField(max_length=50)

View File

@ -34,31 +34,23 @@ class User(Model):
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( user = fields.ForeignKeyField("models_second.User", db_constraint=False)
"models_second.User", db_constraint=False
)
class Category(Model): class Category(Model):
slug = fields.CharField(max_length=200) slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200) name = fields.CharField(max_length=200)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField( user = fields.ForeignKeyField("models_second.User", description="User")
"models_second.User", description="User"
)
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Product(Model): class Product(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField( categories = fields.ManyToManyField("models_second.Category")
"models_second.Category"
)
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num")
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type: int = fields.IntEnumField( type = fields.IntEnumField(ProductType, description="Product Type")
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200) image = fields.CharField(max_length=200)
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
@ -67,5 +59,5 @@ class Product(Model):
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value: dict = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status, default=Status.on)

View File

@ -1,129 +0,0 @@
import datetime
from enum import IntEnum
from tortoise import Model, fields
from tortoise.indexes import Index
from tests.indexes import CustomIndex
class ProductType(IntEnum):
article = 1
page = 2
class PermissionAction(IntEnum):
create = 1
delete = 2
update = 3
read = 4
class Status(IntEnum):
on = 1
off = 0
class User(Model):
username = fields.CharField(max_length=20)
password = fields.CharField(max_length=200)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=12, decimal_places=9)
class Meta:
indexes = [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))]
class Email(Model):
email = fields.CharField(max_length=200)
company = fields.CharField(max_length=100, db_index=True)
is_primary = fields.BooleanField(default=False)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", db_constraint=False
)
class Category(Model):
slug = fields.CharField(max_length=200)
name = fields.CharField(max_length=200)
user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
"models.User", description="User"
)
title = fields.CharField(max_length=20, unique=True)
created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
indexes = [Index(fields=("slug",))]
class Product(Model):
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
uid = fields.IntField(source_field="uuid", unique=True)
name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num")
sort = fields.IntField()
is_review = fields.BooleanField(description="Is Reviewed")
type: int = fields.IntEnumField(
ProductType, description="Product Type", source_field="type_db_alias"
)
image = fields.CharField(max_length=200)
body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True)
is_delete = fields.BooleanField(default=False)
class Config(Model):
slug = fields.CharField(primary_key=True, max_length=10)
category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
"models.Category", through="config_category_map", related_name="config_set"
)
name = fields.CharField(max_length=100, unique=True)
label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20)
value: dict = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on)
class Meta:
table = "configs"
class DontManageMe(Model):
name = fields.CharField(max_length=50)
class Meta:
table = "dont_manage"
class Ignore(Model):
name = fields.CharField(max_length=50)
class Meta:
managed = True
def main() -> None:
"""Generate a python file for the old_models_describe"""
from pathlib import Path
from tortoise import run_async
from tortoise.contrib.test import init_memory_sqlite
from aerich.utils import get_models_describe
@init_memory_sqlite
async def run() -> None:
old_models_describe = get_models_describe("models")
p = Path("old_models_describe.py")
p.write_text(f"{old_models_describe = }", encoding="utf-8")
print(f"Write value to {p}\nYou can reformat it by `ruff format {p}`")
run_async(run())
if __name__ == "__main__":
main()

View File

@ -1,11 +0,0 @@
from aerich import Command
from conftest import tortoise_orm
async def test_command(mocker):
mocker.patch("os.listdir", return_value=[])
async with Command(tortoise_orm) as command:
history = await command.history()
heads = await command.heads()
assert history == []
assert heads == []

View File

@ -1,57 +1,38 @@
import tortoise import pytest
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests.models import Category, Product, User from tests.models import Category, User
def test_create_table(): def test_create_table():
ret = Migrate.ddl.create_table(Category) ret = Migrate.ddl.create_table(Category)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
if tortoise.__version__ >= "0.24":
assert ( assert (
ret ret
== """CREATE TABLE IF NOT EXISTS `category` ( == """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT, `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL, `slug` VARCHAR(200) NOT NULL,
`name` VARCHAR(200), `name` VARCHAR(200) NOT NULL,
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User', `user_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE, CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
FULLTEXT KEY `idx_category_slug_e9bcff` (`slug`) ) CHARACTER SET utf8mb4;"""
) CHARACTER SET utf8mb4"""
)
return
assert (
ret
== """CREATE TABLE IF NOT EXISTS `category` (
`id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
`slug` VARCHAR(100) NOT NULL,
`name` VARCHAR(200),
`title` VARCHAR(20) NOT NULL,
`created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
`owner_id` INT NOT NULL COMMENT 'User',
CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
) CHARACTER SET utf8mb4;
CREATE FULLTEXT INDEX `idx_category_slug_e9bcff` ON `category` (`slug`)"""
) )
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
exists = "IF NOT EXISTS " if tortoise.__version__ >= "0.24" else ""
assert ( assert (
ret ret
== f"""CREATE TABLE IF NOT EXISTS "category" ( == """CREATE TABLE IF NOT EXISTS "category" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL, "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(200) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200) NOT NULL,
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */ "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
); );"""
CREATE INDEX {exists}"idx_category_slug_e9bcff" ON "category" ("slug")"""
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
@ -59,19 +40,17 @@ CREATE INDEX {exists}"idx_category_slug_e9bcff" ON "category" ("slug")"""
ret ret
== """CREATE TABLE IF NOT EXISTS "category" ( == """CREATE TABLE IF NOT EXISTS "category" (
"id" SERIAL NOT NULL PRIMARY KEY, "id" SERIAL NOT NULL PRIMARY KEY,
"slug" VARCHAR(100) NOT NULL, "slug" VARCHAR(200) NOT NULL,
"name" VARCHAR(200), "name" VARCHAR(200) NOT NULL,
"title" VARCHAR(20) NOT NULL,
"created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
"owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug"); COMMENT ON COLUMN "category"."user_id" IS 'User';"""
COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
) )
def test_drop_table(): def test_drop_table():
ret = Migrate.ddl.drop_table(Category._meta.db_table) ret = Migrate.ddl.drop_table(Category)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "DROP TABLE IF EXISTS `category`" assert ret == "DROP TABLE IF EXISTS `category`"
else: else:
@ -79,94 +58,85 @@ def test_drop_table():
def test_add_column(): def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map["name"].describe(False)) ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)" assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL"
else: else:
assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)' assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL'
# add unique column
ret = Migrate.ddl.add_column(User, User._meta.fields_map["username"].describe(False))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL UNIQUE'
else:
assert ret == 'ALTER TABLE "user" ADD "username" VARCHAR(20) NOT NULL'
def test_modify_column(): def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
return with pytest.raises(NotSupportError):
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
else:
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret0 == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)'
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map["name"].describe(False))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map["is_active"].describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
assert ( assert (
ret1 ret1
== "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1" == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
) )
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ( assert ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL'
ret0
== 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200) USING "name"::VARCHAR(200)'
)
assert (
ret1 == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL USING "is_active"::BOOL'
)
def test_alter_column_default(): def test_alter_column_default():
if isinstance(Migrate.ddl, SqliteDDL): ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name"))
return
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map["intro"].describe(False))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\'' assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
elif isinstance(Migrate.ddl, MysqlDDL): else:
assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''" assert ret is None
ret = Migrate.ddl.alter_column_default( ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at"))
Category, Category._meta.fields_map["created_at"].describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP' ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
) )
elif isinstance(Migrate.ddl, MysqlDDL): else:
assert ( assert ret is None
ret
== "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)"
)
ret = Migrate.ddl.alter_column_default( ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar"))
Product, Product._meta.fields_map["view_num"].describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0' assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\''
elif isinstance(Migrate.ddl, MysqlDDL): else:
assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0" assert ret is None
def test_alter_column_null(): def test_alter_column_null():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)): ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
return
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map["name"].describe(False))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL' assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
else:
assert ret is None
def test_set_comment(): def test_set_comment():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)): ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
return if isinstance(Migrate.ddl, PostgresDDL):
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["name"].describe(False))
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["owner"].describe(False)) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
assert ret == 'COMMENT ON COLUMN "category"."owner_id" IS \'User\'' if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
else:
assert ret is None
def test_drop_column(): def test_drop_column():
if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError):
ret = Migrate.ddl.drop_column(Category, "name")
else:
ret = Migrate.ddl.drop_column(Category, "name") ret = Migrate.ddl.drop_column(Category, "name")
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP COLUMN `name`" assert ret == "ALTER TABLE `category` DROP COLUMN `name`"
@ -179,18 +149,20 @@ def test_add_index():
index_u = Migrate.ddl.add_index(Category, ["name"], True) index_u = Migrate.ddl.add_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)" assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
index == 'CREATE INDEX IF NOT EXISTS "idx_category_name_8b0cb9" ON "category" ("name")' index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `uid_category_name_8b0cb9` (`name`)"
) )
elif isinstance(Migrate.ddl, PostgresDDL):
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
assert ( assert (
index_u index_u
== 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_name_8b0cb9" ON "category" ("name")' == 'ALTER TABLE "category" ADD CONSTRAINT "uid_category_name_8b0cb9" UNIQUE ("name")'
) )
else: else:
assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")' assert index == 'ALTER TABLE "category" ADD INDEX "idx_category_name_8b0cb9" ("name")'
assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")' assert (
index_u == 'ALTER TABLE "category" ADD UNIQUE INDEX "uid_category_name_8b0cb9" ("name")'
)
def test_drop_index(): def test_drop_index():
@ -198,35 +170,34 @@ def test_drop_index():
ret_u = Migrate.ddl.drop_index(Category, ["name"], True) ret_u = Migrate.ddl.drop_index(Category, ["name"], True)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`" assert ret == "ALTER TABLE `category` DROP INDEX `idx_category_name_8b0cb9`"
assert ret_u == "ALTER TABLE `category` DROP INDEX `name`" assert ret_u == "ALTER TABLE `category` DROP INDEX `uid_category_name_8b0cb9`"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'ALTER TABLE "category" DROP CONSTRAINT "uid_category_name_8b0cb9"'
else: else:
assert ret == 'DROP INDEX IF EXISTS "idx_category_name_8b0cb9"' assert ret == 'ALTER TABLE "category" DROP INDEX "idx_category_name_8b0cb9"'
assert ret_u == 'DROP INDEX IF EXISTS "uid_category_name_8b0cb9"' assert ret_u == 'ALTER TABLE "category" DROP INDEX "uid_category_name_8b0cb9"'
def test_add_fk(): def test_add_fk():
ret = Migrate.ddl.add_fk( ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user"))
Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
ret ret
== "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE" == "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_e2e3874c` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE"
) )
else: else:
assert ( assert (
ret ret
== 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE' == 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_e2e3874c" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE'
) )
def test_drop_fk(): def test_drop_fk():
ret = Migrate.ddl.drop_fk( ret = Migrate.ddl.drop_fk(Category, Category._meta.fields_map.get("user"))
Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
)
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`" assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_e2e3874c`"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"' assert ret == 'ALTER TABLE "category" DROP CONSTRAINT "fk_category_user_e2e3874c"'
else: else:
assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_110d4c63"' assert ret == 'ALTER TABLE "category" DROP FOREIGN KEY "fk_category_user_e2e3874c"'

View File

@ -1,106 +0,0 @@
from __future__ import annotations
import os
import re
from pathlib import Path
from tests._utils import Dialect, run_shell
def _append_field(*files: str, name="field_1") -> None:
for file in files:
p = Path(file)
field = f" {name} = fields.IntField(default=0)"
with p.open("a") as f:
f.write(os.linesep + field)
def test_fake(new_aerich_project):
if Dialect.is_sqlite():
# TODO: go ahead if sqlite alter-column supported
return
output = run_shell("aerich init -t settings.TORTOISE_ORM")
assert "Success" in output
output = run_shell("aerich init-db")
assert "Success" in output
output = run_shell("aerich --app models_second init-db")
assert "Success" in output
output = run_shell("pytest _tests.py::test_init_db")
assert "error" not in output.lower()
_append_field("models.py", "models_second.py")
output = run_shell("aerich migrate")
assert "Success" in output
output = run_shell("aerich --app models_second migrate")
assert "Success" in output
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("pytest _tests.py::test_fake_field_1")
assert "error" not in output.lower()
_append_field("models.py", "models_second.py", name="field_2")
output = run_shell("aerich migrate")
assert "Success" in output
output = run_shell("aerich --app models_second migrate")
assert "Success" in output
output = run_shell("aerich heads")
assert "_update.py" in output
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("pytest _tests.py::test_fake_field_2")
assert "error" not in output.lower()
output = run_shell("aerich heads")
assert "No available heads." in output
output = run_shell("aerich --app models_second heads")
assert "No available heads." in output
_append_field("models.py", "models_second.py", name="field_3")
run_shell("aerich migrate", capture_output=False)
run_shell("aerich --app models_second migrate", capture_output=False)
run_shell("aerich upgrade --fake", capture_output=False)
run_shell("aerich --app models_second upgrade --fake", capture_output=False)
output = run_shell("aerich downgrade --fake -v 2 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 2 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert not re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert not re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich --app models_second upgrade --fake")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." in output
output = run_shell("aerich --app models_second heads")
assert "No available heads." in output
output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
assert "FAKED" in output
output = run_shell("aerich heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)
output = run_shell("aerich --app models_second heads")
assert "No available heads." not in output
assert re.search(r"1_\d+_update\.py", output)
assert re.search(r"2_\d+_update\.py", output)

View File

@ -1,17 +0,0 @@
from tests._utils import Dialect, run_shell
def test_inspect(new_aerich_project):
if Dialect.is_sqlite():
# TODO: test sqlite after #384 fixed
return
run_shell("aerich init -t settings.TORTOISE_ORM")
run_shell("aerich init-db")
ret = run_shell("aerich inspectdb -t product")
assert ret.startswith("from tortoise import Model, fields")
assert "primary_key=True" in ret
assert "fields.DatetimeField" in ret
assert "fields.FloatField" in ret
assert "fields.UUIDField" in ret
if Dialect.is_mysql():
assert "db_index=True" in ret

File diff suppressed because it is too large Load Diff

View File

@ -1,18 +0,0 @@
import subprocess # nosec
from pathlib import Path
from aerich.version import __version__
from tests._utils import chdir, run_shell
def test_python_m_aerich():
assert __version__ in run_shell("python -m aerich --version")
def test_poetry_add(tmp_path: Path):
package = Path(__file__).parent.resolve().parent
with chdir(tmp_path):
subprocess.run(["poetry", "new", "foo"]) # nosec
with chdir("foo"):
r = subprocess.run(["poetry", "add", package]) # nosec
assert r.returncode == 0

View File

@ -1,213 +0,0 @@
from __future__ import annotations
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path
from tests._utils import Dialect, chdir, copy_files
def run_aerich(cmd: str) -> subprocess.CompletedProcess | None:
if not cmd.startswith("poetry") and not cmd.startswith("python"):
if not cmd.startswith("aerich"):
cmd = "aerich " + cmd
if platform.system() == "Windows":
cmd = "python -m " + cmd
r = None
with contextlib.suppress(subprocess.TimeoutExpired):
r = subprocess.run(shlex.split(cmd), timeout=2)
return r
def run_shell(cmd: str) -> subprocess.CompletedProcess:
envs = dict(os.environ, PYTHONPATH=".")
return subprocess.run(shlex.split(cmd), env=envs)
def _get_empty_db() -> Path:
if (db_file := Path("db.sqlite3")).exists():
db_file.unlink()
return db_file
@contextmanager
def prepare_sqlite_project(tmp_path: Path) -> Generator[tuple[Path, str]]:
test_dir = Path(__file__).parent
asset_dir = test_dir / "assets" / "sqlite_migrate"
with chdir(tmp_path):
files = ("models.py", "settings.py", "_tests.py")
copy_files(*(asset_dir / f for f in files), target_dir=Path())
models_py, settings_py, test_py = (Path(f) for f in files)
copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py"))
_get_empty_db()
yield models_py, models_py.read_text("utf-8")
def test_close_tortoise_connections_patch(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
run_aerich("aerich init -t settings.TORTOISE_ORM")
r = run_aerich("aerich init-db")
assert r is not None
def test_sqlite_migrate_alter_indexed_unique(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "unique=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_unique_is_true")
assert r.returncode == 0
models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
M2M_WITH_CUSTOM_THROUGH = """
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
def test_sqlite_migrate(tmp_path: Path) -> None:
if not Dialect.is_sqlite():
return
with prepare_sqlite_project(tmp_path) as (models_py, models_text):
MODELS = models_text
run_aerich("aerich init -t settings.TORTOISE_ORM")
config_file = Path("pyproject.toml")
modify_time = config_file.stat().st_mtime
run_aerich("aerich init-db")
run_aerich("aerich init -t settings.TORTOISE_ORM")
assert modify_time == config_file.stat().st_mtime
r = run_shell("pytest _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add index
models_py.write_text(MODELS.replace("index=False", "index=True"))
run_aerich("aerich migrate") # migrations/models/1_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Drop index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/2_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add unique index
models_py.write_text(MODELS.replace("index=False", "index=True, unique=True"))
run_aerich("aerich migrate") # migrations/models/3_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_unique_is_true")
assert r.returncode == 0
# Drop unique index
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/4_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_allow_duplicate")
assert r.returncode == 0
# Add field with unique=True
with models_py.open("a") as f:
f.write(" age = fields.IntField(unique=True, default=0)")
run_aerich("aerich migrate") # migrations/models/5_
run_aerich("aerich upgrade")
r = run_shell("pytest _tests.py::test_add_unique_field")
assert r.returncode == 0
# Drop unique field
models_py.write_text(MODELS)
run_aerich("aerich migrate") # migrations/models/6_
run_aerich("aerich upgrade")
r = run_shell("pytest -s _tests.py::test_drop_unique_field")
assert r.returncode == 0
# Initial with indexed field and then drop it
migrations_dir = Path("migrations/models")
shutil.rmtree(migrations_dir)
db_file = _get_empty_db()
models_py.write_text(MODELS + " age = fields.IntField(db_index=True)")
run_aerich("aerich init -t settings.TORTOISE_ORM")
run_aerich("aerich init-db")
migration_file = list(migrations_dir.glob("0_*.py"))[0]
assert "CREATE INDEX" in migration_file.read_text()
r = run_shell("pytest _tests.py::test_with_age_field")
assert r.returncode == 0
models_py.write_text(MODELS)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "DROP INDEX" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_without_age_field")
assert r.returncode == 0
# Generate migration file in emptry directory
db_file.unlink()
run_aerich("aerich init-db")
assert not db_file.exists()
for p in migrations_dir.glob("*"):
if p.is_dir():
shutil.rmtree(p)
else:
p.unlink()
run_aerich("aerich init-db")
assert db_file.exists()
# init without '[tool]' section in pyproject.toml
config_file = Path("pyproject.toml")
config_file.write_text('[project]\nname = "project"')
run_aerich("init -t settings.TORTOISE_ORM")
assert "[tool.aerich]" in config_file.read_text()
# add m2m with custom model for through
models_py.write_text(MODELS + M2M_WITH_CUSTOM_THROUGH)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_m2m_with_custom_through")
assert r.returncode == 0
# add m2m field after init-db
new = """
groups = fields.ManyToManyField("models.Group", through="foo_group", related_name="users")
class Group(Model):
name = fields.CharField(max_length=60)
"""
_get_empty_db()
if migrations_dir.exists():
shutil.rmtree(migrations_dir)
models_py.write_text(MODELS)
run_aerich("aerich init-db")
models_py.write_text(MODELS + new)
run_aerich("aerich migrate")
run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _tests.py::test_add_m2m_field_after_init_db")
assert r.returncode == 0

View File

@ -1,164 +0,0 @@
from aerich.utils import get_dict_diff_by_key, import_py_file
def test_import_py_file() -> None:
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file", None)
class TestDiffFields:
def test_the_same_through_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "members", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert type(get_dict_diff_by_key(old, new)).__name__ == "generator"
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_same_through_with_different_orders(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
{"name": "members", "through": "users_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("change", [0, "name"], ("users", "members"))]
def test_the_same_field_name_order(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_same_field_name_with_different_orders(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users", "through": "user_groups"},
{"name": "admins", "through": "admin_groups"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 4
assert diffs == [
("remove", "", [(0, {"name": "admins", "through": "admins_group"})]),
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "users", "through": "user_groups"})]),
("add", "", [(0, {"name": "admins", "through": "admin_groups"})]),
]
def test_drop_one(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("remove", "", [(0, {"name": "users", "through": "users_group"})])]
def test_add_one(self) -> None:
old = [
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 1
assert diffs == [("add", "", [(0, {"name": "users", "through": "users_group"})])]
def test_drop_some(self) -> None:
old = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "admins", "through": "admins_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("remove", "", [(0, {"name": "users", "through": "users_group"})]),
("remove", "", [(0, {"name": "staffs", "through": "staffs_group"})]),
]
def test_add_some(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins", "through": "admins_group"},
{"name": "staffs", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 2
assert diffs == [
("add", "", [(0, {"name": "users", "through": "users_group"})]),
("add", "", [(0, {"name": "admins", "through": "admins_group"})]),
]
def test_some_through_unchanged(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
]
new = [
{"name": "users", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("add", "", [(0, {"name": "users", "through": "users_group"})]),
]
def test_some_unchanged_without_drop_or_add(self) -> None:
old = [
{"name": "staffs", "through": "staffs_group"},
{"name": "admins", "through": "admins_group"},
{"name": "users", "through": "users_group"},
]
new = [
{"name": "users_new", "through": "users_group"},
{"name": "admins_new", "through": "admins_group"},
{"name": "staffs_new", "through": "staffs_group"},
]
diffs = list(get_dict_diff_by_key(old, new))
assert len(diffs) == 3
assert diffs == [
("change", [0, "name"], ("staffs", "staffs_new")),
("change", [0, "name"], ("admins", "admins_new")),
("change", [0, "name"], ("users", "users_new")),
]