Merge branch 'dev' into dev

Authored by Yasser Tahiri on 2022-09-15 23:19:59 +01:00; committed by GitHub.
24 changed files with 1077 additions and 719 deletions

View File: aerich/__init__.py

@@ -8,7 +8,9 @@ from tortoise.transactions import in_transaction
from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspectdb import InspectDb
from aerich.inspectdb.mysql import InspectMySQL
from aerich.inspectdb.postgres import InspectPostgres
from aerich.inspectdb.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.models import Aerich
from aerich.utils import (
@@ -103,10 +105,19 @@ class Command:
versions = Migrate.get_all_version_files()
return [version for version in versions]
async def inspectdb(self, tables: List[str]):
async def inspectdb(self, tables: List[str] = None) -> str:
connection = get_app_connection(self.tortoise_config, self.app)
inspect = InspectDb(connection, tables)
await inspect.inspect()
dialect = connection.schema_generator.DIALECT
if dialect == "mysql":
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)
return await inspect.inspect()
async def migrate(self, name: str = "update"):
return await Migrate.migrate(name)
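Since inspectdb now returns the generated source instead of printing it, callers can capture the models as a string. A minimal sketch of driving it from code, assuming Command.init() sets up the connection as in current aerich (the config dict and table name here are hypothetical):

import asyncio

from aerich import Command

TORTOISE_ORM = {
    "connections": {"default": "sqlite://db.sqlite3"},  # hypothetical
    "apps": {"models": {"models": ["models"], "default_connection": "default"}},
}

async def main():
    command = Command(tortoise_config=TORTOISE_ORM, app="models")
    await command.init()
    src = await command.inspectdb(["user"])  # now a str, one class per table
    print(src)

asyncio.run(main())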

View File: aerich/cli.py

@@ -10,12 +10,11 @@ from click import Context, UsageError
from tomlkit.exceptions import NonExistentKey
from tortoise import Tortoise
from aerich import Command
from aerich.enums import Color
from aerich.exceptions import DowngradeError
from aerich.utils import add_src_path, get_tortoise_config
from . import Command
from .enums import Color
from .version import __version__
from aerich.version import __version__
CONFIG_DEFAULT_VALUES = {
"src_folder": ".",
@@ -27,11 +26,11 @@ def coro(f):
def wrapper(*args, **kwargs):
loop = asyncio.get_event_loop()
# Close db connections at the end of all all but the cli group function
# Close db connections at the end of all but the cli group function
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ != "cli":
if f.__name__ not in ["cli", "init_db"]:
loop.run_until_complete(Tortoise.close_connections())
return wrapper
@@ -55,10 +54,10 @@ async def cli(ctx: Context, config, app):
invoked_subcommand = ctx.invoked_subcommand
if invoked_subcommand != "init":
if not Path(config).exists():
config_path = Path(config)
if not config_path.exists():
raise UsageError("You must exec init first", ctx=ctx)
with open(config, "r") as f:
content = f.read()
content = config_path.read_text()
doc = tomlkit.parse(content)
try:
tool = doc["tool"]["aerich"]
@@ -193,18 +192,19 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
# check that we can find the configuration, if not we can fail before the config file gets created
add_src_path(src_folder)
get_tortoise_config(ctx, tortoise_orm)
with open(config_file, "r") as f:
content = f.read()
doc = tomlkit.parse(content)
config_path = Path(config_file)
if config_path.exists():
content = config_path.read_text()
doc = tomlkit.parse(content)
else:
doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = tortoise_orm
table["location"] = location
table["src_folder"] = src_folder
doc["tool"]["aerich"] = table
with open(config_file, "w") as f:
f.write(tomlkit.dumps(doc))
config_path.write_text(tomlkit.dumps(doc))
Path(location).mkdir(parents=True, exist_ok=True)
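With this change, init parses and extends an existing config file instead of blindly overwriting it. A standalone sketch of the tomlkit round-trip, with hypothetical values:

from pathlib import Path

import tomlkit

config_path = Path("pyproject.toml")  # hypothetical config file
if config_path.exists():
    doc = tomlkit.parse(config_path.read_text())
else:
    doc = tomlkit.parse("[tool.aerich]")
table = tomlkit.table()
table["tortoise_orm"] = "settings.TORTOISE_ORM"  # hypothetical dotted path
table["location"] = "./migrations"
table["src_folder"] = "."
doc["tool"]["aerich"] = table  # like the diff, assumes a [tool] table exists
config_path.write_text(tomlkit.dumps(doc))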
@@ -214,15 +214,17 @@ async def init(ctx: Context, tortoise_orm, location, src_folder):
@cli.command(help="Generate schema and generate app migrate location.")
@click.option(
"-s",
"--safe",
type=bool,
is_flag=True,
default=True,
help="When set to true, creates the table only when it does not already exist.",
show_default=True,
)
@click.pass_context
@coro
async def init_db(ctx: Context, safe):
async def init_db(ctx: Context, safe: bool):
command = ctx.obj["command"]
app = command.app
dirname = Path(command.location, app)
@@ -248,7 +250,8 @@ async def init_db(ctx: Context, safe):
@coro
async def inspectdb(ctx: Context, table: List[str]):
command = ctx.obj["command"]
await command.inspectdb(table)
ret = await command.inspectdb(table)
click.secho(ret)
def main():

View File: aerich/coder.py (new file, 31 lines)

@@ -0,0 +1,31 @@
import base64
import json
import pickle # nosec: B301,B403
from tortoise.indexes import Index
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, Index):
return {
"type": "index",
"val": base64.b64encode(pickle.dumps(obj)).decode(), # nosec: B301
}
else:
return super().default(obj)
def object_hook(obj):
_type = obj.get("type")
if not _type:
return obj
return pickle.loads(base64.b64decode(obj["val"])) # nosec: B301
def encoder(obj: dict):
return json.dumps(obj, cls=JsonEncoder)
def decoder(obj: str):
return json.loads(obj, object_hook=object_hook)
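The custom coder exists because Index instances are not JSON-serializable on their own; they ride through as base64-encoded pickles. A quick round trip:

from tortoise.indexes import Index

from aerich.coder import decoder, encoder

payload = encoder({"indexes": [Index(fields=("name",))]})
restored = decoder(payload)
# restored["indexes"][0] is an Index again, rebuilt by object_hook via pickle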

View File: aerich/ddl/__init__.py

@@ -78,15 +78,11 @@ class BaseDDL:
auto_now_add = field_describe.get("auto_now_add", False)
auto_now = field_describe.get("auto_now", False)
if default is not None or auto_now_add:
if (
field_describe.get("field_type")
in [
"UUIDField",
"TextField",
"JSONField",
]
or is_default_function(default)
):
if field_describe.get("field_type") in [
"UUIDField",
"TextField",
"JSONField",
] or is_default_function(default):
default = ""
else:
try:
@@ -195,6 +191,12 @@ class BaseDDL:
table_name=model._meta.db_table,
)
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
return self._DROP_INDEX_TEMPLATE.format(
index_name=index_name,
table_name=model._meta.db_table,
)
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
db_table = model._meta.db_table
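The new drop_index_by_name lets the migrator drop an index by the exact name tortoise generated for it, instead of re-deriving the column tuple. A hedged usage sketch (MyModel and the index name are hypothetical):

sql = Migrate.ddl.drop_index_by_name(MyModel, "idx_mymodel_name_f6b0a9")
# fills _DROP_INDEX_TEMPLATE with the given index_name and the model's db_table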

View File: aerich/inspectdb.py (deleted)

@@ -1,86 +0,0 @@
import sys
from typing import List, Optional
from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient
class InspectDb:
_table_template = "class {table}(Model):\n"
_field_template_mapping = {
"INT": " {field} = fields.IntField({pk}{unique}{comment})",
"SMALLINT": " {field} = fields.IntField({pk}{unique}{comment})",
"TINYINT": " {field} = fields.BooleanField({null}{default}{comment})",
"VARCHAR": " {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
"LONGTEXT": " {field} = fields.TextField({null}{default}{comment})",
"TEXT": " {field} = fields.TextField({null}{default}{comment})",
"DATETIME": " {field} = fields.DatetimeField({null}{default}{comment})",
"FLOAT": " {field} = fields.FloatField({null}{default}{comment})",
}
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
self.tables = tables
self.DIALECT = conn.schema_generator.DIALECT
async def show_create_tables(self):
if self.DIALECT == "mysql":
if not self.tables:
sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';" # nosec: B608
ret = await self.conn.execute_query(sql_tables)
self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
for table in self.tables:
sql_show_create_table = f"SHOW CREATE TABLE {table}"
ret = await self.conn.execute_query(sql_show_create_table)
yield ret[1][0]["Create Table"]
else:
raise NotImplementedError("Currently only support MySQL")
async def inspect(self):
ddl_list = self.show_create_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
async for ddl in ddl_list:
parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
table = parser.parse()
name = table.name.title()
columns = table.columns
fields = []
model = self._table_template.format(table=name)
for column_name, column in columns.items():
comment = default = length = unique = null = pk = ""
if column.primary_key:
pk = "pk=True, "
if column.unique:
unique = "unique=True, "
if column.data_type == "VARCHAR":
length = f"max_length={column.length}, "
if not column.not_null:
null = "null=True, "
if column.default is not None:
if column.data_type == "TINYINT":
default = f"default={'True' if column.default == '1' else 'False'}, "
elif column.data_type == "DATETIME":
if "CURRENT_TIMESTAMP" in column.default:
if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
default = "auto_now_add=True, "
else:
default = "auto_now=True, "
else:
default = f"default={column.default}, "
if column.comment:
comment = f"description='{column.comment}', "
field = self._field_template_mapping[column.data_type].format(
field=column_name,
pk=pk,
unique=unique,
length=length,
null=null,
default=default,
comment=comment,
)
fields.append(field)
tables.append(model + "\n".join(fields))
sys.stdout.write(result + "\n\n\n".join(tables))

View File: aerich/inspectdb/__init__.py (new file, 168 lines)

@@ -0,0 +1,168 @@
from typing import Any, List, Optional
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class Column(BaseModel):
name: str
data_type: str
null: bool
default: Any
comment: Optional[str]
pk: bool
unique: bool
index: bool
length: Optional[int]
extra: Optional[str]
decimal_places: Optional[int]
max_digits: Optional[int]
def translate(self) -> dict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
else:
if self.unique:
index = "unique=True, "
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]:
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
else:
default = "auto_now_add=True, "
else:
if "::" in self.default:
default = f"default={self.default.split('::')[0]}, "
elif self.default.endswith("()"):
default = ""
else:
default = f"default={self.default}, "
if self.comment:
comment = f"description='{self.comment}', "
return {
"name": self.name,
"pk": pk,
"index": index,
"null": null,
"default": default,
"length": length,
"comment": comment,
}
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables
@property
def field_map(self) -> dict:
raise NotImplementedError
async def inspect(self) -> str:
if not self.tables:
self.tables = await self.get_all_tables()
result = "from tortoise import Model, fields\n\n\n"
tables = []
for table in self.tables:
columns = await self.get_columns(table)
fields = []
model = self._table_template.format(table=table.title().replace("_", ""))
for column in columns:
field = self.field_map[column.data_type](**column.translate())
fields.append(" " + field)
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
raise NotImplementedError
@classmethod
def decimal_field(cls, **kwargs) -> str:
return "{name} = fields.DecimalField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def time_field(cls, **kwargs) -> str:
return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def date_field(cls, **kwargs) -> str:
return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)
@classmethod
def float_field(cls, **kwargs) -> str:
return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)
@classmethod
def datetime_field(cls, **kwargs) -> str:
return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)
@classmethod
def text_field(cls, **kwargs) -> str:
return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)
@classmethod
def char_field(cls, **kwargs) -> str:
return "{name} = fields.CharField({pk}{index}{length}{null}{default}{comment})".format(
**kwargs
)
@classmethod
def int_field(cls, **kwargs) -> str:
return "{name} = fields.IntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def smallint_field(cls, **kwargs) -> str:
return "{name} = fields.SmallIntField({pk}{index}{comment})".format(**kwargs)
@classmethod
def bigint_field(cls, **kwargs) -> str:
return "{name} = fields.BigIntField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def bool_field(cls, **kwargs) -> str:
return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)
@classmethod
def uuid_field(cls, **kwargs) -> str:
return "{name} = fields.UUIDField({pk}{index}{default}{comment})".format(**kwargs)
@classmethod
def json_field(cls, **kwargs) -> str:
return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)
@classmethod
def binary_field(cls, **kwargs) -> str:
return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

View File: aerich/inspectdb/mysql.py (new file, 69 lines)

@@ -0,0 +1,69 @@
from typing import List
from aerich.inspectdb import Column, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
"datetime": self.datetime_field,
"float": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"json": self.json_field,
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c
left join information_schema.STATISTICS s on c.TABLE_NAME = s.TABLE_NAME
and c.TABLE_SCHEMA = s.TABLE_SCHEMA
and c.COLUMN_NAME = s.COLUMN_NAME
where c.TABLE_SCHEMA = %s
and c.TABLE_NAME = %s"""
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
non_unique = row["NON_UNIQUE"]
if non_unique is None:
unique = False
else:
unique = not non_unique
index_name = row["INDEX_NAME"]
if index_name is None:
index = False
else:
index = row["INDEX_NAME"] != "PRIMARY"
columns.append(
Column(
name=row["COLUMN_NAME"],
data_type=row["DATA_TYPE"],
null=row["IS_NULLABLE"] == "YES",
default=row["COLUMN_DEFAULT"],
pk=row["COLUMN_KEY"] == "PRI",
comment=row["COLUMN_COMMENT"],
unique=unique,
extra=row["EXTRA"],
index=index,
length=row["CHARACTER_MAXIMUM_LENGTH"],
max_digits=row["NUMERIC_PRECISION"],
decimal_places=row["NUMERIC_SCALE"],
)
)
return columns

View File: aerich/inspectdb/postgres.py (new file, 76 lines)

@@ -0,0 +1,76 @@
from typing import List, Optional
from tortoise import BaseDBAsyncClient
from aerich.inspectdb import Column, Inspect
class InspectPostgres(Inspect):
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
super().__init__(conn, tables)
self.schema = self.conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
"date": self.date_field,
"time": self.time_field,
"decimal": self.decimal_field,
"numeric": self.decimal_field,
"uuid": self.uuid_field,
"jsonb": self.json_field,
"bytea": self.binary_field,
"bool": self.bool_field,
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
t.constraint_type as column_key,
udt_name as data_type,
is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale
from information_schema.constraint_column_usage const
join information_schema.table_constraints t
using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3"""
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],
data_type=row["data_type"],
null=row["is_nullable"] == "YES",
default=row["column_default"],
length=row["character_maximum_length"],
max_digits=row["numeric_precision"],
decimal_places=row["numeric_scale"],
comment=row["column_comment"],
pk=row["column_key"] == "PRIMARY KEY",
unique=False, # can't get this simply
index=False, # can't get this simply
)
)
return columns

View File: aerich/inspectdb/sqlite.py (new file, 61 lines)

@@ -0,0 +1,61 @@
from typing import List
from aerich.inspectdb import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
columns_index = await self._get_columns_index(table)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=columns_index.get(row["name"]) == "unique",
index=columns_index.get(row["name"]) == "index",
)
)
return columns
async def _get_columns_index(self, table: str):
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
for index in indexes:
sql = f"PRAGMA index_info({index['name']})"
index_info = (await self.conn.execute_query_dict(sql))[0]
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))
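_get_columns_index folds the two PRAGMA calls into a column-name -> kind map that get_columns consults; note that taking [0] from index_info records only the first column of a composite index. Roughly, for a table with a unique email index (hypothetical data):

# PRAGMA index_list (user)
#   -> [{"name": "sqlite_autoindex_user_1", "unique": 1, ...}]
# PRAGMA index_info(sqlite_autoindex_user_1)
#   -> [{"name": "email", ...}]
# _get_columns_index("user") -> {"email": "unique"}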

View File: aerich/migrate.py

@@ -1,12 +1,15 @@
import importlib
import os
from datetime import datetime
from hashlib import md5
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type, Union
import click
from dictdiffer import diff
from tortoise import BaseDBAsyncClient, Model, Tortoise
from tortoise.exceptions import OperationalError
from tortoise.indexes import Index
from aerich.ddl import BaseDDL
from aerich.models import MAX_VERSION_LENGTH, Aerich
@@ -32,7 +35,7 @@ class Migrate:
ddl: BaseDDL
_last_version_content: Optional[dict] = None
app: str
migrate_location: str
migrate_location: Path
dialect: str
_db_version: Optional[str] = None
@@ -61,6 +64,11 @@ class Migrate:
ret = await connection.execute_query(sql)
cls._db_version = ret[1][0].get("version")
@classmethod
async def load_ddl_class(cls):
ddl_dialect_module = importlib.import_module(f"aerich.ddl.{cls.dialect}")
return getattr(ddl_dialect_module, f"{cls.dialect.capitalize()}DDL")
@classmethod
async def init(cls, config: dict, app: str, location: str):
await Tortoise.init(config=config)
@@ -72,18 +80,8 @@ class Migrate:
connection = get_app_connection(config, app)
cls.dialect = connection.schema_generator.DIALECT
if cls.dialect == "mysql":
from aerich.ddl.mysql import MysqlDDL
cls.ddl = MysqlDDL(connection)
elif cls.dialect == "sqlite":
from aerich.ddl.sqlite import SqliteDDL
cls.ddl = SqliteDDL(connection)
elif cls.dialect == "postgres":
from aerich.ddl.postgres import PostgresDDL
cls.ddl = PostgresDDL(connection)
cls.ddl_class = await cls.load_ddl_class()
cls.ddl = cls.ddl_class(connection)
await cls._get_db_version(connection)
@classmethod
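load_ddl_class replaces the per-dialect import branches with a naming convention; the removed code shows exactly what it resolves to. A sketch of the lookup:

import importlib

dialect = "postgres"
module = importlib.import_module(f"aerich.ddl.{dialect}")
ddl_class = getattr(module, f"{dialect.capitalize()}DDL")
# -> aerich.ddl.postgres.PostgresDDL (likewise MysqlDDL and SqliteDDL)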
@@ -157,6 +155,18 @@ class Migrate:
else:
cls.downgrade_operators.append(operator)
@classmethod
def _handle_indexes(cls, model: Type[Model], indexes: List[Union[Tuple[str], Index]]):
ret = []
for index in indexes:
if isinstance(index, Index):
index.__hash__ = lambda self: md5( # nosec: B303
self.index_name(cls.ddl.schema_generator, model).encode()
+ self.__class__.__name__.encode()
).hexdigest()
ret.append(index)
return ret
@classmethod
def diff_models(cls, old_models: Dict[str, dict], new_models: Dict[str, dict], upgrade=True):
"""
@@ -192,8 +202,18 @@ class Migrate:
new_unique_together = set(
map(lambda x: tuple(x), new_model_describe.get("unique_together"))
)
old_indexes = set(map(lambda x: tuple(x), old_model_describe.get("indexes", [])))
new_indexes = set(map(lambda x: tuple(x), new_model_describe.get("indexes", [])))
old_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, old_model_describe.get("indexes", [])),
)
)
new_indexes = set(
map(
lambda x: x if isinstance(x, Index) else tuple(x),
cls._handle_indexes(model, new_model_describe.get("indexes", [])),
)
)
old_pk_field = old_model_describe.get("pk_field")
new_pk_field = new_model_describe.get("pk_field")
# pk field
@@ -323,26 +343,44 @@ class Migrate:
),
upgrade,
)
if new_data_field["indexed"]:
cls._add_operator(
cls._add_index(
model, {new_data_field["db_column"]}, new_data_field["unique"]
),
upgrade,
True,
)
# remove fields
for old_data_field_name in set(old_data_fields_name).difference(
set(new_data_fields_name)
):
# don't remove field if is rename
# don't remove field if is renamed
if (upgrade and old_data_field_name in cls._rename_old) or (
not upgrade and old_data_field_name in cls._rename_new
):
continue
old_data_field = next(
filter(lambda x: x.get("name") == old_data_field_name, old_data_fields)
)
db_column = old_data_field["db_column"]
cls._add_operator(
cls._remove_field(
model,
next(
filter(
lambda x: x.get("name") == old_data_field_name, old_data_fields
)
).get("db_column"),
db_column,
),
upgrade,
)
if old_data_field["indexed"]:
cls._add_operator(
cls._drop_index(
model,
{db_column},
),
upgrade,
True,
)
old_fk_fields = old_model_describe.get("fk_fields")
new_fk_fields = new_model_describe.get("fk_fields")
@@ -402,8 +440,14 @@ class Migrate:
cls._drop_index(model, (field_name,), unique), upgrade, True
)
elif option == "db_field_types.":
# continue since repeated with others
continue
if new_data_field.get("field_type") == "DecimalField":
# modify column
cls._add_operator(
cls._modify_field(model, new_data_field),
upgrade,
)
else:
continue
elif option == "default":
if not (
is_default_function(old_new[0]) or is_default_function(old_new[1])
@@ -463,12 +507,18 @@ class Migrate:
return ret
@classmethod
def _drop_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _drop_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return cls.ddl.drop_index_by_name(
model, fields_name.index_name(cls.ddl.schema_generator, model)
)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.drop_index(model, fields_name, unique)
@classmethod
def _add_index(cls, model: Type[Model], fields_name: Tuple[str], unique=False):
def _add_index(cls, model: Type[Model], fields_name: Union[Tuple[str], Index], unique=False):
if isinstance(fields_name, Index):
return fields_name.get_sql(cls.ddl.schema_generator, model, False)
fields_name = cls._resolve_fk_fields_name(model, fields_name)
return cls.ddl.add_index(model, fields_name, unique)
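Both helpers now accept either a tuple of field names or an Index instance. A hedged sketch of the two paths (MyModel is hypothetical):

from tortoise.indexes import Index

Migrate._add_index(MyModel, ("name",))
# tuple path: resolves FK names, then cls.ddl.add_index as before
Migrate._add_index(MyModel, Index(fields=("name",)))
# Index path: the index renders its own SQL via get_sql(schema_generator, model, False)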

View File: aerich/models.py

@@ -1,12 +1,15 @@
from tortoise import Model, fields
from aerich.coder import decoder, encoder
MAX_VERSION_LENGTH = 255
MAX_APP_LENGTH = 100
class Aerich(Model):
version = fields.CharField(max_length=MAX_VERSION_LENGTH)
app = fields.CharField(max_length=20)
content = fields.JSONField()
app = fields.CharField(max_length=MAX_APP_LENGTH)
content = fields.JSONField(encoder=encoder, decoder=decoder)
class Meta:
ordering = ["-id"]

View File: aerich/utils.py

@@ -87,19 +87,19 @@ def get_version_content_from_file(version_file: Union[str, Path]) -> Dict:
:param version_file:
:return:
"""
with open(version_file, "r", encoding="utf-8") as f:
content = f.read()
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
return {
"upgrade": [line for line in upgrade_content.split(";\n") if line],
"downgrade": [line for line in downgrade_content.split(";\n") if line],
}
content = Path(version_file).read_text(encoding="utf-8")
first = content.index(_UPGRADE)
try:
second = content.index(_DOWNGRADE)
except ValueError:
second = len(content) - 1
upgrade_content = content[first + len(_UPGRADE) : second].strip() # noqa:E203
downgrade_content = content[second + len(_DOWNGRADE) :].strip() # noqa:E203
ret = {
"upgrade": list(filter(lambda x: x or False, upgrade_content.split(";\n"))),
"downgrade": list(filter(lambda x: x or False, downgrade_content.split(";\n"))),
}
return ret
def write_version_file(version_file: Path, content: Dict):
@@ -109,25 +109,25 @@ def write_version_file(version_file: Path, content: Dict):
:param content:
:return:
"""
with open(version_file, "w", encoding="utf-8") as f:
f.write(_UPGRADE)
upgrade = content.get("upgrade")
if len(upgrade) > 1:
f.write(";\n".join(upgrade))
if not upgrade[-1].endswith(";"):
f.write(";\n")
text = _UPGRADE
upgrade = content.get("upgrade")
if len(upgrade) > 1:
text += ";\n".join(upgrade)
if not upgrade[-1].endswith(";"):
text += ";\n"
else:
text += f"{upgrade[0]}"
if not upgrade[0].endswith(";"):
text += ";"
text += "\n"
downgrade = content.get("downgrade")
if downgrade:
text += _DOWNGRADE
if len(downgrade) > 1:
text += ";\n".join(downgrade) + ";\n"
else:
f.write(f"{upgrade[0]}")
if not upgrade[0].endswith(";"):
f.write(";")
f.write("\n")
downgrade = content.get("downgrade")
if downgrade:
f.write(_DOWNGRADE)
if len(downgrade) > 1:
f.write(";\n".join(downgrade) + ";\n")
else:
f.write(f"{downgrade[0]};\n")
text += f"{downgrade[0]};\n"
version_file.write_text(text, encoding="utf-8")
def get_models_describe(app: str) -> Dict:
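Building the file as one string makes the output easy to show. Assuming _UPGRADE and _DOWNGRADE are the usual "-- upgrade --" / "-- downgrade --" markers, a one-statement migration (hypothetical SQL and filename) comes out as:

from pathlib import Path

from aerich.utils import write_version_file

write_version_file(
    Path("1_20220915231959_update.sql"),
    {
        "upgrade": ['ALTER TABLE "user" ADD "age" INT NOT NULL'],
        "downgrade": ['ALTER TABLE "user" DROP COLUMN "age"'],
    },
)
# resulting file:
# -- upgrade --
# ALTER TABLE "user" ADD "age" INT NOT NULL;
# -- downgrade --
# ALTER TABLE "user" DROP COLUMN "age";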

View File: aerich/version.py

@@ -1 +1 @@
__version__ = "0.6.0"
__version__ = "0.6.4"