diff --git a/CHANGELOG.md b/CHANGELOG.md index 4228c02..ac8e186 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ ### 0.6.3 -- Improve `inspectdb` and support `postgres`. +- Improve `inspectdb` and support `postgres` & `sqlite`. ### 0.6.2 diff --git a/Makefile b/Makefile index cff117e..1522fdf 100644 --- a/Makefile +++ b/Makefile @@ -21,8 +21,7 @@ style: deps check: deps @black --check $(black_opts) $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false) @pflake8 $(checkfiles) - @bandit -x tests -r $(checkfiles) - #@mypy $(checkfiles) + test: deps $(py_warn) TEST_DB=sqlite://:memory: py.test diff --git a/README.md b/README.md index c7d1ca3..7456d46 100644 --- a/README.md +++ b/README.md @@ -165,7 +165,7 @@ Now your db is rolled back to the specified version. ### Inspect db tables to TortoiseORM model -Currently `inspectdb` only supports MySQL & Postgres. +Currently `inspectdb` support MySQL & Postgres & SQLite. ```shell Usage: aerich inspectdb [OPTIONS] diff --git a/aerich/__init__.py b/aerich/__init__.py index cc5d081..1caf1bd 100644 --- a/aerich/__init__.py +++ b/aerich/__init__.py @@ -10,6 +10,7 @@ from tortoise.utils import get_schema_sql from aerich.exceptions import DowngradeError from aerich.inspect.mysql import InspectMySQL from aerich.inspect.postgres import InspectPostgres +from aerich.inspect.sqlite import InspectSQLite from aerich.migrate import Migrate from aerich.models import Aerich from aerich.utils import ( @@ -114,6 +115,8 @@ class Command: cls = InspectMySQL elif dialect == "postgres": cls = InspectPostgres + elif dialect == "sqlite": + cls = InspectSQLite else: raise NotImplementedError(f"{dialect} is not supported") inspect = cls(connection, tables) diff --git a/aerich/cli.py b/aerich/cli.py index 847a27c..bb48774 100644 --- a/aerich/cli.py +++ b/aerich/cli.py @@ -30,7 +30,7 @@ def coro(f): try: loop.run_until_complete(f(*args, **kwargs)) finally: - if f.__name__ not in ["cli", "init"]: + if Tortoise._inited: loop.run_until_complete(Tortoise.close_connections()) return wrapper diff --git a/aerich/inspect/__init__.py b/aerich/inspect/__init__.py index 73b1ee6..5667468 100644 --- a/aerich/inspect/__init__.py +++ b/aerich/inspect/__init__.py @@ -23,18 +23,18 @@ class Column(BaseModel): pk = "pk=True, " if self.unique: unique = "unique=True, " - if self.data_type == "varchar": + if self.data_type in ["varchar", "VARCHAR"]: length = f"max_length={self.length}, " if self.data_type == "decimal": length = f"max_digits={self.max_digits}, decimal_places={self.decimal_places}, " if self.null: null = "null=True, " if self.default is not None: - if self.data_type == "tinyint": + if self.data_type in ["tinyint", "INT"]: default = f"default={'True' if self.default == '1' else 'False'}, " elif self.data_type == "bool": default = f"default={'True' if self.default == 'true' else 'False'}, " - elif self.data_type in ["datetime", "timestamptz"]: + elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]: if "CURRENT_TIMESTAMP" == self.default: if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: default = "auto_now=True, " @@ -66,7 +66,10 @@ class Inspect: def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): self.conn = conn - self.database = conn.database + try: + self.database = conn.database + except AttributeError: + pass self.tables = tables @property diff --git a/aerich/inspect/mysql.py b/aerich/inspect/mysql.py index 8cd0084..28c480a 100644 --- a/aerich/inspect/mysql.py +++ b/aerich/inspect/mysql.py @@ -10,6 +10,7 @@ class InspectMySQL(Inspect): "int": self.int_field, "smallint": self.smallint_field, "tinyint": self.bool_field, + "bigint": self.bigint_field, "varchar": self.char_field, "longtext": self.text_field, "text": self.text_field, @@ -30,8 +31,8 @@ class InspectMySQL(Inspect): async def get_columns(self, table: str) -> List[Column]: columns = [] sql = "select * from information_schema.columns where TABLE_SCHEMA=%s and TABLE_NAME=%s" - ret = await self.conn.execute_query(sql, [self.database, table]) - for row in ret[1]: + ret = await self.conn.execute_query_dict(sql, [self.database, table]) + for row in ret: columns.append( Column( name=row["COLUMN_NAME"], diff --git a/aerich/inspect/postgres.py b/aerich/inspect/postgres.py index e61f637..205db7f 100644 --- a/aerich/inspect/postgres.py +++ b/aerich/inspect/postgres.py @@ -15,8 +15,10 @@ class InspectPostgres(Inspect): return { "int4": self.int_field, "int8": self.int_field, + "smallint": self.smallint_field, "varchar": self.char_field, "text": self.text_field, + "bigint": self.bigint_field, "timestamptz": self.datetime_field, "float4": self.float_field, "float8": self.float_field, @@ -53,8 +55,8 @@ from information_schema.constraint_column_usage const where c.table_catalog = $1 and c.table_name = $2 and c.table_schema = $3;""" - ret = await self.conn.execute_query(sql, [self.database, table, self.schema]) - for row in ret[1]: + ret = await self.conn.execute_query_dict(sql, [self.database, table, self.schema]) + for row in ret: columns.append( Column( name=row["column_name"], diff --git a/aerich/inspect/sqlite.py b/aerich/inspect/sqlite.py new file mode 100644 index 0000000..5ce5bf6 --- /dev/null +++ b/aerich/inspect/sqlite.py @@ -0,0 +1,49 @@ +from typing import List + +from aerich.inspect import Column, Inspect + + +class InspectSQLite(Inspect): + @property + def field_map(self) -> dict: + return { + "INTEGER": self.int_field, + "INT": self.bool_field, + "SMALLINT": self.smallint_field, + "VARCHAR": self.char_field, + "TEXT": self.text_field, + "TIMESTAMP": self.datetime_field, + "REAL": self.float_field, + "BIGINT": self.bigint_field, + "DATE": self.date_field, + "TIME": self.time_field, + "JSON": self.json_field, + "BLOB": self.binary_field, + } + + async def get_columns(self, table: str) -> List[Column]: + columns = [] + sql = f"PRAGMA table_info({table})" + ret = await self.conn.execute_query_dict(sql) + for row in ret: + try: + length = row["type"].split("(")[1].split(")")[0] + except IndexError: + length = None + columns.append( + Column( + name=row["name"], + data_type=row["type"].split("(")[0], + null=row["notnull"] == 0, + default=row["dflt_value"], + length=length, + pk=row["pk"] == 1, + unique=False, # can't get this simply + ) + ) + return columns + + async def get_all_tables(self) -> List[str]: + sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" + ret = await self.conn.execute_query_dict(sql) + return list(map(lambda x: x["tbl_name"], ret))