From 7a109f3c79a431adece72170d79e30084045a5e6 Mon Sep 17 00:00:00 2001 From: Waket Zheng Date: Mon, 12 Sep 2022 00:57:46 +0800 Subject: [PATCH] refactor: use pathlib to read and write text --- aerich/cli.py | 17 +++++++------ aerich/utils.py | 63 ++++++++++++++++++++++++------------------------- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/aerich/cli.py b/aerich/cli.py index 41dc65b..564dfca 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -26,7 +26,7 @@ def coro(f): def wrapper(*args, **kwargs): loop = asyncio.get_event_loop() - # Close db connections at the end of all all but the cli group function + # Close db connections at the end of all but the cli group function try: loop.run_until_complete(f(*args, **kwargs)) finally: @@ -54,10 +54,10 @@ async def cli(ctx: Context, config, app): invoked_subcommand = ctx.invoked_subcommand if invoked_subcommand != "init": - if not Path(config).exists(): + config_path = Path(config) + if not config_path.exists(): raise UsageError("You must exec init first", ctx=ctx) - with open(config, "r") as f: - content = f.read() + content = config_path.read_text() doc = tomlkit.parse(content) try: tool = doc["tool"]["aerich"] @@ -192,9 +192,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder): # check that we can find the configuration, if not we can fail before the config file gets created add_src_path(src_folder) get_tortoise_config(ctx, tortoise_orm) - if Path(config_file).exists(): - with open(config_file, "r") as f: - content = f.read() + config_path = Path(config_file) + if config_path.exists(): + content = config_path.read_text() doc = tomlkit.parse(content) else: doc = tomlkit.parse("[tool.aerich]") @@ -204,8 +204,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder): table["src_folder"] = src_folder doc["tool"]["aerich"] = table - with open(config_file, "w") as f: - f.write(tomlkit.dumps(doc)) + config_path.write_text(tomlkit.dumps(doc)) Path(location).mkdir(parents=True, exist_ok=True) diff --git a/aerich/utils.py b/aerich/utils.py index c8bffdc..b51264b 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -87,20 +87,19 @@ def get_version_content_from_file(version_file: Union[str, Path]) -> Dict: :param version_file: :return: """ - with open(version_file, "r", encoding="utf-8") as f: - content = f.read() - first = content.index(_UPGRADE) - try: - second = content.index(_DOWNGRADE) - except ValueError: - second = len(content) - 1 - upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203 - downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203 - ret = { - "upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))), - "downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))), - } - return ret + content = Path(version_file).read_text(encoding="utf-8") + first = content.index(_UPGRADE) + try: + second = content.index(_DOWNGRADE) + except ValueError: + second = len(content) - 1 + upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203 + downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203 + ret = { + "upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))), + "downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))), + } + return ret def write_version_file(version_file: Path, content: Dict): @@ -110,25 +109,25 @@ def write_version_file(version_file: Path, content: Dict): :param content: :return: """ - with open(version_file, "w", encoding="utf-8") as f: - f.write(_UPGRADE) - upgrade = content.get("upgrade") - if len(upgrade) > 1: - f.write(";\n".join(upgrade)) - if not upgrade[-1].endswith(";"): - f.write(";\n") + text = _UPGRADE + upgrade = content.get("upgrade") + if len(upgrade) > 1: + text += ";\n".join(upgrade) + if not upgrade[-1].endswith(";"): + text += ";\n" + else: + text += f"{upgrade[0]}" + if not upgrade[0].endswith(";"): + text += ";" + text += "\n" + downgrade = content.get("downgrade") + if downgrade: + text += _DOWNGRADE + if len(downgrade) > 1: + text += ";\n".join(downgrade) + ";\n" else: - f.write(f"{upgrade[0]}") - if not upgrade[0].endswith(";"): - f.write(";") - f.write("\n") - downgrade = content.get("downgrade") - if downgrade: - f.write(_DOWNGRADE) - if len(downgrade) > 1: - f.write(";\n".join(downgrade) + ";\n") - else: - f.write(f"{downgrade[0]};\n") + text += f"{downgrade[0]};\n" + version_file.write_text(text, encoding="utf-8") def get_models_describe(app: str) -> Dict: