aerich/alice/migrate.py
2020-05-13 23:16:27 +08:00

112 lines
4.1 KiB
Python

import importlib
import inspect
import os
from copy import deepcopy
import dill
from typing import List, Type, Dict
from tortoise import Model, ForeignKeyFieldInstance, Tortoise
from tortoise.fields import Field
from alice.backends import DDL
class Migrate:
    """Diff Tortoise ORM model definitions and collect the resulting DDL.

    Each ``diff_*`` / ``add_*`` / ``remove_*`` call renders the schema changes
    it finds through the backend ``ddl`` object and appends every result to
    ``operators`` in the order the changes were discovered.
    """

    operators: List  # rendered DDL results, in discovery order
    ddl: DDL  # backend that renders each schema change
    old_models = 'old_models.pickle'  # snapshot file name under <location>/<app>

    def __init__(self, ddl: DDL):
        """Create an empty migration whose changes are rendered by ``ddl``."""
        self.operators = []
        self.ddl = ddl

    @staticmethod
    def write_old_models(app, location):
        """Snapshot the models currently registered for ``app`` to disk.

        The ``{name: model}`` mapping taken from ``Tortoise.apps`` is written
        with dill to ``<location>/<app>/old_models.pickle``. The app directory
        is created first if it does not exist.
        """
        ret = Tortoise.apps.get(app)
        # NOTE(review): deepcopy returns a class object unchanged (the copy
        # module treats classes as atomic), so this stores the live model
        # classes; how faithful the snapshot is depends on dill's serialization.
        old_models = {k: deepcopy(v) for k, v in ret.items()}
        dirname = os.path.join(location, app)
        # open() does not create parent directories.
        os.makedirs(dirname, exist_ok=True)
        with open(os.path.join(dirname, Migrate.old_models), 'wb') as f:
            dill.dump(old_models, f)

    @staticmethod
    def read_old_models(app, location):
        """Load the model snapshot stored at ``<location>/<app>/old_models.pickle``."""
        dirname = os.path.join(location, app)
        # Unpickling can execute arbitrary code: only load snapshot files that
        # this tool wrote itself, never files from an untrusted source.
        with open(os.path.join(dirname, Migrate.old_models), 'rb') as f:
            return dill.load(f)

    @staticmethod
    def _collect_models(module_name: str) -> Dict[str, Type[Model]]:
        """Import ``module_name`` and return its Model subclasses keyed by class name."""
        module = importlib.import_module(module_name)
        return {
            obj.__name__: obj
            for _, obj in inspect.getmembers(module, inspect.isclass)
            # Skip the ``Model`` base class itself: it is present whenever the
            # module does ``from tortoise import Model`` and is not a table.
            if issubclass(obj, Model) and obj is not Model
        }

    def diff_models_module(self, old_models_module, new_models_module):
        """Diff the models defined in two modules given by dotted import path."""
        old_models = self._collect_models(old_models_module)
        new_models = self._collect_models(new_models_module)
        self.diff_models(old_models, new_models)

    def diff_models(self, old_models: Dict[str, Type[Model]], new_models: Dict[str, Type[Model]]):
        """Queue DDL turning ``old_models`` into ``new_models``.

        Models are matched by their dict key: keys only in ``new_models`` are
        created, keys in both are diffed field by field, and keys only in
        ``old_models`` are dropped.
        """
        for new_model_str, new_model in new_models.items():
            if new_model_str not in old_models:
                self.add_model(new_model)
            else:
                self.diff_model(old_models[new_model_str], new_model)
        for old_model_str, old_model in old_models.items():
            if old_model_str not in new_models:
                self.remove_model(old_model)

    def _add_operator(self, operator):
        """Append one rendered DDL result to ``operators``."""
        self.operators.append(operator)

    def add_model(self, model: Type[Model]):
        """Queue creation of the table for ``model``."""
        self._add_operator(self.ddl.create_table(model))

    def remove_model(self, model: Type[Model]):
        """Queue dropping of the table for ``model``."""
        self._add_operator(self.ddl.drop_table(model))

    def diff_model(self, old_model: Type[Model], new_model: Type[Model]):
        """Queue DDL for field-level differences between two versions of a model.

        Fields are matched by name in ``_meta.fields_map``. New names are added
        and vanished names are removed. For names present in both versions,
        only a change of the ``index`` flag is detected; other attribute changes
        (type, nullability, uniqueness, ...) are not handled here.
        """
        old_fields_map = old_model._meta.fields_map
        new_fields_map = new_model._meta.fields_map
        for new_key, new_field in new_fields_map.items():
            if new_key not in old_fields_map:
                self._add_field(new_model, new_field)
                continue
            old_field = old_fields_map[new_key]
            if old_field.index and not new_field.index:
                self._remove_index(old_model, old_field)
            elif new_field.index and not old_field.index:
                self._add_index(new_model, new_field)
        for old_key, old_field in old_fields_map.items():
            if old_key not in new_fields_map:
                self._remove_field(old_model, old_field)

    def _remove_index(self, model: Type[Model], field: Field):
        """Queue dropping the single-column index on ``field`` (``field.unique`` passed through)."""
        self._add_operator(self.ddl.drop_index(model, [field.model_field_name], field.unique))

    def _add_index(self, model: Type[Model], field: Field):
        """Queue creating a single-column index on ``field`` (``field.unique`` passed through)."""
        self._add_operator(self.ddl.add_index(model, [field.model_field_name], field.unique))

    def _add_field(self, model: Type[Model], field: Field):
        """Queue adding ``field``: a FK constraint for FK relations, else a column."""
        if isinstance(field, ForeignKeyFieldInstance):
            self._add_operator(self.ddl.add_fk(model, field))
        else:
            self._add_operator(self.ddl.add_column(model, field))

    def _remove_field(self, model: Type[Model], field: Field):
        """Queue removing ``field``: the FK constraint for FK relations, else the column.

        This mirrors ``_add_field``, which adds no column for a FK relation
        field. The relation's name is not itself a DB column, so no
        ``drop_column`` is issued for it. NOTE(review): this assumes Tortoise
        registers the FK's key column (e.g. ``<name>_id``) as its own
        ``fields_map`` entry, which is then dropped by its own call. Confirm
        against the Tortoise version in use.
        """
        if isinstance(field, ForeignKeyFieldInstance):
            self._add_operator(self.ddl.drop_fk(model, field))
        else:
            self._add_operator(self.ddl.drop_column(model, field.model_field_name))