diff --git a/aerich/cli.py b/aerich/cli.py index a168f13..20fa4a6 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -18,6 +18,7 @@ from aerich.migrate import Migrate from aerich.utils import ( get_app_connection, get_app_connection_name, + get_models_describe, get_tortoise_config, get_version_content_from_file, write_version_file, @@ -34,11 +35,7 @@ def coro(f): @wraps(f) def wrapper(*args, **kwargs): loop = asyncio.get_event_loop() - ctx = args[0] loop.run_until_complete(f(*args, **kwargs)) - app = ctx.obj.get("app") - if app: - Migrate.remove_old_model_file(app, ctx.obj["location"]) loop.run_until_complete(Tortoise.close_connections()) return wrapper @@ -86,7 +83,7 @@ async def cli(ctx: Context, config, app, name): if invoked_subcommand != "init-db": if not Path(location, app).exists(): raise UsageError("You must exec init-db first", ctx=ctx) - await Migrate.init_with_old_models(tortoise_config, app, location) + await Migrate.init(tortoise_config, app, location) @cli.command(help="Generate migrate changes file.") @@ -106,7 +103,6 @@ async def migrate(ctx: Context, name): async def upgrade(ctx: Context): config = ctx.obj["config"] app = ctx.obj["app"] - location = ctx.obj["location"] migrated = False for version_file in Migrate.get_all_version_files(): try: @@ -123,7 +119,7 @@ async def upgrade(ctx: Context): await Aerich.create( version=version_file, app=app, - content=Migrate.get_models_content(config, app, location), + content=get_models_describe(app), ) click.secho(f"Success upgrade {version_file}", fg=Color.green) migrated = True @@ -281,7 +277,7 @@ async def init_db(ctx: Context, safe): await Aerich.create( version=version, app=app, - content=Migrate.get_models_content(config, app, location), + content=get_models_describe(app), ) content = { "upgrade": [schema], diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index b39c644..59bb86a 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -140,7 +140,7 @@ class BaseDDL: ) def change_column( - self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str + self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str ): return self._CHANGE_COLUMN_TEMPLATE.format( table_name=model._meta.db_table, @@ -167,26 +167,23 @@ class BaseDDL: table_name=model._meta.db_table, ) - def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): + def add_fk(self, model: "Type[Model]", field: dict): db_table = model._meta.db_table - to_field_name = field.to_field_instance.source_field - if not to_field_name: - to_field_name = field.to_field_instance.model_field_name - db_column = field.source_field or field.model_field_name + "_id" + db_column = field.get('db_column') fk_name = self.schema_generator._generate_fk_name( from_table=db_table, from_field=db_column, to_table=field.related_model._meta.db_table, - to_field=to_field_name, + to_field=db_column, ) return self._ADD_FK_TEMPLATE.format( table_name=db_table, fk_name=fk_name, db_column=db_column, table=field.related_model._meta.db_table, - field=to_field_name, - on_delete=field.on_delete, + field=db_column, + on_delete=field.get('on_delete'), ) def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): diff --git a/aerich/migrate.py b/aerich/migrate.py index 9f6903b..858ddde 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -1,14 +1,10 @@ -import inspect import os -import re from datetime import datetime -from importlib import import_module -from io import StringIO from pathlib import Path -from types import ModuleType from typing import Dict, List, Optional, Tuple, Type import click +from dictdiffer import diff from tortoise import ( BackwardFKRelation, BackwardOneToOneRelation, @@ -23,7 +19,7 @@ from tortoise.fields import Field from aerich.ddl import BaseDDL from aerich.models import MAX_VERSION_LENGTH, Aerich -from aerich.utils import get_app_connection, write_version_file +from aerich.utils import get_app_connection, get_models_describe, write_version_file class Migrate: @@ -38,18 +34,12 @@ class Migrate: _rename_new = [] ddl: BaseDDL - migrate_config: dict - old_models = "old_models" - diff_app = "diff_models" + _last_version_content: Optional[dict] = None app: str migrate_location: str dialect: str _db_version: Optional[str] = None - @classmethod - def get_old_model_file(cls, app: str, location: str): - return Path(location, app, cls.old_models + ".py") - @classmethod def get_all_version_files(cls) -> List[str]: return sorted( @@ -57,6 +47,10 @@ class Migrate: key=lambda x: int(x.split("_")[0]), ) + @classmethod + def _get_model(cls, model: str) -> Type[Model]: + return Tortoise.apps.get(cls.app).get(model) + @classmethod async def get_last_version(cls) -> Optional[Aerich]: try: @@ -64,13 +58,6 @@ class Migrate: except OperationalError: pass - @classmethod - def remove_old_model_file(cls, app: str, location: str): - try: - os.unlink(cls.get_old_model_file(app, location)) - except (OSError, FileNotFoundError): - pass - @classmethod async def _get_db_version(cls, connection: BaseDBAsyncClient): if cls.dialect == "mysql": @@ -79,19 +66,13 @@ class Migrate: cls._db_version = ret[1][0].get("version") @classmethod - async def init_with_old_models(cls, config: dict, app: str, location: str): + async def init(cls, config: dict, app: str, location: str): await Tortoise.init(config=config) last_version = await cls.get_last_version() cls.app = app cls.migrate_location = Path(location, app) if last_version: - content = last_version.content - with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f: - f.write(content) - - migrate_config = cls._get_migrate_config(config, app, location) - cls.migrate_config = migrate_config - await Tortoise.init(config=migrate_config) + cls._last_version_content = last_version.content connection = get_app_connection(config, app) cls.dialect = connection.schema_generator.DIALECT @@ -149,12 +130,9 @@ class Migrate: :param name: :return: """ - apps = Tortoise.apps - diff_models = apps.get(cls.diff_app) - app_models = apps.get(cls.app) - - cls.diff_models(diff_models, app_models) - cls.diff_models(app_models, diff_models, False) + new_version_content = get_models_describe(cls.app) + cls.diff_models(cls._last_version_content, new_version_content) + cls.diff_models(new_version_content, cls._last_version_content, False) cls._merge_operators() @@ -183,58 +161,9 @@ class Migrate: else: cls.downgrade_operators.append(operator) - @classmethod - def _get_migrate_config(cls, config: dict, app: str, location: str): - """ - generate tmp config with old models - :param config: - :param app: - :param location: - :return: - """ - path = Path(location, app, cls.old_models).as_posix().replace("/", ".") - config["apps"][cls.diff_app] = { - "models": [path], - "default_connection": config.get("apps").get(app).get("default_connection", "default"), - } - return config - - @classmethod - def get_models_content(cls, config: dict, app: str, location: str): - """ - write new models to old models - :param config: - :param app: - :param location: - :return: - """ - old_model_files = [] - models = config.get("apps").get(app).get("models") - for model in models: - if isinstance(model, ModuleType): - module = model - else: - module = import_module(model) - possible_models = [getattr(module, attr_name) for attr_name in dir(module)] - for attr in filter( - lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model, - possible_models, - ): - file = inspect.getfile(attr) - if file not in old_model_files: - old_model_files.append(file) - pattern = rf"(\n)?('|\")({app})(.\w+)('|\")" - str_io = StringIO() - for i, model_file in enumerate(old_model_files): - with open(model_file, "r", encoding="utf-8") as f: - content = f.read() - ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content) - str_io.write(f"{ret}\n") - return str_io.getvalue() - @classmethod def diff_models( - cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True + cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True ): """ diff models and add operators @@ -243,18 +172,19 @@ class Migrate: :param upgrade: :return: """ - old_models.pop(cls._aerich, None) - new_models.pop(cls._aerich, None) + _aerich = f'{cls.app}.{cls._aerich}' + old_models.pop(_aerich, None) + new_models.pop(_aerich, None) - for new_model_str, new_model in new_models.items(): + for new_model_str, new_model_describe in new_models.items(): if new_model_str not in old_models.keys(): - cls._add_operator(cls.add_model(new_model), upgrade) + cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade) else: - cls.diff_model(old_models.get(new_model_str), new_model, upgrade) + cls.diff_model(old_models.get(new_model_str), new_model_describe, upgrade) for old_model in old_models: if old_model not in new_models.keys(): - cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade) + cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade) @classmethod def _is_fk_m2m(cls, field: Field): @@ -269,166 +199,20 @@ class Migrate: return cls.ddl.drop_table(model) @classmethod - def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True): + def diff_model(cls, old_model_describe: dict, new_model_describe: dict, upgrade=True): """ diff single model - :param old_model: - :param new_model: + :param old_model_describe: + :param new_model_describe: :param upgrade: :return: """ - old_indexes = old_model._meta.indexes - new_indexes = new_model._meta.indexes - - old_unique_together = old_model._meta.unique_together - new_unique_together = new_model._meta.unique_together - - old_fields_map = old_model._meta.fields_map - new_fields_map = new_model._meta.fields_map - - old_keys = old_fields_map.keys() - new_keys = new_fields_map.keys() - for new_key in new_keys: - new_field = new_fields_map.get(new_key) - if cls._exclude_field(new_field, upgrade): - continue - if new_key not in old_keys: - new_field_dict = new_field.describe(serializable=True) - new_field_dict.pop("name", None) - new_field_dict.pop("db_column", None) - for diff_key in old_keys - new_keys: - old_field = old_fields_map.get(diff_key) - old_field_dict = old_field.describe(serializable=True) - old_field_dict.pop("name", None) - old_field_dict.pop("db_column", None) - if old_field_dict == new_field_dict: - if upgrade: - is_rename = click.prompt( - f"Rename {diff_key} to {new_key}?", - default=True, - type=bool, - show_choices=True, - ) - cls._rename_new.append(new_key) - cls._rename_old.append(diff_key) - else: - is_rename = diff_key in cls._rename_new - if is_rename: - if ( - cls.dialect == "mysql" - and cls._db_version - and cls._db_version.startswith("5.") - ): - cls._add_operator( - cls._change_field(new_model, old_field, new_field), - upgrade, - ) - else: - cls._add_operator( - cls._rename_field(new_model, old_field, new_field), - upgrade, - ) - break - else: - cls._add_operator( - cls._add_field(new_model, new_field), - upgrade, - cls._is_fk_m2m(new_field), - ) - else: - old_field = old_fields_map.get(new_key) - new_field_dict = new_field.describe(serializable=True) - new_field_dict.pop("unique") - new_field_dict.pop("indexed") - old_field_dict = old_field.describe(serializable=True) - old_field_dict.pop("unique") - old_field_dict.pop("indexed") - if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict: - if cls.dialect == "postgres": - if new_field.null != old_field.null: - cls._add_operator( - cls._alter_null(new_model, new_field), upgrade=upgrade - ) - if new_field.default != old_field.default and not callable( - new_field.default - ): - cls._add_operator( - cls._alter_default(new_model, new_field), upgrade=upgrade - ) - if new_field.description != old_field.description: - cls._add_operator( - cls._set_comment(new_model, new_field), upgrade=upgrade - ) - if new_field.field_type != old_field.field_type: - cls._add_operator( - cls._modify_field(new_model, new_field), upgrade=upgrade - ) - else: - cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade) - if (old_field.index and not new_field.index) or ( - old_field.unique and not new_field.unique - ): - cls._add_operator( - cls._remove_index( - old_model, (old_field.model_field_name,), old_field.unique - ), - upgrade, - cls._is_fk_m2m(old_field), - ) - elif (new_field.index and not old_field.index) or ( - new_field.unique and not old_field.unique - ): - cls._add_operator( - cls._add_index(new_model, (new_field.model_field_name,), new_field.unique), - upgrade, - cls._is_fk_m2m(new_field), - ) - if isinstance(new_field, ForeignKeyFieldInstance): - if old_field.db_constraint and not new_field.db_constraint: - cls._add_operator( - cls._drop_fk(new_model, new_field), - upgrade, - True, - ) - if new_field.db_constraint and not old_field.db_constraint: - cls._add_operator( - cls._add_fk(new_model, new_field), - upgrade, - True, - ) - - for old_key in old_keys: - field = old_fields_map.get(old_key) - if old_key not in new_keys and not cls._exclude_field(field, upgrade): - if (upgrade and old_key not in cls._rename_old) or ( - not upgrade and old_key not in cls._rename_new - ): - cls._add_operator( - cls._remove_field(old_model, field), - upgrade, - cls._is_fk_m2m(field), - ) - - for new_index in new_indexes: - if new_index not in old_indexes: - cls._add_operator( - cls._add_index( - new_model, - new_index, - ), - upgrade, - ) - for old_index in old_indexes: - if old_index not in new_indexes: - cls._add_operator(cls._remove_index(old_model, old_index), upgrade) - - for new_unique in new_unique_together: - if new_unique not in old_unique_together: - cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade) - - for old_unique in old_unique_together: - if old_unique not in new_unique_together: - cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade) + for change in diff(old_model_describe, new_model_describe): + action, field_type, fields = change + if action == 'add': + for field in fields: + _, field_describe = field + cls._add_field(cls._get_model) @classmethod def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): @@ -474,10 +258,10 @@ class Migrate: return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation)) @classmethod - def _add_field(cls, model: Type[Model], field: Field): - if isinstance(field, ForeignKeyFieldInstance): + def _add_field(cls, model: Type[Model], field: dict): + if field.get('field_type') == 'ForeignKeyFieldInstance': return cls.ddl.add_fk(model, field) - if isinstance(field, ManyToManyFieldInstance): + if field.get('field_type') == 'ManyToManyFieldInstance': return cls.ddl.create_m2m_table(model, field) return cls.ddl.add_column(model, field) diff --git a/aerich/models.py b/aerich/models.py index da57d67..a689f03 100644 --- a/aerich/models.py +++ b/aerich/models.py @@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255 class Aerich(Model): version = fields.CharField(max_length=MAX_VERSION_LENGTH) app = fields.CharField(max_length=20) - content = fields.TextField() + content = fields.JSONField() class Meta: ordering = ["-id"] diff --git a/aerich/utils.py b/aerich/utils.py index c291874..6e12716 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -108,3 +108,16 @@ def write_version_file(version_file: str, content: Dict): f.write(";\n".join(downgrade) + ";\n") else: f.write(f"{downgrade[0]};\n") + + +def get_models_describe(app: str) -> Dict: + """ + get app models describe + :param app: + :return: + """ + ret = {} + for model in Tortoise.apps.get(app).values(): + describe = model.describe() + ret[describe.get("name")] = describe + return ret