feat: use .py version files

This commit is contained in:
long2ice 2022-09-23 10:29:48 +08:00
parent 8c2ecbaef1
commit e9b76bdd35
10 changed files with 91 additions and 111 deletions

View File

@ -1,8 +1,16 @@
# ChangeLog
## 0.6
## 0.7
### 0.6.4
### 0.7.0
**Now aerich use `.py` file to record versions.**
Upgrade Note:
1. Drop `aerich` table
2. Delete `migrations/models` folder
3. Run `aerich init-db`
- Improve `inspectdb` adding support to `postgresql::numeric` data type
- Add support for dynamically load DDL classes easing to add support to
@ -10,6 +18,8 @@
- Fix decimal field change. (#246)
- Support add/remove field with index.
## 0.6
### 0.6.3
- Improve `inspectdb` and support `postgres` & `sqlite`.

View File

@ -101,11 +101,11 @@ e.g. `aerich --app other_models init-db`.
```shell
> aerich migrate --name drop_column
Success migrate 1_202029051520102929_drop_column.sql
Success migrate 1_202029051520102929_drop_column.py
```
Format of migrate filename is
`{version_num}_{datetime}_{name|update}.sql`.
`{version_num}_{datetime}_{name|update}.py`.
If `aerich` guesses you are renaming a column, it will ask `Rename {old_column} to {new_column} [True]`. You can choose
`True` to rename column without column drop, or choose `False` to drop the column then create. Note that the latter may
@ -116,7 +116,7 @@ lose data.
```shell
> aerich upgrade
Success upgrade 1_202029051520102929_drop_column.sql
Success upgrade 1_202029051520102929_drop_column.py
```
Now your db is migrated to latest.
@ -142,7 +142,7 @@ Options:
```shell
> aerich downgrade
Success downgrade 1_202029051520102929_drop_column.sql
Success downgrade 1_202029051520102929_drop_column.py
```
Now your db is rolled back to the specified version.
@ -152,7 +152,7 @@ Now your db is rolled back to the specified version.
```shell
> aerich history
1_202029051520102929_drop_column.sql
1_202029051520102929_drop_column.py
```
### Show heads to be migrated
@ -160,7 +160,7 @@ Now your db is rolled back to the specified version.
```shell
> aerich heads
1_202029051520102929_drop_column.sql
1_202029051520102929_drop_column.py
```
### Inspect db tables to TortoiseORM model

View File

@ -11,14 +11,13 @@ from aerich.exceptions import DowngradeError
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.migrate import MIGRATE_TEMPLATE, Migrate
from aerich.models import Aerich
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
get_version_content_from_file,
write_version_file,
import_py_file,
)
@ -49,10 +48,9 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
m = import_py_file(file_path)
upgrade = getattr(m, "upgrade")
await upgrade(conn)
await Aerich.create(
version=version_file,
app=self.app,
@ -81,12 +79,11 @@ class Command:
get_app_connection_name(self.tortoise_config, self.app)
) as conn:
file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
m = import_py_file(file_path)
downgrade = getattr(m, "downgrade", None)
if not downgrade:
raise DowngradeError("No downgrade items found")
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await downgrade(conn)
await version.delete()
if delete:
os.unlink(file_path)
@ -143,7 +140,7 @@ class Command:
app=app,
content=get_models_describe(app),
)
content = {
"upgrade": [schema],
}
write_version_file(Path(dirname, version), content)
version_file = Path(dirname, version)
content = MIGRATE_TEMPLATE.format(upgrade_sql=f'"""{schema}"""', downgrade_sql="")
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)

View File

@ -26,11 +26,11 @@ def coro(f):
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init_db"]:
if f.__name__ not in ["cli", "init_db", "init"]:
loop.run_until_complete(Tortoise.close_connections())
return wrapper

View File

@ -13,12 +13,24 @@ from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import (
get_app_connection,
get_models_describe,
is_default_function,
write_version_file,
)
from aerich.utils import get_app_connection, get_models_describe, is_default_function
MIGRATE_TEMPLATE = """from typing import List
from tortoise import BaseDBAsyncClient
async def upgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{upgrade_sql}
]
async def downgrade(db: BaseDBAsyncClient) -> List[str]:
return [
{downgrade_sql}
]
"""
class Migrate:
@ -40,9 +52,9 @@ class Migrate:
_db_version: Optional[str] = None
@classmethod
def get_all_version_files(cls) -> List[str]:
def get_all_version_files(cls) -> list[str]:
return sorted(
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
filter(lambda x: x.endswith("py"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]),
)
@ -97,24 +109,27 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
return f"0_{now}_init.sql"
version = f"{last_version_num + 1}_{now}_{name}.sql"
return f"0_{now}_init.py"
version = f"{last_version_num + 1}_{now}_{name}.py"
if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version
@classmethod
async def _generate_diff_sql(cls, name):
async def _generate_diff_py(cls, name):
version = await cls.generate_version(name)
# delete if same version exists
for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]):
os.unlink(Path(cls.migrate_location, version_file))
content = {
"upgrade": list(dict.fromkeys(cls.upgrade_operators)),
"downgrade": list(dict.fromkeys(cls.downgrade_operators)),
}
write_version_file(Path(cls.migrate_location, version), content)
version_file = Path(cls.migrate_location, version)
content = MIGRATE_TEMPLATE.format(
upgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.upgrade_operators)),
downgrade_sql=",\n ".join(map(lambda x: f'"""{x}"""', cls.downgrade_operators)),
)
with open(version_file, "w", encoding="utf-8") as f:
f.write(content)
return version
@classmethod
@ -133,7 +148,7 @@ class Migrate:
if not cls.upgrade_operators:
return ""
return await cls._generate_diff_sql(name)
return await cls._generate_diff_py(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk_m2m_index=False):

View File

@ -1,9 +1,9 @@
import importlib
import importlib.util
import os
import re
import sys
from pathlib import Path
from typing import Dict, Union
from typing import Dict
from click import BadOptionUsage, ClickException, Context
from tortoise import BaseDBAsyncClient, Tortoise
@ -11,7 +11,7 @@ from tortoise import BaseDBAsyncClient, Tortoise
def add_src_path(path: str) -> str:
"""
add a folder to the paths so we can import from there
add a folder to the paths, so we can import from there
:param path: path to add
:return: absolute path
"""
@ -77,60 +77,6 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
return config
_UPGRADE = "-- upgrade --\n"
_DOWNGRADE = "-- downgrade --\n"
def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: Path, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
@ -146,3 +92,11 @@ def get_models_describe(app: str) -> Dict:
def is_default_function(string: str):
return re.match(r"^<function.+>$", str(string or ""))
def import_py_file(file: Path):
module_name, file_ext = os.path.splitext(os.path.split(file)[-1])
spec = importlib.util.spec_from_file_location(module_name, file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "aerich"
version = "0.6.4"
version = "0.7.0"
description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0"

View File

@ -28,8 +28,6 @@ class User(Model):
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
intro = fields.TextField(default="")
longitude = fields.DecimalField(max_digits=10, decimal_places=8)
class Email(Model):

View File

@ -959,18 +959,18 @@ def test_sort_all_version_files(mocker):
mocker.patch(
"os.listdir",
return_value=[
"1_datetime_update.sql",
"11_datetime_update.sql",
"10_datetime_update.sql",
"2_datetime_update.sql",
"1_datetime_update.py",
"11_datetime_update.py",
"10_datetime_update.py",
"2_datetime_update.py",
],
)
Migrate.migrate_location = "."
assert Migrate.get_all_version_files() == [
"1_datetime_update.sql",
"2_datetime_update.sql",
"10_datetime_update.sql",
"11_datetime_update.sql",
"1_datetime_update.py",
"2_datetime_update.py",
"10_datetime_update.py",
"11_datetime_update.py",
]

6
tests/test_utils.py Normal file
View File

@ -0,0 +1,6 @@
from aerich.utils import import_py_file
def test_import_py_file():
m = import_py_file("aerich/utils.py")
assert getattr(m, "import_py_file")