feat: use .py version files

This commit is contained in:
long2ice 2022-09-23 10:29:48 +08:00
parent 8c2ecbaef1
commit e9b76bdd35
10 changed files with 91 additions and 111 deletions

View File

@ -1,8 +1,16 @@
# ChangeLog # ChangeLog
## 0.6 ## 0.7
### 0.6.4 ### 0.7.0
**Now aerich use `.py` file to record versions.**
Upgrade Note:
1. Drop `aerich` table
2. Delete `migrations/models` folder
3. Run `aerich init-db`
- Improve `inspectdb` adding support to `postgresql::numeric` data type - Improve `inspectdb` adding support to `postgresql::numeric` data type
- Add support for dynamically load DDL classes easing to add support to - Add support for dynamically load DDL classes easing to add support to
@ -10,6 +18,8 @@
- Fix decimal field change. (#246) - Fix decimal field change. (#246)
- Support add/remove field with index. - Support add/remove field with index.
## 0.6
### 0.6.3 ### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`. - Improve `inspectdb` and support `postgres` & `sqlite`.

View File

@ -101,11 +101,11 @@ e.g. `aerich --app other_models init-db`.
```shell ```shell
> aerich migrate --name drop_column > aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.sql Success migrate 1_202029051520102929_drop_column.py
``` ```
Format of migrate filename is Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`. `{version_num}_{datetime}_{name|update}.py`.
If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may `True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
@ -116,7 +116,7 @@ lose data.
```shell ```shell
> aerich upgrade > aerich upgrade
Success upgrade 1_202029051520102929_drop_column.sql Success upgrade 1_202029051520102929_drop_column.py
``` ```
Now your db is migrated to latest. Now your db is migrated to latest.
@ -142,7 +142,7 @@ Options:
```shell ```shell
> aerich downgrade > aerich downgrade
Success downgrade 1_202029051520102929_drop_column.sql Success downgrade 1_202029051520102929_drop_column.py
``` ```
Now your db is rolled back to the specified version. Now your db is rolled back to the specified version.
@ -152,7 +152,7 @@ Now your db is rolled back to the specified version.
```shell ```shell
> aerich history > aerich history
1_202029051520102929_drop_column.sql 1_202029051520102929_drop_column.py
``` ```
### Show heads to be migrated ### Show heads to be migrated
@ -160,7 +160,7 @@ Now your db is rolled back to the specified version.
```shell ```shell
> aerich heads > aerich heads
1_202029051520102929_drop_column.sql 1_202029051520102929_drop_column.py
``` ```
### Inspect db tables to TortoiseORM model ### Inspect db tables to TortoiseORM model

View File

@ -11,14 +11,13 @@ from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import Migrate from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich from aerich.models import Aerich
from aerich.utils import ( from aerich.utils import (
get_app_connection, get_app_connection,
get_app_connection_name, get_app_connection_name,
get_models_describe, get_models_describe,
get_version_content_from_file, import_py_file,
write_version_file,
) )
@ -49,10 +48,9 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
file_path = Path(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path) m = import_py_file(file_path)
upgrade_query_list = content.get("upgrade") upgrade = getattr(m, "upgrade")
for upgrade_query in upgrade_query_list: await upgrade(conn)
await conn.execute_script(upgrade_query)
await Aerich.create( await Aerich.create(
version=version_file, version=version_file,
app=self.app, app=self.app,
@ -81,12 +79,11 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app) get_app_connection_name(self.tortoise_config, self.app)
) as conn: ) as conn:
file_path = Path(Migrate.migrate_location, file) file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path) m = import_py_file(file_path)
downgrade_query_list = content.get("downgrade") downgrade = getattr(m, "downgrade", None)
if not downgrade_query_list: if not downgrade:
raise DowngradeError("No downgrade items found") raise DowngradeError("No downgrade items found")
for downgrade_query in downgrade_query_list: await downgrade(conn)
await conn.execute_query(downgrade_query)
await version.delete() await version.delete()
if delete: if delete:
os.unlink(file_path) os.unlink(file_path)
@ -143,7 +140,7 @@ class Command:
app=app, app=app,
content=get_models_describe(app), content=get_models_describe(app),
) )
content = { version_file = Path(dirname, version)
"upgrade": [schema], content = MIGRATE_TEMPLATE.format(upgrade_sql=f'"""{schema}"""', downgrade_sql="")
} with open(version_file, "w", encoding="utf-8") as f:
write_version_file(Path(dirname, version), content) f.write(content)

View File

@ -26,11 +26,11 @@ def coro(f):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function # Close db connections at the end of all but the cli group function
try: try:
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
finally: finally:
if f.__name__ not in ["cli", "init_db"]: if f.__name__ not in ["cli", "init_db", "init"]:
loop.run_until_complete(Tortoise.close_connections()) loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper

View File

@ -13,12 +13,24 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import ( from aerich.utils import get_app_connection, get_models_describe, is_default_function
get_app_connection,
get_models_describe, MIGRATE_TEMPLATE = """from typing import List
is_default_function,
write_version_file, from tortoise import BaseDBAsyncClient
)
async def upgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{upgrade_sql}
]
async def downgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{downgrade_sql}
]
"""
class Migrate: class Migrate:
@ -40,9 +52,9 @@ class Migrate:
_db_version: Optional[str] = None _db_version: Optional[str] = None
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> list[str]:
return sorted( return sorted(
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)), filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]), key=lambda x: int(x.split("_")[0]),
) )
@ -97,24 +109,27 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "") now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num() last_version_num = await cls._get_last_version_num()
if last_version_num is None: if last_version_num is None:
return f"0_{now}_init.sql" return f"0_{now}_init.py"
version = f"{last_version_num + 1}_{now}_{name}.sql" version = f"{last_version_num + 1}_{now}_{name}.py"
if len(version) > MAX_VERSION_LENGTH: if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})") raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version return version
@classmethod @classmethod
async def _generate_diff_sql(cls, name): async def _generate_diff_py(cls, name):
version = await cls.generate_version(name) version = await cls.generate_version(name)
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = {
"upgrade": list(dict.fromkeys(cls.upgrade_operators)), version_file = Path(cls.migrate_location, version)
"downgrade": list(dict.fromkeys(cls.downgrade_operators)), content = MIGRATE_TEMPLATE.format(
} upgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.upgrade_operators)),
write_version_file(Path(cls.migrate_location, version), content) downgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.downgrade_operators)),
)
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version return version
@classmethod @classmethod
@ -133,7 +148,7 @@ class Migrate:
if not cls.upgrade_operators: if not cls.upgrade_operators:
return "" return ""
return await cls._generate_diff_sql(name) return await cls._generate_diff_py(name)
@classmethod @classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False): def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):

View File

@ -1,9 +1,9 @@
import importlib import importlib.util
import os import os
import re import re
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Dict
from click import BadOptionUsage, ClickException, Context from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
@ -11,7 +11,7 @@ from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str: def add_src_path(path: str) -> str:
""" """
add a folder to the paths so we can import from there add a folder to the paths, so we can import from there
:param path: path to add :param path: path to add
:return: absolute path :return: absolute path
""" """
@ -77,60 +77,6 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config return config
_UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: Path, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict: def get_models_describe(app: str) -> Dict:
""" """
get app models describe get app models describe
@ -146,3 +92,11 @@ def get_models_describe(app: str) -> Dict:
def is_default_function(string: str): def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or "")) return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

View File

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.6.4" version = "0.7.0"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"

View File

@ -28,8 +28,6 @@ class User(Model):
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
class Email(Model): class Email(Model):

View File

@ -959,18 +959,18 @@ def test_sort_all_version_files(mocker):
mocker.patch( mocker.patch(
"os.listdir", "os.listdir",
return_value=[ return_value=[
"1_datetime_update.sql", "1_datetime_update.py",
"11_datetime_update.sql", "11_datetime_update.py",
"10_datetime_update.sql", "10_datetime_update.py",
"2_datetime_update.sql", "2_datetime_update.py",
], ],
) )
Migrate.migrate_location = "." Migrate.migrate_location = "."
assert Migrate.get_all_version_files() == [ assert Migrate.get_all_version_files() == [
"1_datetime_update.sql", "1_datetime_update.py",
"2_datetime_update.sql", "2_datetime_update.py",
"10_datetime_update.sql", "10_datetime_update.py",
"11_datetime_update.sql", "11_datetime_update.py",
] ]

6
tests/test_utils.py Normal file
View File

@ -0,0 +1,6 @@
from aerich.utils import import_py_file
def test_import_py_file():
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file")