Compare commits


21 Commits
v0.8.1 ... dev

Author SHA1 Message Date
Waket Zheng
9c3ba7e273
fix: aerich init-db process is suspended (#435) 2025-03-06 13:39:56 +08:00
Waket Zheng
074ba9b743
fix: ci failed with m2m field migrate test (#434)
* fix style issue

* fixing m2m test error
2025-03-05 10:28:41 +08:00
Waket Zheng
5d9adbdb54
chore: improve type hints (#432)
* chore: improve type hints

* chore: set `warn_unused_ignores` true for mypy

* refactor: use function to compare tortoise version

* refactor: change function name
2025-03-04 14:52:12 +08:00
Waket Zheng
8609435815
Release 0.8.2 (#429) 2025-02-28 20:24:06 +08:00
Waket Zheng
a624d1b43b
fix: migrate does not recognise attribute changes for string primary key (#428)
* refactor: show warning for unsupported pk field changes

* fix: migrate does not recognise attribute changes for string primary key

* docs: update changelog

* refactor: reduce indents

* chore: update docs
2025-02-27 22:23:26 +08:00
Waket Zheng
e299f8e1d6
feat: aerich.Command support async with syntax (#427)
* feat: `aerich.Command` support `async with` syntax

* docs: update readme
2025-02-27 10:55:48 +08:00
Waket Zheng
db0cf656fc chore: show friendly message when config missing 'apps' section 2025-02-26 18:08:12 +08:00
Waket Zheng
49bfbf4e6b
feat: support psycopg (#425) 2025-02-26 17:11:31 +08:00
Waket Zheng
0364ae3f83
feat: add project section (#424)
* refactor: apply future style type hints

* chore: use project section

* ci: upgrade to poetry v2

* ci: explicit declare python version for poetry

* fix error for generate index name

* fix _generate_fk_name

* ci: verify aiomysql support

* tests: poetry add

* Add patch to fix tortoise 0.24.1

* docs: update changelog
2025-02-26 14:24:02 +08:00
Waket Zheng
91adf9334e
feat: support skip table migration by set managed=False (#397) 2025-02-21 17:08:03 +08:00
Waket Zheng
41df464e8b
fix: no migration occurs when adding unique true to indexed field (#414)
* feat: alter unique for indexed column

* chore: update docs and change some var names
2025-02-20 16:58:32 +08:00
程序猿过家家
c35282c2a3
fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL
* added:
1. inspectdb now matches the DECIMAL, DOUBLE, CHAR and TIME data types;
2. added exception handling, so an unsupported data type no longer forces manually creating the entire table.

* fix: aerich inspectdb raise KeyError for double in MySQL

* feat: support command `python -m aerich`

* docs: update changelog

* tests: verify mysql inspectdb for float field

* fix mysql uuid field inspect to be charfield

* refactor: use `db_index=True` instead of `index=True` for inspectdb

* docs: update changelog

---------

Co-authored-by: xiechen <xiechen@jinse.com>
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-19 16:04:15 +08:00
Waket Zheng
557271c8e1
feat: support command python -m aerich (#417)
* feat: support command `python -m aerich`

* docs: update changelog
2025-02-18 15:44:02 +08:00
radluz
7f8c5dcddc
fix: update asyncio event loop policy on Windows (#251)
* fix: update asyncio event loop policy on Windows

* Use `platform.system` instead of `sys.platform`

---------

Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-17 18:06:10 +08:00
Waket Zheng
1793dab43d
refactor: apply future type hints style (#416)
* refactor: apply future style type hints

* chore: put cryptography out of dev dependencies
2025-02-17 11:42:56 +08:00
Waket Zheng
6bdfdfc6db
fix: aerich migrate raises tortoise.exceptions.FieldError when index.INDEX_TYPE is not empty (#415)
* fix: aerich migrate raises `tortoise.exceptions.FieldError` when `index.INDEX_TYPE` is not empty

* feat: add `IF NOT EXISTS` to postgres create index template

* chore: explicit declare type hints of function parameters
2025-02-13 18:48:45 +08:00
alistairmaclean
0be5c1b545
Remove system dependency on libsqlite3.so on command.upgrade (#413)
* Remove system dependency on libsqlite3.so on command.upgrade

* Fix styling using `make style` command
2025-02-07 20:09:04 +08:00
Abdeldjalil Hezouat
d6b35ab0ac
change hardcoded version (#412)
Co-authored-by: Waket Zheng <waketzheng@gmail.com>
2025-02-07 19:50:41 +08:00
Waket Zheng
b46ceafb2e
feat: support --fake for aerich upgrade (#398)
* feat: support `--fake` for aerich upgrade

* Add `--fake` to downgrade

* tests: check --fake result for aerich upgrade and downgrade

* Update readme

* Fix unittest failed because of `db_field_types` changed

* refactor: improve type hints and document
2025-02-07 19:44:15 +08:00
Waket Zheng
ac847ba616
refactor: avoid updating inited config file (#402)
* refactor: avoid updating config file if init config items not changed

* fix unittest error with tortoise develop branch

* Remove extra space

* fix mysql test error

* fix mysql create index error
2025-01-04 09:08:14 +08:00
Waket Zheng
f5d7d56fa5
fix: inspectdb raise KeyError 'int2' for smallint (#401)
* fix: inspectdb raise KeyError 'int2' for smallint

* fix ci error

* no ask confirm for ci

* docs: update changelog
2024-12-27 23:49:53 +08:00
40 changed files with 2811 additions and 952 deletions

View File

@@ -25,26 +25,34 @@ jobs:
           - tortoise021
           - tortoise022
           - tortoise023
-          - tortoisedev
+          - tortoise024
+          # TODO: add dev back when drop python3.8 support
+          # - tortoisedev
     steps:
       - name: Start MySQL
        run: sudo systemctl start mysql.service
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
       - uses: actions/cache@v4
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('**/poetry.lock') }}
          restore-keys: |
            ${{ runner.os }}-pip-
-      - uses: actions/checkout@v4
-      - uses: actions/setup-python@v5
-        with:
-          python-version: ${{ matrix.python-version }}
       - name: Install and configure Poetry
        run: |
-          pip install -U pip poetry
-          poetry config virtualenvs.create false
+          pip install -U pip
+          if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
+            # poetry2.0+ does not support installed by python3.8, but can manage project using py38
+            python3.12 -m pip install "poetry>=2.0"
+          else
+            pip install "poetry>=2.0"
+          fi
+          poetry env use python${{ matrix.python-version }}
       - name: Install dependencies and check style
-        run: make check
+        run: poetry run make check
       - name: Install TortoiseORM v0.21
        if: matrix.tortoise-orm == 'tortoise021'
        run: poetry run pip install --upgrade "tortoise-orm>=0.21,<0.22"
@@ -54,9 +62,23 @@ jobs:
       - name: Install TortoiseORM v0.23
        if: matrix.tortoise-orm == 'tortoise023'
        run: poetry run pip install --upgrade "tortoise-orm>=0.23,<0.24"
+      - name: Install TortoiseORM v0.24
+        if: matrix.tortoise-orm == 'tortoise024'
+        run: |
+          if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
+            echo "Skip test for tortoise v0.24 as it does not support Python3.8"
+          else
+            poetry run pip install --upgrade "tortoise-orm>=0.24,<0.25"
+          fi
       - name: Install TortoiseORM develop branch
        if: matrix.tortoise-orm == 'tortoisedev'
-        run: poetry run pip install --upgrade "git+https://github.com/tortoise/tortoise-orm"
+        run: |
+          if [[ "${{ matrix.python-version }}" == "3.8" ]]; then
+            echo "Skip test for tortoise develop branch as it does not support Python3.8"
+          else
+            poetry run pip uninstall -y tortoise-orm
+            poetry run pip install --upgrade "git+https://github.com/tortoise/tortoise-orm"
+          fi
       - name: CI
        env:
          MYSQL_PASS: root
@@ -65,4 +87,23 @@
          POSTGRES_PASS: 123456
          POSTGRES_HOST: 127.0.0.1
          POSTGRES_PORT: 5432
-        run: make _testall
+        run: poetry run make _testall
+      - name: Verify aiomysql support
+        # Only check the latest version of tortoise
+        if: matrix.tortoise-orm == 'tortoise024'
+        run: |
+          poetry run pip uninstall -y asyncmy
+          poetry run make test_mysql
+          poetry run pip install asyncmy
+        env:
+          MYSQL_PASS: root
+          MYSQL_HOST: 127.0.0.1
+          MYSQL_PORT: 3306
+      - name: Verify psycopg support
+        # Only check the latest version of tortoise
+        if: matrix.tortoise-orm == 'tortoise024'
+        run: poetry run make test_psycopg
+        env:
+          POSTGRES_PASS: 123456
+          POSTGRES_HOST: 127.0.0.1
+          POSTGRES_PORT: 5432

View File

@@ -2,6 +2,42 @@

 ## 0.8

+### [0.8.3]**(Unreleased)**
+
+#### Fixed
+- fix: `aerich init-db` process is suspended. ([#435])
+
+[#435]: https://github.com/tortoise/aerich/pull/435
+
+### [0.8.2](../../releases/tag/v0.8.2) - 2025-02-28
+
+#### Added
+- Support changes `max_length` or int type for primary key field. ([#428])
+- feat: support psycopg. ([#425])
+- Support run `poetry add aerich` in project that inited by poetry v2. ([#424])
+- feat: support command `python -m aerich`. ([#417])
+- feat: add --fake to upgrade/downgrade. ([#398])
+- Support ignore table by settings `managed=False` in `Meta` class. ([#397])
+
+#### Fixed
+- fix: aerich migrate raises tortoise.exceptions.FieldError when `index.INDEX_TYPE` is not empty. ([#415])
+- No migration occurs as expected when adding `unique=True` to indexed field. ([#404])
+- fix: inspectdb raise KeyError 'int2' for smallint. ([#401])
+- fix: inspectdb not match data type 'DOUBLE' and 'CHAR' for MySQL. ([#187])
+
+#### Changed
+- Refactored version management to use `importlib.metadata.version(__package__)` instead of hardcoded version string ([#412])
+
+[#397]: https://github.com/tortoise/aerich/pull/397
+[#398]: https://github.com/tortoise/aerich/pull/398
+[#401]: https://github.com/tortoise/aerich/pull/401
+[#404]: https://github.com/tortoise/aerich/pull/404
+[#412]: https://github.com/tortoise/aerich/pull/412
+[#415]: https://github.com/tortoise/aerich/pull/415
+[#417]: https://github.com/tortoise/aerich/pull/417
+[#424]: https://github.com/tortoise/aerich/pull/424
+[#425]: https://github.com/tortoise/aerich/pull/425
+
 ### [0.8.1](../../releases/tag/v0.8.1) - 2024-12-27

 #### Fixed
@@ -29,19 +65,18 @@
 [#395]: https://github.com/tortoise/aerich/pull/395
 [#394]: https://github.com/tortoise/aerich/pull/394
 [#393]: https://github.com/tortoise/aerich/pull/393
-[#376]: https://github.com/tortoise/aerich/pull/376
+[#392]: https://github.com/tortoise/aerich/pull/392
+[#388]: https://github.com/tortoise/aerich/pull/388
 [#386]: https://github.com/tortoise/aerich/pull/386
-[#272]: https://github.com/tortoise/aerich/pull/272
-[#334]: https://github.com/tortoise/aerich/pull/334
-[#284]: https://github.com/tortoise/aerich/pull/284
-[#286]: https://github.com/tortoise/aerich/pull/286
-[#302]: https://github.com/tortoise/aerich/pull/302
 [#378]: https://github.com/tortoise/aerich/pull/378
 [#377]: https://github.com/tortoise/aerich/pull/377
-[#271]: https://github.com/tortoise/aerich/pull/271
+[#376]: https://github.com/tortoise/aerich/pull/376
+[#334]: https://github.com/tortoise/aerich/pull/334
+[#302]: https://github.com/tortoise/aerich/pull/302
 [#286]: https://github.com/tortoise/aerich/pull/286
-[#388]: https://github.com/tortoise/aerich/pull/388
-[#392]: https://github.com/tortoise/aerich/pull/392
+[#284]: https://github.com/tortoise/aerich/pull/284
+[#272]: https://github.com/tortoise/aerich/pull/272
+[#271]: https://github.com/tortoise/aerich/pull/271

 ### [0.8.0](../../releases/tag/v0.8.0) - 2024-12-04

View File

@@ -1,5 +1,4 @@
 checkfiles = aerich/ tests/ conftest.py
-black_opts = -l 100 -t py38
 py_warn = PYTHONDEVMODE=1
 MYSQL_HOST ?= "127.0.0.1"
 MYSQL_PORT ?= 3306
@@ -12,20 +11,28 @@ up:
	@poetry update

 deps:
-	@poetry install -E asyncpg -E asyncmy -E toml
+	@poetry install --all-extras --all-groups

 _style:
-	@isort -src $(checkfiles)
-	@black $(black_opts) $(checkfiles)
+	@ruff check --fix $(checkfiles)
+	@ruff format $(checkfiles)
 style: deps _style

 _check:
-	@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
+	@ruff format --check $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
	@ruff check $(checkfiles)
	@mypy $(checkfiles)
	@bandit -r aerich
 check: deps _check

+_lint: _build
+	@ruff format $(checkfiles)
+	ruff check --fix $(checkfiles)
+	mypy $(checkfiles)
+	bandit -c pyproject.toml -r $(checkfiles)
+	twine check dist/*
+lint: deps _lint
+
 test: deps
	$(py_warn) TEST_DB=sqlite://:memory: pytest
@@ -38,10 +45,14 @@ test_mysql:

 test_postgres:
	$(py_warn) TEST_DB="postgres://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s

+test_psycopg:
+	$(py_warn) TEST_DB="psycopg://postgres:$(POSTGRES_PASS)@$(POSTGRES_HOST):$(POSTGRES_PORT)/test_\{\}" pytest -vv -s
+
 _testall: test_sqlite test_postgres test_mysql
 testall: deps _testall

-build: deps
+_build:
	@poetry build
+build: deps _build

 ci: build _check _testall

View File

@@ -226,14 +226,14 @@ from tortoise import Model, fields

 class Test(Model):
-    date = fields.DateField(null=True, )
-    datetime = fields.DatetimeField(auto_now=True, )
-    decimal = fields.DecimalField(max_digits=10, decimal_places=2, )
-    float = fields.FloatField(null=True, )
-    id = fields.IntField(pk=True, )
-    string = fields.CharField(max_length=200, null=True, )
-    time = fields.TimeField(null=True, )
-    tinyint = fields.BooleanField(null=True, )
+    date = fields.DateField(null=True)
+    datetime = fields.DatetimeField(auto_now=True)
+    decimal = fields.DecimalField(max_digits=10, decimal_places=2)
+    float = fields.FloatField(null=True)
+    id = fields.IntField(primary_key=True)
+    string = fields.CharField(max_length=200, null=True)
+    time = fields.TimeField(null=True)
+    tinyint = fields.BooleanField(null=True)
 ```

 Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others.
@@ -243,8 +243,8 @@ Note that this command is limited and can't infer some fields, such as `IntEnumF
 ```python
 tortoise_orm = {
     "connections": {
-        "default": expand_db_url(db_url, True),
-        "second": expand_db_url(db_url_second, True),
+        "default": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db1",
+        "second": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db2",
     },
     "apps": {
         "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@@ -253,7 +253,7 @@ tortoise_orm = {
 }
 ```

-You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on.
+You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`.

 ## Restore `aerich` workflow
@@ -273,11 +273,38 @@ You can use `aerich` out of cli by use `Command` class.
 ```python
 from aerich import Command

-command = Command(tortoise_config=config, app='models')
-await command.init()
-await command.migrate('test')
+async with Command(tortoise_config=config, app='models') as command:
+    await command.migrate('test')
+    await command.upgrade()
 ```

+## Upgrade/Downgrade with `--fake` option
+
+Marks the migrations up to the latest one (or back to the target one) as applied, without actually running the SQL to change your database schema.
+
+- Upgrade:
+
+```bash
+aerich upgrade --fake
+aerich --app models upgrade --fake
+```
+
+- Downgrade:
+
+```bash
+aerich downgrade --fake -v 2
+aerich --app models downgrade --fake -v 2
+```
+
+### Ignore tables
+
+You can tell aerich to ignore a table by setting `managed=False` in its `Meta` class, e.g.:
+
+```py
+class MyModel(Model):
+    class Meta:
+        managed = False
+```
+
+**Note** `managed=False` is not recognized by `tortoise-orm` and `aerich init-db`; it only affects `aerich migrate`.
+
 ## License

 This project is licensed under the

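Putting the README changes in this diff together, here is a minimal sketch of programmatic usage combining the `async with` syntax from #427 with the `fake` upgrade semantics from #398. The config dict, connection URL, and migration name below are placeholders, not part of the diff:

```python
import asyncio

from aerich import Command

# Placeholder config in the same shape as the README example above.
TORTOISE_ORM = {
    "connections": {"default": "sqlite://db.sqlite3"},
    "apps": {
        "models": {"models": ["models", "aerich.models"], "default_connection": "default"},
    },
}


async def main() -> None:
    # __aenter__ calls command.init(); __aexit__ closes the tortoise connections.
    async with Command(tortoise_config=TORTOISE_ORM, app="models") as command:
        await command.migrate("example")  # generate a migration file (name is illustrative)
        # fake=True records the versions as applied without executing their SQL,
        # mirroring `aerich upgrade --fake`.
        applied = await command.upgrade(run_in_transaction=True, fake=True)
        print(applied)


asyncio.run(main())
```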
View File

@@ -1,8 +1,13 @@
-import os
-from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional, Type
+from __future__ import annotations

-from tortoise import Tortoise, generate_schema_for_client
+import os
+import platform
+from contextlib import AbstractAsyncContextManager
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import tortoise
+from tortoise import Tortoise, connections, generate_schema_for_client
 from tortoise.exceptions import OperationalError
 from tortoise.transactions import in_transaction
 from tortoise.utils import get_schema_sql
@@ -21,10 +26,117 @@ from aerich.utils import (
 )

 if TYPE_CHECKING:
-    from aerich.inspectdb import Inspect  # noqa:F401
+    from tortoise import Model
+    from tortoise.fields.relational import ManyToManyFieldInstance  # NOQA:F401
+
+    from aerich.inspectdb import Inspect
+
+
+def _init_asyncio_patch():
+    """
+    Select compatible event loop for psycopg3.
+
+    As of Python 3.8+, the default event loop on Windows is `proactor`,
+    however psycopg3 requires the old default "selector" event loop.
+    See https://www.psycopg.org/psycopg3/docs/advanced/async.html
+    """
+    if platform.system() == "Windows":
+        try:
+            from asyncio import WindowsSelectorEventLoopPolicy  # type:ignore
+        except ImportError:
+            pass  # Can't assign a policy which doesn't exist.
+        else:
+            from asyncio import get_event_loop_policy, set_event_loop_policy
+
+            if not isinstance(get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
+                set_event_loop_policy(WindowsSelectorEventLoopPolicy())
+
+
+def _init_tortoise_0_24_1_patch():
+    # this patch is for "tortoise-orm==0.24.1" to fix:
+    # https://github.com/tortoise/tortoise-orm/issues/1893
+    if tortoise.__version__ != "0.24.1":
+        return
+    from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re
+
+    def _get_m2m_tables(
+        self, model: type[Model], db_table: str, safe: bool, models_tables: list[str]
+    ) -> list[str]:  # Copied from tortoise-orm
+        m2m_tables_for_create = []
+        for m2m_field in model._meta.m2m_fields:
+            field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
+            if field_object._generated or field_object.through in models_tables:
+                continue
+            backward_key, forward_key = field_object.backward_key, field_object.forward_key
+            if field_object.db_constraint:
+                backward_fk = self._create_fk_string(
+                    "",
+                    backward_key,
+                    db_table,
+                    model._meta.db_pk_column,
+                    field_object.on_delete,
+                    "",
+                )
+                forward_fk = self._create_fk_string(
+                    "",
+                    forward_key,
+                    field_object.related_model._meta.db_table,
+                    field_object.related_model._meta.db_pk_column,
+                    field_object.on_delete,
+                    "",
+                )
+            else:
+                backward_fk = forward_fk = ""
+            exists = "IF NOT EXISTS " if safe else ""
+            through_table_name = field_object.through
+            backward_type = self._get_pk_field_sql_type(model._meta.pk)
+            forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
+            comment = ""
+            if desc := field_object.description:
+                comment = self._table_comment_generator(table=through_table_name, comment=desc)
+            m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
+                exists=exists,
+                table_name=through_table_name,
+                backward_fk=backward_fk,
+                forward_fk=forward_fk,
+                backward_key=backward_key,
+                backward_type=backward_type,
+                forward_key=forward_key,
+                forward_type=forward_type,
+                extra=self._table_generate_extra(table=field_object.through),
+                comment=comment,
+            )
+            if not field_object.db_constraint:
+                m2m_create_string = m2m_create_string.replace(
+                    """,
+    ,
+""",
+                    "",
+                )  # may have better way
+            m2m_create_string += self._post_table_hook()
+            if getattr(field_object, "create_unique_index", field_object.unique):
+                unique_index_create_sql = self._get_unique_index_sql(
+                    exists, through_table_name, [backward_key, forward_key]
+                )
+                if unique_index_create_sql.endswith(";"):
+                    m2m_create_string += "\n" + unique_index_create_sql
+                else:
+                    lines = m2m_create_string.splitlines()
+                    lines[-2] += ","
+                    indent = m.group() if (m := re.match(r"\s+", lines[-2])) else ""
+                    lines.insert(-1, indent + unique_index_create_sql)
+                    m2m_create_string = "\n".join(lines)
+            m2m_tables_for_create.append(m2m_create_string)
+        return m2m_tables_for_create
+
+    setattr(BaseSchemaGenerator, "_get_m2m_tables", _get_m2m_tables)
+
+
+_init_asyncio_patch()
+_init_tortoise_0_24_1_patch()


-class Command:
+class Command(AbstractAsyncContextManager):
     def __init__(
         self,
         tortoise_config: dict,
@@ -39,18 +151,29 @@ class Command:
     async def init(self) -> None:
         await Migrate.init(self.tortoise_config, self.app, self.location)

-    async def _upgrade(self, conn, version_file) -> None:
+    async def __aenter__(self) -> Command:
+        await self.init()
+        return self
+
+    async def close(self) -> None:
+        await connections.close_all()
+
+    async def __aexit__(self, *args, **kw) -> None:
+        await self.close()
+
+    async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
         file_path = Path(Migrate.migrate_location, version_file)
         m = import_py_file(file_path)
         upgrade = m.upgrade
-        await conn.execute_script(await upgrade(conn))
+        if not fake:
+            await conn.execute_script(await upgrade(conn))
         await Aerich.create(
             version=version_file,
             app=self.app,
             content=get_models_describe(self.app),
         )

-    async def upgrade(self, run_in_transaction: bool = True) -> List[str]:
+    async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> list[str]:
         migrated = []
         for version_file in Migrate.get_all_version_files():
             try:
@@ -61,15 +184,15 @@ class Command:
             app_conn_name = get_app_connection_name(self.tortoise_config, self.app)
             if run_in_transaction:
                 async with in_transaction(app_conn_name) as conn:
-                    await self._upgrade(conn, version_file)
+                    await self._upgrade(conn, version_file, fake=fake)
             else:
                 app_conn = get_app_connection(self.tortoise_config, self.app)
-                await self._upgrade(app_conn, version_file)
+                await self._upgrade(app_conn, version_file, fake=fake)
             migrated.append(version_file)
         return migrated

-    async def downgrade(self, version: int, delete: bool) -> List[str]:
-        ret: List[str] = []
+    async def downgrade(self, version: int, delete: bool, fake: bool = False) -> list[str]:
+        ret: list[str] = []
         if version == -1:
             specified_version = await Migrate.get_last_version()
         else:
@@ -93,14 +216,15 @@ class Command:
                 downgrade_sql = await downgrade(conn)
                 if not downgrade_sql.strip():
                     raise DowngradeError("No downgrade items found")
-                await conn.execute_script(downgrade_sql)
+                if not fake:
+                    await conn.execute_script(downgrade_sql)
                 await version_obj.delete()
                 if delete:
                     os.unlink(file_path)
                 ret.append(file)
         return ret

-    async def heads(self) -> List[str]:
+    async def heads(self) -> list[str]:
         ret = []
         versions = Migrate.get_all_version_files()
         for version in versions:
@@ -108,15 +232,15 @@ class Command:
                 ret.append(version)
         return ret

-    async def history(self) -> List[str]:
+    async def history(self) -> list[str]:
         versions = Migrate.get_all_version_files()
         return [version for version in versions]

-    async def inspectdb(self, tables: Optional[List[str]] = None) -> str:
+    async def inspectdb(self, tables: list[str] | None = None) -> str:
         connection = get_app_connection(self.tortoise_config, self.app)
         dialect = connection.schema_generator.DIALECT
         if dialect == "mysql":
-            cls: Type["Inspect"] = InspectMySQL
+            cls: type[Inspect] = InspectMySQL
         elif dialect == "postgres":
             cls = InspectPostgres
         elif dialect == "sqlite":

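The `_init_tortoise_0_24_1_patch` above is a version-gated monkeypatch: it bails out unless the exact broken release is installed, then swaps a fixed copy of the upstream method in place at import time. A self-contained sketch of the same shape, using a hypothetical `Greeter` class rather than aerich code:

```python
class Greeter:
    def greet(self) -> str:
        return "helo"  # stand-in for the upstream bug


GREETER_VERSION = "1.2.3"  # stand-in for tortoise.__version__


def _init_greeter_1_2_3_patch() -> None:
    # Gate on the exact broken release so the patch is a no-op everywhere else.
    if GREETER_VERSION != "1.2.3":
        return

    def greet(self) -> str:  # fixed copy of the upstream method
        return "hello"

    setattr(Greeter, "greet", greet)


_init_greeter_1_2_3_patch()  # applied at import time, like the patches above
assert Greeter().greet() == "hello"
```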
aerich/__main__.py Normal file
View File

@@ -0,0 +1,3 @@
+from .cli import main
+
+main()

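This new file is what enables `python -m aerich ...` (#417): running the package as a module executes `__main__.py`, which just imports and calls the click entry point. A minimal standalone analogue, for a hypothetical `mytool` package:

```python
# mytool/__main__.py -- hypothetical analogue of aerich/__main__.py
from .cli import main  # the package's click group

main()  # runs when invoked as `python -m mytool`
```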
aerich/_compat.py Normal file
View File

@@ -0,0 +1,28 @@
+# mypy: disable-error-code="no-redef"
+from __future__ import annotations
+
+import sys
+from types import ModuleType
+
+import tortoise
+
+if sys.version_info >= (3, 11):
+    import tomllib
+else:
+    try:
+        import tomli as tomllib
+    except ImportError:
+        import tomlkit as tomllib
+
+
+def imports_tomlkit() -> ModuleType:
+    try:
+        import tomli_w as tomlkit
+    except ImportError:
+        import tomlkit
+    return tomlkit
+
+
+def tortoise_version_less_than(version: str) -> bool:
+    # The min version of tortoise is '0.11.0', so we can compare it by a `<`,
+    return tortoise.__version__ < version

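The comment in `tortoise_version_less_than` is doing real work: plain string comparison only orders versions correctly here because every supported tortoise release is at least `0.11.0`, so the minor version never drops back to a single digit. A quick check of that assumption:

```python
# Lexicographic and numeric ordering agree for the versions aerich cares about:
assert "0.11.0" < "0.23.1" < "0.24.1"
# ...but they would NOT agree in general, e.g. for a hypothetical 0.9.x release:
assert not ("0.9.0" < "0.11.0")  # "9" > "1" as a character, despite 9 < 11
```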
View File

@@ -1,30 +1,40 @@
+from __future__ import annotations
+
 import os
-import sys
 from pathlib import Path
-from typing import Dict, List, cast
+from typing import cast

 import asyncclick as click
 from asyncclick import Context, UsageError

 from aerich import Command
+from aerich._compat import imports_tomlkit, tomllib
 from aerich.enums import Color
 from aerich.exceptions import DowngradeError
 from aerich.utils import add_src_path, get_tortoise_config
 from aerich.version import __version__

-if sys.version_info >= (3, 11):
-    import tomllib
-else:
-    try:
-        import tomli as tomllib
-    except ImportError:
-        import tomlkit as tomllib  # type: ignore
-
 CONFIG_DEFAULT_VALUES = {
     "src_folder": ".",
 }

+
+def _patch_context_to_close_tortoise_connections_when_exit() -> None:
+    from tortoise import Tortoise, connections
+
+    origin_aexit = Context.__aexit__
+
+    async def aexit(*args, **kw) -> None:
+        await origin_aexit(*args, **kw)
+        if Tortoise._inited:
+            await connections.close_all()
+
+    Context.__aexit__ = aexit  # type:ignore[method-assign]
+
+
+_patch_context_to_close_tortoise_connections_when_exit()
+

 @click.group(context_settings={"help_option_names": ["-h", "--help"]})
 @click.version_option(__version__, "-V", "--version")
 @click.option(
@@ -50,7 +60,7 @@ async def cli(ctx: Context, config, app) -> None:
     content = config_path.read_text("utf-8")
     doc: dict = tomllib.loads(content)
     try:
-        tool = cast(Dict[str, str], doc["tool"]["aerich"])
+        tool = cast("dict[str, str]", doc["tool"]["aerich"])
         location = tool["location"]
         tortoise_orm = tool["tortoise_orm"]
         src_folder = tool.get("src_folder", CONFIG_DEFAULT_VALUES["src_folder"])
@@ -61,7 +71,10 @@ async def cli(ctx: Context, config, app) -> None:
     add_src_path(src_folder)
     tortoise_config = get_tortoise_config(ctx, tortoise_orm)
     if not app:
-        apps_config = cast(dict, tortoise_config.get("apps"))
+        try:
+            apps_config = cast(dict, tortoise_config["apps"])
+        except KeyError:
+            raise UsageError('Config must define "apps" section')
         app = list(apps_config.keys())[0]
     command = Command(tortoise_config=tortoise_config, app=app, location=location)
     ctx.obj["command"] = command
@@ -93,15 +106,26 @@ async def migrate(ctx: Context, name, empty) -> None:
     type=bool,
     help="Make migrations in a single transaction or not. Can be helpful for large migrations or creating concurrent indexes.",
 )
+@click.option(
+    "--fake",
+    default=False,
+    is_flag=True,
+    help="Mark migrations as run without actually running them.",
+)
 @click.pass_context
-async def upgrade(ctx: Context, in_transaction: bool) -> None:
+async def upgrade(ctx: Context, in_transaction: bool, fake: bool) -> None:
     command = ctx.obj["command"]
-    migrated = await command.upgrade(run_in_transaction=in_transaction)
+    migrated = await command.upgrade(run_in_transaction=in_transaction, fake=fake)
     if not migrated:
         click.secho("No upgrade items found", fg=Color.yellow)
     else:
         for version_file in migrated:
-            click.secho(f"Success upgrading to {version_file}", fg=Color.green)
+            if fake:
+                click.echo(
+                    f"Upgrading to {version_file}... " + click.style("FAKED", fg=Color.green)
+                )
+            else:
+                click.secho(f"Success upgrading to {version_file}", fg=Color.green)


 @cli.command(help="Downgrade to specified version.")
@@ -121,18 +145,27 @@ async def upgrade(ctx: Context, in_transaction: bool) -> None:
     show_default=True,
     help="Also delete the migration files.",
 )
+@click.option(
+    "--fake",
+    default=False,
+    is_flag=True,
+    help="Mark migrations as run without actually running them.",
+)
 @click.pass_context
 @click.confirmation_option(
     prompt="Downgrade is dangerous: you might lose your data! Are you sure?",
 )
-async def downgrade(ctx: Context, version: int, delete: bool) -> None:
+async def downgrade(ctx: Context, version: int, delete: bool, fake: bool) -> None:
     command = ctx.obj["command"]
     try:
-        files = await command.downgrade(version, delete)
+        files = await command.downgrade(version, delete, fake=fake)
     except DowngradeError as e:
         return click.secho(str(e), fg=Color.yellow)
     for file in files:
-        click.secho(f"Success downgrading to {file}", fg=Color.green)
+        if fake:
+            click.echo(f"Downgrading to {file}... " + click.style("FAKED", fg=Color.green))
+        else:
+            click.secho(f"Success downgrading to {file}", fg=Color.green)


 @cli.command(help="Show currently available heads (unapplied migrations).")
@@ -157,6 +190,16 @@ async def history(ctx: Context) -> None:
         click.secho(version, fg=Color.green)


+def _write_config(config_path, doc, table) -> None:
+    tomlkit = imports_tomlkit()
+
+    try:
+        doc["tool"]["aerich"] = table
+    except KeyError:
+        doc["tool"] = {"aerich": table}
+    config_path.write_text(tomlkit.dumps(doc))
+
+
 @cli.command(help="Initialize aerich config and create migrations folder.")
 @click.option(
     "-t",
@@ -179,10 +222,6 @@ async def history(ctx: Context) -> None:
 )
 @click.pass_context
 async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
-    try:
-        import tomli_w as tomlkit
-    except ImportError:
-        import tomlkit  # type: ignore
-
     config_file = ctx.obj["config_file"]

     if os.path.isabs(src_folder):
@@ -197,20 +236,18 @@ async def init(ctx: Context, tortoise_orm, location, src_folder) -> None:
     config_path = Path(config_file)
     content = config_path.read_text("utf-8") if config_path.exists() else "[tool.aerich]"
     doc: dict = tomllib.loads(content)
-    table: dict = getattr(tomlkit, "table", dict)()
-    table["tortoise_orm"] = tortoise_orm
-    table["location"] = location
-    table["src_folder"] = src_folder
-    try:
-        doc["tool"]["aerich"] = table
-    except KeyError:
-        doc["tool"] = {"aerich": table}

-    config_path.write_text(tomlkit.dumps(doc))
+    table = {"tortoise_orm": tortoise_orm, "location": location, "src_folder": src_folder}
+    if (aerich_config := doc.get("tool", {}).get("aerich")) and all(
+        aerich_config.get(k) == v for k, v in table.items()
+    ):
+        click.echo(f"Aerich config {config_file} already inited.")
+    else:
+        _write_config(config_path, doc, table)
+        click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)

     Path(location).mkdir(parents=True, exist_ok=True)
     click.secho(f"Success creating migrations folder {location}", fg=Color.green)
-    click.secho(f"Success writing aerich config to {config_file}", fg=Color.green)


 @cli.command(help="Generate schema and generate app migration folder.")
@@ -247,7 +284,7 @@ async def init_db(ctx: Context, safe: bool) -> None:
     required=False,
 )
 @click.pass_context
-async def inspectdb(ctx: Context, table: List[str]) -> None:
+async def inspectdb(ctx: Context, table: list[str]) -> None:
     command = ctx.obj["command"]
     ret = await command.inspectdb(table)
     click.secho(ret)

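The reworked `init` command above now skips rewriting the config file when nothing changed (#402). A minimal sketch of that idempotence check, with inline data standing in for a parsed pyproject.toml:

```python
# Existing parsed config (as tomllib would return it) and the values `init` wants to write.
doc = {"tool": {"aerich": {"tortoise_orm": "settings.TORTOISE_ORM",
                           "location": "./migrations", "src_folder": "."}}}
table = {"tortoise_orm": "settings.TORTOISE_ORM",
         "location": "./migrations", "src_folder": "."}

if (aerich_config := doc.get("tool", {}).get("aerich")) and all(
    aerich_config.get(k) == v for k, v in table.items()
):
    print("already inited - leave the config file untouched")
else:
    print("write the [tool.aerich] table")
```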
View File

@@ -1,7 +1,9 @@
+from __future__ import annotations
+
 import base64
 import json
 import pickle  # nosec: B301,B403
-from typing import Any, Union
+from typing import Any

 from tortoise.indexes import Index

@@ -9,6 +11,9 @@ from tortoise.indexes import Index
 class JsonEncoder(json.JSONEncoder):
     def default(self, obj) -> Any:
         if isinstance(obj, Index):
+            if hasattr(obj, "describe"):
+                # For tortoise>=0.24
+                return obj.describe()
             return {
                 "type": "index",
                 "val": base64.b64encode(pickle.dumps(obj)).decode(),  # nosec: B301
@@ -18,15 +23,27 @@ class JsonEncoder(json.JSONEncoder):

 def object_hook(obj) -> Any:
-    _type = obj.get("type")
-    if not _type:
-        return obj
-    return pickle.loads(base64.b64decode(obj["val"]))  # nosec: B301
+    if (type_ := obj.get("type")) and type_ == "index" and (val := obj.get("val")):
+        return pickle.loads(base64.b64decode(val))  # nosec: B301
+    return obj
+
+
+def load_index(obj: dict) -> Index:
+    """Convert a dict that generated by `Index.decribe()` to a Index instance"""
+    try:
+        index = Index(fields=obj["fields"] or obj["expressions"], name=obj.get("name"))
+    except KeyError:
+        return object_hook(obj)
+    if extra := obj.get("extra"):
+        index.extra = extra
+    if idx_type := obj.get("type"):
+        index.INDEX_TYPE = idx_type
+    return index


 def encoder(obj: dict) -> str:
     return json.dumps(obj, cls=JsonEncoder)


-def decoder(obj: Union[str, bytes]) -> Any:
+def decoder(obj: str | bytes) -> Any:
     return json.loads(obj, object_hook=object_hook)

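A hedged round-trip sketch of the coder module above, assuming aerich is installed alongside tortoise-orm>=0.24 (where `Index.describe()` exists); on older versions the encoder instead emits the pickle/base64 payload that `object_hook` rebuilds:

```python
from tortoise.indexes import Index

from aerich.coder import decoder, encoder, load_index

idx = Index(fields=("name",), name="idx_user_name")  # illustrative index

payload = encoder({"indexes": [idx]})   # Index -> describe() dict -> JSON
restored = decoder(payload)["indexes"][0]

# On tortoise>=0.24 the decoded value is a plain describe()-style dict,
# which load_index converts back into an Index instance.
if isinstance(restored, dict):
    restored = load_index(restored)
print(restored)
```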
View File

@@ -1,15 +1,20 @@
+from __future__ import annotations
+
+import re
 from enum import Enum
-from typing import Any, List, Type, cast
+from typing import TYPE_CHECKING, Any, cast

-from tortoise import BaseDBAsyncClient, Model
 from tortoise.backends.base.schema_generator import BaseSchemaGenerator
-from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator

+from aerich._compat import tortoise_version_less_than
 from aerich.utils import is_default_function

+if TYPE_CHECKING:
+    from tortoise import BaseDBAsyncClient, Model
+

 class BaseDDL:
-    schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
+    schema_generator_cls: type[BaseSchemaGenerator] = BaseSchemaGenerator
     DIALECT = "sql"
     _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
     _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
@@ -18,10 +23,8 @@ class BaseDDL:
     _RENAME_COLUMN_TEMPLATE = (
         'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
     )
-    _ADD_INDEX_TEMPLATE = (
-        'ALTER TABLE "{table_name}" ADD {unique}INDEX "{index_name}" ({column_names})'
-    )
-    _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
+    _ADD_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {index_type}{unique}INDEX "{index_name}" ({column_names}){extra}'
+    _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX IF EXISTS "{index_name}"'
     _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
     _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
     _M2M_TABLE_TEMPLATE = (
@@ -36,20 +39,26 @@ class BaseDDL:
     )
     _RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'

-    def __init__(self, client: "BaseDBAsyncClient") -> None:
+    def __init__(self, client: BaseDBAsyncClient) -> None:
         self.client = client
         self.schema_generator = self.schema_generator_cls(client)

-    def create_table(self, model: "Type[Model]") -> str:
-        return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
-            ";"
-        )
+    @staticmethod
+    def get_table_name(model: type[Model]) -> str:
+        return model._meta.db_table
+
+    def create_table(self, model: type[Model]) -> str:
+        schema = self.schema_generator._get_table_sql(model, True)["table_creation_string"]
+        if tortoise_version_less_than("0.23.1"):
+            # Remove extra space
+            schema = re.sub(r'(["()A-Za-z])  (["()A-Za-z])', r"\1 \2", schema)
+        return schema.rstrip(";")

     def drop_table(self, table_name: str) -> str:
         return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)

     def create_m2m(
-        self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
+        self, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         through = cast(str, field_describe.get("through"))
         description = field_describe.get("description")
@@ -78,7 +87,7 @@ class BaseDDL:
     def drop_m2m(self, table_name: str) -> str:
         return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)

-    def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any:
+    def _get_default(self, model: type[Model], field_describe: dict) -> Any:
         db_table = model._meta.db_table
         default = field_describe.get("default")
         if isinstance(default, Enum):
@@ -104,14 +113,14 @@ class BaseDDL:
                 )
             except NotImplementedError:
                 default = ""
-        else:
-            default = None
         return default

-    def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
+    def add_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
         return self._add_or_modify_column(model, field_describe, is_pk)

-    def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
+    def _add_or_modify_column(
+        self, model: type[Model], field_describe: dict, is_pk: bool, modify: bool = False
+    ) -> str:
         db_table = model._meta.db_table
         description = field_describe.get("description")
         db_column = cast(str, field_describe.get("db_column"))
@@ -124,44 +133,38 @@ class BaseDDL:
             template = self._MODIFY_COLUMN_TEMPLATE
         else:
             # sqlite does not support alter table to add unique column
-            unique = (
-                "UNIQUE"
-                if field_describe.get("unique") and self.DIALECT != SqliteSchemaGenerator.DIALECT
-                else ""
-            )
+            unique = " UNIQUE" if field_describe.get("unique") and self.DIALECT != "sqlite" else ""
             template = self._ADD_COLUMN_TEMPLATE
-        return template.format(
-            table_name=db_table,
-            column=self.schema_generator._create_string(
-                db_column=db_column,
-                field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
-                nullable="NOT NULL" if not field_describe.get("nullable") else "",
-                unique=unique,
-                comment=(
-                    self.schema_generator._column_comment_generator(
-                        table=db_table,
-                        column=db_column,
-                        comment=description,
-                    )
-                    if description
-                    else ""
-                ),
-                is_primary_key=is_pk,
-                default=default,
+        column = self.schema_generator._create_string(
+            db_column=db_column,
+            field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
+            nullable=" NOT NULL" if not field_describe.get("nullable") else "",
+            unique=unique,
+            comment=(
+                self.schema_generator._column_comment_generator(
+                    table=db_table,
+                    column=db_column,
+                    comment=description,
+                )
+                if description
+                else ""
             ),
+            is_primary_key=is_pk,
+            default=default,
         )
+        if tortoise_version_less_than("0.23.1"):
+            column = column.replace("  ", " ")
+        return template.format(table_name=db_table, column=column)

-    def drop_column(self, model: "Type[Model]", column_name: str) -> str:
+    def drop_column(self, model: type[Model], column_name: str) -> str:
         return self._DROP_COLUMN_TEMPLATE.format(
             table_name=model._meta.db_table, column_name=column_name
         )

-    def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
+    def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
         return self._add_or_modify_column(model, field_describe, is_pk, modify=True)

-    def rename_column(
-        self, model: "Type[Model]", old_column_name: str, new_column_name: str
-    ) -> str:
+    def rename_column(self, model: type[Model], old_column_name: str, new_column_name: str) -> str:
         return self._RENAME_COLUMN_TEMPLATE.format(
             table_name=model._meta.db_table,
             old_column_name=old_column_name,
@@ -169,7 +172,7 @@ class BaseDDL:
         )

     def change_column(
-        self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
+        self, model: type[Model], old_column_name: str, new_column_name: str, new_column_type: str
     ) -> str:
         return self._CHANGE_COLUMN_TEMPLATE.format(
             table_name=model._meta.db_table,
@@ -178,39 +181,61 @@ class BaseDDL:
             new_column_type=new_column_type,
         )

-    def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
+    def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
+        func_name = "_get_index_name"
+        if not hasattr(self.schema_generator, func_name):
+            # For tortoise-orm<0.24.1
+            func_name = "_generate_index_name"
+        return getattr(self.schema_generator, func_name)(
+            "idx" if not unique else "uid", model, field_names
+        )
+
+    def add_index(
+        self,
+        model: type[Model],
+        field_names: list[str],
+        unique: bool | None = False,
+        name: str | None = None,
+        index_type: str = "",
+        extra: str | None = "",
+    ) -> str:
         return self._ADD_INDEX_TEMPLATE.format(
             unique="UNIQUE " if unique else "",
-            index_name=self.schema_generator._generate_index_name(
-                "idx" if not unique else "uid", model, field_names
-            ),
+            index_name=name or self._index_name(unique, model, field_names),
             table_name=model._meta.db_table,
             column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
+            index_type=f"{index_type} " if index_type else "",
+            extra=f"{extra}" if extra else "",
         )

-    def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
+    def drop_index(
+        self,
+        model: type[Model],
+        field_names: list[str],
+        unique: bool | None = False,
+        name: str | None = None,
+    ) -> str:
         return self._DROP_INDEX_TEMPLATE.format(
-            index_name=self.schema_generator._generate_index_name(
-                "idx" if not unique else "uid", model, field_names
-            ),
+            index_name=name or self._index_name(unique, model, field_names),
             table_name=model._meta.db_table,
         )

-    def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str:
-        return self._DROP_INDEX_TEMPLATE.format(
-            index_name=index_name,
-            table_name=model._meta.db_table,
-        )
+    def drop_index_by_name(self, model: type[Model], index_name: str) -> str:
+        return self.drop_index(model, [], name=index_name)

     def _generate_fk_name(
-        self, db_table, field_describe: dict, reference_table_describe: dict
+        self, db_table: str, field_describe: dict, reference_table_describe: dict
     ) -> str:
         """Generate fk name"""
         db_column = cast(str, field_describe.get("raw_field"))
         pk_field = cast(dict, reference_table_describe.get("pk_field"))
         to_field = cast(str, pk_field.get("db_column"))
         to_table = cast(str, reference_table_describe.get("table"))
-        return self.schema_generator._generate_fk_name(
+        func_name = "_get_fk_name"
+        if not hasattr(self.schema_generator, func_name):
+            # For tortoise-orm<0.24.1
+            func_name = "_generate_fk_name"
+        return getattr(self.schema_generator, func_name)(
             from_table=db_table,
             from_field=db_column,
             to_table=to_table,
@@ -218,7 +243,7 @@ class BaseDDL:
         )

     def add_fk(
-        self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
+        self, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         db_table = model._meta.db_table
@@ -235,13 +260,13 @@ class BaseDDL:
         )

     def drop_fk(
-        self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
+        self, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         db_table = model._meta.db_table
         fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
         return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)

-    def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str:
+    def alter_column_default(self, model: type[Model], field_describe: dict) -> str:
         db_table = model._meta.db_table
         default = self._get_default(model, field_describe)
         return self._ALTER_DEFAULT_TEMPLATE.format(
@@ -250,14 +275,28 @@ class BaseDDL:
             default="SET" + default if default is not None else "DROP DEFAULT",
         )

-    def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
+    def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
         return self.modify_column(model, field_describe)

-    def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
+    def set_comment(self, model: type[Model], field_describe: dict) -> str:
         return self.modify_column(model, field_describe)

-    def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str:
+    def rename_table(self, model: type[Model], old_table_name: str, new_table_name: str) -> str:
         db_table = model._meta.db_table
         return self._RENAME_TABLE_TEMPLATE.format(
             table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
         )
+
+    def alter_indexed_column_unique(
+        self, model: type[Model], field_name: str, drop: bool = False
+    ) -> list[str]:
+        """Change unique constraint for indexed field, e.g.: Field(db_index=True) --> Field(unique=True)"""
+        fields = [field_name]
+        if drop:
+            drop_unique = self.drop_index(model, fields, unique=True)
+            add_normal_index = self.add_index(model, fields, unique=False)
+            return [drop_unique, add_normal_index]
+        else:
+            drop_index = self.drop_index(model, fields, unique=False)
+            add_unique_index = self.add_index(model, fields, unique=True)
+            return [drop_index, add_unique_index]

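The new `alter_indexed_column_unique` in the base class swaps an ordinary index for a unique one (or back) as two separate statements. A rough render of what that produces, formatting the Postgres-style templates directly; the table, column, and index names here are illustrative:

```python
ADD = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}'
DROP = 'DROP INDEX IF EXISTS "{index_name}"'

# Field(db_index=True) -> Field(unique=True): drop the plain index, add a unique one.
print(DROP.format(index_name="idx_user_email"))
print(ADD.format(unique="UNIQUE ", index_name="uid_user_email",
                 table_name="user", index_type="", column_names='"email"', extra=""))
# -> DROP INDEX IF EXISTS "idx_user_email"
# -> CREATE UNIQUE INDEX IF NOT EXISTS "uid_user_email" ON "user" ("email")
```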
View File

@@ -1,11 +1,13 @@
-from typing import TYPE_CHECKING, List, Type
+from __future__ import annotations
+
+from typing import TYPE_CHECKING

 from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator

 from aerich.ddl import BaseDDL

 if TYPE_CHECKING:
-    from tortoise import Model  # noqa:F401
+    from tortoise import Model


 class MysqlDDL(BaseDDL):
@@ -21,10 +23,14 @@ class MysqlDDL(BaseDDL):
     _RENAME_COLUMN_TEMPLATE = (
         "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
     )
-    _ADD_INDEX_TEMPLATE = (
-        "ALTER TABLE `{table_name}` ADD {unique}INDEX `{index_name}` ({column_names})"
-    )
+    _ADD_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` ADD {index_type}{unique}INDEX `{index_name}` ({column_names}){extra}"
     _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
+    _ADD_INDEXED_UNIQUE_TEMPLATE = (
+        "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)"
+    )
+    _DROP_INDEXED_UNIQUE_TEMPLATE = (
+        "ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)"
+    )
     _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
     _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
     _M2M_TABLE_TEMPLATE = (
@@ -36,28 +42,20 @@ class MysqlDDL(BaseDDL):
     _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
     _RENAME_TABLE_TEMPLATE = "ALTER TABLE `{old_table_name}` RENAME TO `{new_table_name}`"

-    def _index_name(self, unique: bool, model: "Type[Model]", field_names: List[str]) -> str:
-        if unique:
-            if len(field_names) == 1:
-                # Example: `email = CharField(max_length=50, unique=True)`
-                # Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
-                # Unique index key is the same as field name: `email`
-                return field_names[0]
-            index_prefix = "uid"
-        else:
-            index_prefix = "idx"
-        return self.schema_generator._generate_index_name(index_prefix, model, field_names)
+    def _index_name(self, unique: bool | None, model: type[Model], field_names: list[str]) -> str:
+        if unique and len(field_names) == 1:
+            # Example: `email = CharField(max_length=50, unique=True)`
+            # Generate schema: `"email" VARCHAR(10) NOT NULL UNIQUE`
+            # Unique index key is the same as field name: `email`
+            return field_names[0]
+        return super()._index_name(unique, model, field_names)

-    def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
-        return self._ADD_INDEX_TEMPLATE.format(
-            unique="UNIQUE " if unique else "",
-            index_name=self._index_name(unique, model, field_names),
-            table_name=model._meta.db_table,
-            column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
-        )
-
-    def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
-        return self._DROP_INDEX_TEMPLATE.format(
-            index_name=self._index_name(unique, model, field_names),
-            table_name=model._meta.db_table,
-        )
+    def alter_indexed_column_unique(
+        self, model: type[Model], field_name: str, drop: bool = False
+    ) -> list[str]:
+        # if drop is false: Drop index and add unique
+        # else: Drop unique index and add normal index
+        template = self._DROP_INDEXED_UNIQUE_TEMPLATE if drop else self._ADD_INDEXED_UNIQUE_TEMPLATE
+        table = self.get_table_name(model)
+        index = self._index_name(unique=False, model=model, field_names=[field_name])
+        return [template.format(table_name=table, index_name=index, column_name=field_name)]

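MySQL's override collapses the same operation into a single ALTER statement using the two new templates above; rendering them with illustrative names shows the difference from the base class's two-statement approach:

```python
ADD_UNIQUE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`, ADD UNIQUE (`{column_name}`)"
DROP_UNIQUE = "ALTER TABLE `{table_name}` DROP INDEX `{column_name}`, ADD INDEX (`{index_name}`)"

# db_index=True -> unique=True (one statement instead of two):
print(ADD_UNIQUE.format(table_name="user", index_name="idx_user_email_x", column_name="email"))
# unique=True -> db_index=True:
print(DROP_UNIQUE.format(table_name="user", index_name="idx_user_email_x", column_name="email"))
```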
View File

@@ -1,15 +1,17 @@
-from typing import Type, cast
+from __future__ import annotations
+
+from typing import cast

 from tortoise import Model
-from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
+from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator

 from aerich.ddl import BaseDDL


 class PostgresDDL(BaseDDL):
-    schema_generator_cls = AsyncpgSchemaGenerator
-    DIALECT = AsyncpgSchemaGenerator.DIALECT
-    _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
+    schema_generator_cls = BasePostgresSchemaGenerator
+    DIALECT = BasePostgresSchemaGenerator.DIALECT
+    _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}'
     _DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'
     _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
     _MODIFY_COLUMN_TEMPLATE = (
@@ -18,7 +20,7 @@ class PostgresDDL(BaseDDL):
     _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
     _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT IF EXISTS "{fk_name}"'

-    def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
+    def alter_column_null(self, model: type[Model], field_describe: dict) -> str:
         db_table = model._meta.db_table
         return self._ALTER_NULL_TEMPLATE.format(
             table_name=db_table,
@@ -26,7 +28,7 @@ class PostgresDDL(BaseDDL):
             set_drop="DROP" if field_describe.get("nullable") else "SET",
         )

-    def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
+    def modify_column(self, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
         db_table = model._meta.db_table
         db_field_types = cast(dict, field_describe.get("db_field_types"))
         db_column = field_describe.get("db_column")
@@ -38,7 +40,7 @@ class PostgresDDL(BaseDDL):
             using=f' USING "{db_column}"::{datatype}',
         )

-    def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
+    def set_comment(self, model: type[Model], field_describe: dict) -> str:
         db_table = model._meta.db_table
         return self._SET_COMMENT_TEMPLATE.format(
             table_name=db_table,
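The reworked `_ADD_INDEX_TEMPLATE` gains IF NOT EXISTS plus slots for an index type and a trailing extra clause. A sketch of how it renders; the `index_type` and `extra` values come from the tortoise Index object at runtime, so the ones below are illustrative:

TEMPLATE = 'CREATE {unique}INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" {index_type}({column_names}){extra}'
print(TEMPLATE.format(
    unique="",
    index_name="idx_user_name_x",
    table_name="user",
    index_type="USING HASH ",  # assumed shape for a hash index; "" for the default btree
    column_names='"name"',
    extra="",
))
# CREATE INDEX IF NOT EXISTS "idx_user_name_x" ON "user" USING HASH ("name")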


@@ -1,4 +1,4 @@
-from typing import Type
+from __future__ import annotations

 from tortoise import Model
 from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
@@ -13,14 +13,14 @@ class SqliteDDL(BaseDDL):
     _ADD_INDEX_TEMPLATE = 'CREATE {unique}INDEX "{index_name}" ON "{table_name}" ({column_names})'
     _DROP_INDEX_TEMPLATE = 'DROP INDEX IF EXISTS "{index_name}"'

-    def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
+    def modify_column(self, model: type[Model], field_object: dict, is_pk: bool = True):
         raise NotSupportError("Modify column is unsupported in SQLite.")

-    def alter_column_default(self, model: "Type[Model]", field_describe: dict):
+    def alter_column_default(self, model: type[Model], field_describe: dict):
         raise NotSupportError("Alter column default is unsupported in SQLite.")

-    def alter_column_null(self, model: "Type[Model]", field_describe: dict):
+    def alter_column_null(self, model: type[Model], field_describe: dict):
         raise NotSupportError("Alter column null is unsupported in SQLite.")

-    def set_comment(self, model: "Type[Model]", field_describe: dict):
+    def set_comment(self, model: type[Model], field_describe: dict):
         raise NotSupportError("Alter column comment is unsupported in SQLite.")


@@ -1,7 +1,7 @@
 from __future__ import annotations

 import contextlib
-from typing import Any, Callable, Dict, Optional, TypedDict
+from typing import Any, Callable, Dict, TypedDict

 from pydantic import BaseModel
 from tortoise import BaseDBAsyncClient
@@ -17,6 +17,7 @@ class ColumnInfoDict(TypedDict):
     comment: str


+# TODO: use dict to replace typing.Dict when dropping support for Python3.8
 FieldMapDict = Dict[str, Callable[..., str]]
@@ -25,25 +26,24 @@ class Column(BaseModel):
     data_type: str
     null: bool
     default: Any
-    comment: Optional[str] = None
+    comment: str | None = None
     pk: bool
     unique: bool
     index: bool
-    length: Optional[int] = None
-    extra: Optional[str] = None
-    decimal_places: Optional[int] = None
-    max_digits: Optional[int] = None
+    length: int | None = None
+    extra: str | None = None
+    decimal_places: int | None = None
+    max_digits: int | None = None

     def translate(self) -> ColumnInfoDict:
         comment = default = length = index = null = pk = ""
         if self.pk:
-            pk = "pk=True, "
+            pk = "primary_key=True, "
         else:
             if self.unique:
                 index = "unique=True, "
-            else:
-                if self.index:
-                    index = "index=True, "
+            elif self.index:
+                index = "db_index=True, "
         if self.data_type in ("varchar", "VARCHAR"):
             length = f"max_length={self.length}, "
         elif self.data_type in ("decimal", "numeric"):
@@ -56,7 +56,7 @@ class Column(BaseModel):
             length = ", ".join(length_parts) + ", "
         if self.null:
             null = "null=True, "
-        if self.default is not None:
+        if self.default is not None and not self.pk:
             if self.data_type in ("tinyint", "INT"):
                 default = f"default={'True' if self.default == '1' else 'False'}, "
             elif self.data_type == "bool":
@@ -124,62 +124,69 @@ class Inspect:
     async def get_all_tables(self) -> list[str]:
         raise NotImplementedError

+    @staticmethod
+    def get_field_string(
+        field_class: str, arguments: str = "{null}{default}{comment}", **kwargs
+    ) -> str:
+        name = kwargs["name"]
+        field_params = arguments.format(**kwargs).strip().rstrip(",")
+        return f"{name} = fields.{field_class}({field_params})"
+
     @classmethod
     def decimal_field(cls, **kwargs) -> str:
-        return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
-            **kwargs
-        )
+        return cls.get_field_string("DecimalField", **kwargs)

     @classmethod
     def time_field(cls, **kwargs) -> str:
-        return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("TimeField", **kwargs)

     @classmethod
     def date_field(cls, **kwargs) -> str:
-        return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("DateField", **kwargs)

     @classmethod
     def float_field(cls, **kwargs) -> str:
-        return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("FloatField", **kwargs)

     @classmethod
     def datetime_field(cls, **kwargs) -> str:
-        return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("DatetimeField", **kwargs)

     @classmethod
     def text_field(cls, **kwargs) -> str:
-        return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("TextField", **kwargs)

     @classmethod
     def char_field(cls, **kwargs) -> str:
-        return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
-            **kwargs
-        )
+        arguments = "{pk}{index}{length}{null}{default}{comment}"
+        return cls.get_field_string("CharField", arguments, **kwargs)

     @classmethod
-    def int_field(cls, **kwargs) -> str:
-        return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
+    def int_field(cls, field_class="IntField", **kwargs) -> str:
+        arguments = "{pk}{index}{default}{comment}"
+        return cls.get_field_string(field_class, arguments, **kwargs)

     @classmethod
     def smallint_field(cls, **kwargs) -> str:
-        return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
+        return cls.int_field("SmallIntField", **kwargs)

     @classmethod
     def bigint_field(cls, **kwargs) -> str:
-        return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
+        return cls.int_field("BigIntField", **kwargs)

     @classmethod
     def bool_field(cls, **kwargs) -> str:
-        return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("BooleanField", **kwargs)

     @classmethod
     def uuid_field(cls, **kwargs) -> str:
-        return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
+        arguments = "{pk}{index}{default}{comment}"
+        return cls.get_field_string("UUIDField", arguments, **kwargs)

     @classmethod
     def json_field(cls, **kwargs) -> str:
-        return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("JSONField", **kwargs)

     @classmethod
     def binary_field(cls, **kwargs) -> str:
-        return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)
+        return cls.get_field_string("BinaryField", **kwargs)


@@ -12,11 +12,12 @@ class InspectMySQL(Inspect):
             "tinyint": self.bool_field,
             "bigint": self.bigint_field,
             "varchar": self.char_field,
-            "char": self.char_field,
+            "char": self.uuid_field,
             "longtext": self.text_field,
             "text": self.text_field,
             "datetime": self.datetime_field,
             "float": self.float_field,
+            "double": self.float_field,
             "date": self.date_field,
             "time": self.time_field,
             "decimal": self.decimal_field,
@@ -43,6 +44,8 @@ where c.TABLE_SCHEMA = %s
         unique = index = False
         if (non_unique := row["NON_UNIQUE"]) is not None:
             unique = not non_unique
+        elif row["COLUMN_KEY"] == "UNI":
+            unique = True
         if (index_name := row["INDEX_NAME"]) is not None:
             index = index_name != "PRIMARY"
         columns.append(
@@ -53,10 +56,8 @@ where c.TABLE_SCHEMA = %s
                 default=row["COLUMN_DEFAULT"],
                 pk=row["COLUMN_KEY"] == "PRI",
                 comment=row["COLUMN_COMMENT"],
-                unique=row["COLUMN_KEY"] == "UNI",
+                unique=unique,
                 extra=row["EXTRA"],
-                # TODO: why `unque`?
-                unque=unique,  # type:ignore
                 index=index,
                 length=row["CHARACTER_MAXIMUM_LENGTH"],
                 max_digits=row["NUMERIC_PRECISION"],


@@ -1,5 +1,6 @@
 from __future__ import annotations

+import re
 from typing import TYPE_CHECKING

 from aerich.inspectdb import Column, FieldMapDict, Inspect
@@ -9,19 +10,20 @@ if TYPE_CHECKING:


 class InspectPostgres(Inspect):
-    def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
+    def __init__(self, conn: BasePostgresClient, tables: list[str] | None = None) -> None:
         super().__init__(conn, tables)
         self.schema = conn.server_settings.get("schema") or "public"

     @property
     def field_map(self) -> FieldMapDict:
         return {
+            "int2": self.smallint_field,
             "int4": self.int_field,
-            "int8": self.int_field,
+            "int8": self.bigint_field,
             "smallint": self.smallint_field,
-            "bigint": self.bigint_field,
             "varchar": self.char_field,
             "text": self.text_field,
+            "bigint": self.bigint_field,
             "timestamptz": self.datetime_field,
             "float4": self.float_field,
             "float8": self.float_field,
@@ -59,6 +61,8 @@ from information_schema.constraint_column_usage const
 where c.table_catalog = $1
 and c.table_name = $2
 and c.table_schema = $3"""  # nosec:B608
+        if "psycopg" in str(type(self.conn)).lower():
+            sql = re.sub(r"\$[123]", "%s", sql)
         ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
         for row in ret:
             columns.append(
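asyncpg-style numbered placeholders ($1, $2, $3) are rewritten to the %s style psycopg expects, keyed off the connection's class name. The substitution itself is plain re:

import re

sql = "select * from c where a = $1 and b = $2 and d = $3"
print(re.sub(r"\$[123]", "%s", sql))
# select * from c where a = %s and b = %s and d = %s

It only matches $1 through $3 because this particular query takes exactly three parameters; a general-purpose rewrite would use r"\$\d+".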


@@ -2,18 +2,21 @@ from __future__ import annotations

 import importlib
 import os
+from collections.abc import Iterable
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Set, Tuple, Type, Union, cast
+from typing import cast

 import asyncclick as click
-import tortoise
 from dictdiffer import diff
 from tortoise import BaseDBAsyncClient, Model, Tortoise
 from tortoise.exceptions import OperationalError
 from tortoise.indexes import Index

+from aerich._compat import tortoise_version_less_than
+from aerich.coder import load_index
 from aerich.ddl import BaseDDL
+from aerich.enums import Color
 from aerich.models import MAX_VERSION_LENGTH, Aerich
 from aerich.utils import (
     get_app_connection,
@@ -37,29 +40,29 @@ async def downgrade(db: BaseDBAsyncClient) -> str:

 class Migrate:
-    upgrade_operators: List[str] = []
-    downgrade_operators: List[str] = []
-    _upgrade_fk_m2m_index_operators: List[str] = []
-    _downgrade_fk_m2m_index_operators: List[str] = []
-    _upgrade_m2m: List[str] = []
-    _downgrade_m2m: List[str] = []
+    upgrade_operators: list[str] = []
+    downgrade_operators: list[str] = []
+    _upgrade_fk_m2m_index_operators: list[str] = []
+    _downgrade_fk_m2m_index_operators: list[str] = []
+    _upgrade_m2m: list[str] = []
+    _downgrade_m2m: list[str] = []
     _aerich = Aerich.__name__
-    _rename_fields: Dict[str, Dict[str, str]] = {}  # {'model': {'old_field': 'new_field'}}
+    _rename_fields: dict[str, dict[str, str]] = {}  # {'model': {'old_field': 'new_field'}}

     ddl: BaseDDL
-    ddl_class: Type[BaseDDL]
-    _last_version_content: Optional[dict] = None
+    ddl_class: type[BaseDDL]
+    _last_version_content: dict | None = None
     app: str
     migrate_location: Path
     dialect: str
-    _db_version: Optional[str] = None
+    _db_version: str | None = None

     @staticmethod
-    def get_field_by_name(name: str, fields: List[dict]) -> dict:
+    def get_field_by_name(name: str, fields: list[dict]) -> dict:
         return next(filter(lambda x: x.get("name") == name, fields))

     @classmethod
-    def get_all_version_files(cls) -> List[str]:
+    def get_all_version_files(cls) -> list[str]:
         def get_file_version(file_name: str) -> str:
             return file_name.split("_")[0]
@@ -74,11 +77,11 @@ class Migrate:
         return sorted(files, key=lambda x: int(get_file_version(x)))

     @classmethod
-    def _get_model(cls, model: str) -> Type[Model]:
+    def _get_model(cls, model: str) -> type[Model]:
         return Tortoise.apps[cls.app].get(model)  # type: ignore

     @classmethod
-    async def get_last_version(cls) -> Optional[Aerich]:
+    async def get_last_version(cls) -> Aerich | None:
         try:
             return await Aerich.filter(app=cls.app).first()
         except OperationalError:
@@ -92,7 +95,7 @@ class Migrate:
         cls._db_version = ret[1][0].get("version")

     @classmethod
-    async def load_ddl_class(cls) -> Type[BaseDDL]:
+    async def load_ddl_class(cls) -> type[BaseDDL]:
         ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
         return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@@ -112,7 +115,7 @@ class Migrate:
         await cls._get_db_version(connection)

     @classmethod
-    async def _get_last_version_num(cls) -> Optional[int]:
+    async def _get_last_version_num(cls) -> int | None:
         last_version = await cls.get_last_version()
         if not last_version:
             return None
@@ -120,7 +123,7 @@ class Migrate:
         return int(version.split("_", 1)[0])

     @classmethod
-    async def generate_version(cls, name=None) -> str:
+    async def generate_version(cls, name: str | None = None) -> str:
         now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
         last_version_num = await cls._get_last_version_num()
         if last_version_num is None:
@@ -142,6 +145,22 @@ class Migrate:
         Path(cls.migrate_location, version).write_text(content, encoding="utf-8")
         return version
+    @classmethod
+    def _exclude_extra_field_types(cls, diffs) -> list[tuple]:
+        # Exclude changes of db_field_types that is not about the current dialect, e.g.:
+        # {"db_field_types": {
+        #     "oracle": "VARCHAR(255)"  --> "oracle": "NVARCHAR2(255)"
+        # }}
+        return [
+            c
+            for c in diffs
+            if not (
+                len(c) == 3
+                and c[1] == "db_field_types"
+                and not ({i[0] for i in c[2]} & {cls.dialect, ""})
+            )
+        ]
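To see what this filter drops, here is a sketch with hand-written entries shaped like dictdiffer output, i.e. ('add'/'remove'/'change', key, payload) tuples; the dialect value is illustrative:

dialect = "mysql"
diffs = [
    ("remove", "db_field_types", [("oracle", "VARCHAR(255)")]),  # other dialect: dropped
    ("add", "db_field_types", [("mysql", "LONGTEXT")]),          # current dialect: kept
    ("change", "nullable", (False, True)),                       # unrelated change: kept
]
kept = [
    c
    for c in diffs
    if not (
        len(c) == 3
        and c[1] == "db_field_types"
        and not ({i[0] for i in c[2]} & {dialect, ""})
    )
]
print([c[:2] for c in kept])  # [('add', 'db_field_types'), ('change', 'nullable')]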
     @classmethod
     async def migrate(cls, name: str, empty: bool) -> str:
         """
@@ -170,7 +189,7 @@ class Migrate:
         builds content for diff file from template
         """

-        def join_lines(lines: List[str]) -> str:
+        def join_lines(lines: list[str]) -> str:
             if not lines:
                 return ""
             return ";\n    ".join(lines) + ";"
@@ -181,7 +200,7 @@ class Migrate:
         )

     @classmethod
-    def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False) -> None:
+    def _add_operator(cls, operator: str, upgrade: bool = True, fk_m2m_index: bool = False) -> None:
         """
         add operator,differentiate fk because fk is order limit
         :param operator:
@@ -202,10 +221,9 @@ class Migrate:
             cls.downgrade_operators.append(operator)

     @classmethod
-    def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]) -> list:
-        if tortoise.__version__ > "0.22.2":
-            # The min version of tortoise is '0.11.0', so we can compare it by a `>`,
-            # tortoise>0.22.2 have __eq__/__hash__ with Index class since 313ee76.
+    def _handle_indexes(cls, model: type[Model], indexes: list[tuple[str] | Index]) -> list:
+        if not tortoise_version_less_than("0.23.0"):
+            # tortoise>=0.23.0 have __eq__/__hash__ with Index class since 313ee76.
             return indexes
         if index_classes := set(index.__class__ for index in indexes if isinstance(index, Index)):
             # Leave magic patch here to compare with older version of tortoise-orm
@@ -224,13 +242,15 @@ class Migrate:
         return indexes

     @classmethod
-    def _get_indexes(cls, model, model_describe: dict) -> Set[Union[Index, Tuple[str, ...]]]:
-        indexes: Set[Union[Index, Tuple[str, ...]]] = set()
+    def _get_indexes(cls, model, model_describe: dict) -> set[Index | tuple[str, ...]]:
+        indexes: set[Index | tuple[str, ...]] = set()
         for x in cls._handle_indexes(model, model_describe.get("indexes", [])):
             if isinstance(x, Index):
                 indexes.add(x)
+            elif isinstance(x, dict):
+                indexes.add(load_index(x))
             else:
-                indexes.add(cast(Tuple[str, ...], tuple(x)))
+                indexes.add(cast("tuple[str, ...]", tuple(x)))
         return indexes

     @staticmethod
@@ -240,13 +260,29 @@ class Migrate:
     @classmethod
     def _handle_m2m_fields(
-        cls, old_model_describe: Dict, new_model_describe: Dict, model, new_models, upgrade=True
+        cls, old_model_describe: dict, new_model_describe: dict, model, new_models, upgrade=True
     ) -> None:
-        old_m2m_fields = cast(List[dict], old_model_describe.get("m2m_fields", []))
-        new_m2m_fields = cast(List[dict], new_model_describe.get("m2m_fields", []))
-        new_tables: Dict[str, dict] = {field["table"]: field for field in new_models.values()}
+        old_m2m_fields = cast("list[dict]", old_model_describe.get("m2m_fields", []))
+        new_m2m_fields = cast("list[dict]", new_model_describe.get("m2m_fields", []))
+        new_tables: dict[str, dict] = {
+            field["table"]: field
+            for field in new_models.values()
+            if field.get("managed") is not False
+        }
         for action, option, change in get_dict_diff_by_key(old_m2m_fields, new_m2m_fields):
-            if (option and option[-1] == "nullable") or change[0][0] == "db_constraint":
+            if action == "change":
+                # Example:: action = 'change'; option = [0, 'unique']; change = (False, True)
+                attr = option[-1]
+                if attr == "indexed":
+                    # Ignore changing of indexed, as it usually changed by unique
+                    continue
+                elif attr == "unique":
+                    # TODO:
+                    continue
+                elif attr == "nullable":
+                    # nullable of m2m relation is constrainted by orm framework, not by db
+                    continue
+            if change[0][0] == "db_constraint":
                 continue
             new_value = change[0][1]
             if isinstance(new_value, str):
@@ -290,18 +326,18 @@ class Migrate:
     def _handle_relational(
         cls,
         key: str,
-        old_model_describe: Dict,
-        new_model_describe: Dict,
-        model: Type[Model],
-        old_models: Dict,
-        new_models: Dict,
+        old_model_describe: dict,
+        new_model_describe: dict,
+        model: type[Model],
+        old_models: dict,
+        new_models: dict,
         upgrade=True,
     ) -> None:
-        old_fk_fields = cast(List[dict], old_model_describe.get(key))
-        new_fk_fields = cast(List[dict], new_model_describe.get(key))
+        old_fk_fields = cast("list[dict]", old_model_describe.get(key))
+        new_fk_fields = cast("list[dict]", new_model_describe.get(key))

-        old_fk_fields_name: List[str] = [i.get("name", "") for i in old_fk_fields]
-        new_fk_fields_name: List[str] = [i.get("name", "") for i in new_fk_fields]
+        old_fk_fields_name: list[str] = [i.get("name", "") for i in old_fk_fields]
+        new_fk_fields_name: list[str] = [i.get("name", "") for i in new_fk_fields]

         # add
         for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
@@ -312,7 +348,9 @@ class Migrate:
                 cls._add_operator(sql, upgrade, fk_m2m_index=True)
         # drop
         for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
-            old_fk_field = cls.get_field_by_name(old_fk_field_name, cast(List[dict], old_fk_fields))
+            old_fk_field = cls.get_field_by_name(
+                old_fk_field_name, cast("list[dict]", old_fk_fields)
+            )
             if old_fk_field.get("db_constraint"):
                 ref_describe = cast(dict, old_models[old_fk_field["python_type"]])
                 sql = cls._drop_fk(model, old_fk_field, ref_describe)
@@ -321,11 +359,11 @@ class Migrate:
     @classmethod
     def _handle_fk_fields(
         cls,
-        old_model_describe: Dict,
-        new_model_describe: Dict,
-        model: Type[Model],
-        old_models: Dict,
-        new_models: Dict,
+        old_model_describe: dict,
+        new_model_describe: dict,
+        model: type[Model],
+        old_models: dict,
+        new_models: dict,
         upgrade=True,
     ) -> None:
         key = "fk_fields"
@@ -336,11 +374,11 @@ class Migrate:
     @classmethod
     def _handle_o2o_fields(
         cls,
-        old_model_describe: Dict,
-        new_model_describe: Dict,
-        model: Type[Model],
-        old_models: Dict,
-        new_models: Dict,
+        old_model_describe: dict,
+        new_model_describe: dict,
+        model: type[Model],
+        old_models: dict,
+        new_models: dict,
         upgrade=True,
     ) -> None:
         key = "o2o_fields"
@@ -350,7 +388,7 @@ class Migrate:
     @classmethod
     def diff_models(
-        cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
+        cls, old_models: dict[str, dict], new_models: dict[str, dict], upgrade=True
     ) -> None:
         """
         diff models and add operators
@@ -362,9 +400,11 @@ class Migrate:
         _aerich = f"{cls.app}.{cls._aerich}"
         old_models.pop(_aerich, None)
         new_models.pop(_aerich, None)

-        models_with_rename_field: Set[str] = set()  # models that trigger the click.prompt
+        models_with_rename_field: set[str] = set()  # models that trigger the click.prompt

         for new_model_str, new_model_describe in new_models.items():
+            if upgrade and new_model_describe.get("managed") is False:
+                continue
             model = cls._get_model(new_model_describe["name"].split(".")[1])
             if new_model_str not in old_models:
                 if upgrade:
@@ -375,6 +415,8 @@ class Migrate:
                     pass
             else:
                 old_model_describe = cast(dict, old_models.get(new_model_str))
+                if not upgrade and old_model_describe.get("managed") is False:
+                    continue

                 # rename table
                 new_table = cast(str, new_model_describe.get("table"))
                 old_table = cast(str, old_model_describe.get("table"))
@@ -383,25 +425,19 @@ class Migrate:
                 old_unique_together = set(
                     map(
                         lambda x: tuple(x),
-                        cast(List[Iterable[str]], old_model_describe.get("unique_together")),
+                        cast("list[Iterable[str]]", old_model_describe.get("unique_together")),
                     )
                 )
                 new_unique_together = set(
                     map(
                         lambda x: tuple(x),
-                        cast(List[Iterable[str]], new_model_describe.get("unique_together")),
+                        cast("list[Iterable[str]]", new_model_describe.get("unique_together")),
                     )
                 )
                 old_indexes = cls._get_indexes(model, old_model_describe)
                 new_indexes = cls._get_indexes(model, new_model_describe)
-                old_pk_field = old_model_describe.get("pk_field")
-                new_pk_field = new_model_describe.get("pk_field")
                 # pk field
-                changes = diff(old_pk_field, new_pk_field)
-                for action, option, change in changes:
-                    # current only support rename pk
-                    if action == "change" and option == "name":
-                        cls._add_operator(cls._rename_field(model, *change), upgrade)
+                cls._handle_pk_field_alter(model, old_model_describe, new_model_describe, upgrade)
                 # fk fields
                 args = (old_model_describe, new_model_describe, model, old_models, new_models)
                 cls._handle_fk_fields(*args, upgrade=upgrade)
@@ -421,25 +457,25 @@ class Migrate:
                     cls._add_operator(cls._drop_index(model, index, True), upgrade, True)
                 # add indexes
                 for idx in new_indexes.difference(old_indexes):
-                    cls._add_operator(cls._add_index(model, idx, False), upgrade, True)
+                    cls._add_operator(cls._add_index(model, idx), upgrade, fk_m2m_index=True)
                 # remove indexes
                 for idx in old_indexes.difference(new_indexes):
-                    cls._add_operator(cls._drop_index(model, idx, False), upgrade, True)
+                    cls._add_operator(cls._drop_index(model, idx), upgrade, fk_m2m_index=True)

                 old_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        cast(List[dict], old_model_describe.get("data_fields")),
+                        cast("list[dict]", old_model_describe.get("data_fields")),
                     )
                 )
                 new_data_fields = list(
                     filter(
                         lambda x: x.get("db_field_types") is not None,
-                        cast(List[dict], new_model_describe.get("data_fields")),
+                        cast("list[dict]", new_model_describe.get("data_fields")),
                     )
                 )

-                old_data_fields_name = cast(List[str], [i.get("name") for i in old_data_fields])
-                new_data_fields_name = cast(List[str], [i.get("name") for i in new_data_fields])
+                old_data_fields_name = cast("list[str]", [i.get("name") for i in old_data_fields])
+                new_data_fields_name = cast("list[str]", [i.get("name") for i in new_data_fields])

                 # add fields or rename fields
                 for new_data_field_name in set(new_data_fields_name).difference(
@@ -459,7 +495,9 @@ class Migrate:
                             len(new_name.symmetric_difference(set(f.get("name", "")))),
                         ),
                     ):
-                        changes = list(diff(old_data_field, new_data_field))
+                        changes = cls._exclude_extra_field_types(
+                            diff(old_data_field, new_data_field)
+                        )
                         old_data_field_name = cast(str, old_data_field.get("name"))
                         if len(changes) == 2:
                             # rename field
@@ -564,69 +602,115 @@ class Migrate:
                 # change fields
                 for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
-                    old_data_field = cls.get_field_by_name(field_name, old_data_fields)
-                    new_data_field = cls.get_field_by_name(field_name, new_data_fields)
-                    changes = diff(old_data_field, new_data_field)
-                    modified = False
-                    for change in changes:
-                        _, option, old_new = change
-                        if option == "indexed":
-                            # change index
-                            if old_new[0] is False and old_new[1] is True:
-                                unique = new_data_field.get("unique")
-                                cls._add_operator(
-                                    cls._add_index(model, (field_name,), unique), upgrade, True
-                                )
-                            else:
-                                unique = old_data_field.get("unique")
-                                cls._add_operator(
-                                    cls._drop_index(model, (field_name,), unique), upgrade, True
-                                )
-                        elif option == "db_field_types.":
-                            if new_data_field.get("field_type") == "DecimalField":
-                                # modify column
-                                cls._add_operator(
-                                    cls._modify_field(model, new_data_field),
-                                    upgrade,
-                                )
-                            else:
-                                continue
-                        elif option == "default":
-                            if not (
-                                is_default_function(old_new[0]) or is_default_function(old_new[1])
-                            ):
-                                # change column default
-                                cls._add_operator(
-                                    cls._alter_default(model, new_data_field), upgrade
-                                )
-                        elif option == "unique":
-                            # because indexed include it
-                            continue
-                        elif option == "nullable":
-                            # change nullable
-                            cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
-                        elif option == "description":
-                            # change comment
-                            cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
-                        else:
-                            if modified:
-                                continue
-                            # modify column
-                            cls._add_operator(
-                                cls._modify_field(model, new_data_field),
-                                upgrade,
-                            )
-                            modified = True
+                    cls._handle_field_changes(
+                        model, field_name, old_data_fields, new_data_fields, upgrade
+                    )

         for old_model in old_models.keys() - new_models.keys():
+            if not upgrade and old_models[old_model].get("managed") is False:
+                continue
             cls._add_operator(cls.drop_model(old_models[old_model]["table"]), upgrade)
     @classmethod
-    def rename_table(cls, model: Type[Model], old_table_name: str, new_table_name: str) -> str:
+    def _handle_pk_field_alter(
+        cls,
+        model: type[Model],
+        old_model_describe: dict[str, dict],
+        new_model_describe: dict[str, dict],
+        upgrade: bool,
+    ) -> None:
+        old_pk_field = old_model_describe.get("pk_field", {})
+        new_pk_field = new_model_describe.get("pk_field", {})
+        changes = cls._exclude_extra_field_types(diff(old_pk_field, new_pk_field))
+        sqls: list[str] = []
+        for action, option, change in changes:
+            if action != "change":
+                continue
+            if option == "db_column":
+                # rename pk
+                sql = cls._rename_field(model, *change)
+            elif option == "constraints.max_length":
+                sql = cls._modify_field(model, new_pk_field)
+            elif option == "field_type":
+                # Only support change field type between int fields, e.g.: IntField -> BigIntField
+                if not all(field_type.endswith("IntField") for field_type in change):
+                    if upgrade:
+                        model_name = model._meta.full_name.split(".")[-1]
+                        field_name = new_pk_field.get("name", "")
+                        msg = (
+                            f"Does not support change primary_key({model_name}.{field_name}) field type,"
+                            " you may need to do it manually."
+                        )
+                        click.secho(msg, fg=Color.yellow)
+                    return
+                sql = cls._modify_field(model, new_pk_field)
+            else:
+                # Skip option like 'constraints.ge', 'constraints.le', 'db_field_types.'
+                continue
+            sqls.append(sql)
+        for sql in sorted(sqls, key=lambda x: "RENAME" not in x):
+            # TODO: alter references field in m2m table
+            cls._add_operator(sql, upgrade)
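Two details worth noting. The final sorted() call runs RENAME statements before type changes, since False sorts before True. And the field_type branch only migrates integer-to-integer primary key changes; everything else warns and bails. A sketch of that gate, with illustrative change tuples:

for change in [("IntField", "BigIntField"), ("IntField", "UUIDField")]:
    ok = all(field_type.endswith("IntField") for field_type in change)
    print(change, "->", "modify column" if ok else "warn and skip the whole pk alter")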
+    @classmethod
+    def _handle_field_changes(
+        cls,
+        model: type[Model],
+        field_name: str,
+        old_data_fields: list[dict],
+        new_data_fields: list[dict],
+        upgrade: bool,
+    ) -> None:
+        old_data_field = cls.get_field_by_name(field_name, old_data_fields)
+        new_data_field = cls.get_field_by_name(field_name, new_data_fields)
+        changes = cls._exclude_extra_field_types(diff(old_data_field, new_data_field))
+        options = {c[1] for c in changes}
+        modified = False
+        for change in changes:
+            _, option, old_new = change
+            if option == "indexed":
+                # change index
+                if old_new[0] is False and old_new[1] is True:
+                    unique = new_data_field.get("unique")
+                    cls._add_operator(cls._add_index(model, (field_name,), unique), upgrade, True)
+                else:
+                    unique = old_data_field.get("unique")
+                    cls._add_operator(cls._drop_index(model, (field_name,), unique), upgrade, True)
+            elif option == "db_field_types.":
+                if new_data_field.get("field_type") == "DecimalField":
+                    # modify column
+                    cls._add_operator(cls._modify_field(model, new_data_field), upgrade)
+            elif option == "default":
+                if not (is_default_function(old_new[0]) or is_default_function(old_new[1])):
+                    # change column default
+                    cls._add_operator(cls._alter_default(model, new_data_field), upgrade)
+            elif option == "unique":
+                if "indexed" in options:
+                    # indexed include it
+                    continue
+                # Change unique for indexed field, e.g.: `db_index=True, unique=False` --> `db_index=True, unique=True`
+                drop_unique = old_new[0] is True and old_new[1] is False
+                for sql in cls.ddl.alter_indexed_column_unique(model, field_name, drop_unique):
+                    cls._add_operator(sql, upgrade, True)
+            elif option == "nullable":
+                # change nullable
+                cls._add_operator(cls._alter_null(model, new_data_field), upgrade)
+            elif option == "description":
+                # change comment
+                cls._add_operator(cls._set_comment(model, new_data_field), upgrade)
+            else:
+                if modified:
+                    continue
+                # modify column
+                cls._add_operator(cls._modify_field(model, new_data_field), upgrade)
+                modified = True
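The new unique branch is what makes toggling unique on an already-indexed column produce a migration on its own, delegating to the dialect's alter_indexed_column_unique() shown in the MySQL DDL hunk above. A sketch of the triggering model edit (field and model names illustrative):

from tortoise import fields, models

class User(models.Model):
    # before: fields.CharField(max_length=50, db_index=True)
    name = fields.CharField(max_length=50, db_index=True, unique=True)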
+    @classmethod
+    def rename_table(cls, model: type[Model], old_table_name: str, new_table_name: str) -> str:
         return cls.ddl.rename_table(model, old_table_name, new_table_name)
     @classmethod
-    def add_model(cls, model: Type[Model]) -> str:
+    def add_model(cls, model: type[Model]) -> str:
         return cls.ddl.create_table(model)

     @classmethod
@@ -635,7 +719,7 @@ class Migrate:
     @classmethod
     def create_m2m(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+        cls, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@@ -644,7 +728,7 @@ class Migrate:
         return cls.ddl.drop_m2m(table_name)

     @classmethod
-    def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Iterable[str]) -> List[str]:
+    def _resolve_fk_fields_name(cls, model: type[Model], fields_name: Iterable[str]) -> list[str]:
         ret = []
         for field_name in fields_name:
             try:
@@ -662,9 +746,19 @@ class Migrate:
     @classmethod
     def _drop_index(
-        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
+        cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False
     ) -> str:
         if isinstance(fields_name, Index):
+            if cls.dialect == "mysql":
+                # schema_generator of MySQL return a empty index sql
+                if hasattr(fields_name, "field_names"):
+                    # tortoise>=0.24
+                    fields = fields_name.field_names
+                else:
+                    # TODO: remove else when drop support for tortoise<0.24
+                    if not (fields := fields_name.fields):
+                        fields = [getattr(i, "get_sql")() for i in fields_name.expressions]
+                return cls.ddl.drop_index(model, fields, unique, name=fields_name.name)
             return cls.ddl.drop_index_by_name(
                 model, fields_name.index_name(cls.ddl.schema_generator, model)
             )
@@ -673,50 +767,72 @@ class Migrate:
     @classmethod
     def _add_index(
-        cls, model: Type[Model], fields_name: Union[Iterable[str], Index], unique=False
+        cls, model: type[Model], fields_name: Iterable[str] | Index, unique=False
     ) -> str:
         if isinstance(fields_name, Index):
-            return fields_name.get_sql(cls.ddl.schema_generator, model, False)
+            if cls.dialect == "mysql":
+                # schema_generator of MySQL return a empty index sql
+                if hasattr(fields_name, "field_names"):
+                    # tortoise>=0.24
+                    fields = fields_name.field_names
+                else:
+                    # TODO: remove else when drop support for tortoise<0.24
+                    if not (fields := fields_name.fields):
+                        fields = [getattr(i, "get_sql")() for i in fields_name.expressions]
+                return cls.ddl.add_index(
+                    model,
+                    fields,
+                    name=fields_name.name,
+                    index_type=fields_name.INDEX_TYPE,
+                    extra=fields_name.extra,
+                )
+            sql = fields_name.get_sql(cls.ddl.schema_generator, model, safe=True)
+            if tortoise_version_less_than("0.24.0"):
+                sql = sql.replace("  ", " ")
+            if cls.dialect == "postgres" and (exists := "IF NOT EXISTS ") not in sql:
+                idx = " INDEX "
+                sql = sql.replace(idx, idx + exists)
+            return sql
         field_names = cls._resolve_fk_fields_name(model, fields_name)
         return cls.ddl.add_index(model, field_names, unique)
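For non-MySQL dialects the Index object still renders its own SQL and aerich then patches the string: tortoise below 0.24 emitted a doubled space that gets collapsed, and on Postgres IF NOT EXISTS is spliced in after INDEX when missing. The splice is plain string surgery (the sql value below is illustrative):

sql = 'CREATE INDEX "idx_user_name" ON "user" ("name")'
exists = "IF NOT EXISTS "
if exists not in sql:
    idx = " INDEX "
    sql = sql.replace(idx, idx + exists)
print(sql)  # CREATE INDEX IF NOT EXISTS "idx_user_name" ON "user" ("name")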
     @classmethod
-    def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False) -> str:
+    def _add_field(cls, model: type[Model], field_describe: dict, is_pk: bool = False) -> str:
         return cls.ddl.add_column(model, field_describe, is_pk)

     @classmethod
-    def _alter_default(cls, model: Type[Model], field_describe: dict) -> str:
+    def _alter_default(cls, model: type[Model], field_describe: dict) -> str:
         return cls.ddl.alter_column_default(model, field_describe)

     @classmethod
-    def _alter_null(cls, model: Type[Model], field_describe: dict) -> str:
+    def _alter_null(cls, model: type[Model], field_describe: dict) -> str:
         return cls.ddl.alter_column_null(model, field_describe)

     @classmethod
-    def _set_comment(cls, model: Type[Model], field_describe: dict) -> str:
+    def _set_comment(cls, model: type[Model], field_describe: dict) -> str:
         return cls.ddl.set_comment(model, field_describe)

     @classmethod
-    def _modify_field(cls, model: Type[Model], field_describe: dict) -> str:
+    def _modify_field(cls, model: type[Model], field_describe: dict) -> str:
         return cls.ddl.modify_column(model, field_describe)

     @classmethod
     def _drop_fk(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+        cls, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         return cls.ddl.drop_fk(model, field_describe, reference_table_describe)

     @classmethod
-    def _remove_field(cls, model: Type[Model], column_name: str) -> str:
+    def _remove_field(cls, model: type[Model], column_name: str) -> str:
         return cls.ddl.drop_column(model, column_name)

     @classmethod
-    def _rename_field(cls, model: Type[Model], old_field_name: str, new_field_name: str) -> str:
+    def _rename_field(cls, model: type[Model], old_field_name: str, new_field_name: str) -> str:
         return cls.ddl.rename_column(model, old_field_name, new_field_name)

     @classmethod
     def _change_field(
-        cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict
+        cls, model: type[Model], old_field_describe: dict, new_field_describe: dict
     ) -> str:
         db_field_types = cast(dict, new_field_describe.get("db_field_types"))
         return cls.ddl.change_column(
@@ -728,7 +844,7 @@ class Migrate:
     @classmethod
     def _add_fk(
-        cls, model: Type[Model], field_describe: dict, reference_table_describe: dict
+        cls, model: type[Model], field_describe: dict, reference_table_describe: dict
     ) -> str:
         """
         add fk


@@ -4,9 +4,9 @@ import importlib.util
 import os
 import re
 import sys
+from collections.abc import Generator
 from pathlib import Path
 from types import ModuleType
-from typing import Dict, Generator, Optional, Union

 from asyncclick import BadOptionUsage, ClickException, Context
 from dictdiffer import diff
@@ -34,23 +34,19 @@ def get_app_connection_name(config, app_name: str) -> str:
     get connection name
     :param config:
     :param app_name:
-    :return:
+    :return: the default connection name (Usally it is 'default')
     """
-    app = config.get("apps").get(app_name)
-    if app:
+    if app := config.get("apps").get(app_name):
         return app.get("default_connection", "default")
-    raise BadOptionUsage(
-        option_name="--app",
-        message=f'Can\'t get app named "{app_name}"',
-    )
+    raise BadOptionUsage(option_name="--app", message=f"Can't get app named {app_name!r}")


 def get_app_connection(config, app) -> BaseDBAsyncClient:
     """
-    get connection name
+    get connection client
     :param config:
     :param app:
-    :return:
+    :return: client instance
     """
     return Tortoise.get_connection(get_app_connection_name(config, app))
@@ -81,7 +77,7 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
     return config


-def get_models_describe(app: str) -> Dict:
+def get_models_describe(app: str) -> dict:
     """
     get app models describe
     :param app:
@@ -89,16 +85,17 @@ def get_models_describe(app: str) -> dict:
     """
     ret = {}
     for model in Tortoise.apps[app].values():
+        managed = getattr(model.Meta, "managed", None)
         describe = model.describe()
-        ret[describe.get("name")] = describe
+        ret[describe.get("name")] = dict(describe, managed=managed)
     return ret


-def is_default_function(string: str) -> Optional[re.Match]:
+def is_default_function(string: str) -> re.Match | None:
     return re.match(r"^<function.+>$", str(string or ""))


-def import_py_file(file: Union[str, Path]) -> ModuleType:
+def import_py_file(file: str | Path) -> ModuleType:
     module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
     spec = importlib.util.spec_from_file_location(module_name, file)
     module = importlib.util.module_from_spec(spec)  # type:ignore[arg-type]
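get_models_describe() now copies each model's Meta.managed flag into the describe dict so diff_models() can skip unmanaged tables. A sketch of opting a table out of migrations (the model is illustrative; tortoise itself ignores the attribute, aerich reads it via getattr with a None default):

from tortoise import fields, models

class LegacyTable(models.Model):
    id = fields.IntField(primary_key=True)

    class Meta:
        managed = False  # aerich will not create/alter/drop this table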


@@ -1 +1,3 @@
-__version__ = "0.8.0"
+from importlib.metadata import version
+
+__version__ = version(__package__)


@@ -1,27 +1,30 @@
+from __future__ import annotations
+
 import asyncio
-import contextlib
 import os
-from typing import Generator
+import sys
+from collections.abc import Generator
+from pathlib import Path

 import pytest
-from tortoise import Tortoise, expand_db_url, generate_schema_for_client
-from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
+from tortoise import Tortoise, expand_db_url
+from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
 from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
 from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
 from tortoise.contrib.test import MEMORY_SQLITE
-from tortoise.exceptions import DBConnectionError, OperationalError

 from aerich.ddl.mysql import MysqlDDL
 from aerich.ddl.postgres import PostgresDDL
 from aerich.ddl.sqlite import SqliteDDL
 from aerich.migrate import Migrate
+from tests._utils import chdir, copy_files, init_db, run_shell

 db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
 db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)
 tortoise_orm = {
     "connections": {
-        "default": expand_db_url(db_url, True),
-        "second": expand_db_url(db_url_second, True),
+        "default": expand_db_url(db_url, testing=True),
+        "second": expand_db_url(db_url_second, testing=True),
     },
     "apps": {
         "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
@@ -55,20 +58,40 @@ def event_loop() -> Generator:
 @pytest.fixture(scope="session", autouse=True)
 async def initialize_tests(event_loop, request) -> None:
-    # Placing init outside the try block since it doesn't
-    # establish connections to the DB eagerly.
-    await Tortoise.init(config=tortoise_orm)
-    with contextlib.suppress(DBConnectionError, OperationalError):
-        await Tortoise._drop_databases()
-    await Tortoise.init(config=tortoise_orm, _create_db=True)
-    await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
+    await init_db(tortoise_orm)
     client = Tortoise.get_connection("default")
     if client.schema_generator is MySQLSchemaGenerator:
         Migrate.ddl = MysqlDDL(client)
     elif client.schema_generator is SqliteSchemaGenerator:
         Migrate.ddl = SqliteDDL(client)
-    elif client.schema_generator is AsyncpgSchemaGenerator:
+    elif issubclass(client.schema_generator, BasePostgresSchemaGenerator):
         Migrate.ddl = PostgresDDL(client)
     Migrate.dialect = Migrate.ddl.DIALECT
     request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))
+
+
+@pytest.fixture
+def new_aerich_project(tmp_path: Path):
+    test_dir = Path(__file__).parent / "tests"
+    asset_dir = test_dir / "assets" / "fake"
+    settings_py = asset_dir / "settings.py"
+    _tests_py = asset_dir / "_tests.py"
+    db_py = asset_dir / "db.py"
+    models_py = test_dir / "models.py"
+    models_second_py = test_dir / "models_second.py"
+    copy_files(settings_py, _tests_py, models_py, models_second_py, db_py, target_dir=tmp_path)
+    dst_dir = tmp_path / "tests"
+    dst_dir.mkdir()
+    dst_dir.joinpath("__init__.py").touch()
+    copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)
+    if should_remove := str(tmp_path) not in sys.path:
+        sys.path.append(str(tmp_path))
+    with chdir(tmp_path):
+        run_shell("python db.py create", capture_output=False)
+        try:
+            yield
+        finally:
+            if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
+                run_shell("python db.py drop", capture_output=False)
+            if should_remove:
+                sys.path.remove(str(tmp_path))

poetry.lock (generated, 1155 lines changed): diff suppressed because it is too large.

@@ -1,49 +1,60 @@
-[tool.poetry]
+[project]
 name = "aerich"
-version = "0.8.1"
+version = "0.8.2"
 description = "A database migrations tool for Tortoise ORM."
-authors = ["long2ice <long2ice@gmail.com>"]
-license = "Apache-2.0"
+authors = [{name="long2ice", email="long2ice@gmail.com"}]
+license = { text = "Apache-2.0" }
 readme = "README.md"
+keywords = ["migrate", "Tortoise-ORM", "mysql"]
+packages = [{ include = "aerich" }]
+include = ["CHANGELOG.md", "LICENSE", "README.md"]
+requires-python = ">=3.8"
+dependencies = [
+    "tortoise-orm (>=0.21.0,<1.0.0); python_version < '4.0'",
+    "pydantic (>=2.0.2,!=2.1.0,!=2.7.0,<3.0.0)",
+    "dictdiffer (>=0.9.0,<1.0.0)",
+    "asyncclick (>=8.1.7,<9.0.0)",
+    "eval-type-backport (>=0.2.2,<1.0.0); python_version < '3.10'",
+]
+
+[project.optional-dependencies]
+toml = [
+    "tomli-w (>=1.1.0,<2.0.0); python_version >= '3.11'",
+    "tomlkit (>=0.11.4,<1.0.0); python_version < '3.11'",
+]
+# Need asyncpg or psycopg for PostgreSQL
+asyncpg = ["asyncpg"]
+psycopg = ["psycopg[pool,binary] (>=3.0.12,<4.0.0)"]
+# Need asyncmy or aiomysql for MySQL
+asyncmy = ["asyncmy>=0.2.9; python_version < '4.0'"]
+mysql = ["aiomysql>=0.2.0"]
+
+[project.urls]
 homepage = "https://github.com/tortoise/aerich"
 repository = "https://github.com/tortoise/aerich.git"
 documentation = "https://github.com/tortoise/aerich"
-keywords = ["migrate", "Tortoise-ORM", "mysql"]
-packages = [
-    { include = "aerich" }
-]
-include = ["CHANGELOG.md", "LICENSE", "README.md"]

-[tool.poetry.dependencies]
-python = "^3.8"
-tortoise-orm = ">=0.21"
-asyncpg = { version = "*", optional = true }
-asyncmy = { version = "^0.2.9", optional = true, allow-prereleases = true }
-pydantic = "^2.0,!=2.7.0"
-dictdiffer = "*"
-tomlkit = { version = "*", optional = true, python="<3.11" }
-tomli-w = { version = "^1.1.0", optional = true, python=">=3.11" }
-asyncclick = "^8.1.7.2"
+[project.scripts]
+aerich = "aerich.cli:main"
+
+[tool.poetry]
+requires-poetry = ">=2.0"

 [tool.poetry.group.dev.dependencies]
-ruff = "*"
-isort = "*"
-black = "*"
-pytest = "*"
-pytest-xdist = "*"
+ruff = "^0.9.0"
+bandit = "^1.7.0"
+mypy = "^1.10.0"
+twine = "^6.1.0"
+
+[tool.poetry.group.test.dependencies]
+pytest = "^8.3.0"
+pytest-mock = "^3.14.0"
+pytest-xdist = "^3.6.0"
 # Breaking change in 0.23.*
 # https://github.com/pytest-dev/pytest-asyncio/issues/706
 pytest-asyncio = "^0.21.2"
-bandit = "*"
-pytest-mock = "*"
-cryptography = "*"
-mypy = "^1.10.0"
-
-[tool.poetry.extras]
-asyncmy = ["asyncmy"]
-asyncpg = ["asyncpg"]
-toml = ["tomlkit", "tomli-w"]
+# required for sha256_password by asyncmy
+cryptography = {version="^44.0.1", python="!=3.9.0,!=3.9.1"}

 [tool.aerich]
 tortoise_orm = "conftest.tortoise_orm"
@@ -51,25 +62,55 @@ location = "./migrations"
 src_folder = "./."

 [build-system]
-requires = ["poetry-core>=1.0.0"]
+requires = ["poetry-core>=2.0.0"]
 build-backend = "poetry.core.masonry.api"

-[tool.poetry.scripts]
-aerich = "aerich.cli:main"
-
-[tool.black]
-line-length = 100
-target-version = ['py38', 'py39', 'py310', 'py311', 'py312']
-
 [tool.pytest.ini_options]
 asyncio_mode = 'auto'

+[tool.coverage.run]
+branch = true
+source = ["aerich"]
+
+[tool.coverage.report]
+exclude_also = [
+    "if TYPE_CHECKING:"
+]
+
 [tool.mypy]
 pretty = true
 python_version = "3.8"
+check_untyped_defs = true
+warn_unused_ignores = true
+disallow_incomplete_defs = false
+exclude = ["tests/assets", "migrations"]
+
+[[tool.mypy.overrides]]
+module = [
+    'dictdiffer.*',
+    'tomlkit',
+    'tomli_w',
+    'tomli',
+]
 ignore_missing_imports = true

 [tool.ruff]
 line-length = 100

 [tool.ruff.lint]
-ignore = ['E501']
+extend-select = [
+    "I",       # https://docs.astral.sh/ruff/rules/#isort-i
+    "SIM",     # https://docs.astral.sh/ruff/rules/#flake8-simplify-sim
+    "FA",      # https://docs.astral.sh/ruff/rules/#flake8-future-annotations-fa
+    "UP",      # https://docs.astral.sh/ruff/rules/#pyupgrade-up
+    "RUF100",  # https://docs.astral.sh/ruff/rules/#ruff-specific-rules-ruf
+]
+ignore = ["UP031"]  # https://docs.astral.sh/ruff/rules/printf-string-formatting/
+
+[tool.ruff.lint.per-file-ignores]
+# TODO: Remove this line when dropping support for Python3.8
+"aerich/inspectdb/__init__.py" = ["UP006", "UP035"]
+"aerich/_compat.py" = ["F401"]
+
+[tool.bandit]
+exclude_dirs = ["tests", "conftest.py"]
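With the database drivers moved into [project.optional-dependencies], they are now opt-in extras. Assuming a standard pip setup, something like pip install "aerich[psycopg]" pulls in psycopg for PostgreSQL, "aerich[asyncpg]" keeps the previous asyncpg driver, "aerich[asyncmy]" or "aerich[mysql]" covers MySQL, and "aerich[toml]" adds the TOML read/write backends.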

tests/_utils.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import contextlib
import os
import platform
import shlex
import shutil
import subprocess
import sys
from pathlib import Path

from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import DBConnectionError, OperationalError

if sys.version_info >= (3, 11):
    from contextlib import chdir
else:

    class chdir(contextlib.AbstractContextManager):  # Copied from source code of Python3.13
        """Non thread-safe context manager to change the current working directory."""

        def __init__(self, path):
            self.path = path
            self._old_cwd = []

        def __enter__(self):
            self._old_cwd.append(os.getcwd())
            os.chdir(self.path)

        def __exit__(self, *excinfo):
            os.chdir(self._old_cwd.pop())


async def drop_db(tortoise_orm) -> None:
    # Placing init outside the try-block(suppress) since it doesn't
    # establish connections to the DB eagerly.
    await Tortoise.init(config=tortoise_orm)
    with contextlib.suppress(DBConnectionError, OperationalError):
        await Tortoise._drop_databases()


async def init_db(tortoise_orm, generate_schemas=True) -> None:
    await drop_db(tortoise_orm)
    await Tortoise.init(config=tortoise_orm, _create_db=True)
    if generate_schemas:
        await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)


def copy_files(*src_files: Path, target_dir: Path) -> None:
    for src in src_files:
        shutil.copy(src, target_dir)


class Dialect:
    test_db_url: str

    @classmethod
    def load_env(cls) -> None:
        if getattr(cls, "test_db_url", None) is None:
            cls.test_db_url = os.getenv("TEST_DB", "")

    @classmethod
    def is_postgres(cls) -> bool:
        cls.load_env()
        return "postgres" in cls.test_db_url

    @classmethod
    def is_mysql(cls) -> bool:
        cls.load_env()
        return "mysql" in cls.test_db_url

    @classmethod
    def is_sqlite(cls) -> bool:
        cls.load_env()
        return not cls.test_db_url or "sqlite" in cls.test_db_url


WINDOWS = platform.system() == "Windows"


def run_shell(command: str, capture_output=True, **kw) -> str:
    if WINDOWS and command.startswith("aerich "):
        command = "python -m " + command
    # Forward extra kwargs (e.g. input=...) to subprocess.run; otherwise they
    # would be silently swallowed by the signature.
    r = subprocess.run(shlex.split(command), capture_output=capture_output, **kw)
    if r.returncode != 0 and r.stderr:
        return r.stderr.decode()
    if not r.stdout:
        return ""
    return r.stdout.decode()
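These helpers carry the CLI test suites below: `Dialect` sniffs the backend from the `TEST_DB` environment variable, `run_shell` invokes the `aerich` entry point (rewritten to `python -m aerich` on Windows), and the `chdir` backport confines a test to a scratch project directory. A sketch of how they compose; the `tmp_path` fixture and the `settings` module are stand-ins:

    # Sketch only: composes the helpers above inside a throwaway directory.
    from tests._utils import Dialect, chdir, run_shell

    def smoke_init(tmp_path):
        with chdir(tmp_path):
            assert "Success" in run_shell("aerich init -t settings.TORTOISE_ORM")
            if not Dialect.is_sqlite():
                assert "Success" in run_shell("aerich init-db")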

tests/assets/fake/_tests.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import pytest
from models import NewModel
from models_second import Config
from settings import TORTOISE_ORM
from tortoise import Tortoise
from tortoise.exceptions import OperationalError

try:
    # This error does not translate to tortoise's OperationalError
    from psycopg.errors import UndefinedColumn
except ImportError:
    errors = (OperationalError,)
else:
    errors = (OperationalError, UndefinedColumn)


@pytest.fixture(scope="session")
def anyio_backend() -> str:
    return "asyncio"


@pytest.fixture(autouse=True)
async def init_connections():
    await Tortoise.init(TORTOISE_ORM)
    try:
        yield
    finally:
        await Tortoise.close_connections()


@pytest.mark.anyio
async def test_init_db():
    m1 = await NewModel.filter(name="")
    assert isinstance(m1, list)
    m2 = await Config.filter(key="")
    assert isinstance(m2, list)
    await NewModel.create(name="")
    await Config.create(key="", label="", value={})


@pytest.mark.anyio
async def test_fake_field_1():
    assert "field_1" in NewModel._meta.fields_map
    assert "field_1" in Config._meta.fields_map
    with pytest.raises(errors):
        await NewModel.create(name="", field_1=1)
    with pytest.raises(errors):
        await Config.create(key="", label="", value={}, field_1=1)

    obj1 = NewModel(name="", field_1=1)
    with pytest.raises(errors):
        await obj1.save()
    obj1 = NewModel(name="")
    with pytest.raises(errors):
        await obj1.save()
    with pytest.raises(errors):
        obj1 = await NewModel.first()
    obj1 = await NewModel.all().first().values("id", "name")
    assert obj1 and obj1["id"]

    obj2 = Config(key="", label="", value={}, field_1=1)
    with pytest.raises(errors):
        await obj2.save()
    obj2 = Config(key="", label="", value={})
    with pytest.raises(errors):
        await obj2.save()
    with pytest.raises(errors):
        obj2 = await Config.first()
    obj2 = await Config.all().first().values("id", "key")
    assert obj2 and obj2["id"]


@pytest.mark.anyio
async def test_fake_field_2():
    assert "field_2" in NewModel._meta.fields_map
    assert "field_2" in Config._meta.fields_map
    with pytest.raises(errors):
        await NewModel.create(name="")
    with pytest.raises(errors):
        await Config.create(key="", label="", value={})

tests/assets/fake/db.py (new file, 28 lines)
@@ -0,0 +1,28 @@
import asyncclick as click
from settings import TORTOISE_ORM

from tests._utils import drop_db, init_db


@click.group()
def cli(): ...


@cli.command()
async def create():
    await init_db(TORTOISE_ORM, False)
    click.echo(f"Success to create databases for {TORTOISE_ORM['connections']}")


@cli.command()
async def drop():
    await drop_db(TORTOISE_ORM)
    click.echo(f"Dropped databases for {TORTOISE_ORM['connections']}")


def main():
    cli()


if __name__ == "__main__":
    main()
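Run directly, this asyncclick group gives the fake-migration assets their own provisioning commands, independent of aerich itself:

    python db.py create   # drop leftovers, then create both test databases
    python db.py drop     # tear them down again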

tests/assets/fake/settings.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import os
from datetime import date

from tortoise.contrib.test import MEMORY_SQLITE

DB_URL = (
    _u.replace("\\{\\}", f"aerich_fake_{date.today():%Y%m%d}")
    if (_u := os.getenv("TEST_DB"))
    else MEMORY_SQLITE
)
DB_URL_SECOND = (DB_URL + "_second") if DB_URL != MEMORY_SQLITE else MEMORY_SQLITE

TORTOISE_ORM = {
    "connections": {
        "default": DB_URL.replace(MEMORY_SQLITE, "sqlite://db.sqlite3"),
        "second": DB_URL_SECOND.replace(MEMORY_SQLITE, "sqlite://db_second.sqlite3"),
    },
    "apps": {
        "models": {"models": ["models", "aerich.models"], "default_connection": "default"},
        "models_second": {"models": ["models_second"], "default_connection": "second"},
    },
}
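`TEST_DB` is expected to carry a literal `\{\}` placeholder that gets swapped for a dated database name, so runs on different days (or different CI pipelines) stay isolated. A worked example with an illustrative MySQL URL:

    # TEST_DB='mysql://root:pass@127.0.0.1:3306/test_\{\}'  (placeholder escaped)
    # On 2025-03-06 the connections resolve to:
    #   default -> mysql://root:pass@127.0.0.1:3306/test_aerich_fake_20250306
    #   second  -> mysql://root:pass@127.0.0.1:3306/test_aerich_fake_20250306_second
    # With TEST_DB unset, both fall back to the file-based sqlite URLs instead.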

@@ -0,0 +1,76 @@
import uuid

import pytest
from models import Foo
from tortoise.exceptions import IntegrityError


@pytest.mark.asyncio
async def test_allow_duplicate() -> None:
    await Foo.all().delete()
    await Foo.create(name="foo")
    obj = await Foo.create(name="foo")
    assert (await Foo.all().count()) == 2
    await obj.delete()


@pytest.mark.asyncio
async def test_unique_is_true() -> None:
    with pytest.raises(IntegrityError):
        await Foo.create(name="foo")
        await Foo.create(name="foo")


@pytest.mark.asyncio
async def test_add_unique_field() -> None:
    if not await Foo.filter(age=0).exists():
        await Foo.create(name="0_" + uuid.uuid4().hex, age=0)
    with pytest.raises(IntegrityError):
        await Foo.create(name=uuid.uuid4().hex, age=0)


@pytest.mark.asyncio
async def test_drop_unique_field() -> None:
    name = "1_" + uuid.uuid4().hex
    await Foo.create(name=name, age=0)
    assert await Foo.filter(name=name).exists()


@pytest.mark.asyncio
async def test_with_age_field() -> None:
    name = "2_" + uuid.uuid4().hex
    await Foo.create(name=name, age=0)
    obj = await Foo.get(name=name)
    assert obj.age == 0


@pytest.mark.asyncio
async def test_without_age_field() -> None:
    name = "3_" + uuid.uuid4().hex
    await Foo.create(name=name, age=0)
    obj = await Foo.get(name=name)
    assert getattr(obj, "age", None) is None


@pytest.mark.asyncio
async def test_m2m_with_custom_through() -> None:
    from models import FooGroup, Group

    name = "4_" + uuid.uuid4().hex
    foo = await Foo.create(name=name)
    group = await Group.create(name=name + "1")
    await FooGroup.all().delete()
    await foo.groups.add(group)
    foo_group = await FooGroup.get(foo=foo, group=group)
    assert not foo_group.is_active


@pytest.mark.asyncio
async def test_add_m2m_field_after_init_db() -> None:
    from models import Group

    name = "5_" + uuid.uuid4().hex
    foo = await Foo.create(name=name)
    group = await Group.create(name=name + "1")
    await foo.groups.add(group)
    assert (await group.users.all().first()) == foo

@@ -0,0 +1,28 @@
from __future__ import annotations

import asyncio
from collections.abc import Generator

import pytest
import pytest_asyncio
import settings
from tortoise import Tortoise, connections


@pytest.fixture(scope="session")
def event_loop() -> Generator:
    policy = asyncio.get_event_loop_policy()
    res = policy.new_event_loop()
    asyncio.set_event_loop(res)
    res._close = res.close  # type:ignore[attr-defined]
    res.close = lambda: None  # type:ignore[method-assign]
    yield res
    res._close()  # type:ignore[attr-defined]


@pytest_asyncio.fixture(scope="session", autouse=True)
async def api(event_loop, request):
    await Tortoise.init(config=settings.TORTOISE_ORM)
    request.addfinalizer(lambda: event_loop.run_until_complete(connections.close_all(discard=True)))

@@ -0,0 +1,5 @@
from tortoise import Model, fields


class Foo(Model):
    name = fields.CharField(max_length=60, db_index=False)

@@ -0,0 +1,4 @@
TORTOISE_ORM = {
    "connections": {"default": "sqlite://db.sqlite3"},
    "apps": {"models": {"models": ["models", "aerich.models"]}},
}

tests/models.py
@@ -1,10 +1,15 @@
+from __future__ import annotations
+
 import datetime
 import uuid
 from enum import IntEnum

 from tortoise import Model, fields
+from tortoise.contrib.mysql.indexes import FullTextIndex
+from tortoise.contrib.postgres.indexes import HashIndex
 from tortoise.indexes import Index

+from tests._utils import Dialect
 from tests.indexes import CustomIndex
@@ -34,7 +39,7 @@ class User(Model):
     intro = fields.TextField(default="")
     longitude = fields.DecimalField(max_digits=10, decimal_places=8)

-    products: fields.ManyToManyRelation["Product"]
+    products: fields.ManyToManyRelation[Product]

     class Meta:
         # reverse indexes elements
@@ -44,10 +49,11 @@ class User(Model):
 class Email(Model):
     email_id = fields.IntField(primary_key=True)
     email = fields.CharField(max_length=200, db_index=True)
+    company = fields.CharField(max_length=100, db_index=True, unique=True)
     is_primary = fields.BooleanField(default=False)
     address = fields.CharField(max_length=200)
     users: fields.ManyToManyRelation[User] = fields.ManyToManyField("models.User")
-    config: fields.OneToOneRelation["Config"] = fields.OneToOneField("models.Config")
+    config: fields.OneToOneRelation[Config] = fields.OneToOneField("models.Config")


 def default_name():
@@ -63,8 +69,17 @@ class Category(Model):
     title = fields.CharField(max_length=20, unique=False)
     created_at = fields.DatetimeField(auto_now_add=True)

+    class Meta:
+        if Dialect.is_postgres():
+            indexes = [HashIndex(fields=("slug",))]
+        elif Dialect.is_mysql():
+            indexes = [FullTextIndex(fields=("slug",))]  # type:ignore
+        else:
+            indexes = [Index(fields=("slug",))]  # type:ignore
+

 class Product(Model):
+    id = fields.BigIntField(primary_key=True)
     categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
         "models.Category", null=False
     )
@@ -75,20 +90,24 @@ class Product(Model):
     view_num = fields.IntField(description="View Num", default=0)
     sort = fields.IntField()
     is_reviewed = fields.BooleanField(description="Is Reviewed")
-    type = fields.IntEnumField(
+    type: int = fields.IntEnumField(
         ProductType, description="Product Type", source_field="type_db_alias"
     )
     pic = fields.CharField(max_length=200)
     body = fields.TextField()
+    price = fields.FloatField(null=True)
+    no = fields.UUIDField(db_index=True)
     created_at = fields.DatetimeField(auto_now_add=True)
     is_deleted = fields.BooleanField(default=False)

     class Meta:
         unique_together = (("name", "type"),)
         indexes = (("name", "type"),)
+        managed = True


 class Config(Model):
+    slug = fields.CharField(primary_key=True, max_length=20)
     categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
         "models.Category", through="config_category_map", related_name="category_set"
     )
@@ -100,7 +119,22 @@ class Config(Model):
         "models.User", description="User"
     )
-    email: fields.OneToOneRelation["Email"]
+    email: fields.OneToOneRelation[Email]

+    class Meta:
+        managed = True
+
+
+class DontManageMe(Model):
+    name = fields.CharField(max_length=50)
+
+    class Meta:
+        managed = False
+
+
+class Ignore(Model):
+    class Meta:
+        managed = False
+

 class NewModel(Model):

tests/models_second.py
@@ -56,7 +56,7 @@ class Product(Model):
     view_num = fields.IntField(description="View Num")
     sort = fields.IntField()
     is_reviewed = fields.BooleanField(description="Is Reviewed")
-    type = fields.IntEnumField(
+    type: int = fields.IntEnumField(
         ProductType, description="Product Type", source_field="type_db_alias"
     )
     image = fields.CharField(max_length=200)

tests/old_models.py
@@ -40,6 +40,7 @@ class User(Model):
 class Email(Model):
     email = fields.CharField(max_length=200)
+    company = fields.CharField(max_length=100, db_index=True)
     is_primary = fields.BooleanField(default=False)
     user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
         "models.User", db_constraint=False
@@ -52,8 +53,12 @@ class Category(Model):
     user: fields.ForeignKeyRelation[User] = fields.ForeignKeyField(
         "models.User", description="User"
     )
+    title = fields.CharField(max_length=20, unique=True)
     created_at = fields.DatetimeField(auto_now_add=True)

+    class Meta:
+        indexes = [Index(fields=("slug",))]
+

 class Product(Model):
     categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
@@ -62,7 +67,7 @@ class Product(Model):
     view_num = fields.IntField(description="View Num")
     sort = fields.IntField()
     is_review = fields.BooleanField(description="Is Reviewed")
-    type = fields.IntEnumField(
+    type: int = fields.IntEnumField(
         ProductType, description="Product Type", source_field="type_db_alias"
     )
     image = fields.CharField(max_length=200)
@@ -72,6 +77,7 @@ class Product(Model):
 class Config(Model):
+    slug = fields.CharField(primary_key=True, max_length=10)
     category: fields.ManyToManyRelation[Category] = fields.ManyToManyField("models.Category")
     categories: fields.ManyToManyRelation[Category] = fields.ManyToManyField(
         "models.Category", through="config_category_map", related_name="config_set"
@@ -84,3 +90,40 @@ class Config(Model):
     class Meta:
         table = "configs"

+
+class DontManageMe(Model):
+    name = fields.CharField(max_length=50)
+
+    class Meta:
+        table = "dont_manage"
+
+
+class Ignore(Model):
+    name = fields.CharField(max_length=50)
+
+    class Meta:
+        managed = True
+
+
+def main() -> None:
+    """Generate a python file for the old_models_describe"""
+    from pathlib import Path
+
+    from tortoise import run_async
+    from tortoise.contrib.test import init_memory_sqlite
+
+    from aerich.utils import get_models_describe
+
+    @init_memory_sqlite
+    async def run() -> None:
+        old_models_describe = get_models_describe("models")
+        p = Path("old_models_describe.py")
+        p.write_text(f"{old_models_describe = }", encoding="utf-8")
+        print(f"Write value to {p}\nYou can reformat it by `ruff format {p}`")
+
+    run_async(run())
+
+
+if __name__ == "__main__":
+    main()
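When the old models change, the describe snapshot used by test_migrate can be regenerated by executing this module directly (from the tests directory, so the output file lands next to it); the format step is optional, as the print hints:

    python old_models.py
    ruff format old_models_describe.py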

tests/test_command.py (new file, 11 lines)
@@ -0,0 +1,11 @@
from aerich import Command
from conftest import tortoise_orm


async def test_command(mocker):
    mocker.patch("os.listdir", return_value=[])
    async with Command(tortoise_orm) as command:
        history = await command.history()
        heads = await command.heads()
        assert history == []
        assert heads == []
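The `async with` form wraps the init/close lifecycle; without it, the same calls look roughly like this (a sketch against the public `Command` API, with `TORTOISE_ORM` standing in for your config dict):

    from aerich import Command

    async def show_heads(TORTOISE_ORM: dict) -> None:
        command = Command(tortoise_orm=TORTOISE_ORM, app="models")
        await command.init()  # what the context manager's __aenter__ does
        print(await command.heads())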

tests/test_ddl.py
@@ -1,3 +1,5 @@
+import tortoise
+
 from aerich.ddl.mysql import MysqlDDL
 from aerich.ddl.postgres import PostgresDDL
 from aerich.ddl.sqlite import SqliteDDL
@@ -8,6 +10,21 @@ from tests.models import Category, Product, User
 def test_create_table():
     ret = Migrate.ddl.create_table(Category)
     if isinstance(Migrate.ddl, MysqlDDL):
+        if tortoise.__version__ >= "0.24":
+            assert (
+                ret
+                == """CREATE TABLE IF NOT EXISTS `category` (
+    `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,
+    `slug` VARCHAR(100) NOT NULL,
+    `name` VARCHAR(200),
+    `title` VARCHAR(20) NOT NULL,
+    `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
+    `owner_id` INT NOT NULL COMMENT 'User',
+    CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE,
+    FULLTEXT KEY `idx_category_slug_e9bcff` (`slug`)
+) CHARACTER SET utf8mb4"""
+            )
+            return
         assert (
             ret
             == """CREATE TABLE IF NOT EXISTS `category` (
@@ -15,23 +32,26 @@ def test_create_table():
     `slug` VARCHAR(100) NOT NULL,
     `name` VARCHAR(200),
     `title` VARCHAR(20) NOT NULL,
     `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
     `owner_id` INT NOT NULL COMMENT 'User',
     CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE
-) CHARACTER SET utf8mb4"""
+) CHARACTER SET utf8mb4;
+CREATE FULLTEXT INDEX `idx_category_slug_e9bcff` ON `category` (`slug`)"""
         )
     elif isinstance(Migrate.ddl, SqliteDDL):
+        exists = "IF NOT EXISTS " if tortoise.__version__ >= "0.24" else ""
         assert (
             ret
-            == """CREATE TABLE IF NOT EXISTS "category" (
+            == f"""CREATE TABLE IF NOT EXISTS "category" (
    "id" INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
    "slug" VARCHAR(100) NOT NULL,
    "name" VARCHAR(200),
    "title" VARCHAR(20) NOT NULL,
    "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
-    "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
-)"""
+    "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE /* User */
+);
+CREATE INDEX {exists}"idx_category_slug_e9bcff" ON "category" ("slug")"""
         )
     elif isinstance(Migrate.ddl, PostgresDDL):
@@ -42,9 +62,10 @@ def test_create_table():
    "slug" VARCHAR(100) NOT NULL,
    "name" VARCHAR(200),
    "title" VARCHAR(20) NOT NULL,
    "created_at" TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP,
    "owner_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE
 );
+CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug");
 COMMENT ON COLUMN "category"."owner_id" IS 'User'"""
         )
@@ -58,13 +79,13 @@ def test_drop_table():
 def test_add_column():
-    ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
+    ret = Migrate.ddl.add_column(Category, Category._meta.fields_map["name"].describe(False))
     if isinstance(Migrate.ddl, MysqlDDL):
         assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200)"
     else:
         assert ret == 'ALTER TABLE "category" ADD "name" VARCHAR(200)'

     # add unique column
-    ret = Migrate.ddl.add_column(User, User._meta.fields_map.get("username").describe(False))
+    ret = Migrate.ddl.add_column(User, User._meta.fields_map["username"].describe(False))
     if isinstance(Migrate.ddl, MysqlDDL):
         assert ret == "ALTER TABLE `user` ADD `username` VARCHAR(20) NOT NULL UNIQUE"
     elif isinstance(Migrate.ddl, PostgresDDL):
@@ -77,15 +98,13 @@ def test_modify_column():
     if isinstance(Migrate.ddl, SqliteDDL):
         return

-    ret0 = Migrate.ddl.modify_column(
-        Category, Category._meta.fields_map.get("name").describe(False)
-    )
-    ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
+    ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map["name"].describe(False))
+    ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map["is_active"].describe(False))
     if isinstance(Migrate.ddl, MysqlDDL):
         assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200)"
         assert (
             ret1
             == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL COMMENT 'Is Active' DEFAULT 1"
         )
     elif isinstance(Migrate.ddl, PostgresDDL):
         assert (
@@ -101,14 +120,14 @@ def test_modify_column():
 def test_alter_column_default():
     if isinstance(Migrate.ddl, SqliteDDL):
         return
-    ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("intro").describe(False))
+    ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map["intro"].describe(False))
     if isinstance(Migrate.ddl, PostgresDDL):
         assert ret == 'ALTER TABLE "user" ALTER COLUMN "intro" SET DEFAULT \'\''
     elif isinstance(Migrate.ddl, MysqlDDL):
         assert ret == "ALTER TABLE `user` ALTER COLUMN `intro` SET DEFAULT ''"

     ret = Migrate.ddl.alter_column_default(
-        Category, Category._meta.fields_map.get("created_at").describe(False)
+        Category, Category._meta.fields_map["created_at"].describe(False)
     )
     if isinstance(Migrate.ddl, PostgresDDL):
         assert (
@@ -121,7 +140,7 @@ def test_alter_column_default():
     )
     ret = Migrate.ddl.alter_column_default(
-        Product, Product._meta.fields_map.get("view_num").describe(False)
+        Product, Product._meta.fields_map["view_num"].describe(False)
     )
     if isinstance(Migrate.ddl, PostgresDDL):
         assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
@@ -132,9 +151,7 @@ def test_alter_column_default():
 def test_alter_column_null():
     if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
         return
-    ret = Migrate.ddl.alter_column_null(
-        Category, Category._meta.fields_map.get("name").describe(False)
-    )
+    ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map["name"].describe(False))
     if isinstance(Migrate.ddl, PostgresDDL):
         assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL'
@@ -142,10 +159,10 @@ def test_alter_column_null():
 def test_set_comment():
     if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
         return
-    ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name").describe(False))
+    ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["name"].describe(False))
     assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'

-    ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("owner").describe(False))
+    ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map["owner"].describe(False))
     assert ret == 'COMMENT ON COLUMN "category"."owner_id" IS \'User\''
@@ -163,6 +180,14 @@ def test_add_index():
     if isinstance(Migrate.ddl, MysqlDDL):
         assert index == "ALTER TABLE `category` ADD INDEX `idx_category_name_8b0cb9` (`name`)"
         assert index_u == "ALTER TABLE `category` ADD UNIQUE INDEX `name` (`name`)"
+    elif isinstance(Migrate.ddl, PostgresDDL):
+        assert (
+            index == 'CREATE INDEX IF NOT EXISTS "idx_category_name_8b0cb9" ON "category" ("name")'
+        )
+        assert (
+            index_u
+            == 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_name_8b0cb9" ON "category" ("name")'
+        )
     else:
         assert index == 'CREATE INDEX "idx_category_name_8b0cb9" ON "category" ("name")'
         assert index_u == 'CREATE UNIQUE INDEX "uid_category_name_8b0cb9" ON "category" ("name")'
@@ -181,7 +206,7 @@ def test_drop_index():
 def test_add_fk():
     ret = Migrate.ddl.add_fk(
-        Category, Category._meta.fields_map.get("owner").describe(False), User.describe(False)
+        Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
     )
     if isinstance(Migrate.ddl, MysqlDDL):
         assert (
@@ -197,7 +222,7 @@ def test_add_fk():
 def test_drop_fk():
     ret = Migrate.ddl.drop_fk(
-        Category, Category._meta.fields_map.get("owner").describe(False), User.describe(False)
+        Category, Category._meta.fields_map["owner"].describe(False), User.describe(False)
     )
     if isinstance(Migrate.ddl, MysqlDDL):
         assert ret == "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`"

tests/test_fake.py (new file, 106 lines)
@@ -0,0 +1,106 @@
from __future__ import annotations

import os
import re
from pathlib import Path

from tests._utils import Dialect, run_shell


def _append_field(*files: str, name="field_1") -> None:
    for file in files:
        p = Path(file)
        field = f"    {name} = fields.IntField(default=0)"
        with p.open("a") as f:
            f.write(os.linesep + field)


def test_fake(new_aerich_project):
    if Dialect.is_sqlite():
        # TODO: go ahead if sqlite alter-column supported
        return
    output = run_shell("aerich init -t settings.TORTOISE_ORM")
    assert "Success" in output
    output = run_shell("aerich init-db")
    assert "Success" in output
    output = run_shell("aerich --app models_second init-db")
    assert "Success" in output
    output = run_shell("pytest _tests.py::test_init_db")
    assert "error" not in output.lower()
    _append_field("models.py", "models_second.py")
    output = run_shell("aerich migrate")
    assert "Success" in output
    output = run_shell("aerich --app models_second migrate")
    assert "Success" in output
    output = run_shell("aerich upgrade --fake")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second upgrade --fake")
    assert "FAKED" in output
    output = run_shell("pytest _tests.py::test_fake_field_1")
    assert "error" not in output.lower()
    _append_field("models.py", "models_second.py", name="field_2")
    output = run_shell("aerich migrate")
    assert "Success" in output
    output = run_shell("aerich --app models_second migrate")
    assert "Success" in output
    output = run_shell("aerich heads")
    assert "_update.py" in output
    output = run_shell("aerich upgrade --fake")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second upgrade --fake")
    assert "FAKED" in output
    output = run_shell("pytest _tests.py::test_fake_field_2")
    assert "error" not in output.lower()
    output = run_shell("aerich heads")
    assert "No available heads." in output
    output = run_shell("aerich --app models_second heads")
    assert "No available heads." in output
    _append_field("models.py", "models_second.py", name="field_3")
    run_shell("aerich migrate", capture_output=False)
    run_shell("aerich --app models_second migrate", capture_output=False)
    run_shell("aerich upgrade --fake", capture_output=False)
    run_shell("aerich --app models_second upgrade --fake", capture_output=False)
    output = run_shell("aerich downgrade --fake -v 2 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second downgrade --fake -v 2 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich heads")
    assert "No available heads." not in output
    assert not re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
    output = run_shell("aerich --app models_second heads")
    assert "No available heads." not in output
    assert not re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
    output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich heads")
    assert "No available heads." not in output
    assert re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
    output = run_shell("aerich --app models_second heads")
    assert "No available heads." not in output
    assert re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
    output = run_shell("aerich upgrade --fake")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second upgrade --fake")
    assert "FAKED" in output
    output = run_shell("aerich heads")
    assert "No available heads." in output
    output = run_shell("aerich --app models_second heads")
    assert "No available heads." in output
    output = run_shell("aerich downgrade --fake -v 1 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich --app models_second downgrade --fake -v 1 --yes", input="y\n")
    assert "FAKED" in output
    output = run_shell("aerich heads")
    assert "No available heads." not in output
    assert re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
    output = run_shell("aerich --app models_second heads")
    assert "No available heads." not in output
    assert re.search(r"1_\d+_update\.py", output)
    assert re.search(r"2_\d+_update\.py", output)
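The invariant this test pins down: with `--fake`, aerich records a migration as applied (or rolled back) in its history table without executing the migration's SQL, which is why the models gain fields the database never receives, and why `aerich heads` becomes the indicator of what is still pending. The per-app loop it exercises boils down to:

    aerich migrate                      # write a new migration file
    aerich upgrade --fake               # record it as applied, schema untouched
    aerich downgrade --fake -v 1 --yes  # rewind the recorded history to version 1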

tests/test_inspectdb.py (new file, 17 lines)
@@ -0,0 +1,17 @@
from tests._utils import Dialect, run_shell


def test_inspect(new_aerich_project):
    if Dialect.is_sqlite():
        # TODO: test sqlite after #384 fixed
        return
    run_shell("aerich init -t settings.TORTOISE_ORM")
    run_shell("aerich init-db")
    ret = run_shell("aerich inspectdb -t product")
    assert ret.startswith("from tortoise import Model, fields")
    assert "primary_key=True" in ret
    assert "fields.DatetimeField" in ret
    assert "fields.FloatField" in ret
    assert "fields.UUIDField" in ret
    if Dialect.is_mysql():
        assert "db_index=True" in ret

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from pathlib import Path from pathlib import Path
import pytest import pytest
@ -5,6 +7,7 @@ import tortoise
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from tortoise.indexes import Index from tortoise.indexes import Index
from aerich._compat import tortoise_version_less_than
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
@ -13,6 +16,16 @@ from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.utils import get_models_describe from aerich.utils import get_models_describe
from tests.indexes import CustomIndex from tests.indexes import CustomIndex
def describe_index(idx: Index) -> Index | dict:
# tortoise-orm>=0.24 changes Index desribe to be dict
if tortoise_version_less_than("0.24"):
return idx
if hasattr(idx, "describe"):
return idx.describe()
return idx
# tortoise-orm>=0.21 changes IntField constraints # tortoise-orm>=0.21 changes IntField constraints
# from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647} # from {"ge": 1, "le": 2147483647} to {"ge": -2147483648, "le": 2147483647}
MIN_INT = 1 if tortoise.__version__ < "0.21" else -2147483648 MIN_INT = 1 if tortoise.__version__ < "0.21" else -2147483648
@ -25,7 +38,7 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [], "indexes": [describe_index(Index(fields=("slug",)))],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@ -177,19 +190,19 @@ old_models_describe = {
"unique_together": [], "unique_together": [],
"indexes": [], "indexes": [],
"pk_field": { "pk_field": {
"name": "id", "name": "slug",
"field_type": "IntField", "field_type": "CharField",
"db_column": "id", "db_column": "slug",
"python_type": "int", "python_type": "str",
"generated": True, "generated": False,
"nullable": False, "nullable": False,
"unique": True, "unique": True,
"indexed": True, "indexed": True,
"default": None, "default": None,
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {"ge": MIN_INT, "le": 2147483647}, "constraints": {"max_length": 10},
"db_field_types": {"": "INT"}, "db_field_types": {"": "VARCHAR(10)"},
}, },
"data_fields": [ "data_fields": [
{ {
@ -355,6 +368,21 @@ old_models_describe = {
"constraints": {"max_length": 200}, "constraints": {"max_length": 200},
"db_field_types": {"": "VARCHAR(200)"}, "db_field_types": {"": "VARCHAR(200)"},
}, },
{
"name": "company",
"field_type": "CharField",
"db_column": "company",
"python_type": "str",
"generated": False,
"nullable": False,
"unique": False,
"indexed": True,
"default": None,
"description": None,
"docstring": None,
"constraints": {"max_length": 100},
"db_field_types": {"": "VARCHAR(100)"},
},
{ {
"name": "is_primary", "name": "is_primary",
"field_type": "BooleanField", "field_type": "BooleanField",
@ -640,7 +668,10 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"unique_together": [], "unique_together": [],
"indexes": [Index(fields=("username", "is_active")), CustomIndex(fields=("is_superuser",))], "indexes": [
describe_index(Index(fields=("username", "is_active"))),
describe_index(CustomIndex(fields=("is_superuser",))),
],
"pk_field": { "pk_field": {
"name": "id", "name": "id",
"field_type": "IntField", "field_type": "IntField",
@ -911,11 +942,15 @@ def test_migrate(mocker: MockerFixture):
""" """
models.py diff with old_models.py models.py diff with old_models.py
- change email pk: id -> email_id - change email pk: id -> email_id
- change product pk field type: IntField -> BigIntField
- change config pk field attribute: max_length=10 -> max_length=20
- add field: Email.address - add field: Email.address
- add fk field: Config.user - add fk field: Config.user
- drop fk field: Email.user - drop fk field: Email.user
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- add unique to indexed field: Email.company
- change index type for indexed field: Email.slug
- add many to many: Email.users - add many to many: Email.users
- add one to one: Email.config - add one to one: Email.config
- remove unique: Category.title - remove unique: Category.title
@ -952,179 +987,202 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `category` DROP INDEX `title`", "ALTER TABLE `category` DROP INDEX `title`",
"ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`", "ALTER TABLE `category` RENAME COLUMN `user_id` TO `owner_id`",
"ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `category` ADD CONSTRAINT `fk_category_user_110d4c63` FOREIGN KEY (`owner_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `category` ADD FULLTEXT INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `email` DROP COLUMN `user_id`", "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP COLUMN `name`", "ALTER TABLE `config` DROP COLUMN `name`",
"ALTER TABLE `config` DROP INDEX `name`", "ALTER TABLE `config` DROP INDEX `name`",
"ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
"ALTER TABLE `config` MODIFY COLUMN `value` JSON NOT NULL",
"ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_76a9dc71` FOREIGN KEY (`config_id`) REFERENCES `config` (`id`) ON DELETE CASCADE", "ALTER TABLE `email` ADD CONSTRAINT `fk_email_config_88e28c1b` FOREIGN KEY (`config_id`) REFERENCES `config` (`slug`) ON DELETE CASCADE",
"ALTER TABLE `email` ADD `config_id` INT NOT NULL UNIQUE", "ALTER TABLE `email` ADD `config_id` VARCHAR(20) NOT NULL UNIQUE",
"ALTER TABLE `email` DROP INDEX `idx_email_company_1c9234`, ADD UNIQUE (`company`)",
"ALTER TABLE `configs` RENAME TO `config`", "ALTER TABLE `configs` RENAME TO `config`",
"ALTER TABLE `product` DROP COLUMN `uuid`", "ALTER TABLE `product` DROP COLUMN `uuid`",
"ALTER TABLE `product` DROP INDEX `uuid`", "ALTER TABLE `product` DROP INDEX `uuid`",
"ALTER TABLE `product` RENAME COLUMN `image` TO `pic`", "ALTER TABLE `product` RENAME COLUMN `image` TO `pic`",
"ALTER TABLE `product` ADD `price` DOUBLE",
"ALTER TABLE `product` ADD `no` CHAR(36) NOT NULL",
"ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`", "ALTER TABLE `email` RENAME COLUMN `id` TO `email_id`",
"ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ADD INDEX `idx_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ADD INDEX `idx_product_no_e4d701` (`no`)",
"ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_869427` (`name`, `type_db_alias`)",
"ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`", "ALTER TABLE `product` RENAME COLUMN `is_delete` TO `is_deleted`",
"ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`", "ALTER TABLE `product` RENAME COLUMN `is_review` TO `is_reviewed`",
"ALTER TABLE `product` MODIFY COLUMN `id` BIGINT NOT NULL",
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(100) NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(10,8) NOT NULL",
"ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `username` (`username`)",
"CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4", "CREATE TABLE `email_user` (\n `email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4", "CREATE TABLE IF NOT EXISTS `newmodel` (\n `id` INT NOT NULL PRIMARY KEY AUTO_INCREMENT,\n `name` VARCHAR(50) NOT NULL\n) CHARACTER SET utf8mb4",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)", "CREATE TABLE `product_user` (\n `product_id` BIGINT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` VARCHAR(20) NOT NULL REFERENCES `config` (`slug`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `product_user` (\n `product_id` INT NOT NULL REFERENCES `product` (`id`) ON DELETE CASCADE,\n `user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category_map` (\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE,\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category`", "DROP TABLE IF EXISTS `config_category`",
"ALTER TABLE `config` MODIFY COLUMN `slug` VARCHAR(20) NOT NULL",
} }
upgrade_operators = set(Migrate.upgrade_operators)
upgrade_more_than_expected = upgrade_operators - expected_upgrade_operators
assert not upgrade_more_than_expected
upgrade_less_than_expected = expected_upgrade_operators - upgrade_operators
assert not upgrade_less_than_expected
expected_downgrade_operators = { expected_downgrade_operators = {
"ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL", "ALTER TABLE `category` MODIFY COLUMN `slug` VARCHAR(200) NOT NULL",
"ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)", "ALTER TABLE `category` ADD UNIQUE INDEX `title` (`title`)",
"ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`", "ALTER TABLE `category` RENAME COLUMN `owner_id` TO `user_id`",
"ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`", "ALTER TABLE `category` DROP FOREIGN KEY `fk_category_user_110d4c63`",
"ALTER TABLE `category` ADD INDEX `idx_category_slug_e9bcff` (`slug`)",
"ALTER TABLE `category` DROP INDEX `idx_category_slug_e9bcff`",
"ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE", "ALTER TABLE `config` ADD `name` VARCHAR(100) NOT NULL UNIQUE",
"ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)", "ALTER TABLE `config` ADD UNIQUE INDEX `name` (`name`)",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `config` DROP COLUMN `user_id`", "ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` MODIFY COLUMN `slug` VARCHAR(10) NOT NULL",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`", "ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `email` DROP COLUMN `config_id`", "ALTER TABLE `email` DROP COLUMN `config_id`",
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_76a9dc71`", "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_config_88e28c1b`",
"ALTER TABLE `config` RENAME TO `configs`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`", "ALTER TABLE `email` RENAME COLUMN `email_id` TO `id`",
"ALTER TABLE `email` DROP INDEX `company`, ADD INDEX (`idx_email_company_1c9234`)",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` RENAME COLUMN `pic` TO `image`",
"ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE", "ALTER TABLE `product` ADD `uuid` INT NOT NULL UNIQUE",
"ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)", "ALTER TABLE `product` ADD UNIQUE INDEX `uuid` (`uuid`)",
"ALTER TABLE `product` DROP INDEX `idx_product_name_869427`", "ALTER TABLE `product` DROP INDEX `idx_product_name_869427`",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", "ALTER TABLE `product` DROP COLUMN `price`",
"ALTER TABLE `product` DROP COLUMN `no`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_869427`", "ALTER TABLE `product` DROP INDEX `uid_product_name_869427`",
"ALTER TABLE `product` DROP INDEX `idx_product_no_e4d701`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`", "ALTER TABLE `product` RENAME COLUMN `is_deleted` TO `is_delete`",
"ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`", "ALTER TABLE `product` RENAME COLUMN `is_reviewed` TO `is_review`",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `product` MODIFY COLUMN `id` INT NOT NULL",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `username`", "ALTER TABLE `user` DROP INDEX `username`",
"ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `password` VARCHAR(200) NOT NULL",
"DROP TABLE IF EXISTS `email_user`", "DROP TABLE IF EXISTS `email_user`",
"DROP TABLE IF EXISTS `newmodel`", "DROP TABLE IF EXISTS `newmodel`",
"DROP TABLE IF EXISTS `product_user`", "DROP TABLE IF EXISTS `product_user`",
"ALTER TABLE `user` MODIFY COLUMN `intro` LONGTEXT NOT NULL",
"ALTER TABLE `config` MODIFY COLUMN `value` TEXT NOT NULL",
"ALTER TABLE `category` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `product` MODIFY COLUMN `created_at` DATETIME(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6)",
"ALTER TABLE `user` MODIFY COLUMN `last_login` DATETIME(6) NOT NULL COMMENT 'Last Login'",
"ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL", "ALTER TABLE `user` MODIFY COLUMN `longitude` DECIMAL(12,9) NOT NULL",
"ALTER TABLE `product` MODIFY COLUMN `body` LONGTEXT NOT NULL", "CREATE TABLE `config_category` (\n `config_id` VARCHAR(20) NOT NULL REFERENCES `config` (`slug`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"CREATE TABLE `config_category` (\n `config_id` INT NOT NULL REFERENCES `config` (`id`) ON DELETE CASCADE,\n `category_id` INT NOT NULL REFERENCES `category` (`id`) ON DELETE CASCADE\n) CHARACTER SET utf8mb4",
"DROP TABLE IF EXISTS `config_category_map`", "DROP TABLE IF EXISTS `config_category_map`",
} }
assert not set(Migrate.upgrade_operators).symmetric_difference(expected_upgrade_operators) downgrade_operators = set(Migrate.downgrade_operators)
downgrade_more_than_expected = downgrade_operators - expected_downgrade_operators
assert not set(Migrate.downgrade_operators).symmetric_difference( assert not downgrade_more_than_expected
expected_downgrade_operators downgrade_less_than_expected = expected_downgrade_operators - downgrade_operators
) assert not downgrade_less_than_expected
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
expected_upgrade_operators = { expected_upgrade_operators = {
'DROP INDEX IF EXISTS "uid_category_title_f7fc03"', 'DROP INDEX IF EXISTS "uid_category_title_f7fc03"',
'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL', 'ALTER TABLE "category" ALTER COLUMN "name" DROP NOT NULL',
'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)', 'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(100) USING "slug"::VARCHAR(100)',
'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"', 'ALTER TABLE "category" RENAME COLUMN "user_id" TO "owner_id"',
'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "category" ADD CONSTRAINT "fk_category_user_110d4c63" FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" USING HASH ("slug")',
'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
'ALTER TABLE "configs" RENAME TO "config"',
'ALTER TABLE "config" DROP COLUMN "name"', 'ALTER TABLE "config" DROP COLUMN "name"',
'DROP INDEX IF EXISTS "uid_config_name_2c83c8"', 'DROP INDEX IF EXISTS "uid_config_name_2c83c8"',
'ALTER TABLE "config" ADD "user_id" INT NOT NULL', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL',
'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB', 'ALTER TABLE "config" ALTER COLUMN "slug" TYPE VARCHAR(20) USING "slug"::VARCHAR(20)',
'ALTER TABLE "configs" RENAME TO "config"', 'ALTER TABLE "email" ADD "config_id" VARCHAR(20) NOT NULL UNIQUE',
'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"', 'ALTER TABLE "email" RENAME COLUMN "id" TO "email_id"',
'ALTER TABLE "email" DROP COLUMN "user_id"', 'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_76a9dc71" FOREIGN KEY ("config_id") REFERENCES "config" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_config_88e28c1b" FOREIGN KEY ("config_id") REFERENCES "config" ("slug") ON DELETE CASCADE',
'ALTER TABLE "email" ADD "config_id" INT NOT NULL UNIQUE', 'DROP INDEX IF EXISTS "idx_email_company_1c9234"',
'CREATE UNIQUE INDEX IF NOT EXISTS "uid_email_company_1c9234" ON "email" ("company")',
'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"', 'DROP INDEX IF EXISTS "uid_product_uuid_d33c18"',
'ALTER TABLE "product" DROP COLUMN "uuid"', 'ALTER TABLE "product" DROP COLUMN "uuid"',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"', 'ALTER TABLE "product" RENAME COLUMN "image" TO "pic"',
'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"', 'ALTER TABLE "product" RENAME COLUMN "is_review" TO "is_reviewed"',
'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"', 'ALTER TABLE "product" RENAME COLUMN "is_delete" TO "is_deleted"',
'ALTER TABLE "product" ADD "price" DOUBLE PRECISION',
'ALTER TABLE "product" ADD "no" UUID NOT NULL',
'ALTER TABLE "product" ALTER COLUMN "id" TYPE BIGINT USING "id"::BIGINT',
'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)', 'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(100) USING "password"::VARCHAR(100)',
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)', 'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(10,8) USING "longitude"::DECIMAL(10,8)',
'CREATE INDEX "idx_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE INDEX IF NOT EXISTS "idx_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE INDEX "idx_email_email_4a1a33" ON "email" ("email")', 'CREATE INDEX IF NOT EXISTS "idx_email_email_4a1a33" ON "email" ("email")',
'CREATE INDEX IF NOT EXISTS "idx_product_no_e4d701" ON "product" ("no")',
'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "email_user" (\n "email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'', 'CREATE TABLE IF NOT EXISTS "newmodel" (\n "id" SERIAL NOT NULL PRIMARY KEY,\n "name" VARCHAR(50) NOT NULL\n);\nCOMMENT ON COLUMN "config"."user_id" IS \'User\'',
'CREATE UNIQUE INDEX "uid_product_name_869427" ON "product" ("name", "type_db_alias")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_name_869427" ON "product" ("name", "type_db_alias")',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE UNIQUE INDEX IF NOT EXISTS "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "product_user" (\n "product_id" INT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "product_user" (\n "product_id" BIGINT NOT NULL REFERENCES "product" ("id") ON DELETE CASCADE,\n "user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE\n)',
'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" INT NOT NULL REFERENCES "config" ("id") ON DELETE CASCADE\n)', 'CREATE TABLE "config_category_map" (\n "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE,\n "config_id" VARCHAR(20) NOT NULL REFERENCES "config" ("slug") ON DELETE CASCADE\n)',
'DROP TABLE IF EXISTS "config_category"', 'DROP TABLE IF EXISTS "config_category"',
} }
upgrade_operators = set(Migrate.upgrade_operators)
upgrade_more_than_expected = upgrade_operators - expected_upgrade_operators
assert not upgrade_more_than_expected
upgrade_less_than_expected = expected_upgrade_operators - upgrade_operators
assert not upgrade_less_than_expected
        expected_downgrade_operators = {
            'CREATE UNIQUE INDEX IF NOT EXISTS "uid_category_title_f7fc03" ON "category" ("title")',
            'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL',
            'ALTER TABLE "category" ALTER COLUMN "slug" TYPE VARCHAR(200) USING "slug"::VARCHAR(200)',
            'ALTER TABLE "category" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
            'ALTER TABLE "category" RENAME COLUMN "owner_id" TO "user_id"',
            'ALTER TABLE "category" DROP CONSTRAINT IF EXISTS "fk_category_user_110d4c63"',
            'DROP INDEX IF EXISTS "idx_category_slug_e9bcff"',
            'CREATE INDEX IF NOT EXISTS "idx_category_slug_e9bcff" ON "category" ("slug")',
            'ALTER TABLE "config" ADD "name" VARCHAR(100) NOT NULL UNIQUE',
            'CREATE UNIQUE INDEX IF NOT EXISTS "uid_config_name_2c83c8" ON "config" ("name")',
            'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
            'ALTER TABLE "config" DROP CONSTRAINT IF EXISTS "fk_config_user_17daa970"',
            'ALTER TABLE "config" RENAME TO "configs"',
            'ALTER TABLE "config" ALTER COLUMN "value" TYPE JSONB USING "value"::JSONB',
            'ALTER TABLE "config" DROP COLUMN "user_id"',
            'ALTER TABLE "config" ALTER COLUMN "slug" TYPE VARCHAR(10) USING "slug"::VARCHAR(10)',
            'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
            'ALTER TABLE "email" DROP COLUMN "address"',
            'ALTER TABLE "email" RENAME COLUMN "email_id" TO "id"',
            'ALTER TABLE "email" DROP COLUMN "config_id"',
            'ALTER TABLE "email" DROP CONSTRAINT IF EXISTS "fk_email_config_88e28c1b"',
            'CREATE INDEX IF NOT EXISTS "idx_email_company_1c9234" ON "email" ("company")',
            'DROP INDEX IF EXISTS "uid_email_company_1c9234"',
            'ALTER TABLE "product" ADD "uuid" INT NOT NULL UNIQUE',
            'CREATE UNIQUE INDEX IF NOT EXISTS "uid_product_uuid_d33c18" ON "product" ("uuid")',
            'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
            'ALTER TABLE "product" RENAME COLUMN "pic" TO "image"',
            'ALTER TABLE "product" RENAME COLUMN "is_deleted" TO "is_delete"',
            'ALTER TABLE "product" RENAME COLUMN "is_reviewed" TO "is_review"',
            'ALTER TABLE "product" DROP COLUMN "price"',
            'ALTER TABLE "product" DROP COLUMN "no"',
            'ALTER TABLE "product" ALTER COLUMN "id" TYPE INT USING "id"::INT',
            'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
            'ALTER TABLE "user" ALTER COLUMN "password" TYPE VARCHAR(200) USING "password"::VARCHAR(200)',
            'ALTER TABLE "user" ALTER COLUMN "last_login" TYPE TIMESTAMPTZ USING "last_login"::TIMESTAMPTZ',
            'ALTER TABLE "user" ALTER COLUMN "intro" TYPE TEXT USING "intro"::TEXT',
            'ALTER TABLE "user" ALTER COLUMN "longitude" TYPE DECIMAL(12,9) USING "longitude"::DECIMAL(12,9)',
            'ALTER TABLE "product" ALTER COLUMN "created_at" TYPE TIMESTAMPTZ USING "created_at"::TIMESTAMPTZ',
            'ALTER TABLE "product" ALTER COLUMN "body" TYPE TEXT USING "body"::TEXT',
            'DROP TABLE IF EXISTS "product_user"',
            'DROP INDEX IF EXISTS "idx_product_name_869427"',
            'DROP INDEX IF EXISTS "idx_email_email_4a1a33"',
            'DROP INDEX IF EXISTS "uid_user_usernam_9987ab"',
            'DROP INDEX IF EXISTS "uid_product_name_869427"',
            'DROP INDEX IF EXISTS "idx_product_no_e4d701"',
            'DROP TABLE IF EXISTS "email_user"',
            'DROP TABLE IF EXISTS "newmodel"',
            'CREATE TABLE "config_category" (\n    "config_id" VARCHAR(20) NOT NULL REFERENCES "config" ("slug") ON DELETE CASCADE,\n    "category_id" INT NOT NULL REFERENCES "category" ("id") ON DELETE CASCADE\n)',
            'DROP TABLE IF EXISTS "config_category_map"',
        }
        downgrade_operators = set(Migrate.downgrade_operators)
        downgrade_more_than_expected = downgrade_operators - expected_downgrade_operators
        assert not downgrade_more_than_expected
        downgrade_less_than_expected = expected_downgrade_operators - downgrade_operators
        assert not downgrade_less_than_expected
    elif isinstance(Migrate.ddl, SqliteDDL):
        assert Migrate.upgrade_operators == []
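
The comparison above deliberately asserts the two set differences separately rather than one symmetric difference, so a failure names exactly which SQL operators were unexpected and which were missing. The same pattern as a reusable helper — a minimal sketch, not part of the test suite:

    def assert_same_operators(actual: set[str], expected: set[str]) -> None:
        # Two one-directional differences give readable failure messages.
        unexpected = actual - expected  # generated, but not expected
        missing = expected - actual  # expected, but not generated
        assert not unexpected, f"unexpected operators: {unexpected}"
        assert not missing, f"missing operators: {missing}"

    # Passes: both directions are empty.
    assert_same_operators({'DROP INDEX "idx_foo"'}, {'DROP INDEX "idx_foo"'})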
@@ -1142,7 +1200,7 @@ def test_sort_all_version_files(mocker):
        ],
    )
    Migrate.migrate_location = Path(".")
    assert Migrate.get_all_version_files() == [
        "1_datetime_update.py",
@@ -1166,7 +1224,7 @@ def test_sort_files_containing_non_migrations(mocker):
        ],
    )
    Migrate.migrate_location = Path(".")
    assert Migrate.get_all_version_files() == [
        "1_datetime_update.py",

tests/test_python_m.py Normal file

@@ -0,0 +1,18 @@
import subprocess  # nosec
from pathlib import Path

from aerich.version import __version__
from tests._utils import chdir, run_shell


def test_python_m_aerich():
    assert __version__ in run_shell("python -m aerich --version")


def test_poetry_add(tmp_path: Path):
    package = Path(__file__).parent.resolve().parent
    with chdir(tmp_path):
        subprocess.run(["poetry", "new", "foo"])  # nosec
        with chdir("foo"):
            r = subprocess.run(["poetry", "add", package])  # nosec
            assert r.returncode == 0

tests/test_sqlite_migrate.py

@@ -1,165 +1,28 @@
from __future__ import annotations

import contextlib
import os
import platform
import shlex
import shutil
import subprocess
from collections.abc import Generator
from contextlib import contextmanager
from pathlib import Path

from tests._utils import Dialect, chdir, copy_files


def run_aerich(cmd: str) -> subprocess.CompletedProcess | None:
    if not cmd.startswith("poetry") and not cmd.startswith("python"):
        if not cmd.startswith("aerich"):
            cmd = "aerich " + cmd
        if platform.system() == "Windows":
            cmd = "python -m " + cmd
    r = None
    with contextlib.suppress(subprocess.TimeoutExpired):
        r = subprocess.run(shlex.split(cmd), timeout=2)
    return r
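
# run_aerich() returns None only when subprocess.run() hits the 2-second
# timeout; otherwise the CompletedProcess is returned. The rewriting above
# means, for example:
#   "init-db"  -> "aerich init-db"            (non-Windows)
#   "init-db"  -> "python -m aerich init-db"  (Windows)
#   commands starting with "poetry" or "python" are executed unchanged.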
def run_shell(cmd: str) -> subprocess.CompletedProcess:
@@ -167,78 +30,139 @@ def run_shell(cmd: str) -> subprocess.CompletedProcess:
    return subprocess.run(shlex.split(cmd), env=envs)
def _get_empty_db() -> Path:
    if (db_file := Path("db.sqlite3")).exists():
        db_file.unlink()
    return db_file


@contextmanager
def prepare_sqlite_project(tmp_path: Path) -> Generator[tuple[Path, str]]:
    test_dir = Path(__file__).parent
    asset_dir = test_dir / "assets" / "sqlite_migrate"
    with chdir(tmp_path):
        files = ("models.py", "settings.py", "_tests.py")
        copy_files(*(asset_dir / f for f in files), target_dir=Path())
        models_py, settings_py, test_py = (Path(f) for f in files)
        copy_files(asset_dir / "conftest_.py", target_dir=Path("conftest.py"))
        _get_empty_db()
        yield models_py, models_py.read_text("utf-8")


def test_close_tortoise_connections_patch(tmp_path: Path) -> None:
    if not Dialect.is_sqlite():
        return
    with prepare_sqlite_project(tmp_path) as (models_py, models_text):
        run_aerich("aerich init -t settings.TORTOISE_ORM")
        r = run_aerich("aerich init-db")
        assert r is not None
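
# run_aerich() yields None when "aerich init-db" exceeds its 2-second
# timeout, so the "assert r is not None" above is a regression guard
# against init-db hanging instead of exiting.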
def test_sqlite_migrate_alter_indexed_unique(tmp_path: Path) -> None:
    if not Dialect.is_sqlite():
        return
    with prepare_sqlite_project(tmp_path) as (models_py, models_text):
        models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
        run_aerich("aerich init -t settings.TORTOISE_ORM")
        run_aerich("aerich init-db")
        r = run_shell("pytest -s _tests.py::test_allow_duplicate")
        assert r.returncode == 0
        models_py.write_text(models_text.replace("db_index=False", "unique=True"))
        run_aerich("aerich migrate")  # migrations/models/1_
        run_aerich("aerich upgrade")
        r = run_shell("pytest _tests.py::test_unique_is_true")
        assert r.returncode == 0
        models_py.write_text(models_text.replace("db_index=False", "db_index=True"))
        run_aerich("aerich migrate")  # migrations/models/2_
        run_aerich("aerich upgrade")
        r = run_shell("pytest -s _tests.py::test_allow_duplicate")
        assert r.returncode == 0
M2M_WITH_CUSTOM_THROUGH = """
groups = fields.ManyToManyField("models.Group", through="foo_group")
class Group(Model):
name = fields.CharField(max_length=60)
class FooGroup(Model):
foo = fields.ForeignKeyField("models.Foo")
group = fields.ForeignKeyField("models.Group")
is_active = fields.BooleanField(default=False)
class Meta:
table = "foo_group"
"""
def test_sqlite_migrate(tmp_path: Path) -> None:
    if not Dialect.is_sqlite():
        return
    with prepare_sqlite_project(tmp_path) as (models_py, models_text):
        MODELS = models_text
        run_aerich("aerich init -t settings.TORTOISE_ORM")
        config_file = Path("pyproject.toml")
        modify_time = config_file.stat().st_mtime
        run_aerich("aerich init-db")
        run_aerich("aerich init -t settings.TORTOISE_ORM")
        assert modify_time == config_file.stat().st_mtime
        r = run_shell("pytest _tests.py::test_allow_duplicate")
        assert r.returncode == 0
        # Add index
        models_py.write_text(MODELS.replace("index=False", "index=True"))
        run_aerich("aerich migrate")  # migrations/models/1_
        run_aerich("aerich upgrade")
        r = run_shell("pytest -s _tests.py::test_allow_duplicate")
        assert r.returncode == 0
        # Drop index
        models_py.write_text(MODELS)
        run_aerich("aerich migrate")  # migrations/models/2_
        run_aerich("aerich upgrade")
        r = run_shell("pytest -s _tests.py::test_allow_duplicate")
        assert r.returncode == 0
        # Add unique index
        models_py.write_text(MODELS.replace("index=False", "index=True, unique=True"))
        run_aerich("aerich migrate")  # migrations/models/3_
        run_aerich("aerich upgrade")
        r = run_shell("pytest _tests.py::test_unique_is_true")
        assert r.returncode == 0
        # Drop unique index
        models_py.write_text(MODELS)
        run_aerich("aerich migrate")  # migrations/models/4_
        run_aerich("aerich upgrade")
        r = run_shell("pytest _tests.py::test_allow_duplicate")
        assert r.returncode == 0
        # Add field with unique=True
        with models_py.open("a") as f:
            f.write("    age = fields.IntField(unique=True, default=0)")
        run_aerich("aerich migrate")  # migrations/models/5_
        run_aerich("aerich upgrade")
        r = run_shell("pytest _tests.py::test_add_unique_field")
        assert r.returncode == 0
        # Drop unique field
        models_py.write_text(MODELS)
        run_aerich("aerich migrate")  # migrations/models/6_
        run_aerich("aerich upgrade")
        r = run_shell("pytest -s _tests.py::test_drop_unique_field")
        assert r.returncode == 0

        # Initial with indexed field and then drop it
        migrations_dir = Path("migrations/models")
        shutil.rmtree(migrations_dir)
        db_file = _get_empty_db()
        models_py.write_text(MODELS + "    age = fields.IntField(db_index=True)")
        run_aerich("aerich init -t settings.TORTOISE_ORM")
        run_aerich("aerich init-db")
        migration_file = list(migrations_dir.glob("0_*.py"))[0]
        assert "CREATE INDEX" in migration_file.read_text()
        r = run_shell("pytest _tests.py::test_with_age_field")
        assert r.returncode == 0
        models_py.write_text(MODELS)
        run_aerich("aerich migrate")
        run_aerich("aerich upgrade")
        migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
        assert "DROP INDEX" in migration_file_1.read_text()
        r = run_shell("pytest _tests.py::test_without_age_field")
        assert r.returncode == 0
        # Generate migration file in empty directory
@@ -260,26 +184,12 @@ def test_sqlite_migrate(tmp_path: Path) -> None:
assert "[tool.aerich]" in config_file.read_text() assert "[tool.aerich]" in config_file.read_text()
        # add m2m with custom model for through
        models_py.write_text(MODELS + M2M_WITH_CUSTOM_THROUGH)
        run_aerich("aerich migrate")
        run_aerich("aerich upgrade")
        migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
        assert "foo_group" in migration_file_1.read_text()
        r = run_shell("pytest _tests.py::test_m2m_with_custom_through")
        assert r.returncode == 0
        # add m2m field after init-db
@@ -289,8 +199,7 @@ class FooGroup(Model):
class Group(Model):
    name = fields.CharField(max_length=60)
"""
        _get_empty_db()
        if migrations_dir.exists():
            shutil.rmtree(migrations_dir)
        models_py.write_text(MODELS)
@@ -300,5 +209,5 @@ class Group(Model):
run_aerich("aerich upgrade") run_aerich("aerich upgrade")
migration_file_1 = list(migrations_dir.glob("1_*.py"))[0] migration_file_1 = list(migrations_dir.glob("1_*.py"))[0]
assert "foo_group" in migration_file_1.read_text() assert "foo_group" in migration_file_1.read_text()
r = run_shell("pytest _test.py::test_add_m2m_field_after_init_db") r = run_shell("pytest _tests.py::test_add_m2m_field_after_init_db")
assert r.returncode == 0 assert r.returncode == 0
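
The asset files copied by prepare_sqlite_project are not part of this diff. Judging from the `models_text.replace("db_index=False", ...)` calls in the tests above, tests/assets/sqlite_migrate/models.py presumably reduces to something like:

    # Presumed content of tests/assets/sqlite_migrate/models.py (an assumption,
    # reconstructed from the replace() calls in the tests above).
    from tortoise import Model, fields


    class Foo(Model):
        name = fields.CharField(max_length=60, db_index=False)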