feat: inspectdb
support sqlite
This commit is contained in:
parent
75480e2041
commit
801dde15be
@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
### 0.6.3
|
### 0.6.3
|
||||||
|
|
||||||
- Improve `inspectdb` and support `postgres`.
|
- Improve `inspectdb` and support `postgres` & `sqlite`.
|
||||||
|
|
||||||
### 0.6.2
|
### 0.6.2
|
||||||
|
|
||||||
|
3
Makefile
3
Makefile
@ -21,8 +21,7 @@ style: deps
|
|||||||
check: deps
|
check: deps
|
||||||
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
|
@black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
|
||||||
@pflake8 $(checkfiles)
|
@pflake8 $(checkfiles)
|
||||||
@bandit -x tests -r $(checkfiles)
|
|
||||||
#@mypy $(checkfiles)
|
|
||||||
test: deps
|
test: deps
|
||||||
$(py_warn) TEST_DB=sqlite://:memory: py.test
|
$(py_warn) TEST_DB=sqlite://:memory: py.test
|
||||||
|
|
||||||
|
@ -165,7 +165,7 @@ Now your db is rolled back to the specified version.
|
|||||||
|
|
||||||
### Inspect db tables to TortoiseORM model
|
### Inspect db tables to TortoiseORM model
|
||||||
|
|
||||||
Currently `inspectdb` only supports MySQL & Postgres.
|
Currently `inspectdb` supports MySQL, Postgres & SQLite.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
Usage: aerich inspectdb [OPTIONS]
|
Usage: aerich inspectdb [OPTIONS]
|
||||||
|
@ -10,6 +10,7 @@ from tortoise.utils import get_schema_sql
|
|||||||
from aerich.exceptions import DowngradeError
|
from aerich.exceptions import DowngradeError
|
||||||
from aerich.inspect.mysql import InspectMySQL
|
from aerich.inspect.mysql import InspectMySQL
|
||||||
from aerich.inspect.postgres import InspectPostgres
|
from aerich.inspect.postgres import InspectPostgres
|
||||||
|
from aerich.inspect.sqlite import InspectSQLite
|
||||||
from aerich.migrate import Migrate
|
from aerich.migrate import Migrate
|
||||||
from aerich.models import Aerich
|
from aerich.models import Aerich
|
||||||
from aerich.utils import (
|
from aerich.utils import (
|
||||||
@ -114,6 +115,8 @@ class Command:
|
|||||||
cls = InspectMySQL
|
cls = InspectMySQL
|
||||||
elif dialect == "postgres":
|
elif dialect == "postgres":
|
||||||
cls = InspectPostgres
|
cls = InspectPostgres
|
||||||
|
elif dialect == "sqlite":
|
||||||
|
cls = InspectSQLite
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"{dialect} is not supported")
|
raise NotImplementedError(f"{dialect} is not supported")
|
||||||
inspect = cls(connection, tables)
|
inspect = cls(connection, tables)
|
||||||
|
@ -30,7 +30,7 @@ def coro(f):
|
|||||||
try:
|
try:
|
||||||
loop.run_until_complete(f(*args, **kwargs))
|
loop.run_until_complete(f(*args, **kwargs))
|
||||||
finally:
|
finally:
|
||||||
if f.__name__ not in ["cli", "init"]:
|
if Tortoise._inited:
|
||||||
loop.run_until_complete(Tortoise.close_connections())
|
loop.run_until_complete(Tortoise.close_connections())
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
@ -23,18 +23,18 @@ class Column(BaseModel):
|
|||||||
pk = "pk=True, "
|
pk = "pk=True, "
|
||||||
if self.unique:
|
if self.unique:
|
||||||
unique = "unique=True, "
|
unique = "unique=True, "
|
||||||
if self.data_type == "varchar":
|
if self.data_type in ["varchar", "VARCHAR"]:
|
||||||
length = f"max_length={self.length}, "
|
length = f"max_length={self.length}, "
|
||||||
if self.data_type == "decimal":
|
if self.data_type == "decimal":
|
||||||
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
|
length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, "
|
||||||
if self.null:
|
if self.null:
|
||||||
null = "null=True, "
|
null = "null=True, "
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
if self.data_type == "tinyint":
|
if self.data_type in ["tinyint", "INT"]:
|
||||||
default = f"default={'True' if self.default == '1' else 'False'}, "
|
default = f"default={'True' if self.default == '1' else 'False'}, "
|
||||||
elif self.data_type == "bool":
|
elif self.data_type == "bool":
|
||||||
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
||||||
elif self.data_type in ["datetime", "timestamptz"]:
|
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
|
||||||
if "CURRENT_TIMESTAMP" == self.default:
|
if "CURRENT_TIMESTAMP" == self.default:
|
||||||
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
||||||
default = "auto_now=True, "
|
default = "auto_now=True, "
|
||||||
@ -66,7 +66,10 @@ class Inspect:
|
|||||||
|
|
||||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
self.database = conn.database
|
try:
|
||||||
|
self.database = conn.database
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
self.tables = tables
|
self.tables = tables
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -10,6 +10,7 @@ class InspectMySQL(Inspect):
|
|||||||
"int": self.int_field,
|
"int": self.int_field,
|
||||||
"smallint": self.smallint_field,
|
"smallint": self.smallint_field,
|
||||||
"tinyint": self.bool_field,
|
"tinyint": self.bool_field,
|
||||||
|
"bigint": self.bigint_field,
|
||||||
"varchar": self.char_field,
|
"varchar": self.char_field,
|
||||||
"longtext": self.text_field,
|
"longtext": self.text_field,
|
||||||
"text": self.text_field,
|
"text": self.text_field,
|
||||||
@ -30,8 +31,8 @@ class InspectMySQL(Inspect):
|
|||||||
async def get_columns(self, table: str) -> List[Column]:
|
async def get_columns(self, table: str) -> List[Column]:
|
||||||
columns = []
|
columns = []
|
||||||
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
|
sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s"
|
||||||
ret = await self.conn.execute_query(sql, [self.database, table])
|
ret = await self.conn.execute_query_dict(sql, [self.database, table])
|
||||||
for row in ret[1]:
|
for row in ret:
|
||||||
columns.append(
|
columns.append(
|
||||||
Column(
|
Column(
|
||||||
name=row["COLUMN_NAME"],
|
name=row["COLUMN_NAME"],
|
||||||
|
@ -15,8 +15,10 @@ class InspectPostgres(Inspect):
|
|||||||
return {
|
return {
|
||||||
"int4": self.int_field,
|
"int4": self.int_field,
|
||||||
"int8": self.int_field,
|
"int8": self.int_field,
|
||||||
|
"smallint": self.smallint_field,
|
||||||
"varchar": self.char_field,
|
"varchar": self.char_field,
|
||||||
"text": self.text_field,
|
"text": self.text_field,
|
||||||
|
"bigint": self.bigint_field,
|
||||||
"timestamptz": self.datetime_field,
|
"timestamptz": self.datetime_field,
|
||||||
"float4": self.float_field,
|
"float4": self.float_field,
|
||||||
"float8": self.float_field,
|
"float8": self.float_field,
|
||||||
@ -53,8 +55,8 @@ from information_schema.constraint_column_usage const
|
|||||||
where c.table_catalog = $1
|
where c.table_catalog = $1
|
||||||
and c.table_name = $2
|
and c.table_name = $2
|
||||||
and c.table_schema = $3;"""
|
and c.table_schema = $3;"""
|
||||||
ret = await self.conn.execute_query(sql, [self.database, table, self.schema])
|
ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema])
|
||||||
for row in ret[1]:
|
for row in ret:
|
||||||
columns.append(
|
columns.append(
|
||||||
Column(
|
Column(
|
||||||
name=row["column_name"],
|
name=row["column_name"],
|
||||||
|
49
aerich/inspect/sqlite.py
Normal file
49
aerich/inspect/sqlite.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
from aerich.inspect import Column, Inspect
|
||||||
|
|
||||||
|
|
||||||
|
class InspectSQLite(Inspect):
|
||||||
|
@property
|
||||||
|
def field_map(self) -> dict:
|
||||||
|
return {
|
||||||
|
"INTEGER": self.int_field,
|
||||||
|
"INT": self.bool_field,
|
||||||
|
"SMALLINT": self.smallint_field,
|
||||||
|
"VARCHAR": self.char_field,
|
||||||
|
"TEXT": self.text_field,
|
||||||
|
"TIMESTAMP": self.datetime_field,
|
||||||
|
"REAL": self.float_field,
|
||||||
|
"BIGINT": self.bigint_field,
|
||||||
|
"DATE": self.date_field,
|
||||||
|
"TIME": self.time_field,
|
||||||
|
"JSON": self.json_field,
|
||||||
|
"BLOB": self.binary_field,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def get_columns(self, table: str) -> List[Column]:
|
||||||
|
columns = []
|
||||||
|
sql = f"PRAGMA table_info({table})"
|
||||||
|
ret = await self.conn.execute_query_dict(sql)
|
||||||
|
for row in ret:
|
||||||
|
try:
|
||||||
|
length = row["type"].split("(")[1].split(")")[0]
|
||||||
|
except IndexError:
|
||||||
|
length = None
|
||||||
|
columns.append(
|
||||||
|
Column(
|
||||||
|
name=row["name"],
|
||||||
|
data_type=row["type"].split("(")[0],
|
||||||
|
null=row["notnull"] == 0,
|
||||||
|
default=row["dflt_value"],
|
||||||
|
length=length,
|
||||||
|
pk=row["pk"] == 1,
|
||||||
|
unique=False, # can't get this simply
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return columns
|
||||||
|
|
||||||
|
async def get_all_tables(self) -> List[str]:
|
||||||
|
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
|
||||||
|
ret = await self.conn.execute_query_dict(sql)
|
||||||
|
return list(map(lambda x: x["tbl_name"], ret))
|
Loading…
x
Reference in New Issue
Block a user