Compatible with models file in directory. (#70)

This commit is contained in:
long2ice
2020-10-30 19:51:46 +08:00
parent fa73e132e2
commit 648f25a951
5 changed files with 304 additions and 269 deletions

View File

@@ -38,7 +38,11 @@ def coro(f):
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.version_option(__version__, "-V", "--version")
@click.option(
"-c", "--config", default="aerich.ini", show_default=True, help="Config file.",
"-c",
"--config",
default="aerich.ini",
show_default=True,
help="Config file.",
)
@click.option("--app", required=False, help="Tortoise-ORM app name.")
@click.option(
@@ -117,11 +121,6 @@ async def upgrade(ctx: Context):
click.secho("No migrate items", fg=Color.yellow)
def abort_if_false(ctx, param, value):
if not value:
ctx.abort()
@cli.command(help="Downgrade to specified version.")
@click.option(
"-v",
@@ -199,12 +198,17 @@ async def history(ctx: Context):
help="Tortoise-ORM config module dict variable, like settings.TORTOISE_ORM.",
)
@click.option(
"--location", default="./migrations", show_default=True, help="Migrate store location.",
"--location",
default="./migrations",
show_default=True,
help="Migrate store location.",
)
@click.pass_context
@coro
async def init(
ctx: Context, tortoise_orm, location,
ctx: Context,
tortoise_orm,
location,
):
config_file = ctx.obj["config_file"]
name = ctx.obj["name"]
@@ -255,7 +259,9 @@ async def init_db(ctx: Context, safe):
version = await Migrate.generate_version()
await Aerich.create(
version=version, app=app, content=Migrate.get_models_content(config, app, location),
version=version,
app=app,
content=Migrate.get_models_content(config, app, location),
)
with open(os.path.join(dirname, version), "w", encoding="utf-8") as f:
content = {
@@ -268,3 +274,7 @@ async def init_db(ctx: Context, safe):
def main():
sys.path.insert(0, ".")
cli()
if __name__ == "__main__":
main()

View File

@@ -1,3 +1,4 @@
import inspect
import json
import os
import re
@@ -201,7 +202,15 @@ class Migrate:
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
old_model_files.append(import_module(model).__file__)
module = import_module(model)
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in filter(
lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model,
possible_models,
):
file = inspect.getfile(attr)
if file not in old_model_files:
old_model_files.append(file)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
@@ -294,12 +303,15 @@ class Migrate:
is_rename = diff_key in cls._rename_new
if is_rename:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field), upgrade,
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field), upgrade, cls._is_fk_m2m(new_field),
cls._add_field(new_model, new_field),
upgrade,
cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
@@ -350,11 +362,15 @@ class Migrate:
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field), upgrade, True,
cls._drop_fk(new_model, new_field),
upgrade,
True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field), upgrade, True,
cls._add_fk(new_model, new_field),
upgrade,
True,
)
for old_key in old_keys:
@@ -364,12 +380,20 @@ class Migrate:
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
cls._remove_field(old_model, field),
upgrade,
cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(cls._add_index(new_model, new_index,), upgrade)
cls._add_operator(
cls._add_index(
new_model,
new_index,
),
upgrade,
)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)