diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index 47442c4..6826e83 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -1,8 +1,9 @@ +from enum import Enum from typing import List, Type -from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model +from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model from tortoise.backends.base.schema_generator import BaseSchemaGenerator -from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField +from tortoise.fields import CASCADE class BaseDDL: @@ -11,6 +12,7 @@ class BaseDDL: _DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"' _ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}' _DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"' + _ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}' _RENAME_COLUMN_TEMPLATE = ( 'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"' ) @@ -62,6 +64,8 @@ class BaseDDL: def _get_default(self, model: "Type[Model]", field_describe: dict): db_table = model._meta.db_table default = field_describe.get("default") + if isinstance(default, Enum): + default = default.value db_column = field_describe.get("db_column") auto_now_add = field_describe.get("auto_now_add", False) auto_now = field_describe.get("auto_now", False) @@ -80,7 +84,7 @@ class BaseDDL: except NotImplementedError: default = "" else: - default = "" + default = None return default def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): @@ -88,6 +92,9 @@ class BaseDDL: description = field_describe.get("description") db_column = field_describe.get("db_column") db_field_types = field_describe.get("db_field_types") + default = self._get_default(model, field_describe) + if default is None: + default = "" return self._ADD_COLUMN_TEMPLATE.format( table_name=db_table, column=self.schema_generator._create_string( @@ -103,33 +110,37 @@ class BaseDDL: if description else "", is_primary_key=is_pk, - default=self._get_default(model, field_describe), + default=default, ), ) - def drop_column(self, model: "Type[Model]", field_describe: dict): + def drop_column(self, model: "Type[Model]", column_name: str): return self._DROP_COLUMN_TEMPLATE.format( - table_name=model._meta.db_table, column_name=field_describe.get("db_column") + table_name=model._meta.db_table, column_name=column_name ) - def modify_column(self, model: "Type[Model]", field_object: Field): + def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): db_table = model._meta.db_table + db_field_types = field_describe.get("db_field_types") + default = self._get_default(model, field_describe) + if default is None: + default = "" return self._MODIFY_COLUMN_TEMPLATE.format( table_name=db_table, column=self.schema_generator._create_string( - db_column=field_object.model_field_name, - field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), - nullable="NOT NULL" if not field_object.null else "", + db_column=field_describe.get("db_column"), + field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""), + nullable="NOT NULL" if not field_describe.get("nullable") else "", unique="", comment=self.schema_generator._column_comment_generator( table=db_table, - column=field_object.model_field_name, - comment=field_object.description, + column=field_describe.get("db_column"), + comment=field_describe.get("description"), ) - if field_object.description + if field_describe.get("description") else "", - is_primary_key=field_object.pk, - default=self._get_default(model, field_object), + is_primary_key=is_pk, + default=default, ), ) @@ -200,11 +211,17 @@ class BaseDDL: ), ) - def alter_column_default(self, model: "Type[Model]", field_object: Field): - pass + def alter_column_default(self, model: "Type[Model]", field_describe: dict): + db_table = model._meta.db_table + default = self._get_default(model, field_describe) + return self._ALTER_DEFAULT_TEMPLATE.format( + table_name=db_table, + column=field_describe.get("db_column"), + default="SET" + default if default is not None else "DROP DEFAULT", + ) - def alter_column_null(self, model: "Type[Model]", field_object: Field): - pass + def alter_column_null(self, model: "Type[Model]", field_describe: dict): + raise NotImplementedError - def set_comment(self, model: "Type[Model]", field_object: Field): - pass + def set_comment(self, model: "Type[Model]", field_describe: dict): + raise NotImplementedError diff --git a/aerich/ddl/mysql/__init__.py b/aerich/ddl/mysql/__init__.py index 01a08fa..438d0fd 100644 --- a/aerich/ddl/mysql/__init__.py +++ b/aerich/ddl/mysql/__init__.py @@ -1,6 +1,10 @@ +from typing import Type + +from tortoise import Model from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator from aerich.ddl import BaseDDL +from aerich.exceptions import NotSupportError class MysqlDDL(BaseDDL): @@ -8,6 +12,10 @@ class MysqlDDL(BaseDDL): DIALECT = MySQLSchemaGenerator.DIALECT _DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`" _ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}" + _ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}" + _CHANGE_COLUMN_TEMPLATE = ( + "ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}" + ) _DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`" _RENAME_COLUMN_TEMPLATE = ( "ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`" @@ -20,3 +28,9 @@ class MysqlDDL(BaseDDL): _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" + + def alter_column_null(self, model: "Type[Model]", field_describe: dict): + raise NotSupportError("Alter column null is unsupported in MySQL.") + + def set_comment(self, model: "Type[Model]", field_describe: dict): + raise NotSupportError("Alter column comment is unsupported in MySQL.") diff --git a/aerich/ddl/postgres/__init__.py b/aerich/ddl/postgres/__init__.py index 055b041..3ebf617 100644 --- a/aerich/ddl/postgres/__init__.py +++ b/aerich/ddl/postgres/__init__.py @@ -16,35 +16,26 @@ class PostgresDDL(BaseDDL): ) _DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"' _DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"' - _ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}' _ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}' _SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"' - def alter_column_default(self, model: "Type[Model]", field_object: Field): - db_table = model._meta.db_table - default = self._get_default(model, field_object) - return self._ALTER_DEFAULT_TEMPLATE.format( - table_name=db_table, - column=field_object.model_field_name, - default="SET" + default if default else "DROP DEFAULT", - ) - - def alter_column_null(self, model: "Type[Model]", field_object: Field): + def alter_column_null(self, model: "Type[Model]", field_describe: dict): db_table = model._meta.db_table return self._ALTER_NULL_TEMPLATE.format( table_name=db_table, - column=field_object.model_field_name, - set_drop="DROP" if field_object.null else "SET", + column=field_describe.get("db_column"), + set_drop="DROP" if field_describe.get("nullable") else "SET", ) - def modify_column(self, model: "Type[Model]", field_object: Field): + def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False): db_table = model._meta.db_table + db_field_types = field_describe.get("db_field_types") return self._MODIFY_COLUMN_TEMPLATE.format( table_name=db_table, - column=field_object.model_field_name, - datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), + column=field_describe.get("db_column"), + datatype=db_field_types.get(self.DIALECT) or db_field_types.get(""), ) def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): diff --git a/aerich/ddl/sqlite/__init__.py b/aerich/ddl/sqlite/__init__.py index ee66618..539c239 100644 --- a/aerich/ddl/sqlite/__init__.py +++ b/aerich/ddl/sqlite/__init__.py @@ -2,7 +2,6 @@ from typing import Type from tortoise import Model from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator -from tortoise.fields import Field from aerich.ddl import BaseDDL from aerich.exceptions import NotSupportError @@ -15,5 +14,14 @@ class SqliteDDL(BaseDDL): def drop_column(self, model: "Type[Model]", column_name: str): raise NotSupportError("Drop column is unsupported in SQLite.") - def modify_column(self, model: "Type[Model]", field_object: Field): + def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True): raise NotSupportError("Modify column is unsupported in SQLite.") + + def alter_column_default(self, model: "Type[Model]", field_describe: dict): + raise NotSupportError("Alter column default is unsupported in SQLite.") + + def alter_column_null(self, model: "Type[Model]", field_describe: dict): + raise NotSupportError("Alter column null is unsupported in SQLite.") + + def set_comment(self, model: "Type[Model]", field_describe: dict): + raise NotSupportError("Alter column comment is unsupported in SQLite.") diff --git a/aerich/migrate.py b/aerich/migrate.py index 9e2a88a..46f3b21 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -2,6 +2,7 @@ import os from datetime import datetime from pathlib import Path from typing import Dict, List, Optional, Tuple, Type + from dictdiffer import diff from tortoise import ( BackwardFKRelation, @@ -173,6 +174,8 @@ class Migrate: new_models.pop(_aerich, None) for new_model_str, new_model_describe in new_models.items(): + model = cls._get_model(new_model_describe.get("name").split(".")[1]) + if new_model_str not in old_models.keys(): cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade) else: @@ -180,6 +183,18 @@ class Migrate: old_unique_together = old_model_describe.get("unique_together") new_unique_together = new_model_describe.get("unique_together") + # add unique_together + for index in set(new_unique_together).difference(set(old_unique_together)): + cls._add_operator( + cls._add_index(model, index, True), + upgrade, + ) + # remove unique_together + for index in set(old_unique_together).difference(set(new_unique_together)): + cls._add_operator( + cls._drop_index(model, index, True), + upgrade, + ) old_data_fields = old_model_describe.get("data_fields") new_data_fields = new_model_describe.get("data_fields") @@ -187,10 +202,9 @@ class Migrate: old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields)) new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields)) - model = cls._get_model(new_model_describe.get("name").split(".")[1]) # add fields for new_data_field_name in set(new_data_fields_name).difference( - set(old_data_fields_name) + set(old_data_fields_name) ): cls._add_operator( cls._add_field( @@ -205,7 +219,7 @@ class Migrate: ) # remove fields for old_data_field_name in set(old_data_fields_name).difference( - set(new_data_fields_name) + set(new_data_fields_name) ): cls._add_operator( cls._remove_field( @@ -214,7 +228,7 @@ class Migrate: filter( lambda x: x.get("name") == old_data_field_name, old_data_fields ) - ), + ).get("db_column"), ), upgrade, ) @@ -226,7 +240,7 @@ class Migrate: # add fk for new_fk_field_name in set(new_fk_fields_name).difference( - set(old_fk_fields_name) + set(old_fk_fields_name) ): fk_field = next( filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields) @@ -237,7 +251,7 @@ class Migrate: ) # drop fk for old_fk_field_name in set(old_fk_fields_name).difference( - set(new_fk_fields_name) + set(new_fk_fields_name) ): old_fk_field = next( filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields) @@ -259,23 +273,28 @@ class Migrate: changes = diff(old_data_field, new_data_field) for change in changes: _, option, old_new = change - if option == 'indexed': + if option == "indexed": # change index - unique = new_data_field.get('unique') + unique = new_data_field.get("unique") if old_new[0] is False and old_new[1] is True: cls._add_operator( - cls._add_index( - model, (field_name,), unique - ), + cls._add_index(model, (field_name,), unique), upgrade, ) else: cls._add_operator( - cls._drop_index( - model, (field_name,), unique - ), + cls._drop_index(model, (field_name,), unique), upgrade, ) + elif option == "db_field_types.": + # change column + cls._add_operator( + cls._change_field(model, old_data_field, new_data_field), + upgrade, + ) + elif option == "default": + cls._add_operator(cls._alter_default(model, new_data_field), upgrade) + for old_model in old_models: if old_model not in new_models.keys(): cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade) @@ -340,40 +359,41 @@ class Migrate: return cls.ddl.add_column(model, field_describe, is_pk) @classmethod - def _alter_default(cls, model: Type[Model], field: Field): - return cls.ddl.alter_column_default(model, field) + def _alter_default(cls, model: Type[Model], field_describe: dict): + return cls.ddl.alter_column_default(model, field_describe) @classmethod - def _alter_null(cls, model: Type[Model], field: Field): - return cls.ddl.alter_column_null(model, field) + def _alter_null(cls, model: Type[Model], field_describe: dict): + return cls.ddl.alter_column_null(model, field_describe) @classmethod - def _set_comment(cls, model: Type[Model], field: Field): - return cls.ddl.set_comment(model, field) + def _set_comment(cls, model: Type[Model], field_describe: dict): + return cls.ddl.set_comment(model, field_describe) @classmethod - def _modify_field(cls, model: Type[Model], field: Field): - return cls.ddl.modify_column(model, field) + def _modify_field(cls, model: Type[Model], field_describe: dict): + return cls.ddl.modify_column(model, field_describe) @classmethod def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict): return cls.ddl.drop_fk(model, field_describe, reference_table_describe) @classmethod - def _remove_field(cls, model: Type[Model], field_describe: dict): - return cls.ddl.drop_column(model, field_describe) + def _remove_field(cls, model: Type[Model], column_name: str): + return cls.ddl.drop_column(model, column_name) @classmethod def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name) @classmethod - def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field): + def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict): + db_field_types = new_field_describe.get("db_field_types") return cls.ddl.change_column( model, - old_field.model_field_name, - new_field.model_field_name, - new_field.get_for_dialect(cls.dialect, "SQL_TYPE"), + old_field_describe.get("db_column"), + new_field_describe.get("db_column"), + db_field_types.get(cls.dialect) or db_field_types.get(""), ) @classmethod diff --git a/conftest.py b/conftest.py index 4eee9e1..0980aa5 100644 --- a/conftest.py +++ b/conftest.py @@ -51,7 +51,7 @@ def event_loop(): @pytest.fixture(scope="session", autouse=True) async def initialize_tests(event_loop, request): - await Tortoise.init(config=tortoise_orm, _create_db=True) + await Tortoise.init(config=tortoise_orm) await generate_schema_for_client(Tortoise.get_connection("default"), safe=True) client = Tortoise.get_connection("default") diff --git a/poetry.lock b/poetry.lock index 06bd9b4..7970cba 100644 --- a/poetry.lock +++ b/poetry.lock @@ -532,7 +532,7 @@ docs = ["sphinx", "cloud-sptheme", "pygments", "docutils"] type = "git" url = "https://github.com/tortoise/tortoise-orm.git" reference = "develop" -resolved_reference = "37bb36ef3a715b03d18c30452764b348eac21c21" +resolved_reference = "2739267e3dfcfea5e8e33347583de07d547837d6" [[package]] name = "typed-ast" @@ -635,7 +635,6 @@ click = [ ] colorama = [ {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, - {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, ] ddlparse = [ {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, @@ -666,7 +665,6 @@ importlib-metadata = [ {file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"}, ] iniconfig = [ - {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] iso8601 = [ diff --git a/tests/models.py b/tests/models.py index 765e32d..48cc3e0 100644 --- a/tests/models.py +++ b/tests/models.py @@ -23,7 +23,7 @@ class Status(IntEnum): class User(Model): username = fields.CharField(max_length=20, unique=True) - password = fields.CharField(max_length=200) + password = fields.CharField(max_length=100) last_login = fields.DatetimeField(description="Last Login", default=datetime.datetime.now) is_active = fields.BooleanField(default=True, description="Is Active") is_superuser = fields.BooleanField(default=False, description="Is SuperUser") @@ -46,7 +46,7 @@ class Category(Model): class Product(Model): categories = fields.ManyToManyField("models.Category") name = fields.CharField(max_length=50) - view_num = fields.IntField(description="View Num") + view_num = fields.IntField(description="View Num", default=0) sort = fields.IntField() is_reviewed = fields.BooleanField(description="Is Reviewed") type = fields.IntEnumField(ProductType, description="Product Type") @@ -54,10 +54,13 @@ class Product(Model): body = fields.TextField() created_at = fields.DatetimeField(auto_now_add=True) + class Meta: + unique_together = (("name", "type"),) + class Config(Model): label = fields.CharField(max_length=200) key = fields.CharField(max_length=20) value = fields.JSONField() - status: Status = fields.IntEnumField(Status, default=Status.on) + status: Status = fields.IntEnumField(Status) user = fields.ForeignKeyField("models.User", description="User") diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 174bc51..2a9091e 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -5,7 +5,7 @@ from aerich.ddl.postgres import PostgresDDL from aerich.ddl.sqlite import SqliteDDL from aerich.exceptions import NotSupportError from aerich.migrate import Migrate -from tests.models import Category, User +from tests.models import Category, Product, User def test_create_table(): @@ -67,13 +67,12 @@ def test_add_column(): def test_modify_column(): if isinstance(Migrate.ddl, SqliteDDL): - with pytest.raises(NotSupportError): - ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) - ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active")) + return - else: - ret0 = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) - ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active")) + ret0 = Migrate.ddl.modify_column( + Category, Category._meta.fields_map.get("name").describe(False) + ) + ret1 = Migrate.ddl.modify_column(User, User._meta.fields_map.get("is_active").describe(False)) if isinstance(Migrate.ddl, MysqlDDL): assert ret0 == "ALTER TABLE `category` MODIFY COLUMN `name` VARCHAR(200) NOT NULL" elif isinstance(Migrate.ddl, PostgresDDL): @@ -89,41 +88,52 @@ def test_modify_column(): def test_alter_column_default(): - ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("name")) + if isinstance(Migrate.ddl, SqliteDDL): + return + ret = Migrate.ddl.alter_column_default( + Category, Category._meta.fields_map.get("name").describe(False) + ) if isinstance(Migrate.ddl, PostgresDDL): assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" DROP DEFAULT' - else: - assert ret is None + elif isinstance(Migrate.ddl, MysqlDDL): + assert ret == "ALTER TABLE `category` ALTER COLUMN `name` DROP DEFAULT" - ret = Migrate.ddl.alter_column_default(Category, Category._meta.fields_map.get("created_at")) + ret = Migrate.ddl.alter_column_default( + Category, Category._meta.fields_map.get("created_at").describe(False) + ) if isinstance(Migrate.ddl, PostgresDDL): assert ( ret == 'ALTER TABLE "category" ALTER COLUMN "created_at" SET DEFAULT CURRENT_TIMESTAMP' ) - else: - assert ret is None + elif isinstance(Migrate.ddl, MysqlDDL): + assert ( + ret + == "ALTER TABLE `category` ALTER COLUMN `created_at` SET DEFAULT CURRENT_TIMESTAMP(6)" + ) - ret = Migrate.ddl.alter_column_default(User, User._meta.fields_map.get("avatar")) + ret = Migrate.ddl.alter_column_default( + Product, Product._meta.fields_map.get("view_num").describe(False) + ) if isinstance(Migrate.ddl, PostgresDDL): - assert ret == 'ALTER TABLE "user" ALTER COLUMN "avatar" SET DEFAULT \'\'' - else: - assert ret is None + assert ret == 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0' + elif isinstance(Migrate.ddl, MysqlDDL): + assert ret == "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0" def test_alter_column_null(): + if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)): + return ret = Migrate.ddl.alter_column_null(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, PostgresDDL): assert ret == 'ALTER TABLE "category" ALTER COLUMN "name" SET NOT NULL' - else: - assert ret is None def test_set_comment(): + if isinstance(Migrate.ddl, (SqliteDDL, MysqlDDL)): + return ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("name")) if isinstance(Migrate.ddl, PostgresDDL): assert ret == 'COMMENT ON COLUMN "category"."name" IS NULL' - else: - assert ret is None ret = Migrate.ddl.set_comment(Category, Category._meta.fields_map.get("user")) if isinstance(Migrate.ddl, PostgresDDL): diff --git a/tests/test_migrate.py b/tests/test_migrate.py index 23fa838..2a91b01 100644 --- a/tests/test_migrate.py +++ b/tests/test_migrate.py @@ -740,49 +740,90 @@ def test_migrate(): - drop field: User.avatar - add index: Email.email - remove unique: User.username + - change column: length User.password + - add unique_together: (name,type) of Product + - alter default: Config.status """ models_describe = get_models_describe("models") - Migrate.diff_models(old_models_describe, models_describe) + Migrate.app = "models" if isinstance(Migrate.ddl, SqliteDDL): with pytest.raises(NotSupportError): + Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(models_describe, old_models_describe, False) else: + Migrate.diff_models(old_models_describe, models_describe) Migrate.diff_models(models_describe, old_models_describe, False) Migrate._merge_operators() if isinstance(Migrate.ddl, MysqlDDL): - assert Migrate.upgrade_operators == [ - "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", - "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL", - "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", - "ALTER TABLE `user` RENAME COLUMN `last_login_at` TO `last_login`", - ] - assert Migrate.downgrade_operators == [ - "ALTER TABLE `category` DROP COLUMN `name`", - "ALTER TABLE `user` DROP INDEX `uid_user_usernam_9987ab`", - "ALTER TABLE `user` RENAME COLUMN `last_login` TO `last_login_at`", - "ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY " - "(`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", - ] + assert sorted(Migrate.upgrade_operators) == sorted( + [ + "ALTER TABLE `config` ADD `user_id` INT NOT NULL COMMENT 'User'", + "ALTER TABLE `config` ADD CONSTRAINT `fk_config_user_17daa970` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", + "ALTER TABLE `config` ALTER COLUMN `status` DROP DEFAULT", + "ALTER TABLE `email` ADD `address` VARCHAR(200) NOT NULL", + "ALTER TABLE `email` DROP COLUMN `user_id`", + "ALTER TABLE `email` DROP FOREIGN KEY `fk_email_user_5b58673d`", + "ALTER TABLE `email` ADD INDEX `idx_email_email_4a1a33` (`email`)", + "ALTER TABLE `product` ADD UNIQUE INDEX `uid_product_name_f14935` (`name`, `type`)", + "ALTER TABLE `product` ALTER COLUMN `view_num` SET DEFAULT 0", + "ALTER TABLE `user` DROP COLUMN `avatar`", + "ALTER TABLE `user` CHANGE password password VARCHAR(100)", + "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", + ] + ) + + assert sorted(Migrate.downgrade_operators) == sorted( + [ + "ALTER TABLE `config` DROP COLUMN `user_id`", + "ALTER TABLE `config` DROP FOREIGN KEY `fk_config_user_17daa970`", + "ALTER TABLE `config` ALTER COLUMN `status` SET DEFAULT 1", + "ALTER TABLE `email` ADD `user_id` INT NOT NULL", + "ALTER TABLE `email` DROP COLUMN `address`", + "ALTER TABLE `email` ADD CONSTRAINT `fk_email_user_5b58673d` FOREIGN KEY (`user_id`) REFERENCES `user` (`id`) ON DELETE CASCADE", + "ALTER TABLE `email` DROP INDEX `idx_email_email_4a1a33`", + "ALTER TABLE `product` DROP INDEX `uid_product_name_f14935`", + "ALTER TABLE `product` ALTER COLUMN `view_num` DROP DEFAULT", + "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", + "ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`", + "ALTER TABLE `user` CHANGE password password VARCHAR(200)", + ] + ) + elif isinstance(Migrate.ddl, PostgresDDL): assert Migrate.upgrade_operators == [ - 'ALTER TABLE "email" DROP CONSTRAINT "fk_email_user_5b58673d"', - 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', - 'ALTER TABLE "user" ADD CONSTRAINT "uid_user_usernam_9987ab" UNIQUE ("username")', - 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', + 'ALTER TABLE "config" ADD "user_id" INT NOT NULL COMMENT \'User\'', + 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', + 'ALTER TABLE "config" ALTER COLUMN "status" DROP DEFAULT', + 'ALTER TABLE "email" ADD "address" VARCHAR(200) NOT NULL', + 'ALTER TABLE "email" DROP COLUMN "user_id"', + 'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"', + 'ALTER TABLE "email" ADD INDEX "idx_email_email_4a1a33" ("email")', + 'ALTER TABLE "product" ADD UNIQUE INDEX "uid_product_name_f14935" ("name", "type")', + 'ALTER TABLE "product" ALTER COLUMN "view_num" SET DEFAULT 0', + 'ALTER TABLE "user" DROP COLUMN "avatar"', + 'ALTER TABLE "user" CHANGE password password VARCHAR(100)', + 'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")', ] assert Migrate.downgrade_operators == [ - 'ALTER TABLE "category" DROP COLUMN "name"', - 'ALTER TABLE "user" DROP CONSTRAINT "uid_user_usernam_9987ab"', - 'ALTER TABLE "user" RENAME COLUMN "last_login" TO "last_login_at"', + 'ALTER TABLE "config" DROP COLUMN "user_id"', + 'ALTER TABLE "config" DROP FOREIGN KEY "fk_config_user_17daa970"', + 'ALTER TABLE "config" ALTER COLUMN "status" SET DEFAULT 1', + 'ALTER TABLE "email" ADD "user_id" INT NOT NULL', + 'ALTER TABLE "email" DROP COLUMN "address"', 'ALTER TABLE "email" ADD CONSTRAINT "fk_email_user_5b58673d" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', + 'ALTER TABLE "email" DROP INDEX "idx_email_email_4a1a33"', + 'ALTER TABLE "product" DROP INDEX "uid_product_name_f14935"', + 'ALTER TABLE "product" ALTER COLUMN "view_num" DROP DEFAULT', + 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', + 'ALTER TABLE "user" DROP INDEX "idx_user_usernam_9987ab"', + 'ALTER TABLE "user" CHANGE password password VARCHAR(200)', ] elif isinstance(Migrate.ddl, SqliteDDL): assert Migrate.upgrade_operators == [ - 'ALTER TABLE "email" DROP FOREIGN KEY "fk_email_user_5b58673d"', - 'ALTER TABLE "category" ADD "name" VARCHAR(200) NOT NULL', - 'ALTER TABLE "user" ADD UNIQUE INDEX "uid_user_usernam_9987ab" ("username")', - 'ALTER TABLE "user" RENAME COLUMN "last_login_at" TO "last_login"', + 'ALTER TABLE "config" ADD "user_id" INT NOT NULL /* User */', + 'ALTER TABLE "config" ADD CONSTRAINT "fk_config_user_17daa970" FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE', ] + assert Migrate.downgrade_operators == []