add support m2m
This commit is contained in:
parent
0d94b22b3f
commit
38a3df9b5a
@ -1,9 +1,8 @@
|
||||
from enum import Enum
|
||||
from typing import List, Type
|
||||
|
||||
from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model
|
||||
from tortoise import BaseDBAsyncClient, Model
|
||||
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
||||
from tortoise.fields import CASCADE
|
||||
|
||||
|
||||
class BaseDDL:
|
||||
@ -22,7 +21,7 @@ class BaseDDL:
|
||||
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
|
||||
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
|
||||
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
|
||||
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
|
||||
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}'
|
||||
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
|
||||
_CHANGE_COLUMN_TEMPLATE = (
|
||||
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
|
||||
@ -38,28 +37,34 @@ class BaseDDL:
|
||||
def drop_table(self, model: "Type[Model]"):
|
||||
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
|
||||
|
||||
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance):
|
||||
def create_m2m(
|
||||
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
|
||||
):
|
||||
through = field_describe.get("through")
|
||||
description = field_describe.get("description")
|
||||
reference_id = reference_table_describe.get("pk_field").get("db_column")
|
||||
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
|
||||
return self._M2M_TABLE_TEMPLATE.format(
|
||||
table_name=field.through,
|
||||
table_name=through,
|
||||
backward_table=model._meta.db_table,
|
||||
forward_table=field.related_model._meta.db_table,
|
||||
forward_table=reference_table_describe.get("table"),
|
||||
backward_field=model._meta.db_pk_column,
|
||||
forward_field=field.related_model._meta.db_pk_column,
|
||||
backward_key=field.backward_key,
|
||||
forward_field=reference_id,
|
||||
backward_key=field_describe.get("backward_key"),
|
||||
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
|
||||
forward_key=field.forward_key,
|
||||
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
|
||||
on_delete=CASCADE,
|
||||
extra=self.schema_generator._table_generate_extra(table=field.through),
|
||||
forward_key=field_describe.get("forward_key"),
|
||||
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
|
||||
on_delete=field_describe.get("on_delete"),
|
||||
extra=self.schema_generator._table_generate_extra(table=through),
|
||||
comment=self.schema_generator._table_comment_generator(
|
||||
table=field.through, comment=field.description
|
||||
table=through, comment=description
|
||||
)
|
||||
if field.description
|
||||
if description
|
||||
else "",
|
||||
)
|
||||
|
||||
def drop_m2m(self, field: ManyToManyFieldInstance):
|
||||
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
|
||||
def drop_m2m(self, table_name: str):
|
||||
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
|
||||
|
||||
def _get_default(self, model: "Type[Model]", field_describe: dict):
|
||||
db_table = model._meta.db_table
|
||||
|
@ -26,7 +26,7 @@ class MysqlDDL(BaseDDL):
|
||||
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
|
||||
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
|
||||
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
|
||||
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
|
||||
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
|
||||
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
|
||||
|
||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
||||
|
@ -183,7 +183,39 @@ class Migrate:
|
||||
# current only support rename pk
|
||||
if action == "change" and option == "name":
|
||||
cls._add_operator(cls._rename_field(model, *change), upgrade)
|
||||
|
||||
# m2m fields
|
||||
old_m2m_fields = old_model_describe.get("m2m_fields")
|
||||
new_m2m_fields = new_model_describe.get("m2m_fields")
|
||||
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
|
||||
table = change[0][1].get("through")
|
||||
if action == "add":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
cls._add_operator(
|
||||
cls.create_m2m(
|
||||
model,
|
||||
change[0][1],
|
||||
new_models.get(change[0][1].get("model_name")),
|
||||
),
|
||||
upgrade,
|
||||
fk_m2m=True,
|
||||
)
|
||||
elif action == "remove":
|
||||
add = False
|
||||
if upgrade and table not in cls._upgrade_m2m:
|
||||
cls._upgrade_m2m.append(table)
|
||||
add = True
|
||||
elif not upgrade and table not in cls._downgrade_m2m:
|
||||
cls._downgrade_m2m.append(table)
|
||||
add = True
|
||||
if add:
|
||||
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
|
||||
# add unique_together
|
||||
for index in set(new_unique_together).difference(set(old_unique_together)):
|
||||
cls._add_operator(
|
||||
@ -298,6 +330,7 @@ class Migrate:
|
||||
cls._add_operator(
|
||||
cls._add_fk(model, fk_field, old_models.get(fk_field.get("python_type"))),
|
||||
upgrade,
|
||||
fk_m2m=True,
|
||||
)
|
||||
# drop fk
|
||||
for old_fk_field_name in set(old_fk_fields_name).difference(
|
||||
@ -311,6 +344,7 @@ class Migrate:
|
||||
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
|
||||
),
|
||||
upgrade,
|
||||
fk_m2m=True,
|
||||
)
|
||||
# change fields
|
||||
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
|
||||
@ -357,6 +391,14 @@ class Migrate:
|
||||
def remove_model(cls, model: Type[Model]):
|
||||
return cls.ddl.drop_table(model)
|
||||
|
||||
@classmethod
|
||||
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
|
||||
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
|
||||
|
||||
@classmethod
|
||||
def drop_m2m(cls, table_name: str):
|
||||
return cls.ddl.drop_m2m(table_name)
|
||||
|
||||
@classmethod
|
||||
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
|
||||
ret = []
|
||||
|
@ -20,7 +20,10 @@ tortoise_orm = {
|
||||
"second": expand_db_url(db_url_second, True),
|
||||
},
|
||||
"apps": {
|
||||
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
|
||||
"models": {
|
||||
"models": ["tests.models", "aerich.models"],
|
||||
"default_connection": "default",
|
||||
},
|
||||
"models_second": {"models": ["tests.models_second"], "default_connection": "second"},
|
||||
},
|
||||
}
|
||||
|
2
poetry.lock
generated
2
poetry.lock
generated
@ -635,6 +635,7 @@ click = [
|
||||
]
|
||||
colorama = [
|
||||
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
|
||||
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
|
||||
]
|
||||
ddlparse = [
|
||||
{file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"},
|
||||
@ -665,6 +666,7 @@ importlib-metadata = [
|
||||
{file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"},
|
||||
]
|
||||
iniconfig = [
|
||||
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
|
||||
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
|
||||
]
|
||||
iso8601 = [
|
||||
|
@ -35,6 +35,7 @@ class Email(Model):
|
||||
email = fields.CharField(max_length=200, index=True)
|
||||
is_primary = fields.BooleanField(default=False)
|
||||
address = fields.CharField(max_length=200)
|
||||
users = fields.ManyToManyField("models.User")
|
||||
|
||||
|
||||
class Category(Model):
|
||||
|
@ -81,6 +81,8 @@ old_models_describe = {
|
||||
"mysql": "DATETIME(6)",
|
||||
"postgres": "TIMESTAMPTZ",
|
||||
},
|
||||
"auto_now_add": True,
|
||||
"auto_now": False,
|
||||
},
|
||||
{
|
||||
"name": "user_id",
|
||||
@ -131,6 +133,13 @@ old_models_describe = {
|
||||
"description": None,
|
||||
"docstring": None,
|
||||
"constraints": {},
|
||||
"model_name": "models.Product",
|
||||
"related_name": "categories",
|
||||
"forward_key": "product_id",
|
||||
"backward_key": "category_id",
|
||||
"through": "product_category",
|
||||
"on_delete": "CASCADE",
|
||||
"_generated": True,
|
||||
}
|
||||
],
|
||||
},
|
||||
@ -464,6 +473,8 @@ old_models_describe = {
|
||||
"mysql": "DATETIME(6)",
|
||||
"postgres": "TIMESTAMPTZ",
|
||||
},
|
||||
"auto_now_add": True,
|
||||
"auto_now": False,
|
||||
},
|
||||
],
|
||||
"fk_fields": [],
|
||||
@ -483,6 +494,13 @@ old_models_describe = {
|
||||
"description": None,
|
||||
"docstring": None,
|
||||
"constraints": {},
|
||||
"model_name": "models.Category",
|
||||
"related_name": "products",
|
||||
"forward_key": "category_id",
|
||||
"backward_key": "product_id",
|
||||
"through": "product_category",
|
||||
"on_delete": "CASCADE",
|
||||
"_generated": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
@ -558,6 +576,8 @@ old_models_describe = {
|
||||
"mysql": "DATETIME(6)",
|
||||
"postgres": "TIMESTAMPTZ",
|
||||
},
|
||||
"auto_now_add": False,
|
||||
"auto_now": False,
|
||||
},
|
||||
{
|
||||
"name": "is_active",
|
||||
@ -741,6 +761,7 @@ def test_migrate(mocker: MockerFixture):
|
||||
- drop fk: Email.user
|
||||
- drop field: User.avatar
|
||||
- add index: Email.email
|
||||
- add many to many: Email.users
|
||||
- remove unique: User.username
|
||||
- change column: length User.password
|
||||
- add unique_together: (name,type) of Product
|
||||
@ -776,6 +797,7 @@ def test_migrate(mocker: MockerFixture):
|
||||
"ALTER TABLE `user` DROP COLUMN `avatar`",
|
||||
"ALTER TABLE `user` CHANGE password password VARCHAR(100)",
|
||||
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
|
||||
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
|
||||
]
|
||||
)
|
||||
|
||||
@ -795,6 +817,7 @@ def test_migrate(mocker: MockerFixture):
|
||||
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
|
||||
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
|
||||
"ALTER TABLE `user` CHANGE password password VARCHAR(200)",
|
||||
"DROP TABLE IF EXISTS `email_user`",
|
||||
]
|
||||
)
|
||||
|
||||
@ -815,6 +838,7 @@ def test_migrate(mocker: MockerFixture):
|
||||
'ALTER TABLE "user" DROP COLUMN "avatar"',
|
||||
'ALTER TABLE "user" CHANGE password password VARCHAR(100)',
|
||||
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
|
||||
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
|
||||
]
|
||||
)
|
||||
assert sorted(Migrate.downgrade_operators) == sorted(
|
||||
@ -833,6 +857,7 @@ def test_migrate(mocker: MockerFixture):
|
||||
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
|
||||
'DROP INDEX "idx_user_usernam_9987ab"',
|
||||
'ALTER TABLE "user" CHANGE password password VARCHAR(200)',
|
||||
'DROP TABLE IF EXISTS "email_user"',
|
||||
]
|
||||
)
|
||||
elif isinstance(Migrate.ddl, SqliteDDL):
|
||||
|
Loading…
x
Reference in New Issue
Block a user