commit
615d9747dc
@ -4,9 +4,12 @@ ChangeLog
|
||||
|
||||
0.1
|
||||
===
|
||||
|
||||
0.1.8
|
||||
_____
|
||||
-----
|
||||
- Fix upgrade error when migrate.
|
||||
- Fix init db sql error.
|
||||
- Support change column.
|
||||
|
||||
0.1.7
|
||||
-----
|
||||
|
@ -161,7 +161,7 @@ Show heads to be migrated
|
||||
|
||||
Limitations
|
||||
===========
|
||||
* Not support ``change column`` now.
|
||||
* Not support ``rename column`` now.
|
||||
* ``Sqlite`` and ``Postgres`` may not work as expected because I don't use those in my work.
|
||||
|
||||
License
|
||||
|
@ -218,7 +218,7 @@ async def init_db(ctx: Context, safe):
|
||||
await Aerich.create(version=version, app=app)
|
||||
with open(os.path.join(dirname, version), "w") as f:
|
||||
content = {
|
||||
"upgrade": schema,
|
||||
"upgrade": [schema],
|
||||
}
|
||||
json.dump(content, f, ensure_ascii=False, indent=2)
|
||||
return click.secho(f'Success generate schema for app "{app}"', fg=Color.green)
|
||||
|
@ -18,6 +18,7 @@ class BaseDDL:
|
||||
_ADD_FK_TEMPLATE = "ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
|
||||
_DROP_FK_TEMPLATE = "ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}"
|
||||
_M2M_TABLE_TEMPLATE = "CREATE TABLE {table_name} ({backward_key} {backward_type} NOT NULL REFERENCES {backward_table} ({backward_field}) ON DELETE CASCADE,{forward_key} {forward_type} NOT NULL REFERENCES {forward_table} ({forward_field}) ON DELETE CASCADE){extra}{comment};"
|
||||
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE {table_name} MODIFY COLUMN {column}"
|
||||
|
||||
def __init__(self, client: "BaseDBAsyncClient"):
|
||||
self.client = client
|
||||
@ -51,7 +52,7 @@ class BaseDDL:
|
||||
def drop_m2m(self, field: ManyToManyFieldInstance):
|
||||
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
|
||||
|
||||
def add_column(self, model: "Type[Model]", field_object: Field):
|
||||
def _get_default(self, model: "Type[Model]", field_object: Field):
|
||||
db_table = model._meta.db_table
|
||||
default = field_object.default
|
||||
db_column = field_object.model_field_name
|
||||
@ -74,6 +75,11 @@ class BaseDDL:
|
||||
default = ""
|
||||
else:
|
||||
default = ""
|
||||
return default
|
||||
|
||||
def add_column(self, model: "Type[Model]", field_object: Field):
|
||||
db_table = model._meta.db_table
|
||||
|
||||
return self._ADD_COLUMN_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=self.schema_generator._create_string(
|
||||
@ -89,7 +95,7 @@ class BaseDDL:
|
||||
if field_object.description
|
||||
else "",
|
||||
is_primary_key=field_object.pk,
|
||||
default=default,
|
||||
default=self._get_default(model, field_object),
|
||||
),
|
||||
)
|
||||
|
||||
@ -98,6 +104,27 @@ class BaseDDL:
|
||||
table_name=model._meta.db_table, column_name=column_name
|
||||
)
|
||||
|
||||
def modify_column(self, model: "Type[Model]", field_object: Field):
|
||||
db_table = model._meta.db_table
|
||||
return self._MODIFY_COLUMN_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=self.schema_generator._create_string(
|
||||
db_column=field_object.model_field_name,
|
||||
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
|
||||
nullable="NOT NULL" if not field_object.null else "",
|
||||
unique="",
|
||||
comment=self.schema_generator._column_comment_generator(
|
||||
table=db_table,
|
||||
column=field_object.model_field_name,
|
||||
comment=field_object.description,
|
||||
)
|
||||
if field_object.description
|
||||
else "",
|
||||
is_primary_key=field_object.pk,
|
||||
default=self._get_default(model, field_object),
|
||||
),
|
||||
)
|
||||
|
||||
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
||||
return self._ADD_INDEX_TEMPLATE.format(
|
||||
unique="UNIQUE" if unique else "",
|
||||
|
@ -220,6 +220,10 @@ class Migrate:
|
||||
if old_model not in new_models.keys():
|
||||
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
|
||||
|
||||
@classmethod
|
||||
def _is_fk_m2m(cls, field: Field):
|
||||
return isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance))
|
||||
|
||||
@classmethod
|
||||
def add_model(cls, model: Type[Model]):
|
||||
return cls.ddl.create_table(model)
|
||||
@ -260,6 +264,14 @@ class Migrate:
|
||||
)
|
||||
else:
|
||||
old_field = old_fields_map.get(new_key)
|
||||
new_field_dict = new_field.describe(serializable=True)
|
||||
new_field_dict.pop("unique")
|
||||
new_field_dict.pop("indexed")
|
||||
old_field_dict = old_field.describe(serializable=True)
|
||||
old_field_dict.pop("unique")
|
||||
old_field_dict.pop("indexed")
|
||||
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
|
||||
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
|
||||
if (old_field.index and not new_field.index) or (
|
||||
old_field.unique and not new_field.unique
|
||||
):
|
||||
@ -268,7 +280,7 @@ class Migrate:
|
||||
old_model, (old_field.model_field_name,), old_field.unique
|
||||
),
|
||||
upgrade,
|
||||
isinstance(old_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
|
||||
cls._is_fk_m2m(old_field),
|
||||
)
|
||||
elif (new_field.index and not old_field.index) or (
|
||||
new_field.unique and not old_field.unique
|
||||
@ -276,16 +288,14 @@ class Migrate:
|
||||
cls._add_operator(
|
||||
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
|
||||
upgrade,
|
||||
isinstance(new_field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
|
||||
cls._is_fk_m2m(new_field),
|
||||
)
|
||||
|
||||
for old_key in old_keys:
|
||||
field = old_fields_map.get(old_key)
|
||||
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
|
||||
cls._add_operator(
|
||||
cls._remove_field(old_model, field),
|
||||
upgrade,
|
||||
isinstance(field, (ForeignKeyFieldInstance, ManyToManyFieldInstance)),
|
||||
cls._remove_field(old_model, field), upgrade, cls._is_fk_m2m(field),
|
||||
)
|
||||
|
||||
for new_index in new_indexes:
|
||||
@ -354,6 +364,10 @@ class Migrate:
|
||||
return cls.ddl.create_m2m_table(model, field)
|
||||
return cls.ddl.add_column(model, field)
|
||||
|
||||
@classmethod
|
||||
def _modify_field(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.modify_column(model, field)
|
||||
|
||||
@classmethod
|
||||
def _remove_field(cls, model: Type[Model], field: Field):
|
||||
if isinstance(field, ForeignKeyFieldInstance):
|
||||
|
12
conftest.py
12
conftest.py
@ -16,10 +16,7 @@ db_url = os.getenv("TEST_DB", "sqlite://:memory:")
|
||||
tortoise_orm = {
|
||||
"connections": {"default": expand_db_url(db_url, True)},
|
||||
"apps": {
|
||||
"models": {
|
||||
"models": ["tests.models", "aerich.models"],
|
||||
"default_connection": "default",
|
||||
},
|
||||
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default",},
|
||||
},
|
||||
}
|
||||
|
||||
@ -42,8 +39,11 @@ def loop():
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def initialize_tests(loop, request):
|
||||
tortoise_orm['connections']['diff_models'] = "sqlite://:memory:"
|
||||
tortoise_orm['apps']['diff_models'] = {"models": ["tests.diff_models"], "default_connection": "diff_models"}
|
||||
tortoise_orm["connections"]["diff_models"] = "sqlite://:memory:"
|
||||
tortoise_orm["apps"]["diff_models"] = {
|
||||
"models": ["tests.diff_models"],
|
||||
"default_connection": "diff_models",
|
||||
}
|
||||
|
||||
loop.run_until_complete(Tortoise.init(config=tortoise_orm, _create_db=True))
|
||||
loop.run_until_complete(
|
||||
|
@ -61,10 +61,19 @@ def test_add_column():
|
||||
assert ret == 'ALTER TABLE category ADD "name" VARCHAR(200) NOT NULL'
|
||||
|
||||
|
||||
def test_modify_column():
|
||||
ret = Migrate.ddl.modify_column(Category, Category._meta.fields_map.get("name"))
|
||||
if isinstance(Migrate.ddl, MysqlDDL):
|
||||
assert ret == "ALTER TABLE category MODIFY COLUMN `name` VARCHAR(200) NOT NULL"
|
||||
elif isinstance(Migrate.ddl, PostgresDDL):
|
||||
assert ret == 'ALTER TABLE category MODIFY COLUMN "name" VARCHAR(200) NOT NULL'
|
||||
elif isinstance(Migrate.ddl, SqliteDDL):
|
||||
assert ret == 'ALTER TABLE category MODIFY COLUMN "name" VARCHAR(200) NOT NULL'
|
||||
|
||||
|
||||
def test_drop_column():
|
||||
ret = Migrate.ddl.drop_column(Category, "name")
|
||||
assert ret == "ALTER TABLE category DROP COLUMN name"
|
||||
assert ret == "ALTER TABLE category DROP COLUMN name"
|
||||
|
||||
|
||||
def test_add_index():
|
||||
|
Loading…
x
Reference in New Issue
Block a user