Enhance PostgreSQL support
This commit is contained in:
parent
77e9d7bc91
commit
3c111792a9
@ -100,7 +100,7 @@ async def upgrade(ctx: Context):
|
|||||||
content = json.load(f)
|
content = json.load(f)
|
||||||
upgrade_query_list = content.get("upgrade")
|
upgrade_query_list = content.get("upgrade")
|
||||||
for upgrade_query in upgrade_query_list:
|
for upgrade_query in upgrade_query_list:
|
||||||
await conn.execute_query(upgrade_query)
|
await conn.execute_script(upgrade_query)
|
||||||
await Aerich.create(version=version, app=app)
|
await Aerich.create(version=version, app=app)
|
||||||
click.secho(f"Success upgrade {version}", fg=Color.green)
|
click.secho(f"Success upgrade {version}", fg=Color.green)
|
||||||
migrated = True
|
migrated = True
|
||||||
|
@ -179,3 +179,12 @@ class BaseDDL:
|
|||||||
to_field=to_field_name,
|
to_field=to_field_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def alter_column_default(self, model: "Type[Model]", field_object: Field):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def alter_column_null(self, model: "Type[Model]", field_object: Field):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_comment(self, model: "Type[Model]", field_object: Field):
|
||||||
|
pass
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
|
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
|
||||||
|
from tortoise.fields import Field
|
||||||
|
|
||||||
from aerich.ddl import BaseDDL
|
from aerich.ddl import BaseDDL
|
||||||
|
|
||||||
@ -6,3 +7,40 @@ from aerich.ddl import BaseDDL
|
|||||||
class PostgresDDL(BaseDDL):
|
class PostgresDDL(BaseDDL):
|
||||||
schema_generator_cls = AsyncpgSchemaGenerator
|
schema_generator_cls = AsyncpgSchemaGenerator
|
||||||
DIALECT = AsyncpgSchemaGenerator.DIALECT
|
DIALECT = AsyncpgSchemaGenerator.DIALECT
|
||||||
|
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
|
||||||
|
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
|
||||||
|
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
|
||||||
|
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
|
||||||
|
|
||||||
|
def alter_column_default(self, model: "Type[Model]", field_object: Field):
|
||||||
|
db_table = model._meta.db_table
|
||||||
|
default = self._get_default(model, field_object)
|
||||||
|
return self._ALTER_DEFAULT_TEMPLATE.format(
|
||||||
|
table_name=db_table,
|
||||||
|
column=field_object.model_field_name,
|
||||||
|
default="SET" + default if default else "DROP DEFAULT"
|
||||||
|
)
|
||||||
|
|
||||||
|
def alter_column_null(self, model: "Type[Model]", field_object: Field):
|
||||||
|
db_table = model._meta.db_table
|
||||||
|
return self._ALTER_NULL_TEMPLATE.format(
|
||||||
|
table_name=db_table,
|
||||||
|
column=field_object.model_field_name,
|
||||||
|
set_drop="DROP" if field_object.null else "SET"
|
||||||
|
)
|
||||||
|
|
||||||
|
def modify_column(self, model: "Type[Model]", field_object: Field):
|
||||||
|
db_table = model._meta.db_table
|
||||||
|
return self._MODIFY_COLUMN_TEMPLATE.format(
|
||||||
|
table_name=db_table,
|
||||||
|
column=field_object.model_field_name,
|
||||||
|
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE")
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_comment(self, model: "Type[Model]", field_object: Field):
|
||||||
|
db_table = model._meta.db_table
|
||||||
|
return self._SET_COMMENT_TEMPLATE.format(
|
||||||
|
table_name=db_table,
|
||||||
|
column=field_object.model_field_name,
|
||||||
|
comment="'{}'".format(field_object.description) if field_object.description else 'NULL'
|
||||||
|
)
|
||||||
|
@ -36,6 +36,7 @@ class Migrate:
|
|||||||
diff_app = "diff_models"
|
diff_app = "diff_models"
|
||||||
app: str
|
app: str
|
||||||
migrate_location: str
|
migrate_location: str
|
||||||
|
dialect: str
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_old_model_file(cls):
|
def get_old_model_file(cls):
|
||||||
@ -60,15 +61,16 @@ class Migrate:
|
|||||||
await Tortoise.init(config=migrate_config)
|
await Tortoise.init(config=migrate_config)
|
||||||
|
|
||||||
connection = get_app_connection(config, app)
|
connection = get_app_connection(config, app)
|
||||||
if connection.schema_generator.DIALECT == "mysql":
|
cls.dialect = connection.schema_generator.DIALECT
|
||||||
|
if cls.dialect == "mysql":
|
||||||
from aerich.ddl.mysql import MysqlDDL
|
from aerich.ddl.mysql import MysqlDDL
|
||||||
|
|
||||||
cls.ddl = MysqlDDL(connection)
|
cls.ddl = MysqlDDL(connection)
|
||||||
elif connection.schema_generator.DIALECT == "sqlite":
|
elif cls.dialect == "sqlite":
|
||||||
from aerich.ddl.sqlite import SqliteDDL
|
from aerich.ddl.sqlite import SqliteDDL
|
||||||
|
|
||||||
cls.ddl = SqliteDDL(connection)
|
cls.ddl = SqliteDDL(connection)
|
||||||
elif connection.schema_generator.DIALECT == "postgres":
|
elif cls.dialect == "postgres":
|
||||||
from aerich.ddl.postgres import PostgresDDL
|
from aerich.ddl.postgres import PostgresDDL
|
||||||
|
|
||||||
cls.ddl = PostgresDDL(connection)
|
cls.ddl = PostgresDDL(connection)
|
||||||
@ -79,7 +81,7 @@ class Migrate:
|
|||||||
if not last_version:
|
if not last_version:
|
||||||
return None
|
return None
|
||||||
version = last_version.version
|
version = last_version.version
|
||||||
return int(version.split("_")[0])
|
return int(version.split("_", 1)[0])
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def generate_version(cls, name=None):
|
async def generate_version(cls, name=None):
|
||||||
@ -272,6 +274,13 @@ class Migrate:
|
|||||||
old_field_dict.pop("unique")
|
old_field_dict.pop("unique")
|
||||||
old_field_dict.pop("indexed")
|
old_field_dict.pop("indexed")
|
||||||
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
|
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
|
||||||
|
if cls.dialect == "postgres":
|
||||||
|
if new_field.null != old_field.null:
|
||||||
|
cls._add_operator(cls._alter_null(new_model, new_field), upgrade=upgrade)
|
||||||
|
if new_field.default != old_field.default:
|
||||||
|
cls._add_operator(cls._alter_default(new_model, new_field), upgrade=upgrade)
|
||||||
|
if new_field.description != old_field.description:
|
||||||
|
cls._add_operator(cls._set_comment(new_model, new_field), upgrade=upgrade)
|
||||||
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
|
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
|
||||||
if (old_field.index and not new_field.index) or (
|
if (old_field.index and not new_field.index) or (
|
||||||
old_field.unique and not new_field.unique
|
old_field.unique and not new_field.unique
|
||||||
@ -365,6 +374,18 @@ class Migrate:
|
|||||||
return cls.ddl.create_m2m_table(model, field)
|
return cls.ddl.create_m2m_table(model, field)
|
||||||
return cls.ddl.add_column(model, field)
|
return cls.ddl.add_column(model, field)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _alter_default(cls, model: Type[Model], field: Field):
|
||||||
|
return cls.ddl.alter_column_default(model, field)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _alter_null(cls, model: Type[Model], field: Field):
|
||||||
|
return cls.ddl.alter_column_null(model, field)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _set_comment(cls, model: Type[Model], field: Field):
|
||||||
|
return cls.ddl.set_comment(model, field)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _modify_field(cls, model: Type[Model], field: Field):
|
def _modify_field(cls, model: Type[Model], field: Field):
|
||||||
return cls.ddl.modify_column(model, field)
|
return cls.ddl.modify_column(model, field)
|
||||||
|
@ -11,7 +11,7 @@ def get_app_connection_name(config, app) -> str:
|
|||||||
:param app:
|
:param app:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return config.get("apps").get(app).get("default_connection")
|
return config.get("apps").get(app).get("default_connection", "default")
|
||||||
|
|
||||||
|
|
||||||
def get_app_connection(config, app) -> BaseDBAsyncClient:
|
def get_app_connection(config, app) -> BaseDBAsyncClient:
|
||||||
|
@ -2,7 +2,7 @@ from aerich.ddl.mysql import MysqlDDL
|
|||||||
from aerich.ddl.postgres import PostgresDDL
|
from aerich.ddl.postgres import PostgresDDL
|
||||||
from aerich.ddl.sqlite import SqliteDDL
|
from aerich.ddl.sqlite import SqliteDDL
|
||||||
from aerich.migrate import Migrate
|
from aerich.migrate import Migrate
|
||||||
from tests.models import Category
|
from tests.models import Category, User
|
||||||
|
|
||||||
|
|
||||||
def test_create_table():
|
def test_create_table():
|
||||||
@ -66,9 +66,61 @@ def test_modify_column():
|
|||||||
ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
|
ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
|
||||||
if isinstance(Migrate.ddl, MysqlDDL):
|
if isinstance(Migrate.ddl, MysqlDDL):
|
||||||
assert ret == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
|
assert ret == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
|
||||||
|
elif isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" TYPE VARCHAR(200)'
|
||||||
else:
|
else:
|
||||||
assert ret == 'ALTER TABLE "category" MODIFY COLUMN "name" VARCHAR(200) NOT NULL'
|
assert ret == 'ALTER TABLE "category" MODIFY COLUMN "name" VARCHAR(200) NOT NULL'
|
||||||
|
|
||||||
|
ret = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
|
||||||
|
if isinstance(Migrate.ddl, MysqlDDL):
|
||||||
|
assert ret == "ALTER TABLE `user` MODIFY COLUMN `is_active` BOOL NOT NULL"
|
||||||
|
elif isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "user" ALTER COLUMN "is_active" TYPE BOOL'
|
||||||
|
else:
|
||||||
|
assert ret == 'ALTER TABLE "user" MODIFY COLUMN "is_active" INT NOT NULL DEFAULT 1 /* Is Active */'
|
||||||
|
|
||||||
|
|
||||||
|
def test_alter_column_default():
|
||||||
|
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\''
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
|
||||||
|
def test_alter_column_null():
|
||||||
|
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_comment():
|
||||||
|
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
|
||||||
|
if isinstance(Migrate.ddl, PostgresDDL):
|
||||||
|
assert ret == 'COMMENT ON COLUMN "category"."user" IS \'User\''
|
||||||
|
else:
|
||||||
|
assert ret == None
|
||||||
|
|
||||||
|
|
||||||
def test_drop_column():
|
def test_drop_column():
|
||||||
ret = Migrate.ddl.drop_column(Category, "name")
|
ret = Migrate.ddl.drop_column(Category, "name")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user