Use pathlib for path resolving

This commit is contained in:
Mykola Solodukha 2020-11-28 19:23:34 +02:00
parent c707f7ecb2
commit 6e3105690a
2 changed files with 15 additions and 14 deletions

View File

@ -3,6 +3,7 @@ import os
import sys import sys
from configparser import ConfigParser from configparser import ConfigParser
from functools import wraps from functools import wraps
from pathlib import Path
import click import click
from click import Context, UsageError from click import Context, UsageError
@ -66,7 +67,7 @@ async def cli(ctx: Context, config, app, name):
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
if not os.path.exists(config): if not Path(config).exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError("You must exec init first", ctx=ctx)
parser.read(config) parser.read(config)
@ -109,7 +110,7 @@ async def upgrade(ctx: Context):
exists = False exists = False
if not exists: if not exists:
async with in_transaction(get_app_connection_name(config, app)) as conn: async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, version_file) file_path = Path(Migrate.migrate_location, version_file)
content = get_version_content_from_file(file_path) content = get_version_content_from_file(file_path)
upgrade_query_list = content.get("upgrade") upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list: for upgrade_query in upgrade_query_list:
@ -163,7 +164,7 @@ async def downgrade(ctx: Context, version: int, delete: bool):
for version in versions: for version in versions:
file = version.version file = version.version
async with in_transaction(get_app_connection_name(config, app)) as conn: async with in_transaction(get_app_connection_name(config, app)) as conn:
file_path = os.path.join(Migrate.migrate_location, file) file_path = Path(Migrate.migrate_location, file)
content = get_version_content_from_file(file_path) content = get_version_content_from_file(file_path)
downgrade_query_list = content.get("downgrade") downgrade_query_list = content.get("downgrade")
if not downgrade_query_list: if not downgrade_query_list:
@ -224,7 +225,7 @@ async def init(
): ):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"] name = ctx.obj["name"]
if os.path.exists(config_file): if Path(config_file).exists():
return click.secho("You have inited", fg=Color.yellow) return click.secho("You have inited", fg=Color.yellow)
parser.add_section(name) parser.add_section(name)
@ -234,7 +235,7 @@ async def init(
with open(config_file, "w", encoding="utf-8") as f: with open(config_file, "w", encoding="utf-8") as f:
parser.write(f) parser.write(f)
if not os.path.isdir(location): if not Path(location).is_dir():
os.mkdir(location) os.mkdir(location)
click.secho(f"Success create migrate location {location}", fg=Color.green) click.secho(f"Success create migrate location {location}", fg=Color.green)
@ -256,8 +257,8 @@ async def init_db(ctx: Context, safe):
location = ctx.obj["location"] location = ctx.obj["location"]
app = ctx.obj["app"] app = ctx.obj["app"]
dirname = os.path.join(location, app) dirname = Path(location, app)
if not os.path.isdir(dirname): if not dirname.is_dir():
os.mkdir(dirname) os.mkdir(dirname)
click.secho(f"Success create app migrate location {dirname}", fg=Color.green) click.secho(f"Success create app migrate location {dirname}", fg=Color.green)
else: else:
@ -280,7 +281,7 @@ async def init_db(ctx: Context, safe):
content = { content = {
"upgrade": [schema], "upgrade": [schema],
} }
write_version_file(os.path.join(dirname, version), content) write_version_file(Path(dirname, version), content)
click.secho(f'Success generate schema for app "{app}"', fg=Color.green) click.secho(f'Success generate schema for app "{app}"', fg=Color.green)

View File

@ -4,6 +4,7 @@ import re
from datetime import datetime from datetime import datetime
from importlib import import_module from importlib import import_module
from io import StringIO from io import StringIO
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
import click import click
@ -48,7 +49,7 @@ class Migrate:
@classmethod @classmethod
def get_old_model_file(cls, app: str, location: str): def get_old_model_file(cls, app: str, location: str):
return os.path.join(location, app, cls.old_models + ".py") return Path(location, app, cls.old_models + ".py")
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
@ -83,7 +84,7 @@ class Migrate:
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = os.path.join(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
content = last_version.content content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f: with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
@ -134,12 +135,12 @@ class Migrate:
# delete if same version exists # delete if same version exists
for version_file in cls.get_all_version_files(): for version_file in cls.get_all_version_files():
if version_file.startswith(version.split("_")[0]): if version_file.startswith(version.split("_")[0]):
os.unlink(os.path.join(cls.migrate_location, version_file)) os.unlink(Path(cls.migrate_location, version_file))
content = { content = {
"upgrade": cls.upgrade_operators, "upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators, "downgrade": cls.downgrade_operators,
} }
write_version_file(os.path.join(cls.migrate_location, version), content) write_version_file(Path(cls.migrate_location, version), content)
return version return version
@classmethod @classmethod
@ -192,8 +193,7 @@ class Migrate:
:param location: :param location:
:return: :return:
""" """
path = os.path.join(location, app, cls.old_models) path = Path(location, app, cls.old_models).as_posix().replace("/", ".")
path = path.replace(os.sep, ".").lstrip(".")
config["apps"][cls.diff_app] = { config["apps"][cls.diff_app] = {
"models": [path], "models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"), "default_connection": config.get("apps").get(app).get("default_connection", "default"),