Enhance PostgreSQL support

This commit is contained in:
Adam Ciarciński 2020-07-07 19:54:55 +02:00
parent 77e9d7bc91
commit 3c111792a9
6 changed files with 127 additions and 7 deletions

View File

@ -100,7 +100,7 @@ async def upgrade(ctx: Context):
content = json.load(f) content = json.load(f)
upgrade_query_list = content.get("upgrade") upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list: for upgrade_query in upgrade_query_list:
await conn.execute_query(upgrade_query) await conn.execute_script(upgrade_query)
await Aerich.create(version=version, app=app) await Aerich.create(version=version, app=app)
click.secho(f"Success upgrade {version}", fg=Color.green) click.secho(f"Success upgrade {version}", fg=Color.green)
migrated = True migrated = True

View File

@ -179,3 +179,12 @@ class BaseDDL:
to_field=to_field_name, to_field=to_field_name,
), ),
) )
def alter_column_default(self, model: "Type[Model]", field_object: Field):
pass
def alter_column_null(self, model: "Type[Model]", field_object: Field):
pass
def set_comment(self, model: "Type[Model]", field_object: Field):
pass

View File

@ -1,4 +1,5 @@
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
@ -6,3 +7,40 @@ from aerich.ddl import BaseDDL
class PostgresDDL(BaseDDL): class PostgresDDL(BaseDDL):
schema_generator_cls = AsyncpgSchemaGenerator schema_generator_cls = AsyncpgSchemaGenerator
DIALECT = AsyncpgSchemaGenerator.DIALECT DIALECT = AsyncpgSchemaGenerator.DIALECT
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
def alter_column_default(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT"
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
set_drop="DROP" if field_object.null else "SET"
)
def modify_column(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE")
)
def set_comment(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
return self._SET_COMMENT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
comment="'{}'".format(field_object.description) if field_object.description else 'NULL'
)

View File

@ -36,6 +36,7 @@ class Migrate:
diff_app = "diff_models" diff_app = "diff_models"
app: str app: str
migrate_location: str migrate_location: str
dialect: str
@classmethod @classmethod
def get_old_model_file(cls): def get_old_model_file(cls):
@ -60,15 +61,16 @@ class Migrate:
await Tortoise.init(config=migrate_config) await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
if connection.schema_generator.DIALECT == "mysql": cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL from aerich.ddl.mysql import MysqlDDL
cls.ddl = MysqlDDL(connection) cls.ddl = MysqlDDL(connection)
elif connection.schema_generator.DIALECT == "sqlite": elif cls.dialect == "sqlite":
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
cls.ddl = SqliteDDL(connection) cls.ddl = SqliteDDL(connection)
elif connection.schema_generator.DIALECT == "postgres": elif cls.dialect == "postgres":
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection) cls.ddl = PostgresDDL(connection)
@ -79,7 +81,7 @@ class Migrate:
if not last_version: if not last_version:
return None return None
version = last_version.version version = last_version.version
return int(version.split("_")[0]) return int(version.split("_", 1)[0])
@classmethod @classmethod
async def generate_version(cls, name=None): async def generate_version(cls, name=None):
@ -272,6 +274,13 @@ class Migrate:
old_field_dict.pop("unique") old_field_dict.pop("unique")
old_field_dict.pop("indexed") old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict: if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(cls._alter_null(new_model, new_field), upgrade=upgrade)
if new_field.default != old_field.default:
cls._add_operator(cls._alter_default(new_model, new_field), upgrade=upgrade)
if new_field.description != old_field.description:
cls._add_operator(cls._set_comment(new_model, new_field), upgrade=upgrade)
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade) cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or ( if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique old_field.unique and not new_field.unique
@ -365,6 +374,18 @@ class Migrate:
return cls.ddl.create_m2m_table(model, field) return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field) return cls.ddl.add_column(model, field)
@classmethod
def _alter_default(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_default(model, field)
@classmethod
def _alter_null(cls, model: Type[Model], field: Field):
return cls.ddl.alter_column_null(model, field)
@classmethod
def _set_comment(cls, model: Type[Model], field: Field):
return cls.ddl.set_comment(model, field)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field: Field): def _modify_field(cls, model: Type[Model], field: Field):
return cls.ddl.modify_column(model, field) return cls.ddl.modify_column(model, field)

View File

@ -11,7 +11,7 @@ def get_app_connection_name(config, app) -> str:
:param app: :param app:
:return: :return:
""" """
return config.get("apps").get(app).get("default_connection") return config.get("apps").get(app).get("default_connection", "default")
def get_app_connection(config, app) -> BaseDBAsyncClient: def get_app_connection(config, app) -> BaseDBAsyncClient:

View File

@ -2,7 +2,7 @@ from aerich.ddl.mysql import MysqlDDL
from aerich.ddl.postgres import PostgresDDL from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests.models import Category from tests.models import Category, User
def test_create_table(): def test_create_table():
@ -66,9 +66,61 @@ def test_modify_column():
ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL" assert ret == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)'
else: else:
assert ret == 'ALTER TABLE "category" MODIFY COLUMN "name" VARCHAR(200) NOT NULL' assert ret == 'ALTER TABLE "category" MODIFY COLUMN "name" VARCHAR(200) NOT NULL'
ret = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL'
else:
assert ret == 'ALTER TABLE "user" MODIFY COLUMN "is_active" INT NOT NULL DEFAULT 1 /* Is Active */'
def test_alter_column_default():
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
else:
assert ret == None
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
else:
assert ret == None
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\''
else:
assert ret == None
def test_alter_column_null():
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
else:
assert ret == None
def test_set_comment():
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
else:
assert ret == None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
else:
assert ret == None
def test_drop_column(): def test_drop_column():
ret = Migrate.ddl.drop_column(Category, "name") ret = Migrate.ddl.drop_column(Category, "name")