feat: inspectdb
support sqlite
This commit is contained in:
parent
75480e2041
commit
801dde15be
@ -4,7 +4,7 @@
|
||||
|
||||
### 0.6.3
|
||||
|
||||
- Improve `inspectdb` and support `postgres`.
|
||||
- Improve `inspectdb` and support `postgres` & `sqlite`.
|
||||
|
||||
### 0.6.2
|
||||
|
||||
|
3
Makefile
3
Makefile
@ -21,8 +21,7 @@ style: deps
|
||||
check: deps
|
||||
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
|
||||
@pflake8 $(checkfiles)
|
||||
@bandit -x tests -r $(checkfiles)
|
||||
#@mypy $(checkfiles)
|
||||
|
||||
test: deps
|
||||
$(py_warn) TEST_DB=sqlite://:memory: py.test
|
||||
|
||||
|
@ -165,7 +165,7 @@ Now your db is rolled back to the specified version.
|
||||
|
||||
### Inspect db tables to TortoiseORM model
|
||||
|
||||
Currently `inspectdb` only supports MySQL & Postgres.
|
||||
Currently `inspectdb` support MySQL & Postgres & SQLite.
|
||||
|
||||
```shell
|
||||
Usage: aerich inspectdb [OPTIONS]
|
||||
|
@ -10,6 +10,7 @@ from tortoise.utils import get_schema_sql
|
||||
from aerich.exceptions import DowngradeError
|
||||
from aerich.inspect.mysql import InspectMySQL
|
||||
from aerich.inspect.postgres import InspectPostgres
|
||||
from aerich.inspect.sqlite import InspectSQLite
|
||||
from aerich.migrate import Migrate
|
||||
from aerich.models import Aerich
|
||||
from aerich.utils import (
|
||||
@ -114,6 +115,8 @@ class Command:
|
||||
cls = InspectMySQL
|
||||
elif dialect == "postgres":
|
||||
cls = InspectPostgres
|
||||
elif dialect == "sqlite":
|
||||
cls = InspectSQLite
|
||||
else:
|
||||
raise NotImplementedError(f"{dialect} is not supported")
|
||||
inspect = cls(connection, tables)
|
||||
|
@ -30,7 +30,7 @@ def coro(f):
|
||||
try:
|
||||
loop.run_until_complete(f(*args, **kwargs))
|
||||
finally:
|
||||
if f.__name__ not in ["cli", "init"]:
|
||||
if Tortoise._inited:
|
||||
loop.run_until_complete(Tortoise.close_connections())
|
||||
|
||||
return wrapper
|
||||
|
@ -23,18 +23,18 @@ class Column(BaseModel):
|
||||
pk = "pk=True, "
|
||||
if self.unique:
|
||||
unique = "unique=True, "
|
||||
if self.data_type == "varchar":
|
||||
if self.data_type in ["varchar", "VARCHAR"]:
|
||||
length = f"max_length={self.length}, "
|
||||
if self.data_type == "decimal":
|
||||
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
|
||||
if self.null:
|
||||
null = "null=True, "
|
||||
if self.default is not None:
|
||||
if self.data_type == "tinyint":
|
||||
if self.data_type in ["tinyint", "INT"]:
|
||||
default = f"default={'True' if self.default == '1' else 'False'}, "
|
||||
elif self.data_type == "bool":
|
||||
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
||||
elif self.data_type in ["datetime", "timestamptz"]:
|
||||
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
|
||||
if "CURRENT_TIMESTAMP" == self.default:
|
||||
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
||||
default = "auto_now=True, "
|
||||
@ -66,7 +66,10 @@ class Inspect:
|
||||
|
||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
||||
self.conn = conn
|
||||
self.database = conn.database
|
||||
try:
|
||||
self.database = conn.database
|
||||
except AttributeError:
|
||||
pass
|
||||
self.tables = tables
|
||||
|
||||
@property
|
||||
|
@ -10,6 +10,7 @@ class InspectMySQL(Inspect):
|
||||
"int": self.int_field,
|
||||
"smallint": self.smallint_field,
|
||||
"tinyint": self.bool_field,
|
||||
"bigint": self.bigint_field,
|
||||
"varchar": self.char_field,
|
||||
"longtext": self.text_field,
|
||||
"text": self.text_field,
|
||||
@ -30,8 +31,8 @@ class InspectMySQL(Inspect):
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
columns = []
|
||||
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
|
||||
ret = await self.conn.execute_query(sql, [self.database, table])
|
||||
for row in ret[1]:
|
||||
ret = await self.conn.execute_query_dict(sql, [self.database, table])
|
||||
for row in ret:
|
||||
columns.append(
|
||||
Column(
|
||||
name=row["COLUMN_NAME"],
|
||||
|
@ -15,8 +15,10 @@ class InspectPostgres(Inspect):
|
||||
return {
|
||||
"int4": self.int_field,
|
||||
"int8": self.int_field,
|
||||
"smallint": self.smallint_field,
|
||||
"varchar": self.char_field,
|
||||
"text": self.text_field,
|
||||
"bigint": self.bigint_field,
|
||||
"timestamptz": self.datetime_field,
|
||||
"float4": self.float_field,
|
||||
"float8": self.float_field,
|
||||
@ -53,8 +55,8 @@ from information_schema.constraint_column_usage const
|
||||
where c.table_catalog = $1
|
||||
and c.table_name = $2
|
||||
and c.table_schema = $3;"""
|
||||
ret = await self.conn.execute_query(sql, [self.database, table, self.schema])
|
||||
for row in ret[1]:
|
||||
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
|
||||
for row in ret:
|
||||
columns.append(
|
||||
Column(
|
||||
name=row["column_name"],
|
||||
|
49
aerich/inspect/sqlite.py
Normal file
49
aerich/inspect/sqlite.py
Normal file
@ -0,0 +1,49 @@
|
||||
from typing import List
|
||||
|
||||
from aerich.inspect import Column, Inspect
|
||||
|
||||
|
||||
class InspectSQLite(Inspect):
|
||||
@property
|
||||
def field_map(self) -> dict:
|
||||
return {
|
||||
"INTEGER": self.int_field,
|
||||
"INT": self.bool_field,
|
||||
"SMALLINT": self.smallint_field,
|
||||
"VARCHAR": self.char_field,
|
||||
"TEXT": self.text_field,
|
||||
"TIMESTAMP": self.datetime_field,
|
||||
"REAL": self.float_field,
|
||||
"BIGINT": self.bigint_field,
|
||||
"DATE": self.date_field,
|
||||
"TIME": self.time_field,
|
||||
"JSON": self.json_field,
|
||||
"BLOB": self.binary_field,
|
||||
}
|
||||
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
columns = []
|
||||
sql = f"PRAGMA table_info({table})"
|
||||
ret = await self.conn.execute_query_dict(sql)
|
||||
for row in ret:
|
||||
try:
|
||||
length = row["type"].split("(")[1].split(")")[0]
|
||||
except IndexError:
|
||||
length = None
|
||||
columns.append(
|
||||
Column(
|
||||
name=row["name"],
|
||||
data_type=row["type"].split("(")[0],
|
||||
null=row["notnull"] == 0,
|
||||
default=row["dflt_value"],
|
||||
length=length,
|
||||
pk=row["pk"] == 1,
|
||||
unique=False, # can't get this simply
|
||||
)
|
||||
)
|
||||
return columns
|
||||
|
||||
async def get_all_tables(self) -> List[str]:
|
||||
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
|
||||
ret = await self.conn.execute_query_dict(sql)
|
||||
return list(map(lambda x: x["tbl_name"], ret))
|
Loading…
x
Reference in New Issue
Block a user