add support m2m

This commit is contained in:
long2ice 2021-02-03 22:22:22 +08:00
parent 0d94b22b3f
commit 38a3df9b5a
7 changed files with 97 additions and 19 deletions

View File

@ -1,9 +1,8 @@
from enum import Enum
from typing import List, Type
from tortoise import BaseDBAsyncClient, ManyToManyFieldInstance, Model
from tortoise import BaseDBAsyncClient, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import CASCADE
class BaseDDL:
@ -22,7 +21,7 @@ class BaseDDL:
_DROP_INDEX_TEMPLATE = 'ALTER TABLE "{table_name}" DROP INDEX "{index_name}"'
_ADD_FK_TEMPLATE = 'ALTER TABLE "{table_name}" ADD CONSTRAINT "{fk_name}" FOREIGN KEY ("{db_column}") REFERENCES "{table}" ("{field}") ON DELETE {on_delete}'
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP FOREIGN KEY "{fk_name}"'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment};'
_M2M_TABLE_TEMPLATE = 'CREATE TABLE "{table_name}" ("{backward_key}" {backward_type} NOT NULL REFERENCES "{backward_table}" ("{backward_field}") ON DELETE CASCADE,"{forward_key}" {forward_type} NOT NULL REFERENCES "{forward_table}" ("{forward_field}") ON DELETE {on_delete}){extra}{comment}'
_MODIFY_COLUMN_TEMPLATE = 'ALTER TABLE "{table_name}" MODIFY COLUMN {column}'
_CHANGE_COLUMN_TEMPLATE = (
'ALTER TABLE "{table_name}" CHANGE {old_column_name} {new_column_name} {new_column_type}'
@ -38,28 +37,34 @@ class BaseDDL:
def drop_table(self, model: "Type[Model]"):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
def create_m2m_table(self, model: "Type[Model]", field: ManyToManyFieldInstance):
def create_m2m(
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
):
through = field_describe.get("through")
description = field_describe.get("description")
reference_id = reference_table_describe.get("pk_field").get("db_column")
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
return self._M2M_TABLE_TEMPLATE.format(
table_name=field.through,
table_name=through,
backward_table=model._meta.db_table,
forward_table=field.related_model._meta.db_table,
forward_table=reference_table_describe.get("table"),
backward_field=model._meta.db_pk_column,
forward_field=field.related_model._meta.db_pk_column,
backward_key=field.backward_key,
forward_field=reference_id,
backward_key=field_describe.get("backward_key"),
backward_type=model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
forward_key=field.forward_key,
forward_type=field.related_model._meta.pk.get_for_dialect(self.DIALECT, "SQL_TYPE"),
on_delete=CASCADE,
extra=self.schema_generator._table_generate_extra(table=field.through),
forward_key=field_describe.get("forward_key"),
forward_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
on_delete=field_describe.get("on_delete"),
extra=self.schema_generator._table_generate_extra(table=through),
comment=self.schema_generator._table_comment_generator(
table=field.through, comment=field.description
table=through, comment=description
)
if field.description
if description
else "",
)
def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def drop_m2m(self, table_name: str):
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table

View File

@ -26,7 +26,7 @@ class MysqlDDL(BaseDDL):
_DROP_INDEX_TEMPLATE = "ALTER TABLE `{table_name}` DROP INDEX `{index_name}`"
_ADD_FK_TEMPLATE = "ALTER TABLE `{table_name}` ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE `{table_name}` DROP FOREIGN KEY `{fk_name}`"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment};"
_M2M_TABLE_TEMPLATE = "CREATE TABLE `{table_name}` (`{backward_key}` {backward_type} NOT NULL REFERENCES `{backward_table}` (`{backward_field}`) ON DELETE CASCADE,`{forward_key}` {forward_type} NOT NULL REFERENCES `{forward_table}` (`{forward_field}`) ON DELETE CASCADE){extra}{comment}"
_MODIFY_COLUMN_TEMPLATE = "ALTER TABLE `{table_name}` MODIFY COLUMN {column}"
def alter_column_null(self, model: "Type[Model]", field_describe: dict):

View File

@ -183,7 +183,39 @@ class Migrate:
# current only support rename pk
if action == "change" and option == "name":
cls._add_operator(cls._rename_field(model, *change), upgrade)
# m2m fields
old_m2m_fields = old_model_describe.get("m2m_fields")
new_m2m_fields = new_model_describe.get("m2m_fields")
for action, option, change in diff(old_m2m_fields, new_m2m_fields):
table = change[0][1].get("through")
if action == "add":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(
cls.create_m2m(
model,
change[0][1],
new_models.get(change[0][1].get("model_name")),
),
upgrade,
fk_m2m=True,
)
elif action == "remove":
add = False
if upgrade and table not in cls._upgrade_m2m:
cls._upgrade_m2m.append(table)
add = True
elif not upgrade and table not in cls._downgrade_m2m:
cls._downgrade_m2m.append(table)
add = True
if add:
cls._add_operator(cls.drop_m2m(table), upgrade, fk_m2m=True)
# add unique_together
for index in set(new_unique_together).difference(set(old_unique_together)):
cls._add_operator(
@ -298,6 +330,7 @@ class Migrate:
cls._add_operator(
cls._add_fk(model, fk_field, old_models.get(fk_field.get("python_type"))),
upgrade,
fk_m2m=True,
)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(
@ -311,6 +344,7 @@ class Migrate:
model, old_fk_field, old_models.get(old_fk_field.get("python_type"))
),
upgrade,
fk_m2m=True,
)
# change fields
for field_name in set(new_data_fields_name).intersection(set(old_data_fields_name)):
@ -357,6 +391,14 @@ class Migrate:
def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model)
@classmethod
def create_m2m(cls, model: Type[Model], field_describe: dict, reference_table_describe: dict):
return cls.ddl.create_m2m(model, field_describe, reference_table_describe)
@classmethod
def drop_m2m(cls, table_name: str):
return cls.ddl.drop_m2m(table_name)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
ret = []

View File

@ -20,7 +20,10 @@ tortoise_orm = {
"second": expand_db_url(db_url_second, True),
},
"apps": {
"models": {"models": ["tests.models", "aerich.models"], "default_connection": "default"},
"models": {
"models": ["tests.models", "aerich.models"],
"default_connection": "default",
},
"models_second": {"models": ["tests.models_second"], "default_connection": "second"},
},
}

2
poetry.lock generated
View File

@ -635,6 +635,7 @@ click = [
]
colorama = [
{file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"},
{file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"},
]
ddlparse = [
{file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"},
@ -665,6 +666,7 @@ importlib-metadata = [
{file = "importlib_metadata-3.4.0.tar.gz", hash = "sha256:fa5daa4477a7414ae34e95942e4dd07f62adf589143c875c133c1e53c4eff38d"},
]
iniconfig = [
{file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"},
{file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"},
]
iso8601 = [

View File

@ -35,6 +35,7 @@ class Email(Model):
email = fields.CharField(max_length=200, index=True)
is_primary = fields.BooleanField(default=False)
address = fields.CharField(max_length=200)
users = fields.ManyToManyField("models.User")
class Category(Model):

View File

@ -81,6 +81,8 @@ old_models_describe = {
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
{
"name": "user_id",
@ -131,6 +133,13 @@ old_models_describe = {
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Product",
"related_name": "categories",
"forward_key": "product_id",
"backward_key": "category_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": True,
}
],
},
@ -464,6 +473,8 @@ old_models_describe = {
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": True,
"auto_now": False,
},
],
"fk_fields": [],
@ -483,6 +494,13 @@ old_models_describe = {
"description": None,
"docstring": None,
"constraints": {},
"model_name": "models.Category",
"related_name": "products",
"forward_key": "category_id",
"backward_key": "product_id",
"through": "product_category",
"on_delete": "CASCADE",
"_generated": False,
}
],
},
@ -558,6 +576,8 @@ old_models_describe = {
"mysql": "DATETIME(6)",
"postgres": "TIMESTAMPTZ",
},
"auto_now_add": False,
"auto_now": False,
},
{
"name": "is_active",
@ -741,6 +761,7 @@ def test_migrate(mocker: MockerFixture):
- drop fk: Email.user
- drop field: User.avatar
- add index: Email.email
- add many to many: Email.users
- remove unique: User.username
- change column: length User.password
- add unique_together: (name,type) of Product
@ -776,6 +797,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` DROP COLUMN `avatar`",
"ALTER TABLE `user` CHANGE password password VARCHAR(100)",
"ALTER TABLE `user` ADD UNIQUE INDEX `uid_user_usernam_9987ab` (`username`)",
"CREATE TABLE `email_user` (`email_id` INT NOT NULL REFERENCES `email` (`email_id`) ON DELETE CASCADE,`user_id` INT NOT NULL REFERENCES `user` (`id`) ON DELETE CASCADE) CHARACTER SET utf8mb4",
]
)
@ -795,6 +817,7 @@ def test_migrate(mocker: MockerFixture):
"ALTER TABLE `user` ADD `avatar` VARCHAR(200) NOT NULL DEFAULT ''",
"ALTER TABLE `user` DROP INDEX `idx_user_usernam_9987ab`",
"ALTER TABLE `user` CHANGE password password VARCHAR(200)",
"DROP TABLE IF EXISTS `email_user`",
]
)
@ -815,6 +838,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "user" DROP COLUMN "avatar"',
'ALTER TABLE "user" CHANGE password password VARCHAR(100)',
'CREATE UNIQUE INDEX "uid_user_usernam_9987ab" ON "user" ("username")',
'CREATE TABLE "email_user" ("email_id" INT NOT NULL REFERENCES "email" ("email_id") ON DELETE CASCADE,"user_id" INT NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE)',
]
)
assert sorted(Migrate.downgrade_operators) == sorted(
@ -833,6 +857,7 @@ def test_migrate(mocker: MockerFixture):
'ALTER TABLE "user" ADD "avatar" VARCHAR(200) NOT NULL DEFAULT \'\'',
'DROP INDEX "idx_user_usernam_9987ab"',
'ALTER TABLE "user" CHANGE password password VARCHAR(200)',
'DROP TABLE IF EXISTS "email_user"',
]
)
elif isinstance(Migrate.ddl, SqliteDDL):