refactoring migrate logic

This commit is contained in:
long2ice
2020-10-09 00:05:22 +08:00
parent 9889d9492b
commit 8cace21fde
7 changed files with 124 additions and 104 deletions

View File

@@ -1 +1 @@
__version__ = "0.2.5"
__version__ = "0.2.6"

View File

@@ -26,7 +26,9 @@ def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
return loop.run_until_complete(f(*args, **kwargs))
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs))
Migrate.remove_old_model_file(ctx.obj["app"], ctx.obj["location"])
return wrapper
@@ -67,7 +69,7 @@ async def cli(ctx: Context, config, app, name):
ctx.obj["config"] = tortoise_config
ctx.obj["location"] = location
ctx.obj["app"] = app
Migrate.app = app
if invoked_subcommand != "init-db":
await Migrate.init_with_old_models(tortoise_config, app, location)
@@ -80,50 +82,70 @@ async def migrate(ctx: Context, name):
config = ctx.obj["config"]
location = ctx.obj["location"]
app = ctx.obj["app"]
ret = await Migrate.migrate(name)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
Migrate.write_old_models(config, app, location)
click.secho(f"Success migrate {ret}", fg=Color.green)
@cli.command(help="Upgrade to latest version.")
@cli.command(help="Upgrade to specified version.")
@click.option(
"--version",
default=-1,
type=int,
show_default=True,
help="Specified version, default to latest.",
)
@click.pass_context
@coro
async def upgrade(ctx: Context):
async def upgrade(ctx: Context, version: int):
config = ctx.obj["config"]
app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False
for version in Migrate.get_all_version_files():
for version_file in Migrate.get_all_version_files():
try:
exists = await Aerich.exists(version=version, app=app)
exists = await Aerich.exists(version=version_file, app=app)
except OperationalError:
exists = False
if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version)
file_path = os.path.join(Migrate.migrate_location, version_file)
with open(file_path, "r", encoding="utf-8") as f:
content = json.load(f)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query)
await Aerich.create(version=version, app=app)
click.secho(f"Success upgrade {version}", fg=Color.green)
await Aerich.create(
version=version_file,
app=app,
content=Migrate.get_models_content(config, app, location),
)
click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True
if version != -1 and version_file.startswith(str(version)):
break
if not migrated:
click.secho("No migrate items", fg=Color.yellow)
@cli.command(help="Downgrade to previous version.")
@cli.command(help="Downgrade to specified version.")
@click.option(
"--version", default=-1, type=int, show_default=True, help="Specified version, default to last."
)
@click.pass_context
@coro
async def downgrade(ctx: Context):
async def downgrade(ctx: Context, version: int):
app = ctx.obj["app"]
config = ctx.obj["config"]
last_version = await Migrate.get_last_version()
if not last_version:
return click.secho("No last version found", fg=Color.yellow)
file = last_version.version
if version == -1:
specified_version = await Migrate.get_last_version()
else:
specified_version = await Aerich.filter(app=app, pk=version + 1).first()
if not specified_version:
return click.secho("No specified version found", fg=Color.yellow)
file = specified_version.version
async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file)
with open(file_path, "r", encoding="utf-8") as f:
@@ -133,7 +155,8 @@ async def downgrade(ctx: Context):
return click.secho("No downgrade item found", fg=Color.yellow)
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
await last_version.delete()
await specified_version.delete()
os.unlink(file_path)
return click.secho(f"Success downgrade {file}", fg=Color.green)
@@ -149,7 +172,7 @@ async def heads(ctx: Context):
click.secho(version, fg=Color.green)
is_heads = True
if not is_heads:
click.secho("No available heads,try migrate", fg=Color.green)
click.secho("No available heads,try migrate first", fg=Color.green)
@cli.command(help="List all migrate items.")
@@ -219,8 +242,6 @@ async def init_db(ctx: Context, safe):
else:
return click.secho(f"Inited {app} already", fg=Color.yellow)
Migrate.write_old_models(config, app, location)
await Tortoise.init(config=config)
connection = get_app_connection(config, app)
await generate_schema_for_client(connection, safe)
@@ -228,7 +249,9 @@ async def init_db(ctx: Context, safe):
schema = get_schema_sql(connection, safe)
version = await Migrate.generate_version()
await Aerich.create(version=version, app=app)
await Aerich.create(
version=version, app=app, content=Migrate.get_models_content(config, app, location)
)
with open(os.path.join(dirname, version), "w", encoding="utf-8") as f:
content = {
"upgrade": [schema],

View File

@@ -2,3 +2,9 @@ class NotSupportError(Exception):
"""
raise when features not support
"""
class DuplicationError(Exception):
"""
raise when something duplication
"""

View File

@@ -3,6 +3,7 @@ import os
import re
from datetime import datetime
from importlib import import_module
from io import StringIO
from typing import Dict, List, Tuple, Type
import click
@@ -41,8 +42,8 @@ class Migrate:
dialect: str
@classmethod
def get_old_model_file(cls):
return cls.old_models + ".py"
def get_old_model_file(cls, app: str, location: str):
return os.path.join(location, app, cls.old_models + ".py")
@classmethod
def get_all_version_files(cls) -> List[str]:
@@ -56,9 +57,21 @@ class Migrate:
return await Aerich.filter(app=cls.app).first()
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
migrate_config = cls._get_migrate_config(config, app, location)
def remove_old_model_file(cls, app: str, location: str):
try:
os.unlink(cls.get_old_model_file(app, location))
except FileNotFoundError:
pass
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
last_version = await cls.get_last_version()
content = last_version.content
with open(cls.get_old_model_file(app, location), "w") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.app = app
cls.migrate_config = migrate_config
cls.migrate_location = os.path.join(location, app)
@@ -151,26 +164,6 @@ class Migrate:
else:
cls.downgrade_operators.append(operator)
@classmethod
def cp_models(
cls, app: str, model_files: List[str], old_model_file,
):
"""
cp currents models to old_model_files
:param app:
:param model_files:
:param old_model_file:
:return:
"""
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
for i, model_file in enumerate(model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
mode = "w" if i == 0 else "a"
with open(old_model_file, mode, encoding="utf-8") as f:
f.write(f"{ret}\n")
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
@@ -189,7 +182,7 @@ class Migrate:
return config
@classmethod
def write_old_models(cls, config: dict, app: str, location: str):
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
@@ -197,14 +190,18 @@ class Migrate:
:param location:
:return:
"""
cls.app = app
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
old_model_files.append(import_module(model).__file__)
cls.cp_models(app, old_model_files, os.path.join(location, app, cls.get_old_model_file()))
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod
def diff_models(

View File

@@ -6,6 +6,7 @@ MAX_VERSION_LENGTH = 255
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.TextField()
class Meta:
ordering = ["-id"]