basically completed
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import List, Type
|
||||
|
||||
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, ManyToManyFieldInstance, Model
|
||||
from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model
|
||||
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
||||
from tortoise.fields import CASCADE, Field, JSONField, TextField, UUIDField
|
||||
from tortoise.fields import CASCADE
|
||||
|
||||
|
||||
class BaseDDL:
|
||||
@@ -11,6 +12,7 @@ class BaseDDL:
|
||||
_DROP_TABLE_TEMPLATE = 'DROP TABLE IF EXISTS "{table_name}"'
|
||||
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ADD {column}'
|
||||
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" DROP COLUMN "{column_name}"'
|
||||
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
|
||||
_RENAME_COLUMN_TEMPLATE = (
|
||||
'ALTER TABLE "{table_name}" RENAME COLUMN "{old_column_name}" TO "{new_column_name}"'
|
||||
)
|
||||
@@ -62,6 +64,8 @@ class BaseDDL:
|
||||
def _get_default(self, model: "Type[Model]", field_describe: dict):
|
||||
db_table = model._meta.db_table
|
||||
default = field_describe.get("default")
|
||||
if isinstance(default, Enum):
|
||||
default = default.value
|
||||
db_column = field_describe.get("db_column")
|
||||
auto_now_add = field_describe.get("auto_now_add", False)
|
||||
auto_now = field_describe.get("auto_now", False)
|
||||
@@ -80,7 +84,7 @@ class BaseDDL:
|
||||
except NotImplementedError:
|
||||
default = ""
|
||||
else:
|
||||
default = ""
|
||||
default = None
|
||||
return default
|
||||
|
||||
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
||||
@@ -88,6 +92,9 @@ class BaseDDL:
|
||||
description = field_describe.get("description")
|
||||
db_column = field_describe.get("db_column")
|
||||
db_field_types = field_describe.get("db_field_types")
|
||||
default = self._get_default(model, field_describe)
|
||||
if default is None:
|
||||
default = ""
|
||||
return self._ADD_COLUMN_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=self.schema_generator._create_string(
|
||||
@@ -103,33 +110,37 @@ class BaseDDL:
|
||||
if description
|
||||
else "",
|
||||
is_primary_key=is_pk,
|
||||
default=self._get_default(model, field_describe),
|
||||
default=default,
|
||||
),
|
||||
)
|
||||
|
||||
def drop_column(self, model: "Type[Model]", field_describe: dict):
|
||||
def drop_column(self, model: "Type[Model]", column_name: str):
|
||||
return self._DROP_COLUMN_TEMPLATE.format(
|
||||
table_name=model._meta.db_table, column_name=field_describe.get("db_column")
|
||||
table_name=model._meta.db_table, column_name=column_name
|
||||
)
|
||||
|
||||
def modify_column(self, model: "Type[Model]", field_object: Field):
|
||||
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
||||
db_table = model._meta.db_table
|
||||
db_field_types = field_describe.get("db_field_types")
|
||||
default = self._get_default(model, field_describe)
|
||||
if default is None:
|
||||
default = ""
|
||||
return self._MODIFY_COLUMN_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=self.schema_generator._create_string(
|
||||
db_column=field_object.model_field_name,
|
||||
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
|
||||
nullable="NOT NULL" if not field_object.null else "",
|
||||
db_column=field_describe.get("db_column"),
|
||||
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
|
||||
nullable="NOT NULL" if not field_describe.get("nullable") else "",
|
||||
unique="",
|
||||
comment=self.schema_generator._column_comment_generator(
|
||||
table=db_table,
|
||||
column=field_object.model_field_name,
|
||||
comment=field_object.description,
|
||||
column=field_describe.get("db_column"),
|
||||
comment=field_describe.get("description"),
|
||||
)
|
||||
if field_object.description
|
||||
if field_describe.get("description")
|
||||
else "",
|
||||
is_primary_key=field_object.pk,
|
||||
default=self._get_default(model, field_object),
|
||||
is_primary_key=is_pk,
|
||||
default=default,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -200,11 +211,17 @@ class BaseDDL:
|
||||
),
|
||||
)
|
||||
|
||||
def alter_column_default(self, model: "Type[Model]", field_object: Field):
|
||||
pass
|
||||
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
|
||||
db_table = model._meta.db_table
|
||||
default = self._get_default(model, field_describe)
|
||||
return self._ALTER_DEFAULT_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=field_describe.get("db_column"),
|
||||
default="SET" + default if default is not None else "DROP DEFAULT",
|
||||
)
|
||||
|
||||
def alter_column_null(self, model: "Type[Model]", field_object: Field):
|
||||
pass
|
||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_comment(self, model: "Type[Model]", field_object: Field):
|
||||
pass
|
||||
def set_comment(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from typing import Type
|
||||
|
||||
from tortoise import Model
|
||||
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.exceptions import NotSupportError
|
||||
|
||||
|
||||
class MysqlDDL(BaseDDL):
|
||||
@@ -8,6 +12,10 @@ class MysqlDDL(BaseDDL):
|
||||
DIALECT = MySQLSchemaGenerator.DIALECT
|
||||
_DROP_TABLE_TEMPLATE = "DROP TABLE IF EXISTS `{table_name}`"
|
||||
_ADD_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` ADD {column}"
|
||||
_ALTER_DEFAULT_TEMPLATE = "ALTER TABLE `{table_name}` ALTER COLUMN `{column}` {default}"
|
||||
_CHANGE_COLUMN_TEMPLATE = (
|
||||
"ALTER TABLE `{table_name}` CHANGE {old_column_name} {new_column_name} {new_column_type}"
|
||||
)
|
||||
_DROP_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` DROP COLUMN `{column_name}`"
|
||||
_RENAME_COLUMN_TEMPLATE = (
|
||||
"ALTER TABLE `{table_name}` RENAME COLUMN `{old_column_name}` TO `{new_column_name}`"
|
||||
@@ -20,3 +28,9 @@ class MysqlDDL(BaseDDL):
|
||||
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
|
||||
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
|
||||
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
|
||||
|
||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotSupportError("Alter column null is unsupported in MySQL.")
|
||||
|
||||
def set_comment(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotSupportError("Alter column comment is unsupported in MySQL.")
|
||||
|
||||
@@ -16,35 +16,26 @@ class PostgresDDL(BaseDDL):
|
||||
)
|
||||
_DROP_INDEX_TEMPLATE = 'DROP INDEX "{index_name}"'
|
||||
_DROP_UNIQUE_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{index_name}"'
|
||||
_ALTER_DEFAULT_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {default}'
|
||||
_ALTER_NULL_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" {set_drop} NOT NULL'
|
||||
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" ALTER COLUMN "{column}" TYPE {datatype}'
|
||||
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
|
||||
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
|
||||
|
||||
def alter_column_default(self, model: "Type[Model]", field_object: Field):
|
||||
db_table = model._meta.db_table
|
||||
default = self._get_default(model, field_object)
|
||||
return self._ALTER_DEFAULT_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=field_object.model_field_name,
|
||||
default="SET" + default if default else "DROP DEFAULT",
|
||||
)
|
||||
|
||||
def alter_column_null(self, model: "Type[Model]", field_object: Field):
|
||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
||||
db_table = model._meta.db_table
|
||||
return self._ALTER_NULL_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=field_object.model_field_name,
|
||||
set_drop="DROP" if field_object.null else "SET",
|
||||
column=field_describe.get("db_column"),
|
||||
set_drop="DROP" if field_describe.get("nullable") else "SET",
|
||||
)
|
||||
|
||||
def modify_column(self, model: "Type[Model]", field_object: Field):
|
||||
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
||||
db_table = model._meta.db_table
|
||||
db_field_types = field_describe.get("db_field_types")
|
||||
return self._MODIFY_COLUMN_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
column=field_object.model_field_name,
|
||||
datatype=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
|
||||
column=field_describe.get("db_column"),
|
||||
datatype=db_field_types.get(self.DIALECT) or db_field_types.get(""),
|
||||
)
|
||||
|
||||
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
||||
|
||||
@@ -2,7 +2,6 @@ from typing import Type
|
||||
|
||||
from tortoise import Model
|
||||
from tortoise.backends.sqlite.schema_generator import SqliteSchemaGenerator
|
||||
from tortoise.fields import Field
|
||||
|
||||
from aerich.ddl import BaseDDL
|
||||
from aerich.exceptions import NotSupportError
|
||||
@@ -15,5 +14,14 @@ class SqliteDDL(BaseDDL):
|
||||
def drop_column(self, model: "Type[Model]", column_name: str):
|
||||
raise NotSupportError("Drop column is unsupported in SQLite.")
|
||||
|
||||
def modify_column(self, model: "Type[Model]", field_object: Field):
|
||||
def modify_column(self, model: "Type[Model]", field_object: dict, is_pk: bool = True):
|
||||
raise NotSupportError("Modify column is unsupported in SQLite.")
|
||||
|
||||
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotSupportError("Alter column default is unsupported in SQLite.")
|
||||
|
||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotSupportError("Alter column null is unsupported in SQLite.")
|
||||
|
||||
def set_comment(self, model: "Type[Model]", field_describe: dict):
|
||||
raise NotSupportError("Alter column comment is unsupported in SQLite.")
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from dictdiffer import diff
|
||||
from tortoise import (
|
||||
BackwardFKRelation,
|
||||
@@ -173,6 +174,8 @@ class Migrate:
|
||||
new_models.pop(_aerich, None)
|
||||
|
||||
for new_model_str, new_model_describe in new_models.items():
|
||||
model = cls._get_model(new_model_describe.get("name").split(".")[1])
|
||||
|
||||
if new_model_str not in old_models.keys():
|
||||
cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade)
|
||||
else:
|
||||
@@ -180,6 +183,18 @@ class Migrate:
|
||||
|
||||
old_unique_together = old_model_describe.get("unique_together")
|
||||
new_unique_together = new_model_describe.get("unique_together")
|
||||
# add unique_together
|
||||
for index in set(new_unique_together).difference(set(old_unique_together)):
|
||||
cls._add_operator(
|
||||
cls._add_index(model, index, True),
|
||||
upgrade,
|
||||
)
|
||||
# remove unique_together
|
||||
for index in set(old_unique_together).difference(set(new_unique_together)):
|
||||
cls._add_operator(
|
||||
cls._drop_index(model, index, True),
|
||||
upgrade,
|
||||
)
|
||||
|
||||
old_data_fields = old_model_describe.get("data_fields")
|
||||
new_data_fields = new_model_describe.get("data_fields")
|
||||
@@ -187,10 +202,9 @@ class Migrate:
|
||||
old_data_fields_name = list(map(lambda x: x.get("name"), old_data_fields))
|
||||
new_data_fields_name = list(map(lambda x: x.get("name"), new_data_fields))
|
||||
|
||||
model = cls._get_model(new_model_describe.get("name").split(".")[1])
|
||||
# add fields
|
||||
for new_data_field_name in set(new_data_fields_name).difference(
|
||||
set(old_data_fields_name)
|
||||
set(old_data_fields_name)
|
||||
):
|
||||
cls._add_operator(
|
||||
cls._add_field(
|
||||
@@ -205,7 +219,7 @@ class Migrate:
|
||||
)
|
||||
# remove fields
|
||||
for old_data_field_name in set(old_data_fields_name).difference(
|
||||
set(new_data_fields_name)
|
||||
set(new_data_fields_name)
|
||||
):
|
||||
cls._add_operator(
|
||||
cls._remove_field(
|
||||
@@ -214,7 +228,7 @@ class Migrate:
|
||||
filter(
|
||||
lambda x: x.get("name") == old_data_field_name, old_data_fields
|
||||
)
|
||||
),
|
||||
).get("db_column"),
|
||||
),
|
||||
upgrade,
|
||||
)
|
||||
@@ -226,7 +240,7 @@ class Migrate:
|
||||
|
||||
# add fk
|
||||
for new_fk_field_name in set(new_fk_fields_name).difference(
|
||||
set(old_fk_fields_name)
|
||||
set(old_fk_fields_name)
|
||||
):
|
||||
fk_field = next(
|
||||
filter(lambda x: x.get("name") == new_fk_field_name, new_fk_fields)
|
||||
@@ -237,7 +251,7 @@ class Migrate:
|
||||
)
|
||||
# drop fk
|
||||
for old_fk_field_name in set(old_fk_fields_name).difference(
|
||||
set(new_fk_fields_name)
|
||||
set(new_fk_fields_name)
|
||||
):
|
||||
old_fk_field = next(
|
||||
filter(lambda x: x.get("name") == old_fk_field_name, old_fk_fields)
|
||||
@@ -259,23 +273,28 @@ class Migrate:
|
||||
changes = diff(old_data_field, new_data_field)
|
||||
for change in changes:
|
||||
_, option, old_new = change
|
||||
if option == 'indexed':
|
||||
if option == "indexed":
|
||||
# change index
|
||||
unique = new_data_field.get('unique')
|
||||
unique = new_data_field.get("unique")
|
||||
if old_new[0] is False and old_new[1] is True:
|
||||
cls._add_operator(
|
||||
cls._add_index(
|
||||
model, (field_name,), unique
|
||||
),
|
||||
cls._add_index(model, (field_name,), unique),
|
||||
upgrade,
|
||||
)
|
||||
else:
|
||||
cls._add_operator(
|
||||
cls._drop_index(
|
||||
model, (field_name,), unique
|
||||
),
|
||||
cls._drop_index(model, (field_name,), unique),
|
||||
upgrade,
|
||||
)
|
||||
elif option == "db_field_types.":
|
||||
# change column
|
||||
cls._add_operator(
|
||||
cls._change_field(model, old_data_field, new_data_field),
|
||||
upgrade,
|
||||
)
|
||||
elif option == "default":
|
||||
cls._add_operator(cls._alter_default(model, new_data_field), upgrade)
|
||||
|
||||
for old_model in old_models:
|
||||
if old_model not in new_models.keys():
|
||||
cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade)
|
||||
@@ -340,40 +359,41 @@ class Migrate:
|
||||
return cls.ddl.add_column(model, field_describe, is_pk)
|
||||
|
||||
@classmethod
|
||||
def _alter_default(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.alter_column_default(model, field)
|
||||
def _alter_default(cls, model: Type[Model], field_describe: dict):
|
||||
return cls.ddl.alter_column_default(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _alter_null(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.alter_column_null(model, field)
|
||||
def _alter_null(cls, model: Type[Model], field_describe: dict):
|
||||
return cls.ddl.alter_column_null(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _set_comment(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.set_comment(model, field)
|
||||
def _set_comment(cls, model: Type[Model], field_describe: dict):
|
||||
return cls.ddl.set_comment(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _modify_field(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.modify_column(model, field)
|
||||
def _modify_field(cls, model: Type[Model], field_describe: dict):
|
||||
return cls.ddl.modify_column(model, field_describe)
|
||||
|
||||
@classmethod
|
||||
def _drop_fk(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
|
||||
return cls.ddl.drop_fk(model, field_describe, reference_table_describe)
|
||||
|
||||
@classmethod
|
||||
def _remove_field(cls, model: Type[Model], field_describe: dict):
|
||||
return cls.ddl.drop_column(model, field_describe)
|
||||
def _remove_field(cls, model: Type[Model], column_name: str):
|
||||
return cls.ddl.drop_column(model, column_name)
|
||||
|
||||
@classmethod
|
||||
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
|
||||
return cls.ddl.rename_column(model, old_field.model_field_name, new_field.model_field_name)
|
||||
|
||||
@classmethod
|
||||
def _change_field(cls, model: Type[Model], old_field: Field, new_field: Field):
|
||||
def _change_field(cls, model: Type[Model], old_field_describe: dict, new_field_describe: dict):
|
||||
db_field_types = new_field_describe.get("db_field_types")
|
||||
return cls.ddl.change_column(
|
||||
model,
|
||||
old_field.model_field_name,
|
||||
new_field.model_field_name,
|
||||
new_field.get_for_dialect(cls.dialect, "SQL_TYPE"),
|
||||
old_field_describe.get("db_column"),
|
||||
new_field_describe.get("db_column"),
|
||||
db_field_types.get(cls.dialect) or db_field_types.get(""),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user