This commit is contained in:
long2ice
2021-02-01 14:00:12 +08:00
parent b4cc2de0e3
commit 36f84702b7
7 changed files with 67 additions and 40 deletions

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
import click
from dictdiffer import diff
from tortoise import (
BackwardFKRelation,
@@ -162,9 +161,7 @@ class Migrate:
cls.downgrade_operators.append(operator)
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
):
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
diff models and add operators
:param old_models:
@@ -172,7 +169,7 @@ class Migrate:
:param upgrade:
:return:
"""
_aerich = f'{cls.app}.{cls._aerich}'
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
@@ -209,10 +206,28 @@ class Migrate:
"""
for change in diff(old_model_describe, new_model_describe):
action, field_type, fields = change
if action == 'add':
is_pk = field_type == "pk_field"
if action == "add":
for field in fields:
_, field_describe = field
cls._add_field(cls._get_model)
cls._add_operator(
cls._add_field(
cls._get_model(new_model_describe.get("name").split(".")[1]),
field_describe,
is_pk,
),
upgrade,
)
elif action == "remove":
for field in fields:
_, field_describe = field
cls._add_operator(
cls._remove_field(
cls._get_model(new_model_describe.get("name").split(".")[1]),
field_describe,
),
upgrade,
)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@@ -258,12 +273,12 @@ class Migrate:
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod
def _add_field(cls, model: Type[Model], field: dict):
if field.get('field_type') == 'ForeignKeyFieldInstance':
def _add_field(cls, model: Type[Model], field: dict, is_pk: bool = False):
if field.get("field_type") == "ForeignKeyFieldInstance":
return cls.ddl.add_fk(model, field)
if field.get('field_type') == 'ManyToManyFieldInstance':
if field.get("field_type") == "ManyToManyFieldInstance":
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
return cls.ddl.add_column(model, field, is_pk)
@classmethod
def _alter_default(cls, model: Type[Model], field: Field):