This commit is contained in:
long2ice 2021-02-01 16:54:35 +08:00
parent 36f84702b7
commit f443dc68db
7 changed files with 81 additions and 104 deletions

View File

@ -36,7 +36,6 @@ def coro(f):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper
@ -221,9 +220,9 @@ async def history(ctx: Context):
@click.pass_context @click.pass_context
@coro @coro
async def init( async def init(
ctx: Context, ctx: Context,
tortoise_orm, tortoise_orm,
location, location,
): ):
config_file = ctx.obj["config_file"] config_file = ctx.obj["config_file"]
name = ctx.obj["name"] name = ctx.obj["name"]

View File

@ -59,17 +59,16 @@ class BaseDDL:
def drop_m2m(self, field: ManyToManyFieldInstance): def drop_m2m(self, field: ManyToManyFieldInstance):
return self._DROP_TABLE_TEMPLATE.format(table_name=field.through) return self._DROP_TABLE_TEMPLATE.format(table_name=field.through)
def _get_default(self, model: "Type[Model]", field_object: Field): def _get_default(self, model: "Type[Model]", field_describe: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
default = field_object.default default = field_describe.get('default')
db_column = field_object.model_field_name db_column = field_describe.get('db_column')
auto_now_add = getattr(field_object, "auto_now_add", False) auto_now_add = field_describe.get("auto_now_add", False)
auto_now = getattr(field_object, "auto_now", False) auto_now = field_describe.get( "auto_now", False)
if default is not None or auto_now_add: if default is not None or auto_now_add:
if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)): if field_describe.get('field_type')in ['UUIDField', 'TextField', 'JSONField']:
default = "" default = ""
else: else:
default = field_object.to_db_value(default, model)
try: try:
default = self.schema_generator._column_default_generator( default = self.schema_generator._column_default_generator(
db_table, db_table,
@ -104,13 +103,13 @@ class BaseDDL:
if description if description
else "", else "",
is_primary_key=is_pk, is_primary_key=is_pk,
default=field_describe.get("default"), default=self._get_default(model,field_describe),
), ),
) )
def drop_column(self, model: "Type[Model]", column_name: str): def drop_column(self, model: "Type[Model]", field_describe: dict):
return self._DROP_COLUMN_TEMPLATE.format( return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name table_name=model._meta.db_table, column_name=field_describe.get('db_column')
) )
def modify_column(self, model: "Type[Model]", field_object: Field): def modify_column(self, model: "Type[Model]", field_object: Field):
@ -142,7 +141,7 @@ class BaseDDL:
) )
def change_column( def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
): ):
return self._CHANGE_COLUMN_TEMPLATE.format( return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, table_name=model._meta.db_table,
@ -169,37 +168,34 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def add_fk(self, model: "Type[Model]", field: dict): def add_fk(self, model: "Type[Model]", field_describe: dict, field_describe_target: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
db_column = field.get("db_column") db_column = field_describe.get("raw_field")
fk_name = self.schema_generator._generate_fk_name( fk_name = self.schema_generator._generate_fk_name(
from_table=db_table, from_table=db_table,
from_field=db_column, from_field=db_column,
to_table=field.related_model._meta.db_table, to_table=field_describe.get('name'),
to_field=db_column, to_field=db_column,
) )
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=fk_name, fk_name=fk_name,
db_column=db_column, db_column=db_column,
table=field.related_model._meta.db_table, table=field_describe.get('name'),
field=db_column, field=db_column,
on_delete=field.get("on_delete"), on_delete=field_describe.get('on_delete'),
) )
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): def drop_fk(self, model: "Type[Model]", field_describe: dict, field_describe_target: dict):
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_table = model._meta.db_table db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format( return self._DROP_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=self.schema_generator._generate_fk_name( fk_name=self.schema_generator._generate_fk_name(
from_table=db_table, from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id", from_field=field_describe.get('raw_field'),
to_table=field.related_model._meta.db_table, to_table=field_describe.get('name'),
to_field=to_field_name, to_field=field_describe_target.get('db_column'),
), ),
) )

View File

@ -3,7 +3,7 @@ from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
from dictdiffer import diff import click
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
BackwardOneToOneRelation, BackwardOneToOneRelation,
@ -204,30 +204,49 @@ class Migrate:
:param upgrade: :param upgrade:
:return: :return:
""" """
for change in diff(old_model_describe, new_model_describe):
action, field_type, fields = change old_unique_together = old_model_describe.get('unique_together')
is_pk = field_type == "pk_field" new_unique_together = new_model_describe.get('unique_together')
if action == "add":
for field in fields: old_data_fields = old_model_describe.get('data_fields')
_, field_describe = field new_data_fields = new_model_describe.get('data_fields')
cls._add_operator(
cls._add_field( old_data_fields_name = list(map(lambda x: x.get('name'), old_data_fields))
cls._get_model(new_model_describe.get("name").split(".")[1]), new_data_fields_name = list(map(lambda x: x.get('name'), new_data_fields))
field_describe,
is_pk, model = cls._get_model(new_model_describe.get('name').split('.')[1])
), # add fields
upgrade, for new_data_field_name in set(new_data_fields_name).difference(set(old_data_fields_name)):
) cls._add_operator(
elif action == "remove": cls._add_field(model, next(filter(lambda x: x.get('name') == new_data_field_name, new_data_fields))),
for field in fields: upgrade)
_, field_describe = field # remove fields
cls._add_operator( for old_data_field_name in set(old_data_fields_name).difference(set(new_data_fields_name)):
cls._remove_field( cls._add_operator(
cls._get_model(new_model_describe.get("name").split(".")[1]), cls._remove_field(model, next(filter(lambda x: x.get('name') == old_data_field_name, old_data_fields))),
field_describe, upgrade)
),
upgrade, old_fk_fields = old_model_describe.get('fk_fields')
) new_fk_fields = new_model_describe.get('fk_fields')
old_fk_fields_name = list(map(lambda x: x.get('name'), old_fk_fields))
new_fk_fields_name = list(map(lambda x: x.get('name'), new_fk_fields))
# add fk
for new_fk_field_name in set(new_fk_fields_name).difference(set(old_fk_fields_name)):
fk_field = next(filter(lambda x: x.get('name') == new_fk_field_name, new_fk_fields))
cls._add_operator(
cls._add_fk(model, fk_field,
next(filter(lambda x: x.get('db_column') == fk_field.get('raw_field'), new_data_fields))),
upgrade)
# drop fk
for old_fk_field_name in set(old_fk_fields_name).difference(set(new_fk_fields_name)):
old_fk_field = next(filter(lambda x: x.get('name') == old_fk_field_name, old_fk_fields))
cls._add_operator(
cls._drop_fk(
model, old_fk_field,
next(filter(lambda x: x.get('db_column') == old_fk_field.get('raw_field'), old_data_fields))),
upgrade)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@ -273,12 +292,8 @@ class Migrate:
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation)) return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field: dict, is_pk: bool = False): def _add_field(cls, model: Type[Model], field_describe: dict, is_pk: bool = False):
if field.get("field_type") == "ForeignKeyFieldInstance": return cls.ddl.add_column(model, field_describe, is_pk)
return cls.ddl.add_fk(model, field)
if field.get("field_type") == "ManyToManyFieldInstance":
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field, is_pk)
@classmethod @classmethod
def _alter_default(cls, model: Type[Model], field: Field): def _alter_default(cls, model: Type[Model], field: Field):
@ -297,16 +312,12 @@ class Migrate:
return cls.ddl.modify_column(model, field) return cls.ddl.modify_column(model, field)
@classmethod @classmethod
def _drop_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): def _drop_fk(cls, model: Type[Model], field_describe: dict, field_describe_target: dict):
return cls.ddl.drop_fk(model, field) return cls.ddl.drop_fk(model, field_describe, field_describe_target)
@classmethod @classmethod
def _remove_field(cls, model: Type[Model], field: Field): def _remove_field(cls, model: Type[Model], field_describe: dict):
if isinstance(field, ForeignKeyFieldInstance): return cls.ddl.drop_column(model, field_describe)
return cls.ddl.drop_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
return cls.ddl.drop_m2m(field)
return cls.ddl.drop_column(model, field.model_field_name)
@classmethod @classmethod
def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field): def _rename_field(cls, model: Type[Model], old_field: Field, new_field: Field):
@ -322,24 +333,14 @@ class Migrate:
) )
@classmethod @classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance): def _add_fk(cls, model: Type[Model], field_describe: dict, field_describe_target: dict):
""" """
add fk add fk
:param model: :param model:
:param field: :param field:
:return: :return:
""" """
return cls.ddl.add_fk(model, field) return cls.ddl.add_fk(model, field_describe, field_describe_target)
@classmethod
def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
drop fk
:param model:
:param field:
:return:
"""
return cls.ddl.drop_fk(model, field)
@classmethod @classmethod
def _merge_operators(cls): def _merge_operators(cls):

20
poetry.lock generated
View File

@ -138,20 +138,6 @@ python-versions = "*"
[package.dependencies] [package.dependencies]
pyparsing = "*" pyparsing = "*"
[[package]]
name = "dictdiffer"
version = "0.8.1"
description = "Dictdiffer is a library that helps you to diff and patch dictionaries."
category = "main"
optional = false
python-versions = "*"
[package.extras]
all = ["Sphinx (>=1.4.4)", "sphinx-rtd-theme (>=0.1.9)", "check-manifest (>=0.25)", "coverage (>=4.0)", "isort (>=4.2.2)", "mock (>=1.3.0)", "pydocstyle (>=1.0.0)", "pytest-cov (>=1.8.0)", "pytest-pep8 (>=1.0.6)", "pytest (>=2.8.0)", "tox (>=3.7.0)", "numpy (>=1.11.0)"]
docs = ["Sphinx (>=1.4.4)", "sphinx-rtd-theme (>=0.1.9)"]
numpy = ["numpy (>=1.11.0)"]
tests = ["check-manifest (>=0.25)", "coverage (>=4.0)", "isort (>=4.2.2)", "mock (>=1.3.0)", "pydocstyle (>=1.0.0)", "pytest-cov (>=1.8.0)", "pytest-pep8 (>=1.0.6)", "pytest (>=2.8.0)", "tox (>=3.7.0)"]
[[package]] [[package]]
name = "execnet" name = "execnet"
version = "1.8.0" version = "1.8.0"
@ -561,7 +547,7 @@ dbdrivers = ["aiomysql", "asyncpg"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.7" python-versions = "^3.7"
content-hash = "f4ef33a953946570d6d35a479dad75768cd3c6a72e5953c68f2de1566c40873b" content-hash = "9adf7beba99d615c71a9148391386c9016cbafc7c11c5fc3ad81c8ec61026236"
[metadata.files] [metadata.files]
aiomysql = [ aiomysql = [
@ -633,10 +619,6 @@ ddlparse = [
{file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"}, {file = "ddlparse-1.9.0-py3-none-any.whl", hash = "sha256:a7962615a9325be7d0f182cbe34011e6283996473fb98c784c6f675b9783bc18"},
{file = "ddlparse-1.9.0.tar.gz", hash = "sha256:cdffcf2f692f304a23c8e903b00afd7e83a920b79a2ff4e2f25c875b369d4f58"}, {file = "ddlparse-1.9.0.tar.gz", hash = "sha256:cdffcf2f692f304a23c8e903b00afd7e83a920b79a2ff4e2f25c875b369d4f58"},
] ]
dictdiffer = [
{file = "dictdiffer-0.8.1-py2.py3-none-any.whl", hash = "sha256:d79d9a39e459fe33497c858470ca0d2e93cb96621751de06d631856adfd9c390"},
{file = "dictdiffer-0.8.1.tar.gz", hash = "sha256:1adec0d67cdf6166bda96ae2934ddb5e54433998ceab63c984574d187cc563d2"},
]
execnet = [ execnet = [
{file = "execnet-1.8.0-py2.py3-none-any.whl", hash = "sha256:7a13113028b1e1cc4c6492b28098b3c6576c9dccc7973bfe47b342afadafb2ac"}, {file = "execnet-1.8.0-py2.py3-none-any.whl", hash = "sha256:7a13113028b1e1cc4c6492b28098b3c6576c9dccc7973bfe47b342afadafb2ac"},
{file = "execnet-1.8.0.tar.gz", hash = "sha256:b73c5565e517f24b62dea8a5ceac178c661c4309d3aa0c3e420856c072c411b4"}, {file = "execnet-1.8.0.tar.gz", hash = "sha256:b73c5565e517f24b62dea8a5ceac178c661c4309d3aa0c3e420856c072c411b4"},

View File

@ -16,13 +16,12 @@ include = ["CHANGELOG.md", "LICENSE", "README.md"]
[tool.poetry.dependencies] [tool.poetry.dependencies]
python = "^3.7" python = "^3.7"
tortoise-orm = "*" tortoise-orm = "^0.16.21"
click = "*" click = "*"
pydantic = "*" pydantic = "*"
aiomysql = {version = "*", optional = true} aiomysql = { version = "*", optional = true }
asyncpg = {version = "*", optional = true} asyncpg = { version = "*", optional = true }
ddlparse = "*" ddlparse = "*"
dictdiffer = "*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
flake8 = "*" flake8 = "*"

View File

@ -28,12 +28,12 @@ class User(Model):
is_active = fields.BooleanField(default=True, description="Is Active") is_active = fields.BooleanField(default=True, description="Is Active")
is_superuser = fields.BooleanField(default=False, description="Is SuperUser") is_superuser = fields.BooleanField(default=False, description="Is SuperUser")
avatar = fields.CharField(max_length=200, default="") avatar = fields.CharField(max_length=200, default="")
intro = fields.TextField(default="")
class Email(Model): class Email(Model):
email = fields.CharField(max_length=200) email = fields.CharField(max_length=200)
is_primary = fields.BooleanField(default=False) is_primary = fields.BooleanField(default=False)
user = fields.ForeignKeyField("models.User", db_constraint=False)
class Category(Model): class Category(Model):

View File

@ -58,7 +58,7 @@ def test_drop_table():
def test_add_column(): def test_add_column():
ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name")) ret = Migrate.ddl.add_column(Category, Category._meta.fields_map.get("name").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL" assert ret == "ALTER TABLE `category` ADD `name` VARCHAR(200) NOT NULL"
else: else:
@ -180,7 +180,7 @@ def test_drop_index():
def test_add_fk(): def test_add_fk():
ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user")) ret = Migrate.ddl.add_fk(Category, Category._meta.fields_map.get("user").describe(False))
if isinstance(Migrate.ddl, MysqlDDL): if isinstance(Migrate.ddl, MysqlDDL):
assert ( assert (
ret ret