Improve type hints of inspectdb (#371)

This commit is contained in:
Waket Zheng 2024-12-03 12:40:28 +08:00 committed by GitHub
parent 4e46d9d969
commit b2f4029a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 28 deletions

View File

@ -1,9 +1,24 @@
from typing import Any, List, Optional from __future__ import annotations
from typing import Any, Callable, Dict, Optional, TypedDict
from pydantic import BaseModel from pydantic import BaseModel
from tortoise import BaseDBAsyncClient from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel): class Column(BaseModel):
name: str name: str
data_type: str data_type: str
@ -18,7 +33,7 @@ class Column(BaseModel):
decimal_places: Optional[int] = None decimal_places: Optional[int] = None
max_digits: Optional[int] = None max_digits: Optional[int] = None
def translate(self) -> dict: def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = "" comment = default = length = index = null = pk = ""
if self.pk: if self.pk:
pk = "pk=True, " pk = "pk=True, "
@ -28,23 +43,24 @@ class Column(BaseModel):
else: else:
if self.index: if self.index:
index = "index=True, " index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]: if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, " length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]: elif self.data_type in ("decimal", "numeric"):
length_parts = [] length_parts = []
if self.max_digits: if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}") length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places: if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}") length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)+", " if length_parts:
length = ", ".join(length_parts) + ", "
if self.null: if self.null:
null = "null=True, " null = "null=True, "
if self.default is not None: if self.default is not None:
if self.data_type in ["tinyint", "INT"]: if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, " default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool": elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, " default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]: elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if "CURRENT_TIMESTAMP" == self.default: if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra: if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, " default = "auto_now=True, "
@ -76,7 +92,7 @@ class Column(BaseModel):
class Inspect: class Inspect:
_table_template = "class {table}(Model):\n" _table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None): def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn self.conn = conn
try: try:
self.database = conn.database # type:ignore[attr-defined] self.database = conn.database # type:ignore[attr-defined]
@ -85,7 +101,7 @@ class Inspect:
self.tables = tables self.tables = tables
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
raise NotImplementedError raise NotImplementedError
async def inspect(self) -> str: async def inspect(self) -> str:
@ -103,10 +119,10 @@ class Inspect:
tables.append(model + "\n".join(fields)) tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables) return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError raise NotImplementedError
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod

View File

@ -1,11 +1,11 @@
from typing import List from __future__ import annotations
from aerich.inspectdb import Column, Inspect from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect): class InspectMySQL(Inspect):
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
return { return {
"int": self.int_field, "int": self.int_field,
"smallint": self.smallint_field, "smallint": self.smallint_field,
@ -24,12 +24,12 @@ class InspectMySQL(Inspect):
"longblob": self.binary_field, "longblob": self.binary_field,
} }
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s" sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database]) ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret)) return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c from information_schema.COLUMNS c

View File

@ -1,18 +1,20 @@
from typing import TYPE_CHECKING, List, Optional from __future__ import annotations
from aerich.inspectdb import Column, Inspect from typing import TYPE_CHECKING
from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING: if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect): class InspectPostgres(Inspect):
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None: def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
super().__init__(conn, tables) super().__init__(conn, tables)
self.schema = conn.server_settings.get("schema") or "public" self.schema = conn.server_settings.get("schema") or "public"
@property @property
def field_map(self) -> dict: def field_map(self) -> FieldMapDict:
return { return {
"int4": self.int_field, "int4": self.int_field,
"int8": self.int_field, "int8": self.int_field,
@ -34,12 +36,12 @@ class InspectPostgres(Inspect):
"timestamp": self.datetime_field, "timestamp": self.datetime_field,
} }
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2" sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema]) ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret)) return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = f"""select c.column_name, sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment, col_description('public.{table}'::regclass, ordinal_position) as column_comment,

View File

@ -1,11 +1,11 @@
from typing import Callable, Dict, List from __future__ import annotations
from aerich.inspectdb import Column, Inspect from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect): class InspectSQLite(Inspect):
@property @property
def field_map(self) -> Dict[str, Callable[..., str]]: def field_map(self) -> FieldMapDict:
return { return {
"INTEGER": self.int_field, "INTEGER": self.int_field,
"INT": self.bool_field, "INT": self.bool_field,
@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
"BLOB": self.binary_field, "BLOB": self.binary_field,
} }
async def get_columns(self, table: str) -> List[Column]: async def get_columns(self, table: str) -> list[Column]:
columns = [] columns = []
sql = f"PRAGMA table_info({table})" sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql) ret = await self.conn.execute_query_dict(sql)
@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
) )
return columns return columns
async def _get_columns_index(self, table: str) -> Dict[str, str]: async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})" sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql) indexes = await self.conn.execute_query_dict(sql)
ret = {} ret = {}
@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
ret[index_info["name"]] = "unique" if index["unique"] else "index" ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret return ret
async def get_all_tables(self) -> List[str]: async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'" sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql) ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret)) return list(map(lambda x: x["tbl_name"], ret))