diff --git a/aerich/cli.py b/aerich/cli.py index 954eda3..4feb041 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -100,7 +100,7 @@ async def upgrade(ctx: Context): content = json.load(f) upgrade_query_list = content.get("upgrade") for upgrade_query in upgrade_query_list: - await conn.execute_query(upgrade_query) + await conn.execute_script(upgrade_query) await Aerich.create(version=version, app=app) click.secho(f"Success upgrade {version}", fg=Color.green) migrated = True diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 09c898d..c484831 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -179,3 +179,12 @@ class BaseDDL: to_field=to_field_name, ), ) + + def alter_column_default(self, model: "Type[Model]", field_object: Field): + pass + + def alter_column_null(self, model: "Type[Model]", field_object: Field): + pass + + def set_comment(self, model: "Type[Model]", field_object: Field): + pass diff --git a/aerich/ddl/postgres/__init__.py b/aerich/ddl/postgres/__init__.py index 1901301..bdbea6a 100644 --- a/aerich/ddl/postgres/__init__.py +++ b/aerich/ddl/postgres/__init__.py @@ -1,4 +1,5 @@ from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator +from tortoise.fields import Field from aerich.ddl import BaseDDL @@ -6,3 +7,40 @@ from aerich.ddl import BaseDDL class PostgresDDL(BaseDDL): schema_generator_cls = AsyncpgSchemaGenerator DIALECT = AsyncpgSchemaGenerator.DIALECT + _ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}' + _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' + _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}' + _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' + + def alter_column_default(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + default = self._get_default(model, field_object) + return self._ALTER_DEFAULT_TEMPLATE.format( + table_name=db_table, + column=field_object.model_field_name, + default="SET" + default if default else "DROP DEFAULT" + ) + + def alter_column_null(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + return self._ALTER_NULL_TEMPLATE.format( + table_name=db_table, + column=field_object.model_field_name, + set_drop="DROP" if field_object.null else "SET" + ) + + def modify_column(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + return self._MODIFY_COLUMN_TEMPLATE.format( + table_name=db_table, + column=field_object.model_field_name, + datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE") + ) + + def set_comment(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + return self._SET_COMMENT_TEMPLATE.format( + table_name=db_table, + column=field_object.model_field_name, + comment="'{}'".format(field_object.description) if field_object.description else 'NULL' + ) diff --git a/aerich/migrate.py b/aerich/migrate.py index 2f80ae4..f00130b 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -36,6 +36,7 @@ class Migrate: diff_app = "diff_models" app: str migrate_location: str + dialect: str @classmethod def get_old_model_file(cls): @@ -60,15 +61,16 @@ class Migrate: await Tortoise.init(config=migrate_config) connection = get_app_connection(config, app) - if connection.schema_generator.DIALECT == "mysql": + cls.dialect = connection.schema_generator.DIALECT + if cls.dialect == "mysql": from aerich.ddl.mysql import MysqlDDL cls.ddl = MysqlDDL(connection) - elif connection.schema_generator.DIALECT == "sqlite": + elif cls.dialect == "sqlite": from aerich.ddl.sqlite import SqliteDDL cls.ddl = SqliteDDL(connection) - elif connection.schema_generator.DIALECT == "postgres": + elif cls.dialect == "postgres": from aerich.ddl.postgres import PostgresDDL cls.ddl = PostgresDDL(connection) @@ -79,7 +81,7 @@ class Migrate: if not last_version: return None version = last_version.version - return int(version.split("_")[0]) + return int(version.split("_", 1)[0]) @classmethod async def generate_version(cls, name=None): @@ -272,6 +274,13 @@ class Migrate: old_field_dict.pop("unique") old_field_dict.pop("indexed") if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict: + if cls.dialect == "postgres": + if new_field.null != old_field.null: + cls._add_operator(cls._alter_null(new_model, new_field), upgrade=upgrade) + if new_field.default != old_field.default: + cls._add_operator(cls._alter_default(new_model, new_field), upgrade=upgrade) + if new_field.description != old_field.description: + cls._add_operator(cls._set_comment(new_model, new_field), upgrade=upgrade) cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade) if (old_field.index and not new_field.index) or ( old_field.unique and not new_field.unique @@ -365,6 +374,18 @@ class Migrate: return cls.ddl.create_m2m_table(model, field) return cls.ddl.add_column(model, field) + @classmethod + def _alter_default(cls, model: Type[Model], field: Field): + return cls.ddl.alter_column_default(model, field) + + @classmethod + def _alter_null(cls, model: Type[Model], field: Field): + return cls.ddl.alter_column_null(model, field) + + @classmethod + def _set_comment(cls, model: Type[Model], field: Field): + return cls.ddl.set_comment(model, field) + @classmethod def _modify_field(cls, model: Type[Model], field: Field): return cls.ddl.modify_column(model, field) diff --git a/aerich/utils.py b/aerich/utils.py index 4085699..e4a98dc 100644 --- a/aerich/utils.py +++ b/aerich/utils.py @@ -11,7 +11,7 @@ def get_app_connection_name(config, app) -> str: :param app: :return: """ - return config.get("apps").get(app).get("default_connection") + return config.get("apps").get(app).get("default_connection", "default") def get_app_connection(config, app) -> BaseDBAsyncClient: diff --git a/tests/test_ddl.py b/tests/test_ddl.py index bbf6cb8..9a6c574 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -2,7 +2,7 @@ from aerich.ddl.mysql import MysqlDDL from aerich.ddl.postgres import PostgresDDL from aerich.ddl.sqlite import SqliteDDL from aerich.migrate import Migrate -from tests.models import Category +from tests.models import Category, User def test_create_table(): @@ -66,9 +66,61 @@ def test_modify_column(): ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, MysqlDDL): assert ret == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL" + elif isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)' else: assert ret == 'ALTER TABLE "category" MODIFY COLUMN "name" VARCHAR(200) NOT NULL' + ret = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active")) + if isinstance(Migrate.ddl, MysqlDDL): + assert ret == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL" + elif isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL' + else: + assert ret == 'ALTER TABLE "user" MODIFY COLUMN "is_active" INT NOT NULL DEFAULT 1 /* Is Active */' + + +def test_alter_column_default(): + ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT' + else: + assert ret == None + + ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP' + else: + assert ret == None + + ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\'' + else: + assert ret == None + + +def test_alter_column_null(): + ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL' + else: + assert ret == None + + +def test_set_comment(): + ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' + else: + assert ret == None + + ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user")) + if isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\'' + else: + assert ret == None + def test_drop_column(): ret = Migrate.ddl.drop_column(Category, "name")