Improve type hints of inspectdb (#371)

This commit is contained in:
Waket Zheng 2024-12-03 12:40:28 +08:00 committed by GitHub
parent 4e46d9d969
commit b2f4029a4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 46 additions and 28 deletions

View File

@ -1,9 +1,24 @@
from typing import Any, List, Optional
from __future__ import annotations
from typing import Any, Callable, Dict, Optional, TypedDict
from pydantic import BaseModel
from tortoise import BaseDBAsyncClient
class ColumnInfoDict(TypedDict):
name: str
pk: str
index: str
null: str
default: str
length: str
comment: str
FieldMapDict = Dict[str, Callable[..., str]]
class Column(BaseModel):
name: str
data_type: str
@ -18,7 +33,7 @@ class Column(BaseModel):
decimal_places: Optional[int] = None
max_digits: Optional[int] = None
def translate(self) -> dict:
def translate(self) -> ColumnInfoDict:
comment = default = length = index = null = pk = ""
if self.pk:
pk = "pk=True, "
@ -28,23 +43,24 @@ class Column(BaseModel):
else:
if self.index:
index = "index=True, "
if self.data_type in ["varchar", "VARCHAR"]:
if self.data_type in ("varchar", "VARCHAR"):
length = f"max_length={self.length}, "
if self.data_type in ["decimal", "numeric"]:
elif self.data_type in ("decimal", "numeric"):
length_parts = []
if self.max_digits:
length_parts.append(f"max_digits={self.max_digits}")
if self.decimal_places:
length_parts.append(f"decimal_places={self.decimal_places}")
length = ", ".join(length_parts)+", "
if length_parts:
length = ", ".join(length_parts) + ", "
if self.null:
null = "null=True, "
if self.default is not None:
if self.data_type in ["tinyint", "INT"]:
if self.data_type in ("tinyint", "INT"):
default = f"default={'True' if self.default == '1' else 'False'}, "
elif self.data_type == "bool":
default = f"default={'True' if self.default == 'true' else 'False'}, "
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
if "CURRENT_TIMESTAMP" == self.default:
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
default = "auto_now=True, "
@ -76,7 +92,7 @@ class Column(BaseModel):
class Inspect:
_table_template = "class {table}(Model):\n"
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
self.conn = conn
try:
self.database = conn.database # type:ignore[attr-defined]
@ -85,7 +101,7 @@ class Inspect:
self.tables = tables
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
raise NotImplementedError
async def inspect(self) -> str:
@ -103,10 +119,10 @@ class Inspect:
tables.append(model + "\n".join(fields))
return result + "\n\n\n".join(tables)
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
raise NotImplementedError
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
raise NotImplementedError
@classmethod

View File

@ -1,11 +1,11 @@
from typing import List
from __future__ import annotations
from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectMySQL(Inspect):
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int": self.int_field,
"smallint": self.smallint_field,
@ -24,12 +24,12 @@ class InspectMySQL(Inspect):
"longblob": self.binary_field,
}
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
ret = await self.conn.execute_query_dict(sql, [self.database])
return list(map(lambda x: x["TABLE_NAME"], ret))
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
from information_schema.COLUMNS c

View File

@ -1,18 +1,20 @@
from typing import TYPE_CHECKING, List, Optional
from __future__ import annotations
from aerich.inspectdb import Column, Inspect
from typing import TYPE_CHECKING
from aerich.inspectdb import Column, FieldMapDict, Inspect
if TYPE_CHECKING:
from tortoise.backends.base_postgres.client import BasePostgresClient
class InspectPostgres(Inspect):
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
super().__init__(conn, tables)
self.schema = conn.server_settings.get("schema") or "public"
@property
def field_map(self) -> dict:
def field_map(self) -> FieldMapDict:
return {
"int4": self.int_field,
"int8": self.int_field,
@ -34,12 +36,12 @@ class InspectPostgres(Inspect):
"timestamp": self.datetime_field,
}
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
return list(map(lambda x: x["table_name"], ret))
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"""select c.column_name,
col_description('public.{table}'::regclass, ordinal_position) as column_comment,

View File

@ -1,11 +1,11 @@
from typing import Callable, Dict, List
from __future__ import annotations
from aerich.inspectdb import Column, Inspect
from aerich.inspectdb import Column, FieldMapDict, Inspect
class InspectSQLite(Inspect):
@property
def field_map(self) -> Dict[str, Callable[..., str]]:
def field_map(self) -> FieldMapDict:
return {
"INTEGER": self.int_field,
"INT": self.bool_field,
@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
"BLOB": self.binary_field,
}
async def get_columns(self, table: str) -> List[Column]:
async def get_columns(self, table: str) -> list[Column]:
columns = []
sql = f"PRAGMA table_info({table})"
ret = await self.conn.execute_query_dict(sql)
@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
)
return columns
async def _get_columns_index(self, table: str) -> Dict[str, str]:
async def _get_columns_index(self, table: str) -> dict[str, str]:
sql = f"PRAGMA index_list ({table})"
indexes = await self.conn.execute_query_dict(sql)
ret = {}
@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
ret[index_info["name"]] = "unique" if index["unique"] else "index"
return ret
async def get_all_tables(self) -> List[str]:
async def get_all_tables(self) -> list[str]:
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
ret = await self.conn.execute_query_dict(sql)
return list(map(lambda x: x["tbl_name"], ret))