finish base cli
This commit is contained in:
@@ -1 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
__version__ = "0.1.0"
|
||||
|
||||
208
alice/cli.py
208
alice/cli.py
@@ -1,13 +1,13 @@
|
||||
import importlib
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
|
||||
import asyncclick as click
|
||||
from asyncclick import BadParameter, ClickException
|
||||
from tortoise import Tortoise, generate_schema_for_client
|
||||
from tortoise import generate_schema_for_client, ConfigurationError, Tortoise
|
||||
|
||||
from alice.backends.mysql import MysqlDDL
|
||||
from alice.migrate import Migrate
|
||||
from alice.utils import get_app_connection
|
||||
|
||||
@@ -15,103 +15,165 @@ sys.path.append(os.getcwd())
|
||||
|
||||
|
||||
class Color(str, Enum):
|
||||
green = 'green'
|
||||
red = 'red'
|
||||
green = "green"
|
||||
red = "red"
|
||||
|
||||
|
||||
@click.group(context_settings={'help_option_names': ['-h', '--help']})
|
||||
@click.option('-c', '--config', default='settings', show_default=True,
|
||||
help='Tortoise-ORM config module, will read config variable from it, default is `settings`.')
|
||||
@click.option('-t', '--tortoise-orm', default='TORTOISE_ORM', show_default=True,
|
||||
help='Tortoise-ORM config dict variable, default is `TORTOISE_ORM`.')
|
||||
@click.option('-l', '--location', default='./migrations', show_default=True,
|
||||
help='Migrate store location, default is `./migrations`.')
|
||||
@click.option('-a', '--app', default='models', show_default=True, help='Tortoise-ORM app name, default is `models`.')
|
||||
@click.group(context_settings={"help_option_names": ["-h", "--help"]})
|
||||
@click.option(
|
||||
"--config",
|
||||
default="settings",
|
||||
show_default=True,
|
||||
help="Tortoise-ORM config module, will read config variable from it.",
|
||||
)
|
||||
@click.option(
|
||||
"--tortoise-orm",
|
||||
default="TORTOISE_ORM",
|
||||
show_default=True,
|
||||
help="Tortoise-ORM config dict variable.",
|
||||
)
|
||||
@click.option(
|
||||
"--location", default="./migrations", show_default=True, help="Migrate store location."
|
||||
)
|
||||
@click.option("--app", default="models", show_default=True, help="Tortoise-ORM app name.")
|
||||
@click.pass_context
|
||||
async def cli(ctx, config, tortoise_orm, location, app):
|
||||
ctx.ensure_object(dict)
|
||||
try:
|
||||
config_module = importlib.import_module(config)
|
||||
config = getattr(config_module, tortoise_orm, None)
|
||||
if not config:
|
||||
raise BadParameter(param_hint=['--config'],
|
||||
message=f'Can\'t get "{tortoise_orm}" from module "{config_module}"')
|
||||
|
||||
await Tortoise.init(config=config)
|
||||
|
||||
ctx.obj['config'] = config
|
||||
ctx.obj['location'] = location
|
||||
ctx.obj['app'] = app
|
||||
|
||||
if app not in config.get('apps').keys():
|
||||
raise BadParameter(param_hint=['--app'], message=f'No app found in "{config}"')
|
||||
|
||||
except ModuleNotFoundError:
|
||||
raise BadParameter(param_hint=['--tortoise-orm'], message=f'No module named "{config}"')
|
||||
raise BadParameter(param_hint=["--tortoise-orm"], message=f'No module named "{config}"')
|
||||
config = getattr(config_module, tortoise_orm, None)
|
||||
if not config:
|
||||
raise BadParameter(
|
||||
param_hint=["--config"],
|
||||
message=f'Can\'t get "{tortoise_orm}" from module "{config_module}"',
|
||||
)
|
||||
if app not in config.get("apps").keys():
|
||||
raise BadParameter(param_hint=["--app"], message=f'No app found in "{config}"')
|
||||
|
||||
ctx.obj["config"] = config
|
||||
ctx.obj["location"] = location
|
||||
ctx.obj["app"] = app
|
||||
try:
|
||||
await Migrate.init_with_old_models(config, app, location)
|
||||
except ConfigurationError:
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(help="Generate migrate changes file.")
|
||||
@click.option("--name", default="update", show_default=True, help="Migrate name.")
|
||||
@click.pass_context
|
||||
def migrate(ctx):
|
||||
config = ctx.obj['config']
|
||||
location = ctx.obj['location']
|
||||
app = ctx.obj['app']
|
||||
async def migrate(ctx, name):
|
||||
config = ctx.obj["config"]
|
||||
location = ctx.obj["location"]
|
||||
app = ctx.obj["app"]
|
||||
|
||||
old_models = Migrate.read_old_models(app, location)
|
||||
print(old_models)
|
||||
|
||||
new_models = Tortoise.apps.get(app)
|
||||
print(new_models)
|
||||
|
||||
ret = Migrate(MysqlDDL(get_app_connection(config, app))).diff_models(old_models, new_models)
|
||||
print(ret)
|
||||
ret = Migrate.migrate(name)
|
||||
if not ret:
|
||||
click.secho("No changes detected", fg=Color.green)
|
||||
else:
|
||||
Migrate.write_old_models(config, app, location)
|
||||
click.secho(f"Success migrate {ret}", fg=Color.green)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(help="Upgrade to latest version.")
|
||||
@click.pass_context
|
||||
def upgrade():
|
||||
pass
|
||||
async def upgrade(ctx):
|
||||
app = ctx.obj["app"]
|
||||
config = ctx.obj["config"]
|
||||
connection = get_app_connection(config, app)
|
||||
available_versions = Migrate.get_all_version_files(is_all=False)
|
||||
if not available_versions:
|
||||
return click.secho("No migrate items", fg=Color.green)
|
||||
async with connection._in_transaction() as conn:
|
||||
for file in available_versions:
|
||||
file_path = os.path.join(Migrate.migrate_location, file)
|
||||
with open(file_path, "r") as f:
|
||||
content = json.load(f)
|
||||
upgrade_query_list = content.get("upgrade")
|
||||
for upgrade_query in upgrade_query_list:
|
||||
await conn.execute_query(upgrade_query)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
content["migrate"] = True
|
||||
json.dump(content, f, indent=4)
|
||||
click.secho(f"Success upgrade {file}", fg=Color.green)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@cli.command(help="Downgrade to previous version.")
|
||||
@click.pass_context
|
||||
def downgrade():
|
||||
pass
|
||||
async def downgrade(ctx):
|
||||
app = ctx.obj["app"]
|
||||
config = ctx.obj["config"]
|
||||
connection = get_app_connection(config, app)
|
||||
available_versions = Migrate.get_all_version_files()
|
||||
if not available_versions:
|
||||
return click.secho("No migrate items", fg=Color.green)
|
||||
|
||||
async with connection._in_transaction() as conn:
|
||||
for file in available_versions:
|
||||
file_path = os.path.join(Migrate.migrate_location, file)
|
||||
with open(file_path, "r") as f:
|
||||
content = json.load(f)
|
||||
if content.get("migrate"):
|
||||
downgrade_query_list = content.get("downgrade")
|
||||
for downgrade_query in downgrade_query_list:
|
||||
await conn.execute_query(downgrade_query)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
content["migrate"] = False
|
||||
json.dump(content, f, indent=4)
|
||||
return click.secho(f"Success downgrade {file}", fg=Color.green)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option('--safe', is_flag=True, default=True,
|
||||
help='When set to true, creates the table only when it does not already exist..', show_default=True)
|
||||
@cli.command(help="Show current available heads in migrate location.")
|
||||
@click.pass_context
|
||||
async def initdb(ctx, safe):
|
||||
location = ctx.obj['location']
|
||||
config = ctx.obj['config']
|
||||
app = ctx.obj['app']
|
||||
|
||||
await generate_schema_for_client(get_app_connection(config, app), safe)
|
||||
|
||||
Migrate.write_old_models(app, location)
|
||||
|
||||
click.secho(f'Success initdb for app `{app}`', fg=Color.green)
|
||||
def heads(ctx):
|
||||
for version in Migrate.get_all_version_files(is_all=False):
|
||||
click.secho(version, fg=Color.green)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option('--overwrite', is_flag=True, default=False, help=f'Overwrite {Migrate.old_models}.', show_default=True)
|
||||
@cli.command(help="List all migrate items.")
|
||||
@click.pass_context
|
||||
def init(ctx, overwrite):
|
||||
location = ctx.obj['location']
|
||||
app = ctx.obj['app']
|
||||
def history(ctx):
|
||||
for version in Migrate.get_all_version_files():
|
||||
click.secho(version, fg=Color.green)
|
||||
|
||||
|
||||
@cli.command(
|
||||
help="Init migrate location and generate schema, you must call first before other actions."
|
||||
)
|
||||
@click.option(
|
||||
"--safe",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="When set to true, creates the table only when it does not already exist.",
|
||||
show_default=True,
|
||||
)
|
||||
@click.pass_context
|
||||
async def init(ctx, safe):
|
||||
location = ctx.obj["location"]
|
||||
app = ctx.obj["app"]
|
||||
config = ctx.obj["config"]
|
||||
|
||||
if not os.path.isdir(location):
|
||||
os.mkdir(location)
|
||||
dirname = os.path.join(location, app)
|
||||
if not os.path.isdir(dirname):
|
||||
os.mkdir(dirname)
|
||||
click.secho(f'Success create migrate location {dirname}', fg=Color.green)
|
||||
if overwrite:
|
||||
Migrate.write_old_models(app, location)
|
||||
|
||||
dirname = os.path.join(location, app)
|
||||
if not os.path.isdir(dirname):
|
||||
os.mkdir(dirname)
|
||||
click.secho(f"Success create migrate location {dirname}", fg=Color.green)
|
||||
else:
|
||||
raise ClickException('Already inited')
|
||||
raise ClickException(f"Already inited app `{app}`")
|
||||
|
||||
Migrate.write_old_models(config, app, location)
|
||||
|
||||
await Migrate.init_with_old_models(config, app, location)
|
||||
await generate_schema_for_client(get_app_connection(config, app), safe)
|
||||
|
||||
click.secho(f"Success init for app `{app}`", fg=Color.green)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cli(_anyio_backend='asyncio')
|
||||
def main():
|
||||
cli(_anyio_backend="asyncio")
|
||||
|
||||
@@ -1,32 +1,32 @@
|
||||
from typing import Type, List
|
||||
from typing import List, Type
|
||||
|
||||
from tortoise import Model, BaseDBAsyncClient, ForeignKeyFieldInstance
|
||||
from tortoise import BaseDBAsyncClient, ForeignKeyFieldInstance, Model
|
||||
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
||||
from tortoise.fields import Field, UUIDField, TextField, JSONField
|
||||
from tortoise.fields import Field, JSONField, TextField, UUIDField
|
||||
|
||||
|
||||
class DDL:
|
||||
schema_generator_cls: Type[BaseSchemaGenerator] = BaseSchemaGenerator
|
||||
DIALECT = "sql"
|
||||
_DROP_TABLE_TEMPLATE = 'DROP TABLE {table_name} IF EXISTS'
|
||||
_ADD_COLUMN_TEMPLATE = 'ALTER TABLE {table_name} ADD {column}'
|
||||
_DROP_COLUMN_TEMPLATE = 'ALTER TABLE {table_name} DROP COLUMN {column_name}'
|
||||
_ADD_INDEX_TEMPLATE = 'ALTER TABLE {table_name} ADD {unique} INDEX {index_name} ({column_names})'
|
||||
_DROP_INDEX_TEMPLATE = 'ALTER TABLE {table_name} DROP INDEX {index_name}'
|
||||
_ADD_FK_TEMPLATE = 'ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}'
|
||||
_DROP_FK_TEMPLATE = 'ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}'
|
||||
_DROP_TABLE_TEMPLATE = "DROP TABLE {table_name} IF EXISTS"
|
||||
_ADD_COLUMN_TEMPLATE = "ALTER TABLE {table_name} ADD {column}"
|
||||
_DROP_COLUMN_TEMPLATE = "ALTER TABLE {table_name} DROP COLUMN {column_name}"
|
||||
_ADD_INDEX_TEMPLATE = (
|
||||
"ALTER TABLE {table_name} ADD {unique} INDEX {index_name} ({column_names})"
|
||||
)
|
||||
_DROP_INDEX_TEMPLATE = "ALTER TABLE {table_name} DROP INDEX {index_name}"
|
||||
_ADD_FK_TEMPLATE = "ALTER TABLE {table_name} ADD CONSTRAINT `{fk_name}` FOREIGN KEY (`{db_column}`) REFERENCES `{table}` (`{field}`) ON DELETE {on_delete}"
|
||||
_DROP_FK_TEMPLATE = "ALTER TABLE {table_name} DROP FOREIGN KEY {fk_name}"
|
||||
|
||||
def __init__(self, client: "BaseDBAsyncClient"):
|
||||
self.client = client
|
||||
self.schema_generator = self.schema_generator_cls(client)
|
||||
|
||||
def create_table(self, model: "Type[Model]"):
|
||||
return self.schema_generator._get_table_sql(model, True)['table_creation_string']
|
||||
return self.schema_generator._get_table_sql(model, True)["table_creation_string"]
|
||||
|
||||
def drop_table(self, model: "Type[Model]"):
|
||||
return self._DROP_TABLE_TEMPLATE.format(
|
||||
table_name=model._meta.db_table
|
||||
)
|
||||
return self._DROP_TABLE_TEMPLATE.format(table_name=model._meta.db_table)
|
||||
|
||||
def add_column(self, model: "Type[Model]", field_object: Field):
|
||||
db_table = model._meta.db_table
|
||||
@@ -59,33 +59,37 @@ class DDL:
|
||||
nullable="NOT NULL" if not field_object.null else "",
|
||||
unique="UNIQUE" if field_object.unique else "",
|
||||
comment=self.schema_generator._column_comment_generator(
|
||||
table=db_table, column=field_object.model_field_name, comment=field_object.description
|
||||
table=db_table,
|
||||
column=field_object.model_field_name,
|
||||
comment=field_object.description,
|
||||
)
|
||||
if field_object.description else "",
|
||||
if field_object.description
|
||||
else "",
|
||||
is_primary_key=field_object.pk,
|
||||
default=default
|
||||
)
|
||||
default=default,
|
||||
),
|
||||
)
|
||||
|
||||
def drop_column(self, model: "Type[Model]", column_name: str):
|
||||
return self._DROP_COLUMN_TEMPLATE.format(
|
||||
table_name=model._meta.db_table,
|
||||
column_name=column_name
|
||||
table_name=model._meta.db_table, column_name=column_name
|
||||
)
|
||||
|
||||
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
||||
return self._ADD_INDEX_TEMPLATE.format(
|
||||
unique='UNIQUE' if unique else '',
|
||||
index_name=self.schema_generator._generate_index_name("idx" if not unique else "uid", model,
|
||||
field_names),
|
||||
unique="UNIQUE" if unique else "",
|
||||
index_name=self.schema_generator._generate_index_name(
|
||||
"idx" if not unique else "uid", model, field_names
|
||||
),
|
||||
table_name=model._meta.db_table,
|
||||
column_names=", ".join([self.schema_generator.quote(f) for f in field_names]),
|
||||
)
|
||||
|
||||
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
||||
return self._DROP_INDEX_TEMPLATE.format(
|
||||
index_name=self.schema_generator._generate_index_name("idx" if not unique else "uid", model,
|
||||
field_names),
|
||||
index_name=self.schema_generator._generate_index_name(
|
||||
"idx" if not unique else "uid", model, field_names
|
||||
),
|
||||
table_name=model._meta.db_table,
|
||||
)
|
||||
|
||||
@@ -99,7 +103,7 @@ class DDL:
|
||||
from_table=db_table,
|
||||
from_field=field.model_field_name,
|
||||
to_table=field.related_model._meta.db_table,
|
||||
to_field=to_field_name
|
||||
to_field=to_field_name,
|
||||
)
|
||||
return self._ADD_FK_TEMPLATE.format(
|
||||
table_name=db_table,
|
||||
@@ -120,6 +124,6 @@ class DDL:
|
||||
from_table=model._meta.db_table,
|
||||
from_field=field.model_field_name,
|
||||
to_table=field.related_model._meta.db_table,
|
||||
to_field=to_field_name
|
||||
)
|
||||
to_field=to_field_name,
|
||||
),
|
||||
)
|
||||
@@ -1,8 +1,8 @@
|
||||
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
|
||||
|
||||
from alice.backends import DDL
|
||||
from alice.ddl import DDL
|
||||
|
||||
|
||||
class MysqlDDL(DDL):
|
||||
schema_generator_cls = MySQLSchemaGenerator
|
||||
DIALECT = "mysql"
|
||||
DIALECT = MySQLSchemaGenerator.DIALECT
|
||||
6
alice/exceptions.py
Normal file
6
alice/exceptions.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class ConfigurationError(Exception):
|
||||
"""
|
||||
config error
|
||||
"""
|
||||
|
||||
pass
|
||||
236
alice/migrate.py
236
alice/migrate.py
@@ -1,79 +1,183 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Type
|
||||
|
||||
import dill
|
||||
from typing import List, Type, Dict
|
||||
|
||||
from tortoise import Model, ForeignKeyFieldInstance, Tortoise
|
||||
from tortoise import ForeignKeyFieldInstance, Model, Tortoise
|
||||
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
|
||||
from tortoise.fields import Field
|
||||
|
||||
from alice.backends import DDL
|
||||
from alice.ddl import DDL
|
||||
from alice.ddl.mysql import MysqlDDL
|
||||
from alice.exceptions import ConfigurationError
|
||||
from alice.utils import get_app_connection
|
||||
|
||||
|
||||
class Migrate:
|
||||
operators: List
|
||||
upgrade_operators: List[str] = []
|
||||
downgrade_operators: List[str] = []
|
||||
ddl: DDL
|
||||
old_models = 'old_models.pickle'
|
||||
migrate_config: dict
|
||||
old_models = "old_models"
|
||||
diff_app = "diff_models"
|
||||
app: str
|
||||
migrate_location: str
|
||||
|
||||
def __init__(self, ddl: DDL):
|
||||
self.operators = []
|
||||
self.ddl = ddl
|
||||
@classmethod
|
||||
def get_old_model_file(cls):
|
||||
return cls.old_models + ".py"
|
||||
|
||||
@staticmethod
|
||||
def write_old_models(app, location):
|
||||
ret = Tortoise.apps.get(app)
|
||||
old_models = {}
|
||||
for k, v in ret.items():
|
||||
old_models[k] = deepcopy(v)
|
||||
@classmethod
|
||||
def _get_all_migrate_files(cls):
|
||||
return sorted(filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)))
|
||||
|
||||
dirname = os.path.join(location, app)
|
||||
@classmethod
|
||||
def _get_latest_version(cls) -> int:
|
||||
ret = cls._get_all_migrate_files()
|
||||
if ret:
|
||||
return int(ret[-1].split("_")[0])
|
||||
return 0
|
||||
|
||||
with open(os.path.join(dirname, Migrate.old_models), 'wb') as f:
|
||||
dill.dump(old_models, f, )
|
||||
@classmethod
|
||||
def get_all_version_files(cls, is_all=True):
|
||||
files = cls._get_all_migrate_files()
|
||||
ret = []
|
||||
for file in files:
|
||||
with open(os.path.join(cls.migrate_location, file), "r") as f:
|
||||
content = json.load(f)
|
||||
if is_all or not content.get("migrate"):
|
||||
ret.append(file)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def read_old_models(app, location):
|
||||
dirname = os.path.join(location, app)
|
||||
with open(os.path.join(dirname, Migrate.old_models), 'rb') as f:
|
||||
return dill.load(f, )
|
||||
@classmethod
|
||||
async def init_with_old_models(cls, config: dict, app: str, location: str):
|
||||
migrate_config = cls._get_migrate_config(config, app, location)
|
||||
|
||||
def diff_models_module(self, old_models_module, new_models_module):
|
||||
old_module = importlib.import_module(old_models_module)
|
||||
old_models = {}
|
||||
new_models = {}
|
||||
for name, obj in inspect.getmembers(old_module):
|
||||
if inspect.isclass(obj) and issubclass(obj, Model):
|
||||
old_models[obj.__name__] = obj
|
||||
cls.app = app
|
||||
cls.migrate_config = migrate_config
|
||||
cls.migrate_location = os.path.join(location, app)
|
||||
|
||||
new_module = importlib.import_module(new_models_module)
|
||||
for name, obj in inspect.getmembers(new_module):
|
||||
if inspect.isclass(obj) and issubclass(obj, Model):
|
||||
new_models[obj.__name__] = obj
|
||||
self.diff_models(old_models, new_models)
|
||||
await Tortoise.init(config=migrate_config)
|
||||
|
||||
def diff_models(self, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]]):
|
||||
connection = get_app_connection(config, app)
|
||||
if connection.schema_generator is MySQLSchemaGenerator:
|
||||
cls.ddl = MysqlDDL(connection)
|
||||
else:
|
||||
raise NotImplementedError("Current only support MySQL")
|
||||
|
||||
@classmethod
|
||||
def _generate_diff_sql(cls, name):
|
||||
now = datetime.now().strftime("%Y%M%D%H%M%S").replace("/", "")
|
||||
filename = f"{cls._get_latest_version() + 1}_{now}_{name}.json"
|
||||
content = {
|
||||
"upgrade": cls.upgrade_operators,
|
||||
"download": cls.downgrade_operators,
|
||||
"migrate": False,
|
||||
}
|
||||
with open(os.path.join(cls.migrate_location, filename), "w") as f:
|
||||
json.dump(content, f, indent=4)
|
||||
return filename
|
||||
|
||||
@classmethod
|
||||
def migrate(cls, name):
|
||||
if not cls.migrate_config:
|
||||
raise ConfigurationError("You must call init_with_old_models() first!")
|
||||
apps = Tortoise.apps
|
||||
diff_models = apps.get(cls.diff_app)
|
||||
app_models = apps.get(cls.app)
|
||||
|
||||
cls._diff_models(diff_models, app_models)
|
||||
cls._diff_models(app_models, diff_models, False)
|
||||
|
||||
if not cls.upgrade_operators:
|
||||
return False
|
||||
|
||||
return cls._generate_diff_sql(name)
|
||||
|
||||
@classmethod
|
||||
def _add_operator(cls, operator: str, upgrade=True):
|
||||
if upgrade:
|
||||
cls.upgrade_operators.append(operator)
|
||||
else:
|
||||
cls.downgrade_operators.append(operator)
|
||||
|
||||
@classmethod
|
||||
def cp_models(
|
||||
cls, model_files: List[str], old_model_file,
|
||||
):
|
||||
"""
|
||||
cp currents models to old_model_files
|
||||
:param model_files:
|
||||
:param old_model_file:
|
||||
:return:
|
||||
"""
|
||||
pattern = (
|
||||
r"(ManyToManyField|ForeignKeyField|OneToOneField)\((model_name)?(\"|\')(\w+)(.+)\)"
|
||||
)
|
||||
for i, model_file in enumerate(model_files):
|
||||
with open(model_file, "r") as f:
|
||||
content = f.read()
|
||||
ret = re.sub(pattern, rf"\1\2(\3{cls.diff_app}\5)", content)
|
||||
with open(old_model_file, "w" if i == 0 else "w+a") as f:
|
||||
f.write(ret)
|
||||
|
||||
@classmethod
|
||||
def _get_migrate_config(cls, config: dict, app: str, location: str):
|
||||
temp_config = deepcopy(config)
|
||||
path = os.path.join(location, app, cls.old_models)
|
||||
path = path.replace("/", ".").lstrip(".")
|
||||
temp_config["apps"][cls.diff_app] = {"models": [path]}
|
||||
return temp_config
|
||||
|
||||
@classmethod
|
||||
def write_old_models(cls, config: dict, app: str, location: str):
|
||||
old_model_files = []
|
||||
models = config.get("apps").get(app).get("models")
|
||||
for model in models:
|
||||
old_model_files.append(model.replace(".", "/") + ".py")
|
||||
|
||||
cls.cp_models(old_model_files, os.path.join(location, app, cls.get_old_model_file()))
|
||||
|
||||
@classmethod
|
||||
def _diff_models(
|
||||
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
|
||||
):
|
||||
"""
|
||||
diff models and add operators
|
||||
:param old_models:
|
||||
:param new_models:
|
||||
:param upgrade:
|
||||
:return:
|
||||
"""
|
||||
for new_model_str, new_model in new_models.items():
|
||||
if new_model_str not in old_models.keys():
|
||||
self.add_model(new_model)
|
||||
cls._add_operator(cls.add_model(new_model), upgrade)
|
||||
else:
|
||||
self.diff_model(old_models.get(new_model_str), new_model)
|
||||
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
|
||||
|
||||
for old_model in old_models:
|
||||
if old_model not in new_models.keys():
|
||||
self.remove_model(old_models.get(old_model))
|
||||
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
|
||||
|
||||
def _add_operator(self, operator):
|
||||
self.operators.append(operator)
|
||||
@classmethod
|
||||
def add_model(cls, model: Type[Model]):
|
||||
return cls.ddl.create_table(model)
|
||||
|
||||
def add_model(self, model: Type[Model]):
|
||||
self._add_operator(self.ddl.create_table(model))
|
||||
@classmethod
|
||||
def remove_model(cls, model: Type[Model]):
|
||||
return cls.ddl.drop_table(model)
|
||||
|
||||
def remove_model(self, model: Type[Model]):
|
||||
self._add_operator(self.ddl.drop_table(model))
|
||||
|
||||
def diff_model(self, old_model: Type[Model], new_model: Type[Model]):
|
||||
@classmethod
|
||||
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
|
||||
"""
|
||||
diff single model
|
||||
:param old_model:
|
||||
:param new_model:
|
||||
:param upgrade:
|
||||
:return:
|
||||
"""
|
||||
old_fields_map = old_model._meta.fields_map
|
||||
new_fields_map = new_model._meta.fields_map
|
||||
old_keys = old_fields_map.keys()
|
||||
@@ -81,31 +185,35 @@ class Migrate:
|
||||
for new_key in new_keys:
|
||||
new_field = new_fields_map.get(new_key)
|
||||
if new_key not in old_keys:
|
||||
self._add_field(new_model, new_field)
|
||||
cls._add_operator(cls._add_field(new_model, new_field), upgrade)
|
||||
else:
|
||||
old_field = old_fields_map.get(new_key)
|
||||
if old_field.index and not new_field.index:
|
||||
self._remove_index(old_model, old_field)
|
||||
cls._add_operator(cls._remove_index(old_model, old_field), upgrade)
|
||||
elif new_field.index and not old_field.index:
|
||||
self._add_index(new_model, new_field)
|
||||
cls._add_operator(cls._add_index(new_model, new_field), upgrade)
|
||||
for old_key in old_keys:
|
||||
if old_key not in new_keys:
|
||||
field = old_fields_map.get(old_key)
|
||||
self._remove_field(old_model, field)
|
||||
cls._add_operator(cls._remove_field(old_model, field), upgrade)
|
||||
|
||||
def _remove_index(self, model: Type[Model], field: Field):
|
||||
self._add_operator(self.ddl.drop_index(model, [field.model_field_name], field.unique))
|
||||
@classmethod
|
||||
def _remove_index(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.drop_index(model, [field.model_field_name], field.unique)
|
||||
|
||||
def _add_index(self, model: Type[Model], field: Field):
|
||||
self._add_operator(self.ddl.add_index(model, [field.model_field_name], field.unique))
|
||||
@classmethod
|
||||
def _add_index(cls, model: Type[Model], field: Field):
|
||||
return cls.ddl.add_index(model, [field.model_field_name], field.unique)
|
||||
|
||||
def _add_field(self, model: Type[Model], field: Field):
|
||||
@classmethod
|
||||
def _add_field(cls, model: Type[Model], field: Field):
|
||||
if isinstance(field, ForeignKeyFieldInstance):
|
||||
self._add_operator(self.ddl.add_fk(model, field))
|
||||
return cls.ddl.add_fk(model, field)
|
||||
else:
|
||||
self._add_operator(self.ddl.add_column(model, field))
|
||||
return cls.ddl.add_column(model, field)
|
||||
|
||||
def _remove_field(self, model: Type[Model], field: Field):
|
||||
@classmethod
|
||||
def _remove_field(cls, model: Type[Model], field: Field):
|
||||
if isinstance(field, ForeignKeyFieldInstance):
|
||||
self._add_operator(self.ddl.drop_fk(model, field))
|
||||
self._add_operator(self.ddl.drop_column(model, field.model_field_name))
|
||||
return cls.ddl.drop_fk(model, field)
|
||||
return cls.ddl.drop_column(model, field.model_field_name)
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from tortoise import Tortoise
|
||||
|
||||
|
||||
def get_app_connection(config: dict, app: str):
|
||||
def get_app_connection(config, app):
|
||||
"""
|
||||
get tortoise connection by app
|
||||
get tortoise app
|
||||
:param config:
|
||||
:param app:
|
||||
:return:
|
||||
"""
|
||||
return Tortoise.get_connection(config.get('apps').get(app).get('default_connection')),
|
||||
return Tortoise.get_connection(config.get("apps").get(app).get("default_connection"))
|
||||
|
||||
Reference in New Issue
Block a user