v0.5 refactoring
This commit is contained in:
		| @@ -18,6 +18,7 @@ from aerich.migrate import Migrate | ||||
| from aerich.utils import ( | ||||
|     get_app_connection, | ||||
|     get_app_connection_name, | ||||
|     get_models_describe, | ||||
|     get_tortoise_config, | ||||
|     get_version_content_from_file, | ||||
|     write_version_file, | ||||
| @@ -34,11 +35,7 @@ def coro(f): | ||||
|     @wraps(f) | ||||
|     def wrapper(*args, **kwargs): | ||||
|         loop = asyncio.get_event_loop() | ||||
|         ctx = args[0] | ||||
|         loop.run_until_complete(f(*args, **kwargs)) | ||||
|         app = ctx.obj.get("app") | ||||
|         if app: | ||||
|             Migrate.remove_old_model_file(app, ctx.obj["location"]) | ||||
|         loop.run_until_complete(Tortoise.close_connections()) | ||||
|  | ||||
|     return wrapper | ||||
| @@ -86,7 +83,7 @@ async def cli(ctx: Context, config, app, name): | ||||
|         if invoked_subcommand != "init-db": | ||||
|             if not Path(location, app).exists(): | ||||
|                 raise UsageError("You must exec init-db first", ctx=ctx) | ||||
|             await Migrate.init_with_old_models(tortoise_config, app, location) | ||||
|             await Migrate.init(tortoise_config, app, location) | ||||
|  | ||||
|  | ||||
| @cli.command(help="Generate migrate changes file.") | ||||
| @@ -106,7 +103,6 @@ async def migrate(ctx: Context, name): | ||||
| async def upgrade(ctx: Context): | ||||
|     config = ctx.obj["config"] | ||||
|     app = ctx.obj["app"] | ||||
|     location = ctx.obj["location"] | ||||
|     migrated = False | ||||
|     for version_file in Migrate.get_all_version_files(): | ||||
|         try: | ||||
| @@ -123,7 +119,7 @@ async def upgrade(ctx: Context): | ||||
|                 await Aerich.create( | ||||
|                     version=version_file, | ||||
|                     app=app, | ||||
|                     content=Migrate.get_models_content(config, app, location), | ||||
|                     content=get_models_describe(app), | ||||
|                 ) | ||||
|             click.secho(f"Success upgrade {version_file}", fg=Color.green) | ||||
|             migrated = True | ||||
| @@ -281,7 +277,7 @@ async def init_db(ctx: Context, safe): | ||||
|     await Aerich.create( | ||||
|         version=version, | ||||
|         app=app, | ||||
|         content=Migrate.get_models_content(config, app, location), | ||||
|         content=get_models_describe(app), | ||||
|     ) | ||||
|     content = { | ||||
|         "upgrade": [schema], | ||||
|   | ||||
| @@ -140,7 +140,7 @@ class BaseDDL: | ||||
|         ) | ||||
|  | ||||
|     def change_column( | ||||
|         self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str | ||||
|             self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str | ||||
|     ): | ||||
|         return self._CHANGE_COLUMN_TEMPLATE.format( | ||||
|             table_name=model._meta.db_table, | ||||
| @@ -167,26 +167,23 @@ class BaseDDL: | ||||
|             table_name=model._meta.db_table, | ||||
|         ) | ||||
|  | ||||
|     def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): | ||||
|     def add_fk(self, model: "Type[Model]", field: dict): | ||||
|         db_table = model._meta.db_table | ||||
|         to_field_name = field.to_field_instance.source_field | ||||
|         if not to_field_name: | ||||
|             to_field_name = field.to_field_instance.model_field_name | ||||
|  | ||||
|         db_column = field.source_field or field.model_field_name + "_id" | ||||
|         db_column = field.get('db_column') | ||||
|         fk_name = self.schema_generator._generate_fk_name( | ||||
|             from_table=db_table, | ||||
|             from_field=db_column, | ||||
|             to_table=field.related_model._meta.db_table, | ||||
|             to_field=to_field_name, | ||||
|             to_field=db_column, | ||||
|         ) | ||||
|         return self._ADD_FK_TEMPLATE.format( | ||||
|             table_name=db_table, | ||||
|             fk_name=fk_name, | ||||
|             db_column=db_column, | ||||
|             table=field.related_model._meta.db_table, | ||||
|             field=to_field_name, | ||||
|             on_delete=field.on_delete, | ||||
|             field=db_column, | ||||
|             on_delete=field.get('on_delete'), | ||||
|         ) | ||||
|  | ||||
|     def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): | ||||
|   | ||||
| @@ -1,14 +1,10 @@ | ||||
| import inspect | ||||
| import os | ||||
| import re | ||||
| from datetime import datetime | ||||
| from importlib import import_module | ||||
| from io import StringIO | ||||
| from pathlib import Path | ||||
| from types import ModuleType | ||||
| from typing import Dict, List, Optional, Tuple, Type | ||||
|  | ||||
| import click | ||||
| from dictdiffer import diff | ||||
| from tortoise import ( | ||||
|     BackwardFKRelation, | ||||
|     BackwardOneToOneRelation, | ||||
| @@ -23,7 +19,7 @@ from tortoise.fields import Field | ||||
|  | ||||
| from aerich.ddl import BaseDDL | ||||
| from aerich.models import MAX_VERSION_LENGTH, Aerich | ||||
| from aerich.utils import get_app_connection, write_version_file | ||||
| from aerich.utils import get_app_connection, get_models_describe, write_version_file | ||||
|  | ||||
|  | ||||
| class Migrate: | ||||
| @@ -38,18 +34,12 @@ class Migrate: | ||||
|     _rename_new = [] | ||||
|  | ||||
|     ddl: BaseDDL | ||||
|     migrate_config: dict | ||||
|     old_models = "old_models" | ||||
|     diff_app = "diff_models" | ||||
|     _last_version_content: Optional[dict] = None | ||||
|     app: str | ||||
|     migrate_location: str | ||||
|     dialect: str | ||||
|     _db_version: Optional[str] = None | ||||
|  | ||||
|     @classmethod | ||||
|     def get_old_model_file(cls, app: str, location: str): | ||||
|         return Path(location, app, cls.old_models + ".py") | ||||
|  | ||||
|     @classmethod | ||||
|     def get_all_version_files(cls) -> List[str]: | ||||
|         return sorted( | ||||
| @@ -57,6 +47,10 @@ class Migrate: | ||||
|             key=lambda x: int(x.split("_")[0]), | ||||
|         ) | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_model(cls, model: str) -> Type[Model]: | ||||
|         return Tortoise.apps.get(cls.app).get(model) | ||||
|  | ||||
|     @classmethod | ||||
|     async def get_last_version(cls) -> Optional[Aerich]: | ||||
|         try: | ||||
| @@ -64,13 +58,6 @@ class Migrate: | ||||
|         except OperationalError: | ||||
|             pass | ||||
|  | ||||
|     @classmethod | ||||
|     def remove_old_model_file(cls, app: str, location: str): | ||||
|         try: | ||||
|             os.unlink(cls.get_old_model_file(app, location)) | ||||
|         except (OSError, FileNotFoundError): | ||||
|             pass | ||||
|  | ||||
|     @classmethod | ||||
|     async def _get_db_version(cls, connection: BaseDBAsyncClient): | ||||
|         if cls.dialect == "mysql": | ||||
| @@ -79,19 +66,13 @@ class Migrate: | ||||
|             cls._db_version = ret[1][0].get("version") | ||||
|  | ||||
|     @classmethod | ||||
|     async def init_with_old_models(cls, config: dict, app: str, location: str): | ||||
|     async def init(cls, config: dict, app: str, location: str): | ||||
|         await Tortoise.init(config=config) | ||||
|         last_version = await cls.get_last_version() | ||||
|         cls.app = app | ||||
|         cls.migrate_location = Path(location, app) | ||||
|         if last_version: | ||||
|             content = last_version.content | ||||
|             with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f: | ||||
|                 f.write(content) | ||||
|  | ||||
|             migrate_config = cls._get_migrate_config(config, app, location) | ||||
|             cls.migrate_config = migrate_config | ||||
|             await Tortoise.init(config=migrate_config) | ||||
|             cls._last_version_content = last_version.content | ||||
|  | ||||
|         connection = get_app_connection(config, app) | ||||
|         cls.dialect = connection.schema_generator.DIALECT | ||||
| @@ -149,12 +130,9 @@ class Migrate: | ||||
|         :param name: | ||||
|         :return: | ||||
|         """ | ||||
|         apps = Tortoise.apps | ||||
|         diff_models = apps.get(cls.diff_app) | ||||
|         app_models = apps.get(cls.app) | ||||
|  | ||||
|         cls.diff_models(diff_models, app_models) | ||||
|         cls.diff_models(app_models, diff_models, False) | ||||
|         new_version_content = get_models_describe(cls.app) | ||||
|         cls.diff_models(cls._last_version_content, new_version_content) | ||||
|         cls.diff_models(new_version_content, cls._last_version_content, False) | ||||
|  | ||||
|         cls._merge_operators() | ||||
|  | ||||
| @@ -183,58 +161,9 @@ class Migrate: | ||||
|             else: | ||||
|                 cls.downgrade_operators.append(operator) | ||||
|  | ||||
|     @classmethod | ||||
|     def _get_migrate_config(cls, config: dict, app: str, location: str): | ||||
|         """ | ||||
|         generate tmp config with old models | ||||
|         :param config: | ||||
|         :param app: | ||||
|         :param location: | ||||
|         :return: | ||||
|         """ | ||||
|         path = Path(location, app, cls.old_models).as_posix().replace("/", ".") | ||||
|         config["apps"][cls.diff_app] = { | ||||
|             "models": [path], | ||||
|             "default_connection": config.get("apps").get(app).get("default_connection", "default"), | ||||
|         } | ||||
|         return config | ||||
|  | ||||
|     @classmethod | ||||
|     def get_models_content(cls, config: dict, app: str, location: str): | ||||
|         """ | ||||
|         write new models to old models | ||||
|         :param config: | ||||
|         :param app: | ||||
|         :param location: | ||||
|         :return: | ||||
|         """ | ||||
|         old_model_files = [] | ||||
|         models = config.get("apps").get(app).get("models") | ||||
|         for model in models: | ||||
|             if isinstance(model, ModuleType): | ||||
|                 module = model | ||||
|             else: | ||||
|                 module = import_module(model) | ||||
|             possible_models = [getattr(module, attr_name) for attr_name in dir(module)] | ||||
|             for attr in filter( | ||||
|                 lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model, | ||||
|                 possible_models, | ||||
|             ): | ||||
|                 file = inspect.getfile(attr) | ||||
|                 if file not in old_model_files: | ||||
|                     old_model_files.append(file) | ||||
|         pattern = rf"(\n)?('|\")({app})(.\w+)('|\")" | ||||
|         str_io = StringIO() | ||||
|         for i, model_file in enumerate(old_model_files): | ||||
|             with open(model_file, "r", encoding="utf-8") as f: | ||||
|                 content = f.read() | ||||
|             ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content) | ||||
|             str_io.write(f"{ret}\n") | ||||
|         return str_io.getvalue() | ||||
|  | ||||
|     @classmethod | ||||
|     def diff_models( | ||||
|         cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True | ||||
|             cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True | ||||
|     ): | ||||
|         """ | ||||
|         diff models and add operators | ||||
| @@ -243,18 +172,19 @@ class Migrate: | ||||
|         :param upgrade: | ||||
|         :return: | ||||
|         """ | ||||
|         old_models.pop(cls._aerich, None) | ||||
|         new_models.pop(cls._aerich, None) | ||||
|         _aerich = f'{cls.app}.{cls._aerich}' | ||||
|         old_models.pop(_aerich, None) | ||||
|         new_models.pop(_aerich, None) | ||||
|  | ||||
|         for new_model_str, new_model in new_models.items(): | ||||
|         for new_model_str, new_model_describe in new_models.items(): | ||||
|             if new_model_str not in old_models.keys(): | ||||
|                 cls._add_operator(cls.add_model(new_model), upgrade) | ||||
|                 cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade) | ||||
|             else: | ||||
|                 cls.diff_model(old_models.get(new_model_str), new_model, upgrade) | ||||
|                 cls.diff_model(old_models.get(new_model_str), new_model_describe, upgrade) | ||||
|  | ||||
|         for old_model in old_models: | ||||
|             if old_model not in new_models.keys(): | ||||
|                 cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade) | ||||
|                 cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade) | ||||
|  | ||||
|     @classmethod | ||||
|     def _is_fk_m2m(cls, field: Field): | ||||
| @@ -269,166 +199,20 @@ class Migrate: | ||||
|         return cls.ddl.drop_table(model) | ||||
|  | ||||
|     @classmethod | ||||
|     def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True): | ||||
|     def diff_model(cls, old_model_describe: dict, new_model_describe: dict, upgrade=True): | ||||
|         """ | ||||
|         diff single model | ||||
|         :param old_model: | ||||
|         :param new_model: | ||||
|         :param old_model_describe: | ||||
|         :param new_model_describe: | ||||
|         :param upgrade: | ||||
|         :return: | ||||
|         """ | ||||
|         old_indexes = old_model._meta.indexes | ||||
|         new_indexes = new_model._meta.indexes | ||||
|  | ||||
|         old_unique_together = old_model._meta.unique_together | ||||
|         new_unique_together = new_model._meta.unique_together | ||||
|  | ||||
|         old_fields_map = old_model._meta.fields_map | ||||
|         new_fields_map = new_model._meta.fields_map | ||||
|  | ||||
|         old_keys = old_fields_map.keys() | ||||
|         new_keys = new_fields_map.keys() | ||||
|         for new_key in new_keys: | ||||
|             new_field = new_fields_map.get(new_key) | ||||
|             if cls._exclude_field(new_field, upgrade): | ||||
|                 continue | ||||
|             if new_key not in old_keys: | ||||
|                 new_field_dict = new_field.describe(serializable=True) | ||||
|                 new_field_dict.pop("name", None) | ||||
|                 new_field_dict.pop("db_column", None) | ||||
|                 for diff_key in old_keys - new_keys: | ||||
|                     old_field = old_fields_map.get(diff_key) | ||||
|                     old_field_dict = old_field.describe(serializable=True) | ||||
|                     old_field_dict.pop("name", None) | ||||
|                     old_field_dict.pop("db_column", None) | ||||
|                     if old_field_dict == new_field_dict: | ||||
|                         if upgrade: | ||||
|                             is_rename = click.prompt( | ||||
|                                 f"Rename {diff_key} to {new_key}?", | ||||
|                                 default=True, | ||||
|                                 type=bool, | ||||
|                                 show_choices=True, | ||||
|                             ) | ||||
|                             cls._rename_new.append(new_key) | ||||
|                             cls._rename_old.append(diff_key) | ||||
|                         else: | ||||
|                             is_rename = diff_key in cls._rename_new | ||||
|                         if is_rename: | ||||
|                             if ( | ||||
|                                 cls.dialect == "mysql" | ||||
|                                 and cls._db_version | ||||
|                                 and cls._db_version.startswith("5.") | ||||
|                             ): | ||||
|                                 cls._add_operator( | ||||
|                                     cls._change_field(new_model, old_field, new_field), | ||||
|                                     upgrade, | ||||
|                                 ) | ||||
|                             else: | ||||
|                                 cls._add_operator( | ||||
|                                     cls._rename_field(new_model, old_field, new_field), | ||||
|                                     upgrade, | ||||
|                                 ) | ||||
|                             break | ||||
|                 else: | ||||
|                     cls._add_operator( | ||||
|                         cls._add_field(new_model, new_field), | ||||
|                         upgrade, | ||||
|                         cls._is_fk_m2m(new_field), | ||||
|                     ) | ||||
|             else: | ||||
|                 old_field = old_fields_map.get(new_key) | ||||
|                 new_field_dict = new_field.describe(serializable=True) | ||||
|                 new_field_dict.pop("unique") | ||||
|                 new_field_dict.pop("indexed") | ||||
|                 old_field_dict = old_field.describe(serializable=True) | ||||
|                 old_field_dict.pop("unique") | ||||
|                 old_field_dict.pop("indexed") | ||||
|                 if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict: | ||||
|                     if cls.dialect == "postgres": | ||||
|                         if new_field.null != old_field.null: | ||||
|                             cls._add_operator( | ||||
|                                 cls._alter_null(new_model, new_field), upgrade=upgrade | ||||
|                             ) | ||||
|                         if new_field.default != old_field.default and not callable( | ||||
|                             new_field.default | ||||
|                         ): | ||||
|                             cls._add_operator( | ||||
|                                 cls._alter_default(new_model, new_field), upgrade=upgrade | ||||
|                             ) | ||||
|                         if new_field.description != old_field.description: | ||||
|                             cls._add_operator( | ||||
|                                 cls._set_comment(new_model, new_field), upgrade=upgrade | ||||
|                             ) | ||||
|                         if new_field.field_type != old_field.field_type: | ||||
|                             cls._add_operator( | ||||
|                                 cls._modify_field(new_model, new_field), upgrade=upgrade | ||||
|                             ) | ||||
|                     else: | ||||
|                         cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade) | ||||
|                 if (old_field.index and not new_field.index) or ( | ||||
|                     old_field.unique and not new_field.unique | ||||
|                 ): | ||||
|                     cls._add_operator( | ||||
|                         cls._remove_index( | ||||
|                             old_model, (old_field.model_field_name,), old_field.unique | ||||
|                         ), | ||||
|                         upgrade, | ||||
|                         cls._is_fk_m2m(old_field), | ||||
|                     ) | ||||
|                 elif (new_field.index and not old_field.index) or ( | ||||
|                     new_field.unique and not old_field.unique | ||||
|                 ): | ||||
|                     cls._add_operator( | ||||
|                         cls._add_index(new_model, (new_field.model_field_name,), new_field.unique), | ||||
|                         upgrade, | ||||
|                         cls._is_fk_m2m(new_field), | ||||
|                     ) | ||||
|                 if isinstance(new_field, ForeignKeyFieldInstance): | ||||
|                     if old_field.db_constraint and not new_field.db_constraint: | ||||
|                         cls._add_operator( | ||||
|                             cls._drop_fk(new_model, new_field), | ||||
|                             upgrade, | ||||
|                             True, | ||||
|                         ) | ||||
|                     if new_field.db_constraint and not old_field.db_constraint: | ||||
|                         cls._add_operator( | ||||
|                             cls._add_fk(new_model, new_field), | ||||
|                             upgrade, | ||||
|                             True, | ||||
|                         ) | ||||
|  | ||||
|         for old_key in old_keys: | ||||
|             field = old_fields_map.get(old_key) | ||||
|             if old_key not in new_keys and not cls._exclude_field(field, upgrade): | ||||
|                 if (upgrade and old_key not in cls._rename_old) or ( | ||||
|                     not upgrade and old_key not in cls._rename_new | ||||
|                 ): | ||||
|                     cls._add_operator( | ||||
|                         cls._remove_field(old_model, field), | ||||
|                         upgrade, | ||||
|                         cls._is_fk_m2m(field), | ||||
|                     ) | ||||
|  | ||||
|         for new_index in new_indexes: | ||||
|             if new_index not in old_indexes: | ||||
|                 cls._add_operator( | ||||
|                     cls._add_index( | ||||
|                         new_model, | ||||
|                         new_index, | ||||
|                     ), | ||||
|                     upgrade, | ||||
|                 ) | ||||
|         for old_index in old_indexes: | ||||
|             if old_index not in new_indexes: | ||||
|                 cls._add_operator(cls._remove_index(old_model, old_index), upgrade) | ||||
|  | ||||
|         for new_unique in new_unique_together: | ||||
|             if new_unique not in old_unique_together: | ||||
|                 cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade) | ||||
|  | ||||
|         for old_unique in old_unique_together: | ||||
|             if old_unique not in new_unique_together: | ||||
|                 cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade) | ||||
|         for change in diff(old_model_describe, new_model_describe): | ||||
|             action, field_type, fields = change | ||||
|             if action == 'add': | ||||
|                 for field in fields: | ||||
|                     _, field_describe = field | ||||
|                     cls._add_field(cls._get_model) | ||||
|  | ||||
|     @classmethod | ||||
|     def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): | ||||
| @@ -474,10 +258,10 @@ class Migrate: | ||||
|         return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation)) | ||||
|  | ||||
|     @classmethod | ||||
|     def _add_field(cls, model: Type[Model], field: Field): | ||||
|         if isinstance(field, ForeignKeyFieldInstance): | ||||
|     def _add_field(cls, model: Type[Model], field: dict): | ||||
|         if field.get('field_type') == 'ForeignKeyFieldInstance': | ||||
|             return cls.ddl.add_fk(model, field) | ||||
|         if isinstance(field, ManyToManyFieldInstance): | ||||
|         if field.get('field_type') == 'ManyToManyFieldInstance': | ||||
|             return cls.ddl.create_m2m_table(model, field) | ||||
|         return cls.ddl.add_column(model, field) | ||||
|  | ||||
|   | ||||
| @@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255 | ||||
| class Aerich(Model): | ||||
|     version = fields.CharField(max_length=MAX_VERSION_LENGTH) | ||||
|     app = fields.CharField(max_length=20) | ||||
|     content = fields.TextField() | ||||
|     content = fields.JSONField() | ||||
|  | ||||
|     class Meta: | ||||
|         ordering = ["-id"] | ||||
|   | ||||
| @@ -108,3 +108,16 @@ def write_version_file(version_file: str, content: Dict): | ||||
|                 f.write(";\n".join(downgrade) + ";\n") | ||||
|             else: | ||||
|                 f.write(f"{downgrade[0]};\n") | ||||
|  | ||||
|  | ||||
| def get_models_describe(app: str) -> Dict: | ||||
|     """ | ||||
|     get app models describe | ||||
|     :param app: | ||||
|     :return: | ||||
|     """ | ||||
|     ret = {} | ||||
|     for model in Tortoise.apps.get(app).values(): | ||||
|         describe = model.describe() | ||||
|         ret[describe.get("name")] = describe | ||||
|     return ret | ||||
|   | ||||
		Reference in New Issue
	
	Block a user