rename to aerich

This commit is contained in:
long2ice
2020-05-15 13:25:28 +08:00
parent 00764c1b3d
commit b767f409f4
16 changed files with 72 additions and 68 deletions

1
aerich/__init__.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "0.1.1"

182
aerich/cli.py Normal file
View File

@@ -0,0 +1,182 @@
import importlib
import json
import os
import sys
from enum import Enum
import asyncclick as click
from asyncclick import BadOptionUsage, Context, UsageError
from tortoise import Tortoise, generate_schema_for_client
from aerich.migrate import Migrate
from aerich.utils import get_app_connection
class Color(str, Enum):
green = "green"
red = "red"
yellow = "yellow"
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
@click.option(
"--config",
default="settings",
show_default=True,
help="Tortoise-ORM config module, will auto read dict config variable from it.",
)
@click.option(
"--tortoise-orm",
default="TORTOISE_ORM",
show_default=True,
help="Tortoise-ORM config dict variable.",
)
@click.option(
"--location", default="./migrations", show_default=True, help="Migrate store location."
)
@click.option("--app", default="models", show_default=True, help="Tortoise-ORM app name.")
@click.pass_context
async def cli(ctx: Context, config, tortoise_orm, location, app):
ctx.ensure_object(dict)
try:
config_module = importlib.import_module(config, ".")
except ModuleNotFoundError:
raise BadOptionUsage(ctx=ctx, message=f'No module named "{config}"', option_name="--config")
config = getattr(config_module, tortoise_orm, None)
if not config:
raise BadOptionUsage(
option_name="--config",
message=f'Can\'t get "{tortoise_orm}" from module "{config_module}"',
ctx=ctx,
)
if app not in config.get("apps").keys():
raise BadOptionUsage(option_name="--config", message=f'No app found in "{config}"', ctx=ctx)
ctx.obj["config"] = config
ctx.obj["location"] = location
ctx.obj["app"] = app
if ctx.invoked_subcommand == "init":
await Tortoise.init(config=config)
else:
if not os.path.isdir(location):
raise UsageError("You must exec init first", ctx=ctx)
await Migrate.init_with_old_models(config, app, location)
@cli.command(help="Generate migrate changes file.")
@click.option("--name", default="update", show_default=True, help="Migrate name.")
@click.pass_context
async def migrate(ctx: Context, name):
config = ctx.obj["config"]
location = ctx.obj["location"]
app = ctx.obj["app"]
ret = Migrate.migrate(name)
if not ret:
return click.secho("No changes detected", fg=Color.yellow)
Migrate.write_old_models(config, app, location)
click.secho(f"Success migrate {ret}", fg=Color.green)
@cli.command(help="Upgrade to latest version.")
@click.pass_context
async def upgrade(ctx: Context):
app = ctx.obj["app"]
config = ctx.obj["config"]
connection = get_app_connection(config, app)
available_versions = Migrate.get_all_version_files(is_all=False)
if not available_versions:
return click.secho("No migrate items", fg=Color.yellow)
async with connection._in_transaction() as conn:
for file in available_versions:
file_path = os.path.join(Migrate.migrate_location, file)
with open(file_path, "r") as f:
content = json.load(f)
upgrade_query_list = content.get("upgrade")
for upgrade_query in upgrade_query_list:
await conn.execute_query(upgrade_query)
with open(file_path, "w") as f:
content["migrate"] = True
json.dump(content, f, indent=4)
click.secho(f"Success upgrade {file}", fg=Color.green)
@cli.command(help="Downgrade to previous version.")
@click.pass_context
async def downgrade(ctx: Context):
app = ctx.obj["app"]
config = ctx.obj["config"]
connection = get_app_connection(config, app)
available_versions = Migrate.get_all_version_files()
if not available_versions:
return click.secho("No migrate items", fg=Color.yellow)
async with connection._in_transaction() as conn:
for file in reversed(available_versions):
file_path = os.path.join(Migrate.migrate_location, file)
with open(file_path, "r") as f:
content = json.load(f)
if content.get("migrate"):
downgrade_query_list = content.get("downgrade")
for downgrade_query in downgrade_query_list:
await conn.execute_query(downgrade_query)
else:
continue
with open(file_path, "w") as f:
content["migrate"] = False
json.dump(content, f, indent=4)
return click.secho(f"Success downgrade {file}", fg=Color.green)
@cli.command(help="Show current available heads in migrate location.")
@click.pass_context
def heads(ctx: Context):
for version in Migrate.get_all_version_files(is_all=False):
click.secho(version, fg=Color.yellow)
@cli.command(help="List all migrate items.")
@click.pass_context
def history(ctx):
for version in Migrate.get_all_version_files():
click.secho(version, fg=Color.yellow)
@cli.command(help="Init migrate location and generate schema, you must exec first.")
@click.option(
"--safe",
is_flag=True,
default=True,
help="When set to true, creates the table only when it does not already exist.",
show_default=True,
)
@click.pass_context
async def init(ctx: Context, safe):
location = ctx.obj["location"]
app = ctx.obj["app"]
config = ctx.obj["config"]
if not os.path.isdir(location):
os.mkdir(location)
dirname = os.path.join(location, app)
if not os.path.isdir(dirname):
os.mkdir(dirname)
click.secho(f"Success create migrate location {dirname}", fg=Color.green)
else:
return click.secho(f'Already inited app "{app}"', fg=Color.yellow)
Migrate.write_old_models(config, app, location)
connection = get_app_connection(config, app)
await generate_schema_for_client(connection, safe)
return click.secho(f'Success init for app "{app}"', fg=Color.green)
def main():
sys.path.insert(0, ".")
cli(_anyio_backend="asyncio")

47
aerich/ddl/__init__.py Normal file
View File

@@ -0,0 +1,47 @@
from typing import List, Type
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, Model
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.fields import Field
class BaseDDL:
schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
DIALECT = "sql"
_DROP_TABLE_TEMPLATE = "DROP TABLE {table_name} IF EXISTS"
_ADD_COLUMN_TEMPLATE = "ALTER TABLE {table_name} ADD {column}"
_DROP_COLUMN_TEMPLATE = "ALTER TABLE {table_name} DROP COLUMN {column_name}"
_ADD_INDEX_TEMPLATE = (
"ALTER TABLE {table_name} ADD {unique} INDEX {index_name} ({column_names})"
)
_DROP_INDEX_TEMPLATE = "ALTER TABLE {table_name} DROP INDEX {index_name}"
_ADD_FK_TEMPLATE = "ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
_DROP_FK_TEMPLATE = "ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}"
def __init__(self, client: "BaseDBAsyncClient"):
self.client = client
self.schema_generator = self.schema_generator_cls(client)
def create_table(self, model: "Type[Model]"):
raise NotImplementedError
def drop_table(self, model: "Type[Model]"):
raise NotImplementedError
def add_column(self, model: "Type[Model]", field_object: Field):
raise NotImplementedError
def drop_column(self, model: "Type[Model]", column_name: str):
raise NotImplementedError
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
raise NotImplementedError
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
raise NotImplementedError
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
raise NotImplementedError
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
raise NotImplementedError

View File

@@ -0,0 +1,120 @@
from typing import List, Type
from tortoise import ForeignKeyFieldInstance, Model
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.fields import Field, JSONField, TextField, UUIDField
from aerich.ddl import BaseDDL
class MysqlDDL(BaseDDL):
schema_generator_cls = MySQLSchemaGenerator
DIALECT = MySQLSchemaGenerator.DIALECT
def create_table(self, model: "Type[Model]"):
return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
def drop_table(self, model: "Type[Model]"):
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
def add_column(self, model: "Type[Model]", field_object: Field):
db_table = model._meta.db_table
default = field_object.default
db_column = field_object.model_field_name
auto_now_add = getattr(field_object, "auto_now_add", False)
auto_now = getattr(field_object, "auto_now", False)
if default is not None or auto_now_add:
if callable(default) or isinstance(field_object, (UUIDField, TextField, JSONField)):
default = ""
else:
default = field_object.to_db_value(default, model)
try:
default = self.schema_generator._column_default_generator(
db_table,
db_column,
self.schema_generator._escape_default_value(default),
auto_now_add,
auto_now,
)
except NotImplementedError:
default = ""
else:
default = ""
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="UNIQUE" if field_object.unique else "",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
)
if field_object.description
else "",
is_primary_key=field_object.pk,
default=default,
),
)
def drop_column(self, model: "Type[Model]", column_name: str):
return self._DROP_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table, column_name=column_name
)
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._ADD_INDEX_TEMPLATE.format(
unique="UNIQUE" if unique else "",
index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table,
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
)
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
return self._DROP_INDEX_TEMPLATE.format(
index_name=self.schema_generator._generate_index_name(
"idx" if not unique else "uid", model, field_names
),
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
db_table = model._meta.db_table
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_column = field.source_field or field.model_field_name + "_id"
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
)
return self._ADD_FK_TEMPLATE.format(
table_name=db_table,
fk_name=fk_name,
db_column=db_column,
table=field.related_model._meta.db_table,
field=to_field_name,
on_delete=field.on_delete,
)
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):
to_field_name = field.to_field_instance.source_field
if not to_field_name:
to_field_name = field.to_field_instance.model_field_name
db_table = model._meta.db_table
return self._DROP_FK_TEMPLATE.format(
table_name=db_table,
fk_name=self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=field.source_field or field.model_field_name + "_id",
to_table=field.related_model._meta.db_table,
to_field=to_field_name,
),
)

6
aerich/exceptions.py Normal file
View File

@@ -0,0 +1,6 @@
class ConfigurationError(Exception):
"""
config error
"""
pass

319
aerich/migrate.py Normal file
View File

@@ -0,0 +1,319 @@
import json
import os
import re
from copy import deepcopy
from datetime import datetime
from typing import Dict, List, Type
from tortoise import BackwardFKRelation, ForeignKeyFieldInstance, Model, Tortoise
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.fields import Field
from aerich.ddl import BaseDDL
from aerich.ddl.mysql import MysqlDDL
from aerich.exceptions import ConfigurationError
from aerich.utils import get_app_connection
class Migrate:
upgrade_operators: List[str] = []
downgrade_operators: List[str] = []
_upgrade_fk_operators: List[str] = []
_downgrade_fk_operators: List[str] = []
ddl: BaseDDL
migrate_config: dict
old_models = "old_models"
diff_app = "diff_models"
app: str
migrate_location: str
@classmethod
def get_old_model_file(cls):
return cls.old_models + ".py"
@classmethod
def _get_all_migrate_files(cls):
return sorted(filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)))
@classmethod
def _get_latest_version(cls) -> int:
ret = cls._get_all_migrate_files()
if ret:
return int(ret[-1].split("_")[0])
return 0
@classmethod
def get_all_version_files(cls, is_all=True):
files = cls._get_all_migrate_files()
ret = []
for file in files:
with open(os.path.join(cls.migrate_location, file), "r") as f:
content = json.load(f)
if is_all or not content.get("migrate"):
ret.append(file)
return ret
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
migrate_config = cls._get_migrate_config(config, app, location)
cls.app = app
cls.migrate_config = migrate_config
cls.migrate_location = os.path.join(location, app)
await Tortoise.init(config=migrate_config)
connection = get_app_connection(config, app)
if connection.schema_generator is MySQLSchemaGenerator:
cls.ddl = MysqlDDL(connection)
else:
raise NotImplementedError("Current only support MySQL")
@classmethod
def _generate_diff_sql(cls, name):
now = datetime.now().strftime("%Y%M%D%H%M%S").replace("/", "")
filename = f"{cls._get_latest_version() + 1}_{now}_{name}.json"
content = {
"upgrade": cls.upgrade_operators,
"downgrade": cls.downgrade_operators,
"migrate": False,
}
with open(os.path.join(cls.migrate_location, filename), "w") as f:
json.dump(content, f, indent=4)
return filename
@classmethod
def migrate(cls, name):
"""
diff old models and new models to generate diff content
:param name:
:return:
"""
if not cls.migrate_config:
raise ConfigurationError("You must call init_with_old_models() first!")
apps = Tortoise.apps
diff_models = apps.get(cls.diff_app)
app_models = apps.get(cls.app)
cls._diff_models(diff_models, app_models)
cls._diff_models(app_models, diff_models, False)
if not cls.upgrade_operators:
return False
cls._merge_operators()
return cls._generate_diff_sql(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True, fk=False):
"""
add operator,differentiate fk because fk is order limit
:param operator:
:param upgrade:
:param fk:
:return:
"""
if upgrade:
if fk:
cls._upgrade_fk_operators.append(operator)
else:
cls.upgrade_operators.append(operator)
else:
if fk:
cls._downgrade_fk_operators.append(operator)
else:
cls.downgrade_operators.append(operator)
@classmethod
def cp_models(
cls, model_files: List[str], old_model_file,
):
"""
cp currents models to old_model_files
:param model_files:
:param old_model_file:
:return:
"""
pattern = r"(ManyToManyField|ForeignKeyField|OneToOneField)\(('|\")(\w+)."
for i, model_file in enumerate(model_files):
with open(model_file, "r") as f:
content = f.read()
ret = re.sub(pattern, rf"\1(\2{cls.diff_app}.", content)
with open(old_model_file, "w" if i == 0 else "w+a") as f:
f.write(ret)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
"""
generate tmp config with old models
:param config:
:param app:
:param location:
:return:
"""
temp_config = deepcopy(config)
path = os.path.join(location, app, cls.old_models)
path = path.replace("/", ".").lstrip(".")
temp_config["apps"][cls.diff_app] = {"models": [path]}
return temp_config
@classmethod
def write_old_models(cls, config: dict, app: str, location: str):
"""
write new models to old models
:param config:
:param app:
:param location:
:return:
"""
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
old_model_files.append(model.replace(".", "/") + ".py")
cls.cp_models(old_model_files, os.path.join(location, app, cls.get_old_model_file()))
@classmethod
def _diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
):
"""
diff models and add operators
:param old_models:
:param new_models:
:param upgrade:
:return:
"""
for new_model_str, new_model in new_models.items():
if new_model_str not in old_models.keys():
cls._add_operator(cls.add_model(new_model), upgrade)
else:
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
for old_model in old_models:
if old_model not in new_models.keys():
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
@classmethod
def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model)
@classmethod
def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model)
@classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
"""
diff single model
:param old_model:
:param new_model:
:param upgrade:
:return:
"""
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
new_keys = new_fields_map.keys()
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if cls._exclude_field(new_field):
continue
if new_key not in old_keys:
cls._add_operator(
cls._add_field(new_model, new_field),
upgrade,
isinstance(new_field, ForeignKeyFieldInstance),
)
else:
old_field = old_fields_map.get(new_key)
if old_field.index and not new_field.index:
cls._add_operator(
cls._remove_index(old_model, old_field),
upgrade,
isinstance(old_field, ForeignKeyFieldInstance),
)
elif new_field.index and not old_field.index:
cls._add_operator(
cls._add_index(new_model, new_field),
upgrade,
isinstance(new_field, ForeignKeyFieldInstance),
)
for old_key in old_keys:
field = old_fields_map.get(old_key)
if old_key not in new_keys and not cls._exclude_field(field):
cls._add_operator(
cls._remove_field(old_model, field),
upgrade,
isinstance(field, ForeignKeyFieldInstance),
)
@classmethod
def _remove_index(cls, model: Type[Model], field: Field):
return cls.ddl.drop_index(model, [field.model_field_name], field.unique)
@classmethod
def _add_index(cls, model: Type[Model], field: Field):
return cls.ddl.add_index(model, [field.model_field_name], field.unique)
@classmethod
def _exclude_field(cls, field: Field):
"""
exclude BackwardFKRelation
:param field:
:return:
"""
return isinstance(field, BackwardFKRelation)
@classmethod
def _add_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.add_fk(model, field)
else:
return cls.ddl.add_column(model, field)
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
return cls.ddl.drop_fk(model, field)
return cls.ddl.drop_column(model, field.model_field_name)
@classmethod
def _add_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
add fk
:param model:
:param field:
:return:
"""
return cls.ddl.add_fk(model, field)
@classmethod
def _remove_fk(cls, model: Type[Model], field: ForeignKeyFieldInstance):
"""
drop fk
:param model:
:param field:
:return:
"""
return cls.ddl.drop_fk(model, field)
@classmethod
def _merge_operators(cls):
"""
fk must be last when add,first when drop
:return:
"""
for _upgrade_fk_operator in cls._upgrade_fk_operators:
if "ADD" in _upgrade_fk_operator:
cls.upgrade_operators.append(_upgrade_fk_operator)
else:
cls.upgrade_operators.insert(0, _upgrade_fk_operator)
for _downgrade_fk_operator in cls._downgrade_fk_operators:
if "ADD" in _downgrade_fk_operator:
cls.downgrade_operators.append(_downgrade_fk_operator)
else:
cls.downgrade_operators.insert(0, _downgrade_fk_operator)

11
aerich/utils.py Normal file
View File

@@ -0,0 +1,11 @@
from tortoise import Tortoise
def get_app_connection(config, app):
"""
get tortoise app
:param config:
:param app:
:return:
"""
return Tortoise.get_connection(config.get("apps").get(app).get("default_connection"))