feat: inspectdb support sqlite

This commit is contained in:
long2ice 2022-04-01 20:30:36 +08:00
parent 75480e2041
commit 801dde15be
9 changed files with 70 additions and 13 deletions

View File

@ -4,7 +4,7 @@
### 0.6.3 ### 0.6.3
- Improve `inspectdb` and support `postgres`. - Improve `inspectdb` and support `postgres` & `sqlite`.
### 0.6.2 ### 0.6.2

View File

@ -21,8 +21,7 @@ style: deps
check: deps check: deps
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@pflake8 $(checkfiles) @pflake8 $(checkfiles)
@bandit -x tests -r $(checkfiles)
#@mypy $(checkfiles)
test: deps test: deps
$(py_warn) TEST_DB=sqlite://:memory: py.test $(py_warn) TEST_DB=sqlite://:memory: py.test

View File

@ -165,7 +165,7 @@ Now your db is rolled back to the specified version.
### Inspect db tables to TortoiseORM model ### Inspect db tables to TortoiseORM model
Currently `inspectdb` only supports MySQL & Postgres. Currently `inspectdb` support MySQL & Postgres & SQLite.
```shell ```shell
Usage: aerich inspectdb [OPTIONS] Usage: aerich inspectdb [OPTIONS]

View File

@ -10,6 +10,7 @@ from tortoise.utils import get_schema_sql
from aerich.exceptions import DowngradeError from aerich.exceptions import DowngradeError
from aerich.inspect.mysql import InspectMySQL from aerich.inspect.mysql import InspectMySQL
from aerich.inspect.postgres import InspectPostgres from aerich.inspect.postgres import InspectPostgres
from aerich.inspect.sqlite import InspectSQLite
from aerich.migrate import Migrate from aerich.migrate import Migrate
from aerich.models import Aerich from aerich.models import Aerich
from aerich.utils import ( from aerich.utils import (
@ -114,6 +115,8 @@ class Command:
cls = InspectMySQL cls = InspectMySQL
elif dialect == "postgres": elif dialect == "postgres":
cls = InspectPostgres cls = InspectPostgres
elif dialect == "sqlite":
cls = InspectSQLite
else: else:
raise NotImplementedError(f"{dialect} is not supported") raise NotImplementedError(f"{dialect} is not supported")
inspect = cls(connection, tables) inspect = cls(connection, tables)

View File

@ -30,7 +30,7 @@ def coro(f):
try: try:
loop.run_until_complete(f(*args, **kwargs)) loop.run_until_complete(f(*args, **kwargs))
finally: finally:
if f.__name__ not in ["cli", "init"]: if Tortoise._inited:
loop.run_until_complete(Tortoise.close_connections()) loop.run_until_complete(Tortoise.close_connections())
return wrapper return wrapper

View File

@ -23,18 +23,18 @@ class Column(BaseModel):
pk = "pk=True, " pk = "pk=True, "
if self.unique: if self.unique:
unique = "unique=True, " unique = "unique=True, "
if self.data_type == "varchar": if self.data_type in ["varchar", "VARCHAR"]:
length = f"max_length={self.length}, " length = f"max_length={self.length}, "
if self.data_type == "decimal": if self.data_type == "decimal":
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, " length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
if self.null: if self.null:
null = "null=True, " null = "null=True, "
if self.default is not None: if self.default is not None:
if self.data_type == "tinyint": if self.data_type in ["tinyint", "INT"]:
default = f"default={'True' if self.default == '1' else 'False'}, " default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool": elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, " default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz"]: elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
if "CURRENT_TIMESTAMP" == self.default: if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, " default = "auto_now=True, "
@ -66,7 +66,10 @@ class Inspect:
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
self.conn = conn self.conn = conn
self.database = conn.database try:
self.database = conn.database
except AttributeError:
pass
self.tables = tables self.tables = tables
@property @property

View File

@ -10,6 +10,7 @@ class InspectMySQL(Inspect):
"int": self.int_field, "int": self.int_field,
"smallint": self.smallint_field, "smallint": self.smallint_field,
"tinyint": self.bool_field, "tinyint": self.bool_field,
"bigint": self.bigint_field,
"varchar": self.char_field, "varchar": self.char_field,
"longtext": self.text_field, "longtext": self.text_field,
"text": self.text_field, "text": self.text_field,
@ -30,8 +31,8 @@ class InspectMySQL(Inspect):
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> List[Column]:
columns = [] columns = []
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s" sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
ret = await self.conn.execute_query(sql, [self.database, table]) ret = await self.conn.execute_query_dict(sql, [self.database, table])
for row in ret[1]: for row in ret:
columns.append( columns.append(
Column( Column(
name=row["COLUMN_NAME"], name=row["COLUMN_NAME"],

View File

@ -15,8 +15,10 @@ class InspectPostgres(Inspect):
return { return {
"int4": self.int_field, "int4": self.int_field,
"int8": self.int_field, "int8": self.int_field,
"smallint": self.smallint_field,
"varchar": self.char_field, "varchar": self.char_field,
"text": self.text_field, "text": self.text_field,
"bigint": self.bigint_field,
"timestamptz": self.datetime_field, "timestamptz": self.datetime_field,
"float4": self.float_field, "float4": self.float_field,
"float8": self.float_field, "float8": self.float_field,
@ -53,8 +55,8 @@ from information_schema.constraint_column_usage const
where c.table_catalog = $1 where c.table_catalog = $1
and c.table_name = $2 and c.table_name = $2
and c.table_schema = $3;""" and c.table_schema = $3;"""
ret = await self.conn.execute_query(sql, [self.database, table, self.schema]) ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
for row in ret[1]: for row in ret:
columns.append( columns.append(
Column( Column(
name=row["column_name"], name=row["column_name"],

49
aerich/inspect/sqlite.py Normal file
View File

@ -0,0 +1,49 @@
from typing import List
from aerich.inspect import Column, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> dict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
"SMALLINT": self.smallint_field,
"VARCHAR": self.char_field,
"TEXT": self.text_field,
"TIMESTAMP": self.datetime_field,
"REAL": self.float_field,
"BIGINT": self.bigint_field,
"DATE": self.date_field,
"TIME": self.time_field,
"JSON": self.json_field,
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
for row in ret:
try:
length = row["type"].split("(")[1].split(")")[0]
except IndexError:
length = None
columns.append(
Column(
name=row["name"],
data_type=row["type"].split("(")[0],
null=row["notnull"] == 0,
default=row["dflt_value"],
length=length,
pk=row["pk"] == 1,
unique=False, # can't get this simply
)
)
return columns
async def get_all_tables(self) -> List[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))