This commit is contained in:
long2ice
2021-02-01 14:00:12 +08:00
parent b4cc2de0e3
commit 36f84702b7
7 changed files with 67 additions and 40 deletions

View File

@@ -1 +1 @@
__version__ = "0.4.5"
__version__ = "0.5.0"

View File

@@ -84,25 +84,27 @@ class BaseDDL:
default = ""
return default
def add_column(self, model: "Type[Model]", field_object: Field):
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
db_table = model._meta.db_table
description = field_describe.get("description")
db_column = field_describe.get("db_column")
db_field_types = field_describe.get("db_field_types")
return self._ADD_COLUMN_TEMPLATE.format(
table_name=db_table,
column=self.schema_generator._create_string(
db_column=field_object.model_field_name,
field_type=field_object.get_for_dialect(self.DIALECT, "SQL_TYPE"),
nullable="NOT NULL" if not field_object.null else "",
unique="UNIQUE" if field_object.unique else "",
db_column=db_column,
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
nullable="NOT NULL" if not field_describe.get("nullable") else "",
unique="UNIQUE" if field_describe.get("unique") else "",
comment=self.schema_generator._column_comment_generator(
table=db_table,
column=field_object.model_field_name,
comment=field_object.description,
column=db_column,
comment=field_describe.get("description"),
)
if field_object.description
if description
else "",
is_primary_key=field_object.pk,
default=self._get_default(model, field_object),
is_primary_key=is_pk,
default=field_describe.get("default"),
),
)
@@ -140,7 +142,7 @@ class BaseDDL:
)
def change_column(
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
):
return self._CHANGE_COLUMN_TEMPLATE.format(
table_name=model._meta.db_table,
@@ -170,7 +172,7 @@ class BaseDDL:
def add_fk(self, model: "Type[Model]", field: dict):
db_table = model._meta.db_table
db_column = field.get('db_column')
db_column = field.get("db_column")
fk_name = self.schema_generator._generate_fk_name(
from_table=db_table,
from_field=db_column,
@@ -183,7 +185,7 @@ class BaseDDL:
db_column=db_column,
table=field.related_model._meta.db_table,
field=db_column,
on_delete=field.get('on_delete'),
on_delete=field.get("on_delete"),
)
def drop_fk(self, model: "Type[Model]", field: ForeignKeyFieldInstance):

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
import click
from dictdiffer import diff
from tortoise import (
BackwardFKRelation,
@@ -162,9 +161,7 @@ class Migrate:
cls.downgrade_operators.append(operator)
@classmethod
def diff_models(
cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True
):
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
diff models and add operators
:param old_models:
@@ -172,7 +169,7 @@ class Migrate:
:param upgrade:
:return:
"""
_aerich = f'{cls.app}.{cls._aerich}'
_aerich = f"{cls.app}.{cls._aerich}"
old_models.pop(_aerich, None)
new_models.pop(_aerich, None)
@@ -209,10 +206,28 @@ class Migrate:
"""
for change in diff(old_model_describe, new_model_describe):
action, field_type, fields = change
if action == 'add':
is_pk = field_type == "pk_field"
if action == "add":
for field in fields:
_, field_describe = field
cls._add_field(cls._get_model)
cls._add_operator(
cls._add_field(
cls._get_model(new_model_describe.get("name").split(".")[1]),
field_describe,
is_pk,
),
upgrade,
)
elif action == "remove":
for field in fields:
_, field_describe = field
cls._add_operator(
cls._remove_field(
cls._get_model(new_model_describe.get("name").split(".")[1]),
field_describe,
),
upgrade,
)
@classmethod
def _resolve_fk_fields_name(cls, model: Type[Model], fields_name: Tuple[str]):
@@ -258,12 +273,12 @@ class Migrate:
return isinstance(field, (BackwardFKRelation, BackwardOneToOneRelation))
@classmethod
def _add_field(cls, model: Type[Model], field: dict):
if field.get('field_type') == 'ForeignKeyFieldInstance':
def _add_field(cls, model: Type[Model], field: dict, is_pk: bool = False):
if field.get("field_type") == "ForeignKeyFieldInstance":
return cls.ddl.add_fk(model, field)
if field.get('field_type') == 'ManyToManyFieldInstance':
if field.get("field_type") == "ManyToManyFieldInstance":
return cls.ddl.create_m2m_table(model, field)
return cls.ddl.add_column(model, field)
return cls.ddl.add_column(model, field, is_pk)
@classmethod
def _alter_default(cls, model: Type[Model], field: Field):