v0.5 refactoring

This commit is contained in:
long2ice 2021-01-31 23:10:30 +08:00
parent 4780b90c1c
commit b4cc2de0e3
5 changed files with 56 additions and 266 deletions

View File

@ -18,6 +18,7 @@ from aerich.migrate import Migrate
from aerich.utils import ( from aerich.utils import (
get_app_connection, get_app_connection,
get_app_connection_name, get_app_connection_name,
get_models_describe,
get_tortoise_config, get_tortoise_config,
get_version_content_from_file, get_version_content_from_file,
write_version_file, write_version_file,
@ -34,11 +35,7 @@ def coro(f):
@wraps(f) @wraps(f)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
app = ctx.obj.get("app")
if app:
Migrate.remove_old_model_file(app, ctx.obj["location"])
loop.run_until_complete(Tortoise.close_connections()) loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper
@ -86,7 +83,7 @@ async def cli(ctx: Context, config, app, name):
if invoked_subcommand != "init-db": if invoked_subcommand != "init-db":
if not Path(location, app).exists(): if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx) raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init_with_old_models(tortoise_config, app, location) await Migrate.init(tortoise_config, app, location)
@cli.command(help="Generate migrate changes file.") @cli.command(help="Generate migrate changes file.")
@ -106,7 +103,6 @@ async def migrate(ctx: Context, name):
async def upgrade(ctx: Context): async def upgrade(ctx: Context):
config = ctx.obj["config"] config = ctx.obj["config"]
app = ctx.obj["app"] app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False migrated = False
for version_file in Migrate.get_all_version_files(): for version_file in Migrate.get_all_version_files():
try: try:
@ -123,7 +119,7 @@ async def upgrade(ctx: Context):
await Aerich.create( await Aerich.create(
version=version_file, version=version_file,
app=app, app=app,
content=Migrate.get_models_content(config, app, location), content=get_models_describe(app),
) )
click.secho(f"Success upgrade {version_file}", fg=Color.green) click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True migrated = True
@ -281,7 +277,7 @@ async def init_db(ctx: Context, safe):
await Aerich.create( await Aerich.create(
version=version, version=version,
app=app, app=app,
content=Migrate.get_models_content(config, app, location), content=get_models_describe(app),
) )
content = { content = {
"upgrade": [schema], "upgrade": [schema],

View File

@ -167,26 +167,23 @@ class BaseDDL:
table_name=model._meta.db_table, table_name=model._meta.db_table,
) )
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): def add_fk(self, model: "Type[Model]", field: dict):
db_table = model._meta.db_table db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_column = field.source_field or field.model_field_name + "_id" db_column = field.get('db_column')
fk_name = self.schema_generator._generate_fk_name( fk_name = self.schema_generator._generate_fk_name(
from_table=db_table, from_table=db_table,
from_field=db_column, from_field=db_column,
to_table=field.related_model._meta.db_table, to_table=field.related_model._meta.db_table,
to_field=to_field_name, to_field=db_column,
) )
return self._ADD_FK_TEMPLATE.format( return self._ADD_FK_TEMPLATE.format(
table_name=db_table, table_name=db_table,
fk_name=fk_name, fk_name=fk_name,
db_column=db_column, db_column=db_column,
table=field.related_model._meta.db_table, table=field.related_model._meta.db_table,
field=to_field_name, field=db_column,
on_delete=field.on_delete, on_delete=field.get('on_delete'),
) )
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance): def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):

View File

@ -1,14 +1,10 @@
import inspect
import os import os
import re
from datetime import datetime from datetime import datetime
from importlib import import_module
from io import StringIO
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type from typing import Dict, List, Optional, Tuple, Type
import click import click
from dictdiffer import diff
from tortoise import ( from tortoise import (
BackwardFKRelation, BackwardFKRelation,
BackwardOneToOneRelation, BackwardOneToOneRelation,
@ -23,7 +19,7 @@ from tortoise.fields import Field
from aerich.ddl import BaseDDL from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, write_version_file from aerich.utils import get_app_connection, get_models_describe, write_version_file
class Migrate: class Migrate:
@ -38,18 +34,12 @@ class Migrate:
_rename_new = [] _rename_new = []
ddl: BaseDDL ddl: BaseDDL
migrate_config: dict _last_version_content: Optional[dict] = None
old_models = "old_models"
diff_app = "diff_models"
app: str app: str
migrate_location: str migrate_location: str
dialect: str dialect: str
_db_version: Optional[str] = None _db_version: Optional[str] = None
@classmethod
def get_old_model_file(cls, app: str, location: str):
return Path(location, app, cls.old_models + ".py")
@classmethod @classmethod
def get_all_version_files(cls) -> List[str]: def get_all_version_files(cls) -> List[str]:
return sorted( return sorted(
@ -57,6 +47,10 @@ class Migrate:
key=lambda x: int(x.split("_")[0]), key=lambda x: int(x.split("_")[0]),
) )
@classmethod
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
@classmethod @classmethod
async def get_last_version(cls) -> Optional[Aerich]: async def get_last_version(cls) -> Optional[Aerich]:
try: try:
@ -64,13 +58,6 @@ class Migrate:
except OperationalError: except OperationalError:
pass pass
@classmethod
def remove_old_model_file(cls, app: str, location: str):
try:
os.unlink(cls.get_old_model_file(app, location))
except (OSError, FileNotFoundError):
pass
@classmethod @classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient): async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql": if cls.dialect == "mysql":
@ -79,19 +66,13 @@ class Migrate:
cls._db_version = ret[1][0].get("version") cls._db_version = ret[1][0].get("version")
@classmethod @classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str): async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config) await Tortoise.init(config=config)
last_version = await cls.get_last_version() last_version = await cls.get_last_version()
cls.app = app cls.app = app
cls.migrate_location = Path(location, app) cls.migrate_location = Path(location, app)
if last_version: if last_version:
content = last_version.content cls._last_version_content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.migrate_config = migrate_config
await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app) connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT cls.dialect = connection.schema_generator.DIALECT
@ -149,12 +130,9 @@ class Migrate:
:param name: :param name:
:return: :return:
""" """
apps = Tortoise.apps new_version_content = get_models_describe(cls.app)
diff_models = apps.get(cls.diff_app) cls.diff_models(cls._last_version_content, new_version_content)
app_models = apps.get(cls.app) cls.diff_models(new_version_content, cls._last_version_content, False)
cls.diff_models(diff_models, app_models)
cls.diff_models(app_models, diff_models, False)
cls._merge_operators() cls._merge_operators()
@ -183,58 +161,9 @@ class Migrate:
else: else:
cls.downgrade_operators.append(operator) cls.downgrade_operators.append(operator)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
path = Path(location, app, cls.old_models).as_posix().replace("/", ".")
config["apps"][cls.diff_app] = {
"models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"),
}
return config
@classmethod
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
if isinstance(model, ModuleType):
module = model
else:
module = import_module(model)
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in filter(
lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model,
possible_models,
):
file = inspect.getfile(attr)
if file not in old_model_files:
old_model_files.append(file)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod @classmethod
def diff_models( def diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
): ):
""" """
diff models and add operators diff models and add operators
@ -243,18 +172,19 @@ class Migrate:
:param upgrade: :param upgrade:
:return: :return:
""" """
old_models.pop(cls._aerich, None) _aerich = f'{cls.app}.{cls._aerich}'
new_models.pop(cls._aerich, None) old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model in new_models.items(): for new_model_str, new_model_describe in new_models.items():
if new_model_str not in old_models.keys(): if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade) cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade)
else: else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade) cls.diff_model(old_models.get(new_model_str), new_model_describe, upgrade)
for old_model in old_models: for old_model in old_models:
if old_model not in new_models.keys(): if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade) cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade)
@classmethod @classmethod
def _is_fk_m2m(cls, field: Field): def _is_fk_m2m(cls, field: Field):
@ -269,166 +199,20 @@ class Migrate:
return cls.ddl.drop_table(model) return cls.ddl.drop_table(model)
@classmethod @classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True): def diff_model(cls, old_model_describe: dict, new_model_describe: dict, upgrade=True):
""" """
diff single model diff single model
:param old_model: :param old_model_describe:
:param new_model: :param new_model_describe:
:param upgrade: :param upgrade:
:return: :return:
""" """
old_indexes = old_model._meta.indexes for change in diff(old_model_describe, new_model_describe):
new_indexes = new_model._meta.indexes action, field_type, fields = change
if action == 'add':
old_unique_together = old_model._meta.unique_together for field in fields:
new_unique_together = new_model._meta.unique_together _, field_describe = field
cls._add_field(cls._get_model)
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}?",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(new_model, old_field, new_field),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(
cls._alter_null(new_model, new_field), upgrade=upgrade
)
if new_field.default != old_field.default and not callable(
new_field.default
):
cls._add_operator(
cls._alter_default(new_model, new_field), upgrade=upgrade
)
if new_field.description != old_field.description:
cls._add_operator(
cls._set_comment(new_model, new_field), upgrade=upgrade
)
if new_field.field_type != old_field.field_type:
cls._add_operator(
cls._modify_field(new_model, new_field), upgrade=upgrade
)
else:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
cls._add_operator(
cls._remove_index(
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
):
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field),
upgrade,
True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field),
upgrade,
True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(
cls._add_index(
new_model,
new_index,
),
upgrade,
)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)
for new_unique in new_unique_together:
if new_unique not in old_unique_together:
cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade)
for old_unique in old_unique_together:
if old_unique not in new_unique_together:
cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade)
@classmethod @classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]): def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@ -474,10 +258,10 @@ class Migrate:
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation)) return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod @classmethod
def _add_field(cls, model: Type[Model], field: Field): def _add_field(cls, model: Type[Model], field: dict):
if isinstance(field, ForeignKeyFieldInstance): if field.get('field_type') == 'ForeignKeyFieldInstance':
return cls.ddl.add_fk(model, field) return cls.ddl.add_fk(model, field)
if isinstance(field, ManyToManyFieldInstance): if field.get('field_type') == 'ManyToManyFieldInstance':
return cls.ddl.create_m2m_table(model, field) return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field) return cls.ddl.add_column(model, field)

View File

@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255
class Aerich(Model): class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH) version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20) app = fields.CharField(max_length=20)
content = fields.TextField() content = fields.JSONField()
class Meta: class Meta:
ordering = ["-id"] ordering = ["-id"]

View File

@ -108,3 +108,16 @@ def write_version_file(version_file: str, content: Dict):
f.write(";\n".join(downgrade) + ";\n") f.write(";\n".join(downgrade) + ";\n")
else: else:
f.write(f"{downgrade[0]};\n") f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
:param app:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret