feat: inspectdb support sqlite

This commit is contained in:
long2ice 2022-04-01 20:30:36 +08:00
parent 75480e2041
commit 801dde15be
9 changed files with 70 additions and 13 deletions

View File

@ -4,7 +4,7 @@
### 0.6.3
- Improve `inspectdb` and support `postgres`.
- Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2

View File

@ -21,8 +21,7 @@ style: deps
check: deps
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles)
@bandit -x tests -r $(checkfiles)
#@mypy $(checkfiles)
test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test

View File

@ -165,7 +165,7 @@ Now your db is rolled back to the specified version.
### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL & Postgres.
Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell
Usage: aerich inspectdb [OPTIONS]

View File

@ -10,6 +10,7 @@ from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError
from aerich.inspect.mysql import InspectMySQL
from aerich.inspect.postgres import InspectPostgres
from aerich.inspect.sqlite import InspectSQLite
from aerich.migrate import Migrate
from aerich.models import Aerich
from aerich.utils import (
@ -114,6 +115,8 @@ class Command:
cls = InspectMySQL
elif dialect == "postgres":
cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else:
raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables)

View File

@ -30,7 +30,7 @@ def coro(f):
try:
loop.run_until_complete(f(*args, **kwargs))
finally:
if f.__name__ not in ["cli", "init"]:
if Tortoise._inited:
loop.run_until_complete(Tortoise.close_connections())
return wrapper

View File

@ -23,18 +23,18 @@ class Column(BaseModel):
pk = "pk=True, "
if self.unique:
unique = "unique=True, "
if self.data_type == "varchar":
if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, "
if self.data_type == "decimal":
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type == "tinyint":
if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz"]:
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
@ -66,7 +66,10 @@ class Inspect:
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn
try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables
@property

View File

@ -10,6 +10,7 @@ class InspectMySQL(Inspect):
"int": self.int_field,
"smallint": self.smallint_field,
"tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field,
"longtext": self.text_field,
"text": self.text_field,
@ -30,8 +31,8 @@ class InspectMySQL(Inspect):
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
ret = await self.conn.execute_query(sql, [self.database, table])
for row in ret[1]:
ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret:
columns.append(
Column(
name=row["COLUMN_NAME"],

View File

@ -15,8 +15,10 @@ class InspectPostgres(Inspect):
return {
"int4": self.int_field,
"int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field,
"text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field,
"float4": self.float_field,
"float8": self.float_field,
@ -53,8 +55,8 @@ from information_schema.constraint_column_usage const
where c.table_catalog = $1
and c.table_name = $2
and c.table_schema = $3;"""
ret = await self.conn.execute_query(sql, [self.database, table, self.schema])
for row in ret[1]:
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret:
columns.append(
Column(
name=row["column_name"],

49
aerich/inspect/sqlite.py Normal file
View File

@ -0,0 +1,49 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=False, # can't get this simply
)
)
return columns
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))