98 lines
3.5 KiB
Python
98 lines
3.5 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import sys
|
|
from collections.abc import Generator
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from tortoise import Tortoise, expand_db_url
|
|
from tortoise.backends.base_postgres.schema_generator import BasePostgresSchemaGenerator
|
|
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
|
|
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
|
|
from tortoise.contrib.test import MEMORY_SQLITE
|
|
|
|
from aerich.ddl.mysql import MysqlDDL
|
|
from aerich.ddl.postgres import PostgresDDL
|
|
from aerich.ddl.sqlite import SqliteDDL
|
|
from aerich.migrate import Migrate
|
|
from tests._utils import chdir, copy_files, init_db, run_shell
|
|
|
|
# Database URLs for the two test connections; each can be overridden through an
# environment variable and otherwise falls back to tortoise's MEMORY_SQLITE URL.
db_url = os.getenv("TEST_DB", MEMORY_SQLITE)
db_url_second = os.getenv("TEST_DB_SECOND", MEMORY_SQLITE)

# Tortoise ORM configuration used by the session fixtures: the "models" app
# (test models plus aerich's own models) lives on the "default" connection and
# the "models_second" app lives on the "second" connection.
tortoise_orm = {
    "connections": {
        alias: expand_db_url(url, testing=True)
        for alias, url in (("default", db_url), ("second", db_url_second))
    },
    "apps": {
        "models": {
            "models": ["tests.models", "aerich.models"],
            "default_connection": "default",
        },
        "models_second": {
            "models": ["tests.models_second"],
            "default_connection": "second",
        },
    },
}
|
|
|
|
|
|
@pytest.fixture(scope="function", autouse=True)
def reset_migrate() -> None:
    """Give every test a clean slate on the ``Migrate`` class.

    ``Migrate`` keeps its operator and m2m lists as class attributes, so they
    would otherwise carry over from one test to the next. Each one is rebound
    to a brand-new empty list before the test runs.
    """
    for attr_name in (
        "upgrade_operators",
        "downgrade_operators",
        "_upgrade_fk_m2m_index_operators",
        "_downgrade_fk_m2m_index_operators",
        "_upgrade_m2m",
        "_downgrade_m2m",
    ):
        setattr(Migrate, attr_name, [])
|
|
|
|
|
|
@pytest.fixture(scope="session")
def event_loop() -> Generator:
    """Provide a single event loop shared by the whole test session.

    While the session runs, the loop's public ``close`` is replaced by a no-op,
    presumably so that per-test teardown code cannot shut the shared loop down
    early. The genuine ``close`` is stashed on the loop as ``_close`` and is
    invoked once, when this session fixture is torn down.
    """
    loop = asyncio.get_event_loop_policy().new_event_loop()
    asyncio.set_event_loop(loop)

    # Keep a handle on the real close() and neuter the public one.
    real_close = loop.close
    loop._close = real_close  # type:ignore[attr-defined]
    loop.close = lambda: None  # type:ignore[method-assign]

    yield loop

    # Session teardown: actually close the loop now.
    loop._close()  # type:ignore[attr-defined]
|
|
|
|
|
|
@pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request) -> None:
    """Initialise Tortoise once per session and wire ``Migrate`` to the right DDL.

    Runs ``init_db`` with the module's ``tortoise_orm`` config. It then inspects
    the schema generator of the "default" connection and installs the matching
    aerich DDL implementation on ``Migrate`` together with its dialect.
    ``Tortoise._drop_databases()`` is scheduled to run on the session event
    loop at teardown.

    Raises:
        NotImplementedError: if the default connection's schema generator is
            not a MySQL, SQLite or PostgreSQL generator (or a subclass of one).
    """
    await init_db(tortoise_orm)
    # Register teardown immediately after init succeeds, so the databases are
    # still dropped if the backend detection below raises.
    request.addfinalizer(lambda: event_loop.run_until_complete(Tortoise._drop_databases()))

    client = Tortoise.get_connection("default")
    generator = client.schema_generator
    # Use issubclass for every backend (the original used identity checks for
    # MySQL/SQLite), so subclassed schema generators are recognised too.
    if issubclass(generator, MySQLSchemaGenerator):
        Migrate.ddl = MysqlDDL(client)
    elif issubclass(generator, SqliteSchemaGenerator):
        Migrate.ddl = SqliteDDL(client)
    elif issubclass(generator, BasePostgresSchemaGenerator):
        Migrate.ddl = PostgresDDL(client)
    else:
        raise NotImplementedError(f"Unsupported schema generator for tests: {generator!r}")
    Migrate.dialect = Migrate.ddl.DIALECT
|
|
|
|
|
|
@pytest.fixture
def new_aerich_project(tmp_path: Path) -> Generator[None, None, None]:
    """Materialise a throw-away aerich project in ``tmp_path`` for one test.

    Copies the following into ``tmp_path``:

    - ``settings.py``, ``_tests.py`` and ``db.py`` from ``tests/assets/fake``;
    - ``models.py`` and ``models_second.py`` from ``tests``;
    - a ``tests`` package containing ``_utils.py`` and ``indexes.py``.

    It then puts ``tmp_path`` on ``sys.path`` (unless already present), chdirs
    into it and runs ``python db.py create``. After the test it runs
    ``python db.py drop``, unless the ``AERICH_DONT_DROP_FAKE_DB`` environment
    variable is set. The ``sys.path`` entry added here is always removed again,
    even if either ``db.py`` command fails.
    """
    test_dir = Path(__file__).parent / "tests"
    asset_dir = test_dir / "assets" / "fake"
    copy_files(
        asset_dir / "settings.py",
        asset_dir / "_tests.py",
        test_dir / "models.py",
        test_dir / "models_second.py",
        asset_dir / "db.py",
        target_dir=tmp_path,
    )
    dst_dir = tmp_path / "tests"
    dst_dir.mkdir()
    (dst_dir / "__init__.py").touch()
    copy_files(test_dir / "_utils.py", test_dir / "indexes.py", target_dir=dst_dir)

    project_path = str(tmp_path)
    should_remove = project_path not in sys.path
    if should_remove:
        sys.path.append(project_path)
    try:
        with chdir(tmp_path):
            run_shell("python db.py create", capture_output=False)
            try:
                yield
            finally:
                if not os.getenv("AERICH_DONT_DROP_FAKE_DB"):
                    run_shell("python db.py drop", capture_output=False)
    finally:
        # Outer finally: undo our sys.path change even when `db.py create` or
        # `db.py drop` raised, so later tests don't import from a stale tmp dir.
        if should_remove and project_path in sys.path:
            sys.path.remove(project_path)
|