basically completed

This commit is contained in:
long2ice 2021-02-03 15:43:04 +08:00
parent c6c398fdf0
commit 01e3de9522
10 changed files with 222 additions and 120 deletions

View File

@ -1,8 +1,9 @@
from enum import Enum
from typing import List, Type from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField from tortoise.fields import CASCADE
class BaseDDL: class BaseDDL:
@ -11,6 +12,7 @@ class BaseDDL:
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"' _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}' _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"' _DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"' 'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
) )
@ -62,6 +64,8 @@ class BaseDDL:
def _get_default(self, model: "Type[Model]", field_describe: dict): def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_describe.get("default") default = field_describe.get("default")
if isinstance(default, Enum):
default = default.value
db_column = field_describe.get("db_column") db_column = field_describe.get("db_column")
auto_now_add = field_describe.get("auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False) auto_now = field_describe.get("auto_now", False)
@ -80,7 +84,7 @@ class BaseDDL:
except NotImplementedError: except NotImplementedError:
default = "" default = ""
else: else:
default = "" default = None
return default return default
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
@ -88,6 +92,9 @@ class BaseDDL:
description = field_describe.get("description") description = field_describe.get("description")
db_column = field_describe.get("db_column") db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types") db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._ADD_COLUMN_TEMPLATE.format( return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
@ -103,33 +110,37 @@ class BaseDDL:
if description if description
else "", else "",
is_primary_key=is_pk, is_primary_key=is_pk,
default=self._get_default(model, field_describe), default=default,
), ),
) )
def drop_column(self, model: "Type[Model]", field_describe: dict): def drop_column(self, model: "Type[Model]", column_name: str):
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=field_describe.get("db_column") table_name=model._meta.db_table, column_name=column_name
) )
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
default = self._get_default(model, field_describe)
if default is None:
default = ""
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=self.schema_generator._create_string( column=self.schema_generator._create_string(
db_column=field_object.model_field_name, db_column=field_describe.get("db_column"),
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
nullable="NOT NULL" if not field_object.null else "", nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="", unique="",
comment=self.schema_generator._column_comment_generator( comment=self.schema_generator._column_comment_generator(
table=db_table, table=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column"),
comment=field_object.description, comment=field_describe.get("description"),
) )
if field_object.description if field_describe.get("description")
else "", else "",
is_primary_key=field_object.pk, is_primary_key=is_pk,
default=self._get_default(model, field_object), default=default,
), ),
) )
@ -200,11 +211,17 @@ class BaseDDL:
), ),
) )
def alter_column_default(self, model: "Type[Model]", field_object: Field): def alter_column_default(self, model: "Type[Model]", field_describe: dict):
pass db_table = model._meta.db_table
default = self._get_default(model, field_describe)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_describe.get("db_column"),
default="SET" + default if default is not None else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field): def alter_column_null(self, model: "Type[Model]", field_describe: dict):
pass raise NotImplementedError
def set_comment(self, model: "Type[Model]", field_object: Field): def set_comment(self, model: "Type[Model]", field_describe: dict):
pass raise NotImplementedError

View File

@ -1,6 +1,10 @@
from typing import Type
from tortoise import Model
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError
class MysqlDDL(BaseDDL): class MysqlDDL(BaseDDL):
@ -8,6 +12,10 @@ class MysqlDDL(BaseDDL):
DIALECT = MySQLSchemaGenerator.DIALECT DIALECT = MySQLSchemaGenerator.DIALECT
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`" _DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}" _ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
_CHANGE_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
)
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`" _DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
_RENAME_COLUMN_TEMPLATE = ( _RENAME_COLUMN_TEMPLATE = (
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`" "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
@ -20,3 +28,9 @@ class MysqlDDL(BaseDDL):
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};" _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in MySQL.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in MySQL.")

View File

@ -16,35 +16,26 @@ class PostgresDDL(BaseDDL):
) )
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"' _DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
def alter_column_default(self, model: "Type[Model]", field_object: Field): def alter_column_null(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table
default = self._get_default(model, field_object)
return self._ALTER_DEFAULT_TEMPLATE.format(
table_name=db_table,
column=field_object.model_field_name,
default="SET" + default if default else "DROP DEFAULT",
)
def alter_column_null(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table db_table = model._meta.db_table
return self._ALTER_NULL_TEMPLATE.format( return self._ALTER_NULL_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column"),
set_drop="DROP" if field_object.null else "SET", set_drop="DROP" if field_describe.get("nullable") else "SET",
) )
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table db_table = model._meta.db_table
db_field_types = field_describe.get("db_field_types")
return self._MODIFY_COLUMN_TEMPLATE.format( return self._MODIFY_COLUMN_TEMPLATE.format(
table_name=db_table, table_name=db_table,
column=field_object.model_field_name, column=field_describe.get("db_column"),
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), datatype=db_field_types.get(self.DIALECT) or db_field_types.get(""),
) )
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):

View File

@ -2,7 +2,6 @@ from typing import Type
from tortoise import Model from tortoise import Model
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
@ -15,5 +14,14 @@ class SqliteDDL(BaseDDL):
def drop_column(self, model: "Type[Model]", column_name: str): def drop_column(self, model: "Type[Model]", column_name: str):
raise NotSupportError("Drop column is unsupported in SQLite.") raise NotSupportError("Drop column is unsupported in SQLite.")
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
raise NotSupportError("Modify column is unsupported in SQLite.") raise NotSupportError("Modify column is unsupported in SQLite.")
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column default is unsupported in SQLite.")
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column null is unsupported in SQLite.")
def set_comment(self, model: "Type[Model]", field_describe: dict):
raise NotSupportError("Alter column comment is unsupported in SQLite.")

View File

@ -2,6 +2,7 @@ import os
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from dictdiffer import diff from dictdiffer import diff
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
@ -173,6 +174,8 @@ class Migrate:
new_models.pop(_aerich, None) new_models.pop(_aerich, None)
for new_model_str, new_model_describe in new_models.items(): for new_model_str, new_model_describe in new_models.items():
model = cls._get_model(new_model_describe.get("name").split(".")[1])
if new_model_str not in old_models.keys(): if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade) cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade)
else: else:
@ -180,6 +183,18 @@ class Migrate:
old_unique_together = old_model_describe.get("unique_together") old_unique_together = old_model_describe.get("unique_together")
new_unique_together = new_model_describe.get("unique_together") new_unique_together = new_model_describe.get("unique_together")
# add unique_together
for index in set(new_unique_together).difference(set(old_unique_together)):
cls._add_operator(
cls._add_index(model, index, True),
upgrade,
)
# remove unique_together
for index in set(old_unique_together).difference(set(new_unique_together)):
cls._add_operator(
cls._drop_index(model, index, True),
upgrade,
)
old_data_fields = old_model_describe.get("data_fields") old_data_fields = old_model_describe.get("data_fields")
new_data_fields = new_model_describe.get("data_fields") new_data_fields = new_model_describe.get("data_fields")
@ -187,10 +202,9 @@ class Migrate:
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
model = cls._get_model(new_model_describe.get("name").split(".")[1])
# add fields # add fields
for new_data_field_name in set(new_data_fields_name).difference( for new_data_field_name in set(new_data_fields_name).difference(
set(old_data_fields_name) set(old_data_fields_name)
): ):
cls._add_operator( cls._add_operator(
cls._add_field( cls._add_field(
@ -205,7 +219,7 @@ class Migrate:
) )
# remove fields # remove fields
for old_data_field_name in set(old_data_fields_name).difference( for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name) set(new_data_fields_name)
): ):
cls._add_operator( cls._add_operator(
cls._remove_field( cls._remove_field(
@ -214,7 +228,7 @@ class Migrate:
filter( filter(
lambda x: x.get("name") == old_data_field_name, old_data_fields lambda x: x.get("name") == old_data_field_name, old_data_fields
) )
), ).get("db_column"),
), ),
upgrade, upgrade,
) )
@ -226,7 +240,7 @@ class Migrate:
# add fk # add fk
for new_fk_field_name in set(new_fk_fields_name).difference( for new_fk_field_name in set(new_fk_fields_name).difference(
set(old_fk_fields_name) set(old_fk_fields_name)
): ):
fk_field = next( fk_field = next(
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields) filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
@ -237,7 +251,7 @@ class Migrate:
) )
# drop fk # drop fk
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
set(new_fk_fields_name) set(new_fk_fields_name)
): ):
old_fk_field = next( old_fk_field = next(
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
@ -259,23 +273,28 @@ class Migrate:
changes = diff(old_data_field, new_data_field) changes = diff(old_data_field, new_data_field)
for change in changes: for change in changes:
_, option, old_new = change _, option, old_new = change
if option == 'indexed': if option == "indexed":
# change index # change index
unique = new_data_field.get('unique') unique = new_data_field.get("unique")
if old_new[0] is False and old_new[1] is True: if old_new[0] is False and old_new[1] is True:
cls._add_operator( cls._add_operator(
cls._add_index( cls._add_index(model, (field_name,), unique),
model, (field_name,), unique
),
upgrade, upgrade,
) )
else: else:
cls._add_operator( cls._add_operator(
cls._drop_index( cls._drop_index(model, (field_name,), unique),
model, (field_name,), unique
),
upgrade, upgrade,
) )
elif option == "db_field_types.":
# change column
cls._add_operator(
cls._change_field(model, old_data_field, new_data_field),
upgrade,
)
elif option == "default":
cls._add_operator(cls._alter_default(model, new_data_field), upgrade)
for old_model in old_models: for old_model in old_models:
if old_model not in new_models.keys(): if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade) cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade)
@ -340,40 +359,41 @@ class Migrate:
return cls.ddl.add_column(model, field_describe, is_pk) return cls.ddl.add_column(model, field_describe, is_pk)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field: Field): def _alter_default(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_default(model, field) return cls.ddl.alter_column_default(model, field_describe)
@classmethod @classmethod
def _alter_null(cls, model: Type[Model], field: Field): def _alter_null(cls, model: Type[Model], field_describe: dict):
return cls.ddl.alter_column_null(model, field) return cls.ddl.alter_column_null(model, field_describe)
@classmethod @classmethod
def _set_comment(cls, model: Type[Model], field: Field): def _set_comment(cls, model: Type[Model], field_describe: dict):
return cls.ddl.set_comment(model, field) return cls.ddl.set_comment(model, field_describe)
@classmethod @classmethod
def _modify_field(cls, model: Type[Model], field: Field): def _modify_field(cls, model: Type[Model], field_describe: dict):
return cls.ddl.modify_column(model, field) return cls.ddl.modify_column(model, field_describe)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.drop_fk(model, field_describe, reference_table_describe) return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], field_describe: dict): def _remove_field(cls, model: Type[Model], column_name: str):
return cls.ddl.drop_column(model, field_describe) return cls.ddl.drop_column(model, column_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name) return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
@classmethod @classmethod
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
db_field_types = new_field_describe.get("db_field_types")
return cls.ddl.change_column( return cls.ddl.change_column(
model, model,
old_field.model_field_name, old_field_describe.get("db_column"),
new_field.model_field_name, new_field_describe.get("db_column"),
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"), db_field_types.get(cls.dialect) or db_field_types.get(""),
) )
@classmethod @classmethod

View File

@ -51,7 +51,7 @@ def event_loop():
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
async def initialize_tests(event_loop, request): async def initialize_tests(event_loop, request):
await Tortoise.init(config=tortoise_orm, _create_db=True) await Tortoise.init(config=tortoise_orm)
await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True)
client = Tortoise.get_connection("default") client = Tortoise.get_connection("default")

4
poetry.lock generated
View File

@ -532,7 +532,7 @@ docs = ["sphinx", "cloud-sptheme", "pygments", "docutils"]
type = "git" type = "git"
url = "https://github.com/tortoise/tortoise-orm.git" url = "https://github.com/tortoise/tortoise-orm.git"
reference = "develop" reference = "develop"
resolved_reference = "37bb36ef3a715b03d18c30452764b348eac21c21" resolved_reference = "2739267e3dfcfea5e8e33347583de07d547837d6"
[[package]] [[package]]
name = "typed-ast" name = "typed-ast"
@ -635,7 +635,6 @@ click = [
] ]
colorama = [ colorama = [
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
] ]
ddlparse = [ ddlparse = [
{file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"},
@ -666,7 +665,6 @@ importlib-metadata = [
{file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"}, {file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"},
] ]
iniconfig = [ iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
] ]
iso8601 = [ iso8601 = [

View File

@ -23,7 +23,7 @@ class Status(IntEnum):
class User(Model): class User(Model):
username = fields.CharField(max_length=20, unique=True) username = fields.CharField(max_length=20, unique=True)
password = fields.CharField(max_length=200) password = fields.CharField(max_length=100)
last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now)
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
@ -46,7 +46,7 @@ class Category(Model):
class Product(Model): class Product(Model):
categories = fields.ManyToManyField("models.Category") categories = fields.ManyToManyField("models.Category")
name = fields.CharField(max_length=50) name = fields.CharField(max_length=50)
view_num = fields.IntField(description="View Num") view_num = fields.IntField(description="View Num", default=0)
sort = fields.IntField() sort = fields.IntField()
is_reviewed = fields.BooleanField(description="Is Reviewed") is_reviewed = fields.BooleanField(description="Is Reviewed")
type = fields.IntEnumField(ProductType, description="Product Type") type = fields.IntEnumField(ProductType, description="Product Type")
@ -54,10 +54,13 @@ class Product(Model):
body = fields.TextField() body = fields.TextField()
created_at = fields.DatetimeField(auto_now_add=True) created_at = fields.DatetimeField(auto_now_add=True)
class Meta:
unique_together = (("name", "type"),)
class Config(Model): class Config(Model):
label = fields.CharField(max_length=200) label = fields.CharField(max_length=200)
key = fields.CharField(max_length=20) key = fields.CharField(max_length=20)
value = fields.JSONField() value = fields.JSONField()
status: Status = fields.IntEnumField(Status, default=Status.on) status: Status = fields.IntEnumField(Status)
user = fields.ForeignKeyField("models.User", description="User") user = fields.ForeignKeyField("models.User", description="User")

View File

@ -5,7 +5,7 @@ from aerich.ddl.postgres import PostgresDDL
from aerich.ddl.sqlite import SqliteDDL from aerich.ddl.sqlite import SqliteDDL
from aerich.exceptions import NotSupportError from aerich.exceptions import NotSupportError
from aerich.migrate import Migrate from aerich.migrate import Migrate
from tests.models import Category, User from tests.models import Category, Product, User
def test_create_table(): def test_create_table():
@ -67,13 +67,12 @@ def test_add_column():
def test_modify_column(): def test_modify_column():
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): return
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active"))
else: ret0 = Migrate.ddl.modify_column(
ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) Category, Category._meta.fields_map.get("name").describe(False)
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active")) )
ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL" assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
@ -89,41 +88,52 @@ def test_modify_column():
def test_alter_column_default(): def test_alter_column_default():
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, SqliteDDL):
return
ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("name").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT' assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT'
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert ret == "ALTER TABLE `category` ALTER COLUMN `name` DROP DEFAULT"
ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at")) ret = Migrate.ddl.alter_column_default(
Category, Category._meta.fields_map.get("created_at").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ( assert (
ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP' ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP'
) )
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert (
ret
== "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)"
)
ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar")) ret = Migrate.ddl.alter_column_default(
Product, Product._meta.fields_map.get("view_num").describe(False)
)
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\'' assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0'
else: elif isinstance(Migrate.ddl, MysqlDDL):
assert ret is None assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0"
def test_alter_column_null(): def test_alter_column_null():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name")) ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL' assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL'
else:
assert ret is None
def test_set_comment(): def test_set_comment():
if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)):
return
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name")) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name"))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):
assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL'
else:
assert ret is None
ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user")) ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user"))
if isinstance(Migrate.ddl, PostgresDDL): if isinstance(Migrate.ddl, PostgresDDL):

View File

@ -740,49 +740,90 @@ def test_migrate():
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- remove unique: User.username - remove unique: User.username
- change column: length User.password
- add unique_together: (name,type) of Product
- alter default: Config.status
""" """
models_describe = get_models_describe("models") models_describe = get_models_describe("models")
Migrate.diff_models(old_models_describe, models_describe) Migrate.app = "models"
if isinstance(Migrate.ddl, SqliteDDL): if isinstance(Migrate.ddl, SqliteDDL):
with pytest.raises(NotSupportError): with pytest.raises(NotSupportError):
Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
else: else:
Migrate.diff_models(old_models_describe, models_describe)
Migrate.diff_models(models_describe, old_models_describe, False) Migrate.diff_models(models_describe, old_models_describe, False)
Migrate._merge_operators() Migrate._merge_operators()
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert Migrate.upgrade_operators == [ assert sorted(Migrate.upgrade_operators) == sorted(
"ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", [
"ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL", "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`", "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT",
] "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL",
assert Migrate.downgrade_operators == [ "ALTER TABLE `email` DROP COLUMN `user_id`",
"ALTER TABLE `category` DROP COLUMN `name`", "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`",
"ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`", "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)",
"ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`", "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY " "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0",
"(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", "ALTER TABLE `user` DROP COLUMN `avatar`",
] "ALTER TABLE `user` CHANGE password password VARCHAR(100)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
]
)
assert sorted(Migrate.downgrade_operators) == sorted(
[
"ALTER TABLE `config` DROP COLUMN `user_id`",
"ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`",
"ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1",
"ALTER TABLE `email` ADD `user_id` INT NOT NULL",
"ALTER TABLE `email` DROP COLUMN `address`",
"ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE",
"ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`",
"ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`",
"ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT",
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` CHANGE password password VARCHAR(200)",
]
)
elif isinstance(Migrate.ddl, PostgresDDL): elif isinstance(Migrate.ddl, PostgresDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == [
'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL COMMENT \'User\'',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")', 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL',
'ALTER TABLE "email" DROP COLUMN "user_id"',
'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"',
'ALTER TABLE "email" ADD INDEX "idx_email_email_4a1a33" ("email")',
'ALTER TABLE "product" ADD UNIQUE INDEX "uid_product_name_f14935" ("name", "type")',
'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0',
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" CHANGE password password VARCHAR(100)',
'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")',
] ]
assert Migrate.downgrade_operators == [ assert Migrate.downgrade_operators == [
'ALTER TABLE "category" DROP COLUMN "name"', 'ALTER TABLE "config" DROP COLUMN "user_id"',
'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"', 'ALTER TABLE "config" DROP FOREIGN KEY "fk_config_user_17daa970"',
'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"', 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1',
'ALTER TABLE "email" ADD "user_id" INT NOT NULL',
'ALTER TABLE "email" DROP COLUMN "address"',
'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "email" DROP INDEX "idx_email_email_4a1a33"',
'ALTER TABLE "product" DROP INDEX "uid_product_name_f14935"',
'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT',
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'ALTER TABLE "user" DROP INDEX "idx_user_usernam_9987ab"',
'ALTER TABLE "user" CHANGE password password VARCHAR(200)',
] ]
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):
assert Migrate.upgrade_operators == [ assert Migrate.upgrade_operators == [
'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"', 'ALTER TABLE "config" ADD "user_id" INT NOT NULL /* User */',
'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE',
'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")',
'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"',
] ]
assert Migrate.downgrade_operators == [] assert Migrate.downgrade_operators == []