finish base cli

This commit is contained in:
long2ice
2020-05-14 18:46:17 +08:00
parent d385647fba
commit 2da90ecca6
21 changed files with 465 additions and 623 deletions

View File

@@ -1,79 +1,183 @@
import importlib
import inspect
import json
import os
import re
from copy import deepcopy
from datetime import datetime
from typing import Dict, List, Type
import dill
from typing import List, Type, Dict
from tortoise import Model, ForeignKeyFieldInstance, Tortoise
from tortoise import ForeignKeyFieldInstance, Model, Tortoise
from tortoise.backends.mysql.schema_generator import MySQLSchemaGenerator
from tortoise.fields import Field
from alice.backends import DDL
from alice.ddl import DDL
from alice.ddl.mysql import MysqlDDL
from alice.exceptions import ConfigurationError
from alice.utils import get_app_connection
class Migrate:
operators: List
upgrade_operators: List[str] = []
downgrade_operators: List[str] = []
ddl: DDL
old_models = 'old_models.pickle'
migrate_config: dict
old_models = "old_models"
diff_app = "diff_models"
app: str
migrate_location: str
def __init__(self, ddl: DDL):
self.operators = []
self.ddl = ddl
@classmethod
def get_old_model_file(cls):
return cls.old_models + ".py"
@staticmethod
def write_old_models(app, location):
ret = Tortoise.apps.get(app)
old_models = {}
for k, v in ret.items():
old_models[k] = deepcopy(v)
@classmethod
def _get_all_migrate_files(cls):
return sorted(filter(lambda x: x.endswith("json"), os.listdir(cls.migrate_location)))
dirname = os.path.join(location, app)
@classmethod
def _get_latest_version(cls) -> int:
ret = cls._get_all_migrate_files()
if ret:
return int(ret[-1].split("_")[0])
return 0
with open(os.path.join(dirname, Migrate.old_models), 'wb') as f:
dill.dump(old_models, f, )
@classmethod
def get_all_version_files(cls, is_all=True):
files = cls._get_all_migrate_files()
ret = []
for file in files:
with open(os.path.join(cls.migrate_location, file), "r") as f:
content = json.load(f)
if is_all or not content.get("migrate"):
ret.append(file)
return ret
@staticmethod
def read_old_models(app, location):
dirname = os.path.join(location, app)
with open(os.path.join(dirname, Migrate.old_models), 'rb') as f:
return dill.load(f, )
@classmethod
async def init_with_old_models(cls, config: dict, app: str, location: str):
migrate_config = cls._get_migrate_config(config, app, location)
def diff_models_module(self, old_models_module, new_models_module):
old_module = importlib.import_module(old_models_module)
old_models = {}
new_models = {}
for name, obj in inspect.getmembers(old_module):
if inspect.isclass(obj) and issubclass(obj, Model):
old_models[obj.__name__] = obj
cls.app = app
cls.migrate_config = migrate_config
cls.migrate_location = os.path.join(location, app)
new_module = importlib.import_module(new_models_module)
for name, obj in inspect.getmembers(new_module):
if inspect.isclass(obj) and issubclass(obj, Model):
new_models[obj.__name__] = obj
self.diff_models(old_models, new_models)
await Tortoise.init(config=migrate_config)
def diff_models(self, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]]):
connection = get_app_connection(config, app)
if connection.schema_generator is MySQLSchemaGenerator:
cls.ddl = MysqlDDL(connection)
else:
raise NotImplementedError("Current only support MySQL")
@classmethod
def _generate_diff_sql(cls, name):
now = datetime.now().strftime("%Y%M%D%H%M%S").replace("/", "")
filename = f"{cls._get_latest_version() + 1}_{now}_{name}.json"
content = {
"upgrade": cls.upgrade_operators,
"download": cls.downgrade_operators,
"migrate": False,
}
with open(os.path.join(cls.migrate_location, filename), "w") as f:
json.dump(content, f, indent=4)
return filename
@classmethod
def migrate(cls, name):
if not cls.migrate_config:
raise ConfigurationError("You must call init_with_old_models() first!")
apps = Tortoise.apps
diff_models = apps.get(cls.diff_app)
app_models = apps.get(cls.app)
cls._diff_models(diff_models, app_models)
cls._diff_models(app_models, diff_models, False)
if not cls.upgrade_operators:
return False
return cls._generate_diff_sql(name)
@classmethod
def _add_operator(cls, operator: str, upgrade=True):
if upgrade:
cls.upgrade_operators.append(operator)
else:
cls.downgrade_operators.append(operator)
@classmethod
def cp_models(
cls, model_files: List[str], old_model_file,
):
"""
cp currents models to old_model_files
:param model_files:
:param old_model_file:
:return:
"""
pattern = (
r"(ManyToManyField|ForeignKeyField|OneToOneField)\((model_name)?(\"|\')(\w+)(.+)\)"
)
for i, model_file in enumerate(model_files):
with open(model_file, "r") as f:
content = f.read()
ret = re.sub(pattern, rf"\1\2(\3{cls.diff_app}\5)", content)
with open(old_model_file, "w" if i == 0 else "w+a") as f:
f.write(ret)
@classmethod
def _get_migrate_config(cls, config: dict, app: str, location: str):
temp_config = deepcopy(config)
path = os.path.join(location, app, cls.old_models)
path = path.replace("/", ".").lstrip(".")
temp_config["apps"][cls.diff_app] = {"models": [path]}
return temp_config
@classmethod
def write_old_models(cls, config: dict, app: str, location: str):
old_model_files = []
models = config.get("apps").get(app).get("models")
for model in models:
old_model_files.append(model.replace(".", "/") + ".py")
cls.cp_models(old_model_files, os.path.join(location, app, cls.get_old_model_file()))
@classmethod
def _diff_models(
cls, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]], upgrade=True
):
"""
diff models and add operators
:param old_models:
:param new_models:
:param upgrade:
:return:
"""
for new_model_str, new_model in new_models.items():
if new_model_str not in old_models.keys():
self.add_model(new_model)
cls._add_operator(cls.add_model(new_model), upgrade)
else:
self.diff_model(old_models.get(new_model_str), new_model)
cls.diff_model(old_models.get(new_model_str), new_model, upgrade)
for old_model in old_models:
if old_model not in new_models.keys():
self.remove_model(old_models.get(old_model))
cls._add_operator(cls.remove_model(old_models.get(old_model)), upgrade)
def _add_operator(self, operator):
self.operators.append(operator)
@classmethod
def add_model(cls, model: Type[Model]):
return cls.ddl.create_table(model)
def add_model(self, model: Type[Model]):
self._add_operator(self.ddl.create_table(model))
@classmethod
def remove_model(cls, model: Type[Model]):
return cls.ddl.drop_table(model)
def remove_model(self, model: Type[Model]):
self._add_operator(self.ddl.drop_table(model))
def diff_model(self, old_model: Type[Model], new_model: Type[Model]):
@classmethod
def diff_model(cls, old_model: Type[Model], new_model: Type[Model], upgrade=True):
"""
diff single model
:param old_model:
:param new_model:
:param upgrade:
:return:
"""
old_fields_map = old_model._meta.fields_map
new_fields_map = new_model._meta.fields_map
old_keys = old_fields_map.keys()
@@ -81,31 +185,35 @@ class Migrate:
for new_key in new_keys:
new_field = new_fields_map.get(new_key)
if new_key not in old_keys:
self._add_field(new_model, new_field)
cls._add_operator(cls._add_field(new_model, new_field), upgrade)
else:
old_field = old_fields_map.get(new_key)
if old_field.index and not new_field.index:
self._remove_index(old_model, old_field)
cls._add_operator(cls._remove_index(old_model, old_field), upgrade)
elif new_field.index and not old_field.index:
self._add_index(new_model, new_field)
cls._add_operator(cls._add_index(new_model, new_field), upgrade)
for old_key in old_keys:
if old_key not in new_keys:
field = old_fields_map.get(old_key)
self._remove_field(old_model, field)
cls._add_operator(cls._remove_field(old_model, field), upgrade)
def _remove_index(self, model: Type[Model], field: Field):
self._add_operator(self.ddl.drop_index(model, [field.model_field_name], field.unique))
@classmethod
def _remove_index(cls, model: Type[Model], field: Field):
return cls.ddl.drop_index(model, [field.model_field_name], field.unique)
def _add_index(self, model: Type[Model], field: Field):
self._add_operator(self.ddl.add_index(model, [field.model_field_name], field.unique))
@classmethod
def _add_index(cls, model: Type[Model], field: Field):
return cls.ddl.add_index(model, [field.model_field_name], field.unique)
def _add_field(self, model: Type[Model], field: Field):
@classmethod
def _add_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
self._add_operator(self.ddl.add_fk(model, field))
return cls.ddl.add_fk(model, field)
else:
self._add_operator(self.ddl.add_column(model, field))
return cls.ddl.add_column(model, field)
def _remove_field(self, model: Type[Model], field: Field):
@classmethod
def _remove_field(cls, model: Type[Model], field: Field):
if isinstance(field, ForeignKeyFieldInstance):
self._add_operator(self.ddl.drop_fk(model, field))
self._add_operator(self.ddl.drop_column(model, field.model_field_name))
return cls.ddl.drop_fk(model, field)
return cls.ddl.drop_column(model, field.model_field_name)