diff --git a/README.md b/README.md index a8254b4..10bc2c7 100644 --- a/README.md +++ b/README.md @@ -226,14 +226,14 @@ from tortoise import Model, fields class Test(Model): - date = fields.DateField(null=True, ) - datetime = fields.DatetimeField(auto_now=True, ) - decimal = fields.DecimalField(max_digits=10, decimal_places=2, ) - float = fields.FloatField(null=True, ) - id = fields.IntField(pk=True, ) - string = fields.CharField(max_length=200, null=True, ) - time = fields.TimeField(null=True, ) - tinyint = fields.BooleanField(null=True, ) + date = fields.DateField(null=True) + datetime = fields.DatetimeField(auto_now=True) + decimal = fields.DecimalField(max_digits=10, decimal_places=2) + float = fields.FloatField(null=True) + id = fields.IntField(primary_key=True) + string = fields.CharField(max_length=200, null=True) + time = fields.TimeField(null=True) + tinyint = fields.BooleanField(null=True) ``` Note that this command is limited and can't infer some fields, such as `IntEnumField`, `ForeignKeyField`, and others. @@ -243,8 +243,8 @@ Note that this command is limited and can't infer some fields, such as `IntEnumF ```python tortoise_orm = { "connections": { - "default": expand_db_url(db_url, True), - "second": expand_db_url(db_url_second, True), + "default": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db1", + "second": "postgres://postgres_user:postgres_pass@127.0.0.1:5432/db2", }, "apps": { "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"}, @@ -253,7 +253,7 @@ tortoise_orm = { } ``` -You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on. +You only need to specify `aerich.models` in one app, and must specify `--app` when running `aerich migrate` and so on, e.g. `aerich --app models_second migrate`. ## Restore `aerich` workflow @@ -273,9 +273,9 @@ You can use `aerich` out of cli by use `Command` class. ```python from aerich import Command -command = Command(tortoise_config=config, app='models') -await command.init() -await command.migrate('test') +async with Command(tortoise_config=config, app='models') as command: + await command.migrate('test') + await command.upgrade() ``` ## Upgrade/Downgrade with `--fake` option diff --git a/aerich/__init__.py b/aerich/__init__.py index 7a1bda0..9215ae8 100644 --- a/aerich/__init__.py +++ b/aerich/__init__.py @@ -2,11 +2,12 @@ from __future__ import annotations import os import platform +from contextlib import AbstractAsyncContextManager from pathlib import Path from typing import TYPE_CHECKING import tortoise -from tortoise import Tortoise, generate_schema_for_client +from tortoise import Tortoise, connections, generate_schema_for_client from tortoise.exceptions import OperationalError from tortoise.transactions import in_transaction from tortoise.utils import get_schema_sql @@ -59,10 +60,9 @@ def _init_tortoise_0_24_1_patch(): from tortoise.backends.base.schema_generator import BaseSchemaGenerator, cast, re def _get_m2m_tables( - self, model: type[Model], table_name: str, safe: bool, models_tables: list[str] - ) -> list[str]: + self, model: type[Model], db_table: str, safe: bool, models_tables: list[str] + ) -> list[str]: # Copied from tortoise-orm m2m_tables_for_create = [] - db_table = table_name for m2m_field in model._meta.m2m_fields: field_object = cast("ManyToManyFieldInstance", model._meta.fields_map[m2m_field]) if field_object._generated or field_object.through in models_tables: @@ -88,15 +88,15 @@ def _init_tortoise_0_24_1_patch(): else: backward_fk = forward_fk = "" exists = "IF NOT EXISTS " if safe else "" - table_name = field_object.through + through_table_name = field_object.through backward_type = self._get_pk_field_sql_type(model._meta.pk) forward_type = self._get_pk_field_sql_type(field_object.related_model._meta.pk) comment = "" if desc := field_object.description: - comment = self._table_comment_generator(table=table_name, comment=desc) + comment = self._table_comment_generator(table=through_table_name, comment=desc) m2m_create_string = self.M2M_TABLE_TEMPLATE.format( exists=exists, - table_name=table_name, + table_name=through_table_name, backward_fk=backward_fk, forward_fk=forward_fk, backward_key=backward_key, @@ -116,7 +116,7 @@ def _init_tortoise_0_24_1_patch(): m2m_create_string += self._post_table_hook() if field_object.create_unique_index: unique_index_create_sql = self._get_unique_index_sql( - exists, table_name, [backward_key, forward_key] + exists, through_table_name, [backward_key, forward_key] ) if unique_index_create_sql.endswith(";"): m2m_create_string += "\n" + unique_index_create_sql @@ -136,7 +136,7 @@ _init_asyncio_patch() _init_tortoise_0_24_1_patch() -class Command: +class Command(AbstractAsyncContextManager): def __init__( self, tortoise_config: dict, @@ -151,6 +151,16 @@ class Command: async def init(self) -> None: await Migrate.init(self.tortoise_config, self.app, self.location) + async def __aenter__(self) -> Command: + await self.init() + return self + + async def close(self) -> None: + await connections.close_all() + + async def __aexit__(self, *args, **kw) -> None: + await self.close() + async def _upgrade(self, conn, version_file, fake: bool = False) -> None: file_path = Path(Migrate.migrate_location, version_file) m = import_py_file(file_path) diff --git a/tests/test_command.py b/tests/test_command.py new file mode 100644 index 0000000..f7a5ff9 --- /dev/null +++ b/tests/test_command.py @@ -0,0 +1,11 @@ +from aerich import Command +from conftest import tortoise_orm + + +async def test_command(mocker): + mocker.patch("os.listdir", return_value=[]) + async with Command(tortoise_orm) as command: + history = await command.history() + heads = await command.heads() + assert history == [] + assert heads == []