7 Commits

Author SHA1 Message Date
long2ice
d2e0a68351 Fix packaging error. (#92) 2020-12-02 23:03:15 +08:00
long2ice
ee6cc20c7d Fix empty items 2020-11-30 11:14:09 +08:00
long2ice
4e917495a0 Fix upgrade in new db. (#96) 2020-11-30 11:02:48 +08:00
long2ice
bfa66f6dd4 update changelog 2020-11-29 11:15:43 +08:00
long2ice
f00715d4c4 Merge pull request #97 from TrDex/pathlib-for-path-resolving
Use `pathlib` for path resolving
2020-11-29 11:02:44 +08:00
Mykola Solodukha
6e3105690a Use pathlib for path resolving 2020-11-28 19:23:34 +02:00
long2ice
c707f7ecb2 bug fix 2020-11-28 14:31:41 +08:00
6 changed files with 58 additions and 33 deletions

View File

@@ -2,6 +2,16 @@
## 0.4 ## 0.4
### 0.4.2
- Use `pathlib` for path resolving. (#89)
- Fix upgrade in new db. (#96)
- Fix packaging error. (#92)
### 0.4.1
- Bug fix. (#91 #93)
### 0.4.0 ### 0.4.0
- Use `.sql` instead of `.json` to store version file. - Use `.sql` instead of `.json` to store version file.

View File

@@ -1 +1 @@
__version__ = "0.4.0" __version__ = "0.4.2"

View File

@@ -3,6 +3,7 @@ import os
import sys import sys
from configparser import ConfigParser from configparser import ConfigParser
from functools import wraps from functools import wraps
from pathlib import Path
import click import click
from click import Context, UsageError from click import Context, UsageError
@@ -33,7 +34,6 @@ def coro(f):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
ctx = args[0] ctx = args[0]
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
loop.run_until_complete(Tortoise.close_connections())
app = ctx.obj.get("app") app = ctx.obj.get("app")
if app: if app:
Migrate.remove_old_model_file(app, ctx.obj["location"]) Migrate.remove_old_model_file(app, ctx.obj["location"])
@@ -67,7 +67,7 @@ async def cli(ctx: Context, config, app, name):
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
if not os.path.exists(config): if not Path(config).exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError("You must exec init first", ctx=ctx)
parser.read(config) parser.read(config)
@@ -81,6 +81,8 @@ async def cli(ctx: Context, config, app, name):
ctx.obj["app"] = app ctx.obj["app"] = app
Migrate.app = app Migrate.app = app
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init_with_old_models(tortoise_config, app, location) await Migrate.init_with_old_models(tortoise_config, app, location)
@@ -110,10 +112,9 @@ async def upgrade(ctx: Context):
exists = False exists = False
if not exists: if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn: async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path) content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade") upgrade_query_list = content.get("upgrade")
print(upgrade_query_list)
for upgrade_query in upgrade_query_list: for upgrade_query in upgrade_query_list:
await conn.execute_script(upgrade_query) await conn.execute_script(upgrade_query)
await Aerich.create( await Aerich.create(
@@ -124,7 +125,7 @@ async def upgrade(ctx: Context):
click.secho(f"Success upgrade {version_file}", fg=Color.green) click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True migrated = True
if not migrated: if not migrated:
click.secho("No migrate items", fg=Color.yellow) click.secho("No upgrade items found", fg=Color.yellow)
@cli.command(help="Downgrade to specified version.") @cli.command(help="Downgrade to specified version.")
@@ -165,11 +166,12 @@ async def downgrade(ctx: Context, version: int, delete: bool):
for version in versions: for version in versions:
file = version.version file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn: async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file) file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path) content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade") downgrade_query_list = content.get("downgrade")
if not downgrade_query_list: if not downgrade_query_list:
return click.secho("No downgrade items found", fg=Color.yellow) click.secho("No downgrade items found", fg=Color.yellow)
return
for downgrade_query in downgrade_query_list: for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query) await conn.execute_query(downgrade_query)
await version.delete() await version.delete()
@@ -226,7 +228,7 @@ async def init(
): ):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"] name = ctx.obj["name"]
if os.path.exists(config_file): if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow) return click.secho("You have inited", fg=Color.yellow)
parser.add_section(name) parser.add_section(name)
@@ -236,7 +238,7 @@ async def init(
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
parser.write(f) parser.write(f)
if not os.path.isdir(location): if not Path(location).is_dir():
os.mkdir(location) os.mkdir(location)
click.secho(f"Success create migrate location {location}", fg=Color.green) click.secho(f"Success create migrate location {location}", fg=Color.green)
@@ -258,8 +260,8 @@ async def init_db(ctx: Context, safe):
location = ctx.obj["location"] location = ctx.obj["location"]
app = ctx.obj["app"] app = ctx.obj["app"]
dirname = os.path.join(location, app) dirname = Path(location, app)
if not os.path.isdir(dirname): if not dirname.is_dir():
os.mkdir(dirname) os.mkdir(dirname)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green) click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
else: else:
@@ -282,7 +284,7 @@ async def init_db(ctx: Context, safe):
content = { content = {
"upgrade": [schema], "upgrade": [schema],
} }
write_version_file(os.path.join(dirname, version), content) write_version_file(Path(dirname, version), content)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green) click.secho(f'Success generate schema for app "{app}"', fg=Color.green)

View File

@@ -4,11 +4,10 @@ import re
from datetime import datetime from datetime import datetime
from importlib import import_module from importlib import import_module
from io import StringIO from io import StringIO
from typing import Dict, List, Optional, Tuple, Type, Union from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
import click import click
from packaging import version
from packaging.version import LegacyVersion, Version
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
BackwardOneToOneRelation, BackwardOneToOneRelation,
@@ -44,11 +43,11 @@ class Migrate:
app: str app: str
migrate_location: str migrate_location: str
dialect: str dialect: str
_db_version: Union[LegacyVersion, Version] = None _db_version: Optional[str] = None
@classmethod @classmethod
def get_old_model_file(cls, app: str, location: str): def get_old_model_file(cls, app: str, location: str):
return os.path.join(location, app, cls.old_models + ".py") return Path(location, app, cls.old_models + ".py")
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
@@ -76,14 +75,14 @@ class Migrate:
if cls.dialect == "mysql": if cls.dialect == "mysql":
sql = "select version() as version" sql = "select version() as version"
ret = await connection.execute_query(sql) ret = await connection.execute_query(sql)
cls._db_version = version.parse(ret[1][0].get("version")) cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str): async def init_with_old_models(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = os.path.join(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
content = last_version.content content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f: with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
@@ -94,6 +93,7 @@ class Migrate:
await Tortoise.init(config=migrate_config) await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql": if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
@@ -106,7 +106,6 @@ class Migrate:
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection) cls.ddl = PostgresDDL(connection)
cls.dialect = cls.ddl.DIALECT
await cls._get_db_version(connection) await cls._get_db_version(connection)
@classmethod @classmethod
@@ -134,12 +133,12 @@ class Migrate:
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(os.path.join(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = { content = {
"upgrade": cls.upgrade_operators, "upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators, "downgrade": cls.downgrade_operators,
} }
write_version_file(os.path.join(cls.migrate_location, version), content) write_version_file(Path(cls.migrate_location, version), content)
return version return version
@classmethod @classmethod
@@ -192,8 +191,7 @@ class Migrate:
:param location: :param location:
:return: :return:
""" """
path = os.path.join(location, app, cls.old_models) path = Path(location, app, cls.old_models).as_posix().replace("/", ".")
path = path.replace(os.sep, ".").lstrip(".")
config["apps"][cls.diff_app] = { config["apps"][cls.diff_app] = {
"models": [path], "models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"), "default_connection": config.get("apps").get(app).get("default_connection", "default"),
@@ -315,7 +313,7 @@ class Migrate:
if ( if (
cls.dialect == "mysql" cls.dialect == "mysql"
and cls._db_version and cls._db_version
and cls._db_version.major == 5 and cls._db_version.startswith("5.")
): ):
cls._add_operator( cls._add_operator(
cls._change_field(new_model, old_field, new_field), cls._change_field(new_model, old_field, new_field),

View File

@@ -5,14 +5,20 @@ from click import BadOptionUsage, Context
from tortoise import BaseDBAsyncClient, Tortoise from tortoise import BaseDBAsyncClient, Tortoise
def get_app_connection_name(config, app) -> str: def get_app_connection_name(config, app_name: str) -> str:
""" """
get connection name get connection name
:param config: :param config:
:param app: :param app_name:
:return: :return:
""" """
return config.get("apps").get(app).get("default_connection", "default") app = config.get("apps").get(app_name)
if app:
return app.get("default_connection", "default")
raise BadOptionUsage(
option_name="--app",
message=f'Can\'t get app named "{app_name}"',
)
def get_app_connection(config, app) -> BaseDBAsyncClient: def get_app_connection(config, app) -> BaseDBAsyncClient:
@@ -65,10 +71,16 @@ def get_version_content_from_file(version_file: str) -> Dict:
with open(version_file, "r", encoding="utf-8") as f: with open(version_file, "r", encoding="utf-8") as f:
content = f.read() content = f.read()
first = content.index(_UPGRADE) first = content.index(_UPGRADE)
second = content.index(_DOWNGRADE) try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203 upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203 downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {"upgrade": upgrade_content.split("\n"), "downgrade": downgrade_content.split("\n")} ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret return ret
@@ -85,7 +97,10 @@ def write_version_file(version_file: str, content: Dict):
if len(upgrade) > 1: if len(upgrade) > 1:
f.write(";\n".join(upgrade) + ";\n") f.write(";\n".join(upgrade) + ";\n")
else: else:
f.write(f"{upgrade[0]};\n") f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade") downgrade = content.get("downgrade")
if downgrade: if downgrade:
f.write(_DOWNGRADE) f.write(_DOWNGRADE)

View File

@@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "aerich" name = "aerich"
version = "0.4.0" version = "0.4.2"
description = "A database migrations tool for Tortoise ORM." description = "A database migrations tool for Tortoise ORM."
authors = ["long2ice <long2ice@gmail.com>"] authors = ["long2ice <long2ice@gmail.com>"]
license = "Apache-2.0" license = "Apache-2.0"