diff --git a/aerich/migrate.py b/aerich/migrate.py index 6085efc..e114858 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -1,3 +1,4 @@ +import importlib import os from datetime import datetime from hashlib import md5 @@ -63,6 +64,11 @@ class Migrate: ret = await connection.execute_query(sql) cls._db_version = ret[1][0].get("version") + @classmethod + async def load_ddl_class(cls): + ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}") + return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL") + @classmethod async def init(cls, config: dict, app: str, location: str): await Tortoise.init(config=config) @@ -74,18 +80,8 @@ class Migrate: connection = get_app_connection(config, app) cls.dialect = connection.schema_generator.DIALECT - if cls.dialect == "mysql": - from aerich.ddl.mysql import MysqlDDL - - cls.ddl = MysqlDDL(connection) - elif cls.dialect == "sqlite": - from aerich.ddl.sqlite import SqliteDDL - - cls.ddl = SqliteDDL(connection) - elif cls.dialect == "postgres": - from aerich.ddl.postgres import PostgresDDL - - cls.ddl = PostgresDDL(connection) + cls.ddl_class = await cls.load_ddl_class() + cls.ddl = cls.ddl_class(connection) await cls._get_db_version(connection) @classmethod