v0.5 refactoring

This commit is contained in:
long2ice 2021-01-31 23:10:30 +08:00
parent 4780b90c1c
commit b4cc2de0e3
5 changed files with 56 additions and 266 deletions

View File

@ -18,6 +18,7 @@ from aerich.migrate import Migrate
from aerich.utils import (
get_app_connection,
get_app_connection_name,
get_models_describe,
get_tortoise_config,
get_version_content_from_file,
write_version_file,
@ -34,11 +35,7 @@ def coro(f):
@wraps(f)
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
ctx = args[0]
loop.run_until_complete(f(*args, **kwargs))
app = ctx.obj.get("app")
if app:
Migrate.remove_old_model_file(app, ctx.obj["location"])
loop.run_until_complete(Tortoise.close_connections())
return wrapper
@ -86,7 +83,7 @@ async def cli(ctx: Context, config, app, name):
if invoked_subcommand != "init-db":
if not Path(location, app).exists():
raise UsageError("You must exec init-db first", ctx=ctx)
await Migrate.init_with_old_models(tortoise_config, app, location)
await Migrate.init(tortoise_config, app, location)
@cli.command(help="Generate migrate changes file.")
@ -106,7 +103,6 @@ async def migrate(ctx: Context, name):
async def upgrade(ctx: Context):
config = ctx.obj["config"]
app = ctx.obj["app"]
location = ctx.obj["location"]
migrated = False
for version_file in Migrate.get_all_version_files():
try:
@ -123,7 +119,7 @@ async def upgrade(ctx: Context):
await Aerich.create(
version=version_file,
app=app,
content=Migrate.get_models_content(config, app, location),
content=get_models_describe(app),
)
click.secho(f"Success upgrade {version_file}", fg=Color.green)
migrated = True
@ -281,7 +277,7 @@ async def init_db(ctx: Context, safe):
await Aerich.create(
version=version,
app=app,
content=Migrate.get_models_content(config, app, location),
content=get_models_describe(app),
)
content = {
"upgrade": [schema],

View File

@ -140,7 +140,7 @@ class BaseDDL:
)
def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
@ -167,26 +167,23 @@ class BaseDDL:
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
def add_fk(self, model: "Type[Model]", field: dict):
db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_column = field.source_field or field.model_field_name + "_id"
db_column = field.get('db_column')
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
to_field=db_column,
)
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
db_column=db_column,
table=field.related_model._meta.db_table,
field=to_field_name,
on_delete=field.on_delete,
field=db_column,
on_delete=field.get('on_delete'),
)
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):

View File

@ -1,14 +1,10 @@
import inspect
import os
import re
from datetime import datetime
from importlib import import_module
from io import StringIO
from pathlib import Path
from types import ModuleType
from typing import Dict, List, Optional, Tuple, Type
import click
from dictdiffer import diff
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
@ -23,7 +19,7 @@ from tortoise.fields import Field
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
from aerich.utils import get_app_connection, write_version_file
from aerich.utils import get_app_connection, get_models_describe, write_version_file
class Migrate:
@ -38,18 +34,12 @@ class Migrate:
_rename_new = []
ddl: BaseDDL
migrate_config: dict
old_models = "old_models"
diff_app = "diff_models"
_last_version_content: Optional[dict] = None
app: str
migrate_location: str
dialect: str
_db_version: Optional[str] = None
@classmethod
def get_old_model_file(cls, app: str, location: str):
return Path(location, app, cls.old_models + ".py")
@classmethod
def get_all_version_files(cls) -> List[str]:
return sorted(
@ -57,6 +47,10 @@ class Migrate:
key=lambda x: int(x.split("_")[0]),
)
@classmethod
def _get_model(cls, model: str) -> Type[Model]:
return Tortoise.apps.get(cls.app).get(model)
@classmethod
async def get_last_version(cls) -> Optional[Aerich]:
try:
@ -64,13 +58,6 @@ class Migrate:
except OperationalError:
pass
@classmethod
def remove_old_model_file(cls, app: str, location: str):
try:
os.unlink(cls.get_old_model_file(app, location))
except (OSError, FileNotFoundError):
pass
@classmethod
async def _get_db_version(cls, connection: BaseDBAsyncClient):
if cls.dialect == "mysql":
@ -79,19 +66,13 @@ class Migrate:
cls._db_version = ret[1][0].get("version")
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
last_version = await cls.get_last_version()
cls.app = app
cls.migrate_location = Path(location, app)
if last_version:
content = last_version.content
with open(cls.get_old_model_file(app, location), "w", encoding="utf-8") as f:
f.write(content)
migrate_config = cls._get_migrate_config(config, app, location)
cls.migrate_config = migrate_config
await Tortoise.init(config=migrate_config)
cls._last_version_content = last_version.content
connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
@ -149,12 +130,9 @@ class Migrate:
:param name:
:return:
"""
apps = Tortoise.apps
diff_models = apps.get(cls.diff_app)
app_models = apps.get(cls.app)
cls.diff_models(diff_models, app_models)
cls.diff_models(app_models, diff_models, False)
new_version_content = get_models_describe(cls.app)
cls.diff_models(cls._last_version_content, new_version_content)
cls.diff_models(new_version_content, cls._last_version_content, False)
cls._merge_operators()
@ -183,58 +161,9 @@ class Migrate:
else:
cls.downgrade_operators.append(operator)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
path = Path(location, app, cls.old_models).as_posix().replace("/", ".")
config["apps"][cls.diff_app] = {
"models": [path],
"default_connection": config.get("apps").get(app).get("default_connection", "default"),
}
return config
@classmethod
def get_models_content(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
if isinstance(model, ModuleType):
module = model
else:
module = import_module(model)
possible_models = [getattr(module, attr_name) for attr_name in dir(module)]
for attr in filter(
lambda x: inspect.isclass(x) and issubclass(x, Model) and x is not Model,
possible_models,
):
file = inspect.getfile(attr)
if file not in old_model_files:
old_model_files.append(file)
pattern = rf"(\n)?('|\")({app})(.\w+)('|\")"
str_io = StringIO()
for i, model_file in enumerate(old_model_files):
with open(model_file, "r", encoding="utf-8") as f:
content = f.read()
ret = re.sub(pattern, rf"\2{cls.diff_app}\4\5", content)
str_io.write(f"{ret}\n")
return str_io.getvalue()
@classmethod
def diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
):
"""
diff models and add operators
@ -243,18 +172,19 @@ class Migrate:
:param upgrade:
:return:
"""
old_models.pop(cls._aerich, None)
new_models.pop(cls._aerich, None)
_aerich = f'{cls.app}.{cls._aerich}'
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
for new_model_str, new_model in new_models.items():
for new_model_str, new_model_describe in new_models.items():
if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade)
cls._add_operator(cls.add_model(cls._get_model(new_model_str)), upgrade)
else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
cls.diff_model(old_models.get(new_model_str), new_model_describe, upgrade)
for old_model in old_models:
if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
cls._add_operator(cls.remove_model(cls._get_model(old_model)), upgrade)
@classmethod
def _is_fk_m2m(cls, field: Field):
@ -269,166 +199,20 @@ class Migrate:
return cls.ddl.drop_table(model)
@classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
def diff_model(cls, old_model_describe: dict, new_model_describe: dict, upgrade=True):
"""
diff single model
:param old_model:
:param new_model:
:param old_model_describe:
:param new_model_describe:
:param upgrade:
:return:
"""
old_indexes = old_model._meta.indexes
new_indexes = new_model._meta.indexes
old_unique_together = old_model._meta.unique_together
new_unique_together = new_model._meta.unique_together
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field, upgrade):
continue
if new_key not in old_keys:
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("name", None)
new_field_dict.pop("db_column", None)
for diff_key in old_keys - new_keys:
old_field = old_fields_map.get(diff_key)
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("name", None)
old_field_dict.pop("db_column", None)
if old_field_dict == new_field_dict:
if upgrade:
is_rename = click.prompt(
f"Rename {diff_key} to {new_key}?",
default=True,
type=bool,
show_choices=True,
)
cls._rename_new.append(new_key)
cls._rename_old.append(diff_key)
else:
is_rename = diff_key in cls._rename_new
if is_rename:
if (
cls.dialect == "mysql"
and cls._db_version
and cls._db_version.startswith("5.")
):
cls._add_operator(
cls._change_field(new_model, old_field, new_field),
upgrade,
)
else:
cls._add_operator(
cls._rename_field(new_model, old_field, new_field),
upgrade,
)
break
else:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
cls._is_fk_m2m(new_field),
)
else:
old_field = old_fields_map.get(new_key)
new_field_dict = new_field.describe(serializable=True)
new_field_dict.pop("unique")
new_field_dict.pop("indexed")
old_field_dict = old_field.describe(serializable=True)
old_field_dict.pop("unique")
old_field_dict.pop("indexed")
if not cls._is_fk_m2m(new_field) and new_field_dict != old_field_dict:
if cls.dialect == "postgres":
if new_field.null != old_field.null:
cls._add_operator(
cls._alter_null(new_model, new_field), upgrade=upgrade
)
if new_field.default != old_field.default and not callable(
new_field.default
):
cls._add_operator(
cls._alter_default(new_model, new_field), upgrade=upgrade
)
if new_field.description != old_field.description:
cls._add_operator(
cls._set_comment(new_model, new_field), upgrade=upgrade
)
if new_field.field_type != old_field.field_type:
cls._add_operator(
cls._modify_field(new_model, new_field), upgrade=upgrade
)
else:
cls._add_operator(cls._modify_field(new_model, new_field), upgrade=upgrade)
if (old_field.index and not new_field.index) or (
old_field.unique and not new_field.unique
):
cls._add_operator(
cls._remove_index(
old_model, (old_field.model_field_name,), old_field.unique
),
upgrade,
cls._is_fk_m2m(old_field),
)
elif (new_field.index and not old_field.index) or (
new_field.unique and not old_field.unique
):
cls._add_operator(
cls._add_index(new_model, (new_field.model_field_name,), new_field.unique),
upgrade,
cls._is_fk_m2m(new_field),
)
if isinstance(new_field, ForeignKeyFieldInstance):
if old_field.db_constraint and not new_field.db_constraint:
cls._add_operator(
cls._drop_fk(new_model, new_field),
upgrade,
True,
)
if new_field.db_constraint and not old_field.db_constraint:
cls._add_operator(
cls._add_fk(new_model, new_field),
upgrade,
True,
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field, upgrade):
if (upgrade and old_key not in cls._rename_old) or (
not upgrade and old_key not in cls._rename_new
):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
cls._is_fk_m2m(field),
)
for new_index in new_indexes:
if new_index not in old_indexes:
cls._add_operator(
cls._add_index(
new_model,
new_index,
),
upgrade,
)
for old_index in old_indexes:
if old_index not in new_indexes:
cls._add_operator(cls._remove_index(old_model, old_index), upgrade)
for new_unique in new_unique_together:
if new_unique not in old_unique_together:
cls._add_operator(cls._add_index(new_model, new_unique, unique=True), upgrade)
for old_unique in old_unique_together:
if old_unique not in new_unique_together:
cls._add_operator(cls._remove_index(old_model, old_unique, unique=True), upgrade)
for change in diff(old_model_describe, new_model_describe):
action, field_type, fields = change
if action == 'add':
for field in fields:
_, field_describe = field
cls._add_field(cls._get_model)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@ -474,10 +258,10 @@ class Migrate:
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod
def _add_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
def _add_field(cls, model: Type[Model], field: dict):
if field.get('field_type') == 'ForeignKeyFieldInstance':
return cls.ddl.add_fk(model, field)
if isinstance(field, ManyToManyFieldInstance):
if field.get('field_type') == 'ManyToManyFieldInstance':
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)

View File

@ -6,7 +6,7 @@ MAX_VERSION_LENGTH = 255
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.TextField()
content = fields.JSONField()
class Meta:
ordering = ["-id"]

View File

@ -108,3 +108,16 @@ def write_version_file(version_file: str, content: Dict):
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
def get_models_describe(app: str) -> Dict:
"""
get app models describe
:param app:
:return:
"""
ret = {}
for model in Tortoise.apps.get(app).values():
describe = model.describe()
ret[describe.get("name")] = describe
return ret