feat: improve inspectdb and support postgres

long2ice
2022-04-01 19:56:48 +08:00
parent c39462820c
commit 45129cef9f
15 changed files with 598 additions and 274 deletions


@@ -8,7 +8,8 @@ from tortoise.transactions import in_transaction
 from tortoise.utils import get_schema_sql
 from aerich.exceptions import DowngradeError
-from aerich.inspectdb import InspectDb
+from aerich.inspect.mysql import InspectMySQL
+from aerich.inspect.postgres import InspectPostgres
 from aerich.migrate import Migrate
 from aerich.models import Aerich
 from aerich.utils import (
@@ -106,10 +107,17 @@ class Command:
             ret.append(version)
         return ret

-    async def inspectdb(self, tables: List[str]):
+    async def inspectdb(self, tables: List[str]) -> str:
         connection = get_app_connection(self.tortoise_config, self.app)
-        inspect = InspectDb(connection, tables)
-        await inspect.inspect()
+        dialect = connection.schema_generator.DIALECT
+        if dialect == "mysql":
+            cls = InspectMySQL
+        elif dialect == "postgres":
+            cls = InspectPostgres
+        else:
+            raise NotImplementedError(f"{dialect} is not supported")
+        inspect = cls(connection, tables)
+        return await inspect.inspect()

     async def migrate(self, name: str = "update"):
         return await Migrate.migrate(name)
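Not part of this commit, but for orientation: a minimal sketch of driving the reworked method from application code, assuming a TORTOISE_ORM config dict and an app named "models".

from typing import List

from aerich import Command


async def dump_models(tables: List[str]) -> str:
    # Assumed setup: TORTOISE_ORM is your Tortoise config dict, app "models".
    command = Command(tortoise_config=TORTOISE_ORM, app="models")
    await command.init()
    # inspectdb() now returns the generated source instead of printing it,
    # so callers can write it to a file themselves.
    return await command.inspectdb(tables)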


@@ -30,7 +30,7 @@ def coro(f):
         try:
             loop.run_until_complete(f(*args, **kwargs))
         finally:
-            if f.__name__ != "cli":
+            if f.__name__ not in ["cli", "init"]:
                 loop.run_until_complete(Tortoise.close_connections())

     return wrapper
@@ -249,7 +249,8 @@ async def init_db(ctx: Context, safe):
 @coro
 async def inspectdb(ctx: Context, table: List[str]):
     command = ctx.obj["command"]
-    await command.inspectdb(table)
+    ret = await command.inspectdb(table)
+    click.secho(ret)


 def main():
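Because the command now echoes the returned string with click.secho, its output can also be captured programmatically; a hedged sketch using click's CliRunner (assumes an aerich config, e.g. in pyproject.toml, is already present in the working directory).

# Hedged sketch, not part of this commit.
from click.testing import CliRunner

from aerich.cli import cli

runner = CliRunner()
result = runner.invoke(cli, ["inspectdb", "-t", "user"])
print(result.output)  # the model source echoed by click.secho above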


@@ -1,6 +1,6 @@
 import base64
 import json
-import pickle  # nosec: B301
+import pickle  # nosec: B301,B403

 from tortoise.indexes import Index
@@ -10,8 +10,8 @@ class JsonEncoder(json.JSONEncoder):
         if isinstance(obj, Index):
             return {
                 "type": "index",
-                "val": base64.b64encode(pickle.dumps(obj)).decode(),
-            }  # nosec: B301
+                "val": base64.b64encode(pickle.dumps(obj)).decode(),  # nosec: B301
+            }
         else:
             return super().default(obj)
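For context, decoding is the mirror image of this encoder; a hedged sketch (not part of this diff) of an object_hook on the same pattern.

import base64
import pickle  # nosec: B301,B403


def object_hook(obj: dict):
    # Reverse of JsonEncoder.default for the "index" payload above.
    if obj.get("type") == "index":
        return pickle.loads(base64.b64decode(obj["val"]))  # nosec: B301
    return obj

# Usage with the encoder's output:
# index = json.loads(payload, object_hook=object_hook)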


@@ -78,15 +78,11 @@ class BaseDDL:
         auto_now_add = field_describe.get("auto_now_add", False)
         auto_now = field_describe.get("auto_now", False)
         if default is not None or auto_now_add:
-            if (
-                field_describe.get("field_type")
-                in [
-                    "UUIDField",
-                    "TextField",
-                    "JSONField",
-                ]
-                or is_default_function(default)
-            ):
+            if field_describe.get("field_type") in [
+                "UUIDField",
+                "TextField",
+                "JSONField",
+            ] or is_default_function(default):
                 default = ""
             else:
                 try:

aerich/inspect/__init__.py Normal file (155 additions)

@@ -0,0 +1,155 @@
from typing import Any, List, Optional

from pydantic import BaseModel
from tortoise import BaseDBAsyncClient


class Column(BaseModel):
    name: str
    data_type: str
    null: bool
    default: Any
    comment: Optional[str]
    pk: bool
    unique: bool
    length: Optional[int]
    extra: Optional[str]
    decimal_places: Optional[int]
    max_digits: Optional[int]

    def translate(self) -> dict:
        comment = default = length = unique = null = pk = ""
        if self.pk:
            pk = "pk=True, "
        if self.unique:
            unique = "unique=True, "
        if self.data_type == "varchar":
            length = f"max_length={self.length}, "
        if self.data_type == "decimal":
            length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
        if self.null:
            null = "null=True, "
        if self.default is not None:
            if self.data_type == "tinyint":
                default = f"default={'True' if self.default == '1' else 'False'}, "
            elif self.data_type == "bool":
                default = f"default={'True' if self.default == 'true' else 'False'}, "
            elif self.data_type in ["datetime", "timestamptz"]:
                if "CURRENT_TIMESTAMP" == self.default:
                    if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
                        default = "auto_now=True, "
                    else:
                        default = "auto_now_add=True, "
            else:
                if "::" in self.default:
                    default = f"default={self.default.split('::')[0]}, "
                elif self.default.endswith("()"):
                    default = ""
                else:
                    default = f"default={self.default}, "
        if self.comment:
            comment = f"description='{self.comment}', "
        return {
            "name": self.name,
            "pk": pk,
            "unique": unique,
            "null": null,
            "default": default,
            "length": length,
            "comment": comment,
        }


class Inspect:
    _table_template = "class {table}(Model):\n"

    def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
        self.conn = conn
        self.database = conn.database
        self.tables = tables

    @property
    def field_map(self) -> dict:
        raise NotImplementedError

    async def inspect(self) -> str:
        if not self.tables:
            self.tables = await self.get_all_tables()
        result = "from tortoise import Model, fields\n\n\n"
        tables = []
        for table in self.tables:
            columns = await self.get_columns(table)
            fields = []
            model = self._table_template.format(table=table.title().replace("_", ""))
            for column in columns:
                field = self.field_map[column.data_type](**column.translate())
                fields.append("    " + field)
            tables.append(model + "\n".join(fields))
        return result + "\n\n\n".join(tables)

    async def get_columns(self, table: str) -> List[Column]:
        raise NotImplementedError

    async def get_all_tables(self) -> List[str]:
        raise NotImplementedError

    @classmethod
    def decimal_field(cls, **kwargs) -> str:
        return "{name} = fields.DecimalField({pk}{unique}{length}{null}{default}{comment})".format(
            **kwargs
        )

    @classmethod
    def time_field(cls, **kwargs) -> str:
        return "{name} = fields.TimeField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def date_field(cls, **kwargs) -> str:
        return "{name} = fields.DateField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def float_field(cls, **kwargs) -> str:
        return "{name} = fields.FloatField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def datetime_field(cls, **kwargs) -> str:
        return "{name} = fields.DatetimeField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def text_field(cls, **kwargs) -> str:
        return "{name} = fields.TextField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def char_field(cls, **kwargs) -> str:
        return "{name} = fields.CharField({pk}{unique}{length}{null}{default}{comment})".format(
            **kwargs
        )

    @classmethod
    def int_field(cls, **kwargs) -> str:
        return "{name} = fields.IntField({pk}{unique}{comment})".format(**kwargs)

    @classmethod
    def smallint_field(cls, **kwargs) -> str:
        return "{name} = fields.SmallIntField({pk}{unique}{comment})".format(**kwargs)

    @classmethod
    def bigint_field(cls, **kwargs) -> str:
        return "{name} = fields.BigIntField({pk}{unique}{default}{comment})".format(**kwargs)

    @classmethod
    def bool_field(cls, **kwargs) -> str:
        return "{name} = fields.BooleanField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def uuid_field(cls, **kwargs) -> str:
        return "{name} = fields.UUIDField({pk}{unique}{default}{comment})".format(**kwargs)

    @classmethod
    def json_field(cls, **kwargs) -> str:
        return "{name} = fields.JSONField({null}{default}{comment})".format(**kwargs)

    @classmethod
    def binary_field(cls, **kwargs) -> str:
        return "{name} = fields.BinaryField({null}{default}{comment})".format(**kwargs)

aerich/inspect/mysql.py Normal file (50 additions)

@@ -0,0 +1,50 @@
from typing import List

from aerich.inspect import Column, Inspect


class InspectMySQL(Inspect):
    @property
    def field_map(self) -> dict:
        return {
            "int": self.int_field,
            "smallint": self.smallint_field,
            "tinyint": self.bool_field,
            "varchar": self.char_field,
            "longtext": self.text_field,
            "text": self.text_field,
            "datetime": self.datetime_field,
            "float": self.float_field,
            "date": self.date_field,
            "time": self.time_field,
            "decimal": self.decimal_field,
            "json": self.json_field,
            "longblob": self.binary_field,
        }

    async def get_all_tables(self) -> List[str]:
        sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
        ret = await self.conn.execute_query_dict(sql, [self.database])
        return list(map(lambda x: x["TABLE_NAME"], ret))

    async def get_columns(self, table: str) -> List[Column]:
        columns = []
        sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
        ret = await self.conn.execute_query(sql, [self.database, table])
        for row in ret[1]:
            columns.append(
                Column(
                    name=row["COLUMN_NAME"],
                    data_type=row["DATA_TYPE"],
                    null=row["IS_NULLABLE"] == "YES",
                    default=row["COLUMN_DEFAULT"],
                    pk=row["COLUMN_KEY"] == "PRI",
                    comment=row["COLUMN_COMMENT"],
                    unique=row["COLUMN_KEY"] == "UNI",
                    extra=row["EXTRA"],
                    length=row["CHARACTER_MAXIMUM_LENGTH"],
                    max_digits=row["NUMERIC_PRECISION"],
                    decimal_places=row["NUMERIC_SCALE"],
                )
            )
        return columns
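For reference, a hedged illustration (not actual output from this commit) of the source InspectMySQL.inspect() emits for a hypothetical `user` table; note the trailing ", " the templates leave before the closing parenthesis.

# Roughly what inspect() would return for a hypothetical `user` table:
from tortoise import Model, fields


class User(Model):
    id = fields.IntField(pk=True, )
    username = fields.CharField(unique=True, max_length=64, )
    is_active = fields.BooleanField(default=True, )
    created_at = fields.DatetimeField(auto_now_add=True, )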

aerich/inspect/postgres.py Normal file (72 additions)

@@ -0,0 +1,72 @@
from typing import List, Optional

from tortoise import BaseDBAsyncClient

from aerich.inspect import Column, Inspect


class InspectPostgres(Inspect):
    def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
        super().__init__(conn, tables)
        self.schema = self.conn.server_settings.get("schema") or "public"

    @property
    def field_map(self) -> dict:
        return {
            "int4": self.int_field,
            "int8": self.int_field,
            "varchar": self.char_field,
            "text": self.text_field,
            "timestamptz": self.datetime_field,
            "float4": self.float_field,
            "float8": self.float_field,
            "date": self.date_field,
            "time": self.time_field,
            "decimal": self.decimal_field,
            "uuid": self.uuid_field,
            "jsonb": self.json_field,
            "bytea": self.binary_field,
            "bool": self.bool_field,
            "timestamp": self.datetime_field,
        }

    async def get_all_tables(self) -> List[str]:
        sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
        ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
        return list(map(lambda x: x["table_name"], ret))

    async def get_columns(self, table: str) -> List[Column]:
        columns = []
        sql = f"""select c.column_name,
       col_description('public.{table}'::regclass, ordinal_position) as column_comment,
       t.constraint_type as column_key,
       udt_name as data_type,
       is_nullable,
       column_default,
       character_maximum_length,
       numeric_precision,
       numeric_scale
from information_schema.constraint_column_usage const
         join information_schema.table_constraints t
              using (table_catalog, table_schema, table_name, constraint_catalog, constraint_schema, constraint_name)
         right join information_schema.columns c using (column_name, table_catalog, table_schema, table_name)
where c.table_catalog = $1
  and c.table_name = $2
  and c.table_schema = $3;"""
        ret = await self.conn.execute_query(sql, [self.database, table, self.schema])
        for row in ret[1]:
            columns.append(
                Column(
                    name=row["column_name"],
                    data_type=row["data_type"],
                    null=row["is_nullable"] == "YES",
                    default=row["column_default"],
                    length=row["character_maximum_length"],
                    max_digits=row["numeric_precision"],
                    decimal_places=row["numeric_scale"],
                    comment=row["column_comment"],
                    pk=row["column_key"] == "PRIMARY KEY",
                    unique=False,  # can't get this simply
                )
            )
        return columns
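field_map is keyed on the column's udt_name and can be extended by subclassing; a hedged sketch (the "numeric" key is an assumed, unmapped type name, not something this commit handles).

from aerich.inspect.postgres import InspectPostgres


class MyInspectPostgres(InspectPostgres):
    @property
    def field_map(self) -> dict:
        # Merge an extra type mapping on top of the ones defined above.
        return {**super().field_map, "numeric": self.decimal_field}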

aerich/inspectdb.py Deleted file

@@ -1,87 +0,0 @@
import sys
from typing import List, Optional

from ddlparse import DdlParse
from tortoise import BaseDBAsyncClient


class InspectDb:
    _table_template = "class {table}(Model):\n"
    _field_template_mapping = {
        "INT": "    {field} = fields.IntField({pk}{unique}{comment})",
        "SMALLINT": "    {field} = fields.IntField({pk}{unique}{comment})",
        "TINYINT": "    {field} = fields.BooleanField({null}{default}{comment})",
        "VARCHAR": "    {field} = fields.CharField({pk}{unique}{length}{null}{default}{comment})",
        "LONGTEXT": "    {field} = fields.TextField({null}{default}{comment})",
        "TEXT": "    {field} = fields.TextField({null}{default}{comment})",
        "DATETIME": "    {field} = fields.DatetimeField({null}{default}{comment})",
        "FLOAT": "    {field} = fields.FloatField({null}{default}{comment})",
        "DATE": "    {field} = fields.DateField({null}{default}{comment})",
    }

    def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
        self.conn = conn
        self.tables = tables
        self.DIALECT = conn.schema_generator.DIALECT

    async def show_create_tables(self):
        if self.DIALECT == "mysql":
            if not self.tables:
                sql_tables = f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{self.conn.database}';"  # nosec: B608
                ret = await self.conn.execute_query(sql_tables)
                self.tables = map(lambda x: x["TABLE_NAME"], ret[1])
            for table in self.tables:
                sql_show_create_table = f"SHOW CREATE TABLE {table}"
                ret = await self.conn.execute_query(sql_show_create_table)
                yield ret[1][0]["Create Table"]
        else:
            raise NotImplementedError("Currently only support MySQL")

    async def inspect(self):
        ddl_list = self.show_create_tables()
        result = "from tortoise import Model, fields\n\n\n"
        tables = []
        async for ddl in ddl_list:
            parser = DdlParse(ddl, DdlParse.DATABASE.mysql)
            table = parser.parse()
            name = table.name.title()
            columns = table.columns
            fields = []
            model = self._table_template.format(table=name)
            for column_name, column in columns.items():
                comment = default = length = unique = null = pk = ""
                if column.primary_key:
                    pk = "pk=True, "
                if column.unique:
                    unique = "unique=True, "
                if column.data_type == "VARCHAR":
                    length = f"max_length={column.length}, "
                if not column.not_null:
                    null = "null=True, "
                if column.default is not None:
                    if column.data_type == "TINYINT":
                        default = f"default={'True' if column.default == '1' else 'False'}, "
                    elif column.data_type == "DATETIME":
                        if "CURRENT_TIMESTAMP" in column.default:
                            if "ON UPDATE CURRENT_TIMESTAMP" in ddl:
                                default = "auto_now_add=True, "
                            else:
                                default = "auto_now=True, "
                    else:
                        default = f"default={column.default}, "
                if column.comment:
                    comment = f"description='{column.comment}', "
                field = self._field_template_mapping[column.data_type].format(
                    field=column_name,
                    pk=pk,
                    unique=unique,
                    length=length,
                    null=null,
                    default=default,
                    comment=comment,
                )
                fields.append(field)
            tables.append(model + "\n".join(fields))
        sys.stdout.write(result + "\n\n\n".join(tables))


@@ -1 +1 @@
__version__ = "0.6.2"
__version__ = "0.6.3"