From e9b76bdd35cb34e25a2271709fc08c6c06a28032 Mon Sep 17 00:00:00 2001 From: long2ice Date: Fri, 23 Sep 2022 10:29:48 +0800 Subject: [PATCH] feat: use .py version files --- CHANGELOG.md | 14 +++++++-- README.md | 12 ++++---- aerich/__init__.py | 29 +++++++++--------- aerich/cli.py | 4 +-- aerich/migrate.py | 49 ++++++++++++++++++++----------- aerich/utils.py | 68 +++++++------------------------------------ pyproject.toml | 2 +- tests/models.py | 2 -- tests/test_migrate.py | 16 +++++----- tests/test_utils.py | 6 ++++ 10 files changed, 91 insertions(+), 111 deletions(-) create mode 100644 tests/test_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 490d4f4..a11bf7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,16 @@ # ChangeLog -## 0.6 +## 0.7 -### 0.6.4 +### 0.7.0 + +**Now aerich use `.py` file to record versions.** + +Upgrade Note: + +1. Drop `aerich` table +2. Delete `migrations/models` folder +3. Run `aerich init-db` - Improve `inspectdb` adding support to `postgresql::numeric` data type - Add support for dynamically load DDL classes easing to add support to @@ -10,6 +18,8 @@ - Fix decimal field change. (#246) - Support add/remove field with index. +## 0.6 + ### 0.6.3 - Improve `inspectdb` and support `postgres` & `sqlite`. diff --git a/README.md b/README.md index 7456d46..ba05d8c 100644 --- a/README.md +++ b/README.md @@ -101,11 +101,11 @@ e.g. `aerich --app other_models init-db`. ```shell > aerich migrate --name drop_column -Success migrate 1_202029051520102929_drop_column.sql +Success migrate 1_202029051520102929_drop_column.py ``` Format of migrate filename is -`{version_num}_{datetime}_{name|update}.sql`. +`{version_num}_{datetime}_{name|update}.py`. If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may @@ -116,7 +116,7 @@ lose data. ```shell > aerich upgrade -Success upgrade 1_202029051520102929_drop_column.sql +Success upgrade 1_202029051520102929_drop_column.py ``` Now your db is migrated to latest. @@ -142,7 +142,7 @@ Options: ```shell > aerich downgrade -Success downgrade 1_202029051520102929_drop_column.sql +Success downgrade 1_202029051520102929_drop_column.py ``` Now your db is rolled back to the specified version. @@ -152,7 +152,7 @@ Now your db is rolled back to the specified version. ```shell > aerich history -1_202029051520102929_drop_column.sql +1_202029051520102929_drop_column.py ``` ### Show heads to be migrated @@ -160,7 +160,7 @@ Now your db is rolled back to the specified version. ```shell > aerich heads -1_202029051520102929_drop_column.sql +1_202029051520102929_drop_column.py ``` ### Inspect db tables to TortoiseORM model diff --git a/aerich/__init__.py b/aerich/__init__.py index 2eec79c..a4dfe93 100644 --- a/aerich/__init__.py +++ b/aerich/__init__.py @@ -11,14 +11,13 @@ from aerich.exceptions import DowngradeError from aerich.inspectdb.mysql import InspectMySQL from aerich.inspectdb.postgres import InspectPostgres from aerich.inspectdb.sqlite import InspectSQLite -from aerich.migrate import Migrate +from aerich.migrate import MIGRATE_TEMPLATE, Migrate from aerich.models import Aerich from aerich.utils import ( get_app_connection, get_app_connection_name, get_models_describe, - get_version_content_from_file, - write_version_file, + import_py_file, ) @@ -49,10 +48,9 @@ class Command: get_app_connection_name(self.tortoise_config, self.app) ) as conn: file_path = Path(Migrate.migrate_location, version_file) - content = get_version_content_from_file(file_path) - upgrade_query_list = content.get("upgrade") - for upgrade_query in upgrade_query_list: - await conn.execute_script(upgrade_query) + m = import_py_file(file_path) + upgrade = getattr(m, "upgrade") + await upgrade(conn) await Aerich.create( version=version_file, app=self.app, @@ -81,12 +79,11 @@ class Command: get_app_connection_name(self.tortoise_config, self.app) ) as conn: file_path = Path(Migrate.migrate_location, file) - content = get_version_content_from_file(file_path) - downgrade_query_list = content.get("downgrade") - if not downgrade_query_list: + m = import_py_file(file_path) + downgrade = getattr(m, "downgrade", None) + if not downgrade: raise DowngradeError("No downgrade items found") - for downgrade_query in downgrade_query_list: - await conn.execute_query(downgrade_query) + await downgrade(conn) await version.delete() if delete: os.unlink(file_path) @@ -143,7 +140,7 @@ class Command: app=app, content=get_models_describe(app), ) - content = { - "upgrade": [schema], - } - write_version_file(Path(dirname, version), content) + version_file = Path(dirname, version) + content = MIGRATE_TEMPLATE.format(upgrade_sql=f'"""{schema}"""', downgrade_sql="") + with open(version_file, "w", encoding="utf-8") as f: + f.write(content) diff --git a/aerich/cli.py b/aerich/cli.py index 41dc65b..39d9942 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -26,11 +26,11 @@ def coro(f): def wrapper(*args, **kwargs): loop = asyncio.get_event_loop() - # Close db connections at the end of all all but the cli group function + # Close db connections at the end of all but the cli group function try: loop.run_until_complete(f(*args, **kwargs)) finally: - if f.__name__ not in ["cli", "init_db"]: + if f.__name__ not in ["cli", "init_db", "init"]: loop.run_until_complete(Tortoise.close_connections()) return wrapper diff --git a/aerich/migrate.py b/aerich/migrate.py index a38b973..616f976 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -13,12 +13,24 @@ from tortoise.indexes import Index from aerich.ddl import BaseDDL from aerich.models import MAX_VERSION_LENGTH, Aerich -from aerich.utils import ( - get_app_connection, - get_models_describe, - is_default_function, - write_version_file, -) +from aerich.utils import get_app_connection, get_models_describe, is_default_function + +MIGRATE_TEMPLATE = """from typing import List + +from tortoise import BaseDBAsyncClient + + +async def upgrade(db: BaseDBAsyncClient) -> List[str]: + return [ + {upgrade_sql} + ] + + +async def downgrade(db: BaseDBAsyncClient) -> List[str]: + return [ + {downgrade_sql} + ] +""" class Migrate: @@ -40,9 +52,9 @@ class Migrate: _db_version: Optional[str] = None @classmethod - def get_all_version_files(cls) -> List[str]: + def get_all_version_files(cls) -> list[str]: return sorted( - filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)), + filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)), key=lambda x: int(x.split("_")[0]), ) @@ -97,24 +109,27 @@ class Migrate: now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") last_version_num = await cls._get_last_version_num() if last_version_num is None: - return f"0_{now}_init.sql" - version = f"{last_version_num + 1}_{now}_{name}.sql" + return f"0_{now}_init.py" + version = f"{last_version_num + 1}_{now}_{name}.py" if len(version) > MAX_VERSION_LENGTH: raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})") return version @classmethod - async def _generate_diff_sql(cls, name): + async def _generate_diff_py(cls, name): version = await cls.generate_version(name) # delete if same version exists for version_file in cls.get_all_version_files(): if version_file.startswith(version.split("_")[0]): os.unlink(Path(cls.migrate_location, version_file)) - content = { - "upgrade": list(dict.fromkeys(cls.upgrade_operators)), - "downgrade": list(dict.fromkeys(cls.downgrade_operators)), - } - write_version_file(Path(cls.migrate_location, version), content) + + version_file = Path(cls.migrate_location, version) + content = MIGRATE_TEMPLATE.format( + upgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.upgrade_operators)), + downgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.downgrade_operators)), + ) + with open(version_file, "w", encoding="utf-8") as f: + f.write(content) return version @classmethod @@ -133,7 +148,7 @@ class Migrate: if not cls.upgrade_operators: return "" - return await cls._generate_diff_sql(name) + return await cls._generate_diff_py(name) @classmethod def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): diff --git a/aerich/utils.py b/aerich/utils.py index c8bffdc..ac5fb87 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -1,9 +1,9 @@ -import importlib +import importlib.util import os import re import sys from pathlib import Path -from typing import Dict, Union +from typing import Dict from click import BadOptionUsage, ClickException, Context from tortoise import BaseDBAsyncClient, Tortoise @@ -11,7 +11,7 @@ from tortoise import BaseDBAsyncClient, Tortoise def add_src_path(path: str) -> str: """ - add a folder to the paths so we can import from there + add a folder to the paths, so we can import from there :param path: path to add :return: absolute path """ @@ -77,60 +77,6 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict: return config -_UPGRADE = "-- upgrade --\n" -_DOWNGRADE = "-- downgrade --\n" - - -def get_version_content_from_file(version_file: Union[str, Path]) -> Dict: - """ - get version content - :param version_file: - :return: - """ - with open(version_file, "r", encoding="utf-8") as f: - content = f.read() - first = content.index(_UPGRADE) - try: - second = content.index(_DOWNGRADE) - except ValueError: - second = len(content) - 1 - upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203 - downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203 - ret = { - "upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))), - "downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))), - } - return ret - - -def write_version_file(version_file: Path, content: Dict): - """ - write version file - :param version_file: - :param content: - :return: - """ - with open(version_file, "w", encoding="utf-8") as f: - f.write(_UPGRADE) - upgrade = content.get("upgrade") - if len(upgrade) > 1: - f.write(";\n".join(upgrade)) - if not upgrade[-1].endswith(";"): - f.write(";\n") - else: - f.write(f"{upgrade[0]}") - if not upgrade[0].endswith(";"): - f.write(";") - f.write("\n") - downgrade = content.get("downgrade") - if downgrade: - f.write(_DOWNGRADE) - if len(downgrade) > 1: - f.write(";\n".join(downgrade) + ";\n") - else: - f.write(f"{downgrade[0]};\n") - - def get_models_describe(app: str) -> Dict: """ get app models describe @@ -146,3 +92,11 @@ def get_models_describe(app: str) -> Dict: def is_default_function(string: str): return re.match(r"^$", str(string or "")) + + +def import_py_file(file: Path): + module_name, file_ext = os.path.splitext(os.path.split(file)[-1]) + spec = importlib.util.spec_from_file_location(module_name, file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module diff --git a/pyproject.toml b/pyproject.toml index e4a059a..060cd36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aerich" -version = "0.6.4" +version = "0.7.0" description = "A database migrations tool for Tortoise ORM." authors = ["long2ice "] license = "Apache-2.0" diff --git a/tests/models.py b/tests/models.py index 597ee59..923c339 100644 --- a/tests/models.py +++ b/tests/models.py @@ -28,8 +28,6 @@ class User(Model): last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) is_active = fields.BooleanField(default=True, description="Is Active") is_superuser = fields.BooleanField(default=False, description="Is SuperUser") - intro = fields.TextField(default="") - longitude = fields.DecimalField(max_digits=10, decimal_places=8) class Email(Model): diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 41d9dde..6ce773a 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -959,18 +959,18 @@ def test_sort_all_version_files(mocker): mocker.patch( "os.listdir", return_value=[ - "1_datetime_update.sql", - "11_datetime_update.sql", - "10_datetime_update.sql", - "2_datetime_update.sql", + "1_datetime_update.py", + "11_datetime_update.py", + "10_datetime_update.py", + "2_datetime_update.py", ], ) Migrate.migrate_location = "." assert Migrate.get_all_version_files() == [ - "1_datetime_update.sql", - "2_datetime_update.sql", - "10_datetime_update.sql", - "11_datetime_update.sql", + "1_datetime_update.py", + "2_datetime_update.py", + "10_datetime_update.py", + "11_datetime_update.py", ] diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..654bc0d --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,6 @@ +from aerich.utils import import_py_file + + +def test_import_py_file(): + m = import_py_file("aerich/utils.py") + assert getattr(m, "import_py_file")