refactor: use pathlib to read and write text

This commit is contained in:
Waket Zheng 2022-09-12 00:57:46 +08:00
parent 8c2ecbaef1
commit 7a109f3c79
2 changed files with 39 additions and 41 deletions

View File

@ -26,7 +26,7 @@ def coro(f):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function # Close db connections at the end of all but the cli group function
try: try:
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
finally: finally:
@ -54,10 +54,10 @@ async def cli(ctx: Context, config, app):
invoked_subcommand = ctx.invoked_subcommand invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init": if invoked_subcommand != "init":
if not Path(config).exists(): config_path = Path(config)
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx) raise UsageError("You must exec init first", ctx=ctx)
with open(config, "r") as f: content = config_path.read_text()
content = f.read()
doc = tomlkit.parse(content) doc = tomlkit.parse(content)
try: try:
tool = doc["tool"]["aerich"] tool = doc["tool"]["aerich"]
@ -192,9 +192,9 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
# check that we can find the configuration, if not we can fail before the config file gets created # check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder) add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm) get_tortoise_config(ctx, tortoise_orm)
if Path(config_file).exists(): config_path = Path(config_file)
with open(config_file, "r") as f: if config_path.exists():
content = f.read() content = config_path.read_text()
doc = tomlkit.parse(content) doc = tomlkit.parse(content)
else: else:
doc = tomlkit.parse("[tool.aerich]") doc = tomlkit.parse("[tool.aerich]")
@ -204,8 +204,7 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
table["src_folder"] = src_folder table["src_folder"] = src_folder
doc["tool"]["aerich"] = table doc["tool"]["aerich"] = table
with open(config_file, "w") as f: config_path.write_text(tomlkit.dumps(doc))
f.write(tomlkit.dumps(doc))
Path(location).mkdir(parents=True, exist_ok=True) Path(location).mkdir(parents=True, exist_ok=True)

View File

@ -87,20 +87,19 @@ def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
:param version_file: :param version_file:
:return: :return:
""" """
with open(version_file, "r", encoding="utf-8") as f: content = Path(version_file).read_text(encoding="utf-8")
content = f.read() first = content.index(_UPGRADE)
first = content.index(_UPGRADE) try:
try: second = content.index(_DOWNGRADE)
second = content.index(_DOWNGRADE) except ValueError:
except ValueError: second = len(content) - 1
second = len(content) - 1 upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203 downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203 ret = {
ret = { "upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))), "downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))), }
} return ret
return ret
def write_version_file(version_file: Path, content: Dict): def write_version_file(version_file: Path, content: Dict):
@ -110,25 +109,25 @@ def write_version_file(version_file: Path, content: Dict):
:param content: :param content:
:return: :return:
""" """
with open(version_file, "w", encoding="utf-8") as f: text = _UPGRADE
f.write(_UPGRADE) upgrade = content.get("upgrade")
upgrade = content.get("upgrade") if len(upgrade) > 1:
if len(upgrade) > 1: text += ";\n".join(upgrade)
f.write(";\n".join(upgrade)) if not upgrade[-1].endswith(";"):
if not upgrade[-1].endswith(";"): text += ";\n"
f.write(";\n") else:
text += f"{upgrade[0]}"
if not upgrade[0].endswith(";"):
text += ";"
text += "\n"
downgrade = content.get("downgrade")
if downgrade:
text += _DOWNGRADE
if len(downgrade) > 1:
text += ";\n".join(downgrade) + ";\n"
else: else:
f.write(f"{upgrade[0]}") text += f"{downgrade[0]};\n"
if not upgrade[0].endswith(";"): version_file.write_text(text, encoding="utf-8")
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict: def get_models_describe(app: str) -> Dict: