Use .sql instead of .json to store version file. (#79)

This commit is contained in:
long2ice
2020-11-16 22:25:01 +08:00
parent f5588a35c5
commit b21b954d32
8 changed files with 159 additions and 93 deletions

View File

@@ -1 +1 @@
__version__ = "0.3.3"
__version__ = "0.4.0"

View File

@@ -1,5 +1,4 @@
import asyncio
import json
import os
import sys
from configparser import ConfigParser
@@ -13,7 +12,13 @@ from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.migrate import Migrate
from aerich.utils import get_app_connection, get_app_connection_name, get_tortoise_config
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_tortoise_config,
get_version_content_from_file,
write_version_file,
)
from . import __version__
from .enums import Color
@@ -105,11 +110,11 @@ async def upgrade(ctx: Context):
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file)
with open(file_path, "r", encoding="utf-8") as f:
content = json.load(f)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade")
print(upgrade_query_list)
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(
version=version_file,
app=app,
@@ -152,14 +157,13 @@ async def downgrade(ctx: Context, version: int):
file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file)
with open(file_path, "r", encoding="utf-8") as f:
content = json.load(f)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
return click.secho("No downgrade item found", fg=Color.yellow)
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade")
if not downgrade_query_list:
return click.secho("No downgrade items found", fg=Color.yellow)
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await version.delete()
os.unlink(file_path)
click.secho(f"Success downgrade {file}", fg=Color.green)
@@ -263,11 +267,10 @@ async def init_db(ctx: Context, safe):
app=app,
content=Migrate.get_models_content(config, app, location),
)
with open(os.path.join(dirname, version), "w", encoding="utf-8") as f:
content = {
"upgrade": [schema],
}
json.dump(content, f, ensure_ascii=False, indent=2)
content = {
"upgrade": [schema],
}
write_version_file(os.path.join(dirname, version), content)
return click.secho(f'Success generate schema for app "{app}"', fg=Color.green)

View File

@@ -1,5 +1,4 @@
import inspect
import json
import os
import re
from datetime import datetime
@@ -21,7 +20,7 @@ from tortoise.fields import Field
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection
from aerich.utils import get_app_connection, write_version_file
class Migrate:
@@ -50,7 +49,7 @@ class Migrate:
@classmethod
def get_all_version_files(cls) -> List[str]:
return sorted(
filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)),
filter(lambda x: x.endswith("sql"), os.listdir(cls.migrate_location)),
key=lambda x: int(x.split("_")[0]),
)
@@ -111,8 +110,8 @@ class Migrate:
now = datetime.now().strftime("%Y%m%d%H%M%S").replace("/", "")
last_version_num = await cls._get_last_version_num()
if last_version_num is None:
return f"0_{now}_init.json"
version = f"{last_version_num + 1}_{now}_{name}.json"
return f"0_{now}_init.sql"
version = f"{last_version_num + 1}_{now}_{name}.sql"
if len(version) > MAX_VERSION_LENGTH:
raise ValueError(f"Version name exceeds maximum length ({MAX_VERSION_LENGTH})")
return version
@@ -128,8 +127,7 @@ class Migrate:
"upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators,
}
with open(os.path.join(cls.migrate_location, version), "w", encoding="utf-8") as f:
json.dump(content, f, indent=2, ensure_ascii=False)
write_version_file(os.path.join(cls.migrate_location, version), content)
return version
@classmethod

View File

@@ -1,4 +1,5 @@
import importlib
from typing import Dict
from click import BadOptionUsage, Context
from tortoise import BaseDBAsyncClient, Tortoise
@@ -49,3 +50,46 @@ def get_tortoise_config(ctx: Context, tortoise_orm: str) -> dict:
ctx=ctx,
)
return config
_UPGRADE = "##### upgrade #####\n"
_DOWNGRADE = "##### downgrade #####\n"
def get_version_content_from_file(version_file: str) -> Dict:
"""
get version content
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
second = content.index(_DOWNGRADE)
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {"upgrade": upgrade_content.split("\n"), "downgrade": downgrade_content.split("\n")}
return ret
def write_version_file(version_file: str, content: Dict):
"""
write version file
:param version_file:
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n")
else:
f.write(f"{upgrade[0]};\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")