Use .sql instead of .json to store version file. (#79)
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = "0.3.3"
|
||||
__version__ = "0.4.0"
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from configparser import ConfigParser
|
||||
@@ -13,7 +12,13 @@ from tortoise.transactions import in_transaction
|
||||
from tortoise.utils import get_schema_sql
|
||||
|
||||
from aerich.migrate import Migrate
|
||||
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config
|
||||
from aerich.utils import (
|
||||
get_app_connection,
|
||||
get_app_connection_name,
|
||||
get_tortoise_config,
|
||||
get_version_content_from_file,
|
||||
write_version_file,
|
||||
)
|
||||
|
||||
from . import __version__
|
||||
from .enums import Color
|
||||
@@ -105,11 +110,11 @@ async def upgrade(ctx: Context):
|
||||
if not exists:
|
||||
async with in_transaction(get_app_connection_name(config, app)) as conn:
|
||||
file_path = os.path.join(Migrate.migrate_location, version_file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = json.load(f)
|
||||
upgrade_query_list = content.get("upgrade")
|
||||
for upgrade_query in upgrade_query_list:
|
||||
await conn.execute_script(upgrade_query)
|
||||
content = get_version_content_from_file(file_path)
|
||||
upgrade_query_list = content.get("upgrade")
|
||||
print(upgrade_query_list)
|
||||
for upgrade_query in upgrade_query_list:
|
||||
await conn.execute_script(upgrade_query)
|
||||
await Aerich.create(
|
||||
version=version_file,
|
||||
app=app,
|
||||
@@ -152,14 +157,13 @@ async def downgrade(ctx: Context, version: int):
|
||||
file = version.version
|
||||
async with in_transaction(get_app_connection_name(config, app)) as conn:
|
||||
file_path = os.path.join(Migrate.migrate_location, file)
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
content = json.load(f)
|
||||
downgrade_query_list = content.get("downgrade")
|
||||
if not downgrade_query_list:
|
||||
return click.secho("No downgrade item found", fg=Color.yellow)
|
||||
for downgrade_query in downgrade_query_list:
|
||||
await conn.execute_query(downgrade_query)
|
||||
await version.delete()
|
||||
content = get_version_content_from_file(file_path)
|
||||
downgrade_query_list = content.get("downgrade")
|
||||
if not downgrade_query_list:
|
||||
return click.secho("No downgrade items found", fg=Color.yellow)
|
||||
for downgrade_query in downgrade_query_list:
|
||||
await conn.execute_query(downgrade_query)
|
||||
await version.delete()
|
||||
os.unlink(file_path)
|
||||
click.secho(f"Success downgrade {file}", fg=Color.green)
|
||||
|
||||
@@ -263,11 +267,10 @@ async def init_db(ctx: Context, safe):
|
||||
app=app,
|
||||
content=Migrate.get_models_content(config, app, location),
|
||||
)
|
||||
with open(os.path.join(dirname, version), "w", encoding="utf-8") as f:
|
||||
content = {
|
||||
"upgrade": [schema],
|
||||
}
|
||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||
content = {
|
||||
"upgrade": [schema],
|
||||
}
|
||||
write_version_file(os.path.join(dirname, version), content)
|
||||
return click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
@@ -21,7 +20,7 @@ from tortoise.fields import Field
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.models import MAX_VERSION_LENGTH, Aerich
|
||||
from aerich.utils import get_app_connection
|
||||
from aerich.utils import get_app_connection, write_version_file
|
||||
|
||||
|
||||
class Migrate:
|
||||
@@ -50,7 +49,7 @@ class Migrate:
|
||||
@classmethod
|
||||
def get_all_version_files(cls) -> List[str]:
|
||||
return sorted(
|
||||
filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)),
|
||||
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
|
||||
key=lambda x: int(x.split("_")[0]),
|
||||
)
|
||||
|
||||
@@ -111,8 +110,8 @@ class Migrate:
|
||||
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
|
||||
last_version_num = await cls._get_last_version_num()
|
||||
if last_version_num is None:
|
||||
return f"0_{now}_init.json"
|
||||
version = f"{last_version_num + 1}_{now}_{name}.json"
|
||||
return f"0_{now}_init.sql"
|
||||
version = f"{last_version_num + 1}_{now}_{name}.sql"
|
||||
if len(version) > MAX_VERSION_LENGTH:
|
||||
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
|
||||
return version
|
||||
@@ -128,8 +127,7 @@ class Migrate:
|
||||
"upgrade": cls.upgrade_operators,
|
||||
"downgrade": cls.downgrade_operators,
|
||||
}
|
||||
with open(os.path.join(cls.migrate_location, version), "w", encoding="utf-8") as f:
|
||||
json.dump(content, f, indent=2, ensure_ascii=False)
|
||||
write_version_file(os.path.join(cls.migrate_location, version), content)
|
||||
return version
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
from typing import Dict
|
||||
|
||||
from click import BadOptionUsage, Context
|
||||
from tortoise import BaseDBAsyncClient, Tortoise
|
||||
@@ -49,3 +50,46 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
|
||||
ctx=ctx,
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
_UPGRADE = "##### upgrade #####\n"
|
||||
_DOWNGRADE = "##### downgrade #####\n"
|
||||
|
||||
|
||||
def get_version_content_from_file(version_file: str) -> Dict:
|
||||
"""
|
||||
get version content
|
||||
:param version_file:
|
||||
:return:
|
||||
"""
|
||||
with open(version_file, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
first = content.index(_UPGRADE)
|
||||
second = content.index(_DOWNGRADE)
|
||||
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
|
||||
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
|
||||
ret = {"upgrade": upgrade_content.split("\n"), "downgrade": downgrade_content.split("\n")}
|
||||
return ret
|
||||
|
||||
|
||||
def write_version_file(version_file: str, content: Dict):
|
||||
"""
|
||||
write version file
|
||||
:param version_file:
|
||||
:param content:
|
||||
:return:
|
||||
"""
|
||||
with open(version_file, "w", encoding="utf-8") as f:
|
||||
f.write(_UPGRADE)
|
||||
upgrade = content.get("upgrade")
|
||||
if len(upgrade) > 1:
|
||||
f.write(";\n".join(upgrade) + ";\n")
|
||||
else:
|
||||
f.write(f"{upgrade[0]};\n")
|
||||
downgrade = content.get("downgrade")
|
||||
if downgrade:
|
||||
f.write(_DOWNGRADE)
|
||||
if len(downgrade) > 1:
|
||||
f.write(";\n".join(downgrade) + ";\n")
|
||||
else:
|
||||
f.write(f"{downgrade[0]};\n")
|
||||
|
||||
Reference in New Issue
Block a user