add support m2m

This commit is contained in:
long2ice 2021-02-03 22:22:22 +08:00
parent 0d94b22b3f
commit 38a3df9b5a
7 changed files with 97 additions and 19 deletions

View File

@ -1,9 +1,8 @@
from enum import Enum from enum import Enum
from typing import List, Type from typing import List, Type
from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE
class BaseDDL: class BaseDDL:
@ -22,7 +21,7 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"' _DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}' _ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"' _DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};' _M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}' _MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = ( _CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}' 'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
@ -38,28 +37,34 @@ class BaseDDL:
def drop_table(self, model: "Type[Model]"): def drop_table(self, model: "Type[Model]"):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table) return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance): def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
return self._M2M_TABLE_TEMPLATE.format( return self._M2M_TABLE_TEMPLATE.format(
table_name=field.through, table_name=through,
backward_table=model._meta.db_table, backward_table=model._meta.db_table,
forward_table=field.related_model._meta.db_table, forward_table=reference_table_describe.get("table"),
backward_field=model._meta.db_pk_column, backward_field=model._meta.db_pk_column,
forward_field=field.related_model._meta.db_pk_column, forward_field=reference_id,
backward_key=field.backward_key, backward_key=field_describe.get("backward_key"),
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key, forward_key=field_describe.get("forward_key"),
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"), forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=CASCADE, on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=field.through), extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator( comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description table=through, comment=description
) )
if field.description if description
else "", else "",
) )
def drop_m2m(self, field: ManyToManyFieldInstance): def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict): def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table

View File

@ -26,7 +26,7 @@ class MysqlDDL(BaseDDL):
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`" _DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}" _ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`" _DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};" _M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}" _MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
def alter_column_null(self, model: "Type[Model]", field_describe: dict): def alter_column_null(self, model: "Type[Model]", field_describe: dict):

View File

@ -183,7 +183,39 @@ class Migrate:
# current only support rename pk # current only support rename pk
if action == "change" and option == "name": if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade) cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
table = change[0][1].get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
# add unique_together # add unique_together
for index in set(new_unique_together).difference(set(old_unique_together)): for index in set(new_unique_together).difference(set(old_unique_together)):
cls._add_operator( cls._add_operator(
@ -298,6 +330,7 @@ class Migrate:
cls._add_operator( cls._add_operator(
cls._add_fk(model, fk_field, old_models.get(fk_field.get("python_type"))), cls._add_fk(model, fk_field, old_models.get(fk_field.get("python_type"))),
upgrade, upgrade,
fk_m2m=True,
) )
# drop fk # drop fk
for old_fk_field_name in set(old_fk_fields_name).difference( for old_fk_field_name in set(old_fk_fields_name).difference(
@ -311,6 +344,7 @@ class Migrate:
model, old_fk_field, old_models.get(old_fk_field.get("python_type")) model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
), ),
upgrade, upgrade,
fk_m2m=True,
) )
# change fields # change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)): for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
@ -357,6 +391,14 @@ class Migrate:
def remove_model(cls, model: Type[Model]): def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model) return cls.ddl.drop_table(model)
@classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod
def drop_m2m(cls, table_name: str):
return cls.ddl.drop_m2m(table_name)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
ret = [] ret = []

View File

@ -20,7 +20,10 @@ tortoise_orm = {
"second": expand_db_url(db_url_second, True), "second": expand_db_url(db_url_second, True),
}, },
"apps": { "apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"}, "models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
"models_second": {"models": ["tests.models_second"], "default_connection": "second"}, "models_second": {"models": ["tests.models_second"], "default_connection": "second"},
}, },
} }

2
poetry.lock generated
View File

@ -635,6 +635,7 @@ click = [
] ]
colorama = [ colorama = [
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
] ]
ddlparse = [ ddlparse = [
{file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"},
@ -665,6 +666,7 @@ importlib-metadata = [
{file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"}, {file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"},
] ]
iniconfig = [ iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
] ]
iso8601 = [ iso8601 = [

View File

@ -35,6 +35,7 @@ class Email(Model):
email = fields.CharField(max_length=200, index=True) email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200) address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
class Category(Model): class Category(Model):

View File

@ -81,6 +81,8 @@ old_models_describe = {
"mysql": "DATETIME(6)", "mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ", "postgres": "TIMESTAMPTZ",
}, },
"auto_now_add": True,
"auto_now": False,
}, },
{ {
"name": "user_id", "name": "user_id",
@ -131,6 +133,13 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"model_name": "models.Product",
"related_name": "categories",
"forward_key": "product_id",
"backward_key": "category_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": True,
} }
], ],
}, },
@ -464,6 +473,8 @@ old_models_describe = {
"mysql": "DATETIME(6)", "mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ", "postgres": "TIMESTAMPTZ",
}, },
"auto_now_add": True,
"auto_now": False,
}, },
], ],
"fk_fields": [], "fk_fields": [],
@ -483,6 +494,13 @@ old_models_describe = {
"description": None, "description": None,
"docstring": None, "docstring": None,
"constraints": {}, "constraints": {},
"model_name": "models.Category",
"related_name": "products",
"forward_key": "category_id",
"backward_key": "product_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": False,
} }
], ],
}, },
@ -558,6 +576,8 @@ old_models_describe = {
"mysql": "DATETIME(6)", "mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ", "postgres": "TIMESTAMPTZ",
}, },
"auto_now_add": False,
"auto_now": False,
}, },
{ {
"name": "is_active", "name": "is_active",
@ -741,6 +761,7 @@ def test_migrate(mocker: MockerFixture):
- drop fk: Email.user - drop fk: Email.user
- drop field: User.avatar - drop field: User.avatar
- add index: Email.email - add index: Email.email
- add many to many: Email.users
- remove unique: User.username - remove unique: User.username
- change column: length User.password - change column: length User.password
- add unique_together: (name,type) of Product - add unique_together: (name,type) of Product
@ -776,6 +797,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` DROP COLUMN `avatar`", "ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` CHANGE password password VARCHAR(100)", "ALTER TABLE `user` CHANGE password password VARCHAR(100)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)", "ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
] ]
) )
@ -795,6 +817,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''", "ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`", "ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` CHANGE password password VARCHAR(200)", "ALTER TABLE `user` CHANGE password password VARCHAR(200)",
"DROP TABLE IF EXISTS `email_user`",
] ]
) )
@ -815,6 +838,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "user" DROP COLUMN "avatar"', 'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" CHANGE password password VARCHAR(100)', 'ALTER TABLE "user" CHANGE password password VARCHAR(100)',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")', 'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
] ]
) )
assert sorted(Migrate.downgrade_operators) == sorted( assert sorted(Migrate.downgrade_operators) == sorted(
@ -833,6 +857,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'', 'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'DROP INDEX "idx_user_usernam_9987ab"', 'DROP INDEX "idx_user_usernam_9987ab"',
'ALTER TABLE "user" CHANGE password password VARCHAR(200)', 'ALTER TABLE "user" CHANGE password password VARCHAR(200)',
'DROP TABLE IF EXISTS "email_user"',
] ]
) )
elif isinstance(Migrate.ddl, SqliteDDL): elif isinstance(Migrate.ddl, SqliteDDL):