diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 49ff8c2..fc0c877 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,7 @@ ChangeLog ----- - Fix upgrade error when migrate. - Fix init db sql error. +- Support change column. 0.1.7 ----- diff --git a/README.rst b/README.rst index ae1de83..4135026 100644 --- a/README.rst +++ b/README.rst @@ -161,7 +161,7 @@ Show heads to be migrated Limitations =========== -* Not support ``change column`` now. +* Not support ``rename column`` now. * ``Sqlite`` and ``Postgres`` may not work as expected because I don't use those in my work. License diff --git a/aerich/cli.py b/aerich/cli.py index f752207..339566f 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -164,7 +164,7 @@ def history(ctx): ) @click.pass_context async def init( - ctx: Context, tortoise_orm, location, + ctx: Context, tortoise_orm, location, ): config_file = ctx.obj["config_file"] name = ctx.obj["name"] diff --git a/aerich/ddl/__init__.py b/aerich/ddl/__init__.py index c2a3ed0..f803f8f 100644 --- a/aerich/ddl/__init__.py +++ b/aerich/ddl/__init__.py @@ -18,6 +18,7 @@ class BaseDDL: _ADD_FK_TEMPLATE = "ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _DROP_FK_TEMPLATE = "ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}" _M2M_TABLE_TEMPLATE = "CREATE TABLE {table_name} ({backward_key} {backward_type} NOT NULL REFERENCES {backward_table} ({backward_field}) ON DELETE CASCADE,{forward_key} {forward_type} NOT NULL REFERENCES {forward_table} ({forward_field}) ON DELETE CASCADE){extra}{comment};" + _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE {table_name} MODIFY COLUMN {column}" def __init__(self, client: "BaseDBAsyncClient"): self.client = client @@ -51,7 +52,7 @@ class BaseDDL: def drop_m2m(self, field: ManyToManyFieldInstance): return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) - def add_column(self, model: "Type[Model]", field_object: Field): + def _get_default(self, model: "Type[Model]", field_object: Field): db_table = model._meta.db_table default = field_object.default db_column = field_object.model_field_name @@ -74,6 +75,11 @@ class BaseDDL: default = "" else: default = "" + return default + + def add_column(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + return self._ADD_COLUMN_TEMPLATE.format( table_name=db_table, column=self.schema_generator._create_string( @@ -89,7 +95,7 @@ class BaseDDL: if field_object.description else "", is_primary_key=field_object.pk, - default=default, + default=self._get_default(model, field_object), ), ) @@ -98,6 +104,27 @@ class BaseDDL: table_name=model._meta.db_table, column_name=column_name ) + def modify_column(self, model: "Type[Model]", field_object: Field): + db_table = model._meta.db_table + return self._MODIFY_COLUMN_TEMPLATE.format( + table_name=db_table, + column=self.schema_generator._create_string( + db_column=field_object.model_field_name, + field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"), + nullable="NOT NULL" if not field_object.null else "", + unique="", + comment=self.schema_generator._column_comment_generator( + table=db_table, + column=field_object.model_field_name, + comment=field_object.description, + ) + if field_object.description + else "", + is_primary_key=field_object.pk, + default=self._get_default(model, field_object), + ), + ) + def add_index(self, model: "Type[Model]", field_names: List[str], unique=False): return self._ADD_INDEX_TEMPLATE.format( unique="UNIQUE" if unique else "", diff --git a/aerich/migrate.py b/aerich/migrate.py index 862e8a2..b8365de 100644 --- a/aerich/migrate.py +++ b/aerich/migrate.py @@ -220,6 +220,10 @@ class Migrate: if old_model not in new_models.keys(): cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade) + @classmethod + def _is_fk_m2m(cls, field: Field): + return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)) + @classmethod def add_model(cls, model: Type[Model]): return cls.ddl.create_table(model) @@ -260,6 +264,14 @@ class Migrate: ) else: old_field = old_fields_map.get(new_key) + new_field_dict = new_field.describe(serializable=True) + new_field_dict.pop("unique") + new_field_dict.pop("indexed") + old_field_dict = old_field.describe(serializable=True) + old_field_dict.pop("unique") + old_field_dict.pop("indexed") + if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict: + cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade) if (old_field.index and not new_field.index) or ( old_field.unique and not new_field.unique ): @@ -268,7 +280,7 @@ class Migrate: old_model, (old_field.model_field_name,), old_field.unique ), upgrade, - isinstance(old_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)), + cls._is_fk_m2m(old_field), ) elif (new_field.index and not old_field.index) or ( new_field.unique and not old_field.unique @@ -276,16 +288,14 @@ class Migrate: cls._add_operator( cls._add_index(new_model, (new_field.model_field_name,), new_field.unique), upgrade, - isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)), + cls._is_fk_m2m(new_field), ) for old_key in old_keys: field = old_fields_map.get(old_key) if old_key not in new_keys and not cls._exclude_field(field, upgrade): cls._add_operator( - cls._remove_field(old_model, field), - upgrade, - isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)), + cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field), ) for new_index in new_indexes: @@ -354,6 +364,10 @@ class Migrate: return cls.ddl.create_m2m_table(model, field) return cls.ddl.add_column(model, field) + @classmethod + def _modify_field(cls, model: Type[Model], field: Field): + return cls.ddl.modify_column(model, field) + @classmethod def _remove_field(cls, model: Type[Model], field: Field): if isinstance(field, ForeignKeyFieldInstance): diff --git a/conftest.py b/conftest.py index 4c2fa5e..30f7f47 100644 --- a/conftest.py +++ b/conftest.py @@ -16,10 +16,7 @@ db_url = os.getenv("TEST_DB", "sqlite://:memory:") tortoise_orm = { "connections": {"default": expand_db_url(db_url, True)}, "apps": { - "models": { - "models": ["tests.models", "aerich.models"], - "default_connection": "default", - }, + "models": {"models": ["tests.models", "aerich.models"], "default_connection": "default",}, }, } @@ -42,8 +39,11 @@ def loop(): @pytest.fixture(scope="session", autouse=True) def initialize_tests(loop, request): - tortoise_orm['connections']['diff_models'] = "sqlite://:memory:" - tortoise_orm['apps']['diff_models'] = {"models": ["tests.diff_models"], "default_connection": "diff_models"} + tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:" + tortoise_orm["apps"]["diff_models"] = { + "models": ["tests.diff_models"], + "default_connection": "diff_models", + } loop.run_until_complete(Tortoise.init(config=tortoise_orm, _create_db=True)) loop.run_until_complete( diff --git a/tests/test_ddl.py b/tests/test_ddl.py index 5ab4040..5cbfc31 100644 --- a/tests/test_ddl.py +++ b/tests/test_ddl.py @@ -61,10 +61,19 @@ def test_add_column(): assert ret == 'ALTER TABLE category ADD "name" VARCHAR(200) NOT NULL' +def test_modify_column(): + ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name")) + if isinstance(Migrate.ddl, MysqlDDL): + assert ret == "ALTER TABLE category MODIFY COLUMN `name` VARCHAR(200) NOT NULL" + elif isinstance(Migrate.ddl, PostgresDDL): + assert ret == 'ALTER TABLE category MODIFY COLUMN "name" VARCHAR(200) NOT NULL' + elif isinstance(Migrate.ddl, SqliteDDL): + assert ret == 'ALTER TABLE category MODIFY COLUMN "name" VARCHAR(200) NOT NULL' + + def test_drop_column(): ret = Migrate.ddl.drop_column(Category, "name") assert ret == "ALTER TABLE category DROP COLUMN name" - assert ret == "ALTER TABLE category DROP COLUMN name" def test_add_index():