aerich/aerich/__init__.py
Waket Zheng 0364ae3f83
feat: add project section (#424)
* refactor: apply future style type hints

* chore: use project section

* ci: upgrade to poetry v2

* ci: explicit declare python version for poetry

* fix error for generate index name

* fix _generate_fk_name

* ci: verify aiomysql support

* tests: poetry add

* Add patch to fix tortoise 0.24.1

* docs: update changelog
2025-02-26 14:24:02 +08:00

273 lines
10 KiB
Python

from __future__ import annotations
import os
import platform
from pathlib import Path
from typing import TYPE_CHECKING
import tortoise
from tortoise import Tortoise, generate_schema_for_client
from tortoise.exceptions import OperationalError
from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
import_py_file,
)
if TYPE_CHECKING:
from tortoise import Model
from tortoise.fields.relational import ManyToManyFieldInstance # NOQA:F401
from aerich.inspectdb import Inspect
def _init_asyncio_patch():
"""
Select compatible event loop for psycopg3.
As of Python 3.8+, the default event loop on Windows is `proactor`,
however psycopg3 requires the old default "selector" event loop.
See https://www.psycopg.org/psycopg3/docs/advanced/async.html
"""
if platform.system() == "Windows":
try:
from asyncio import WindowsSelectorEventLoopPolicy
except ImportError:
pass # Can't assign a policy which doesn't exist.
else:
from asyncio import get_event_loop_policy, set_event_loop_policy
if not isinstance(get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
set_event_loop_policy(WindowsSelectorEventLoopPolicy())
def _init_tortoise_0_24_1_patch():
# this patch is for "tortoise-orm==0.24.1" to fix:
# https://github.com/tortoise/tortoise-orm/issues/1893
if tortoise.__version__ != "0.24.1":
return
from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re
def _get_m2m_tables(
self, model: type[Model], table_name: str, safe: bool, models_tables: list[str]
) -> list[str]:
m2m_tables_for_create = []
db_table = table_name
for m2m_field in model._meta.m2m_fields:
field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field])
if field_object._generated or field_object.through in models_tables:
continue
backward_key, forward_key = field_object.backward_key, field_object.forward_key
if field_object.db_constraint:
backward_fk = self._create_fk_string(
"",
backward_key,
db_table,
model._meta.db_pk_column,
field_object.on_delete,
"",
)
forward_fk = self._create_fk_string(
"",
forward_key,
field_object.related_model._meta.db_table,
field_object.related_model._meta.db_pk_column,
field_object.on_delete,
"",
)
else:
backward_fk = forward_fk = ""
exists = "IF NOT EXISTS " if safe else ""
table_name = field_object.through
backward_type = self._get_pk_field_sql_type(model._meta.pk)
forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk)
comment = ""
if desc := field_object.description:
comment = self._table_comment_generator(table=table_name, comment=desc)
m2m_create_string = self.M2M_TABLE_TEMPLATE.format(
exists=exists,
table_name=table_name,
backward_fk=backward_fk,
forward_fk=forward_fk,
backward_key=backward_key,
backward_type=backward_type,
forward_key=forward_key,
forward_type=forward_type,
extra=self._table_generate_extra(table=field_object.through),
comment=comment,
)
if not field_object.db_constraint:
m2m_create_string = m2m_create_string.replace(
""",
,
""",
"",
) # may have better way
m2m_create_string += self._post_table_hook()
if field_object.create_unique_index:
unique_index_create_sql = self._get_unique_index_sql(
exists, table_name, [backward_key, forward_key]
)
if unique_index_create_sql.endswith(";"):
m2m_create_string += "\n" + unique_index_create_sql
else:
lines = m2m_create_string.splitlines()
lines[-2] += ","
indent = m.group() if (m := re.match(r"\s+", lines[-2])) else ""
lines.insert(-1, indent + unique_index_create_sql)
m2m_create_string = "\n".join(lines)
m2m_tables_for_create.append(m2m_create_string)
return m2m_tables_for_create
BaseSchemaGenerator._get_m2m_tables = _get_m2m_tables
_init_asyncio_patch()
_init_tortoise_0_24_1_patch()
class Command:
def __init__(
self,
tortoise_config: dict,
app: str = "models",
location: str = "./migrations",
) -> None:
self.tortoise_config = tortoise_config
self.app = app
self.location = location
Migrate.app = app
async def init(self) -> None:
await Migrate.init(self.tortoise_config, self.app, self.location)
async def _upgrade(self, conn, version_file, fake: bool = False) -> None:
file_path = Path(Migrate.migrate_location, version_file)
m = import_py_file(file_path)
upgrade = m.upgrade
if not fake:
await conn.execute_script(await upgrade(conn))
await Aerich.create(
version=version_file,
app=self.app,
content=get_models_describe(self.app),
)
async def upgrade(self, run_in_transaction: bool = True, fake: bool = False) -> list[str]:
migrated = []
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version_file, app=self.app)
except OperationalError:
exists = False
if not exists:
app_conn_name = get_app_connection_name(self.tortoise_config, self.app)
if run_in_transaction:
async with in_transaction(app_conn_name) as conn:
await self._upgrade(conn, version_file, fake=fake)
else:
app_conn = get_app_connection(self.tortoise_config, self.app)
await self._upgrade(app_conn, version_file, fake=fake)
migrated.append(version_file)
return migrated
async def downgrade(self, version: int, delete: bool, fake: bool = False) -> list[str]:
ret: list[str] = []
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(
app=self.app, version__startswith=f"{version}_"
).first()
if not specified_version:
raise DowngradeError("No specified version found")
if version == -1:
versions = [specified_version]
else:
versions = await Aerich.filter(app=self.app, pk__gte=specified_version.pk)
for version_obj in versions:
file = version_obj.version
async with in_transaction(
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
m = import_py_file(file_path)
downgrade = m.downgrade
downgrade_sql = await downgrade(conn)
if not downgrade_sql.strip():
raise DowngradeError("No downgrade items found")
if not fake:
await conn.execute_script(downgrade_sql)
await version_obj.delete()
if delete:
os.unlink(file_path)
ret.append(file)
return ret
async def heads(self) -> list[str]:
ret = []
versions = Migrate.get_all_version_files()
for version in versions:
if not await Aerich.exists(version=version, app=self.app):
ret.append(version)
return ret
async def history(self) -> list[str]:
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: list[str] | None = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls: type[Inspect] = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update", empty: bool = False) -> str:
return await Migrate.migrate(name, empty)
async def init_db(self, safe: bool) -> None:
location = self.location
app = self.app
dirname = Path(location, app)
if not dirname.exists():
dirname.mkdir(parents=True)
else:
# If directory is empty, go ahead, otherwise raise FileExistsError
for unexpected_file in dirname.glob("*"):
raise FileExistsError(str(unexpected_file))
await Tortoise.init(config=self.tortoise_config)
connection = get_app_connection(self.tortoise_config, app)
await generate_schema_for_client(connection, safe)
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(
version=version,
app=app,
content=get_models_describe(app),
)
version_file = Path(dirname, version)
content = MIGRATE_TEMPLATE.format(upgrade_sql=schema, downgrade_sql="")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)