Improve type hints of inspectdb (#371)
This commit is contained in:
parent
4e46d9d969
commit
b2f4029a4a
@ -1,9 +1,24 @@
|
||||
from typing import Any, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tortoise import BaseDBAsyncClient
|
||||
|
||||
|
||||
class ColumnInfoDict(TypedDict):
|
||||
name: str
|
||||
pk: str
|
||||
index: str
|
||||
null: str
|
||||
default: str
|
||||
length: str
|
||||
comment: str
|
||||
|
||||
|
||||
FieldMapDict = Dict[str, Callable[..., str]]
|
||||
|
||||
|
||||
class Column(BaseModel):
|
||||
name: str
|
||||
data_type: str
|
||||
@ -18,7 +33,7 @@ class Column(BaseModel):
|
||||
decimal_places: Optional[int] = None
|
||||
max_digits: Optional[int] = None
|
||||
|
||||
def translate(self) -> dict:
|
||||
def translate(self) -> ColumnInfoDict:
|
||||
comment = default = length = index = null = pk = ""
|
||||
if self.pk:
|
||||
pk = "pk=True, "
|
||||
@ -28,23 +43,24 @@ class Column(BaseModel):
|
||||
else:
|
||||
if self.index:
|
||||
index = "index=True, "
|
||||
if self.data_type in ["varchar", "VARCHAR"]:
|
||||
if self.data_type in ("varchar", "VARCHAR"):
|
||||
length = f"max_length={self.length}, "
|
||||
if self.data_type in ["decimal", "numeric"]:
|
||||
elif self.data_type in ("decimal", "numeric"):
|
||||
length_parts = []
|
||||
if self.max_digits:
|
||||
length_parts.append(f"max_digits={self.max_digits}")
|
||||
if self.decimal_places:
|
||||
length_parts.append(f"decimal_places={self.decimal_places}")
|
||||
length = ", ".join(length_parts)+", "
|
||||
if length_parts:
|
||||
length = ", ".join(length_parts) + ", "
|
||||
if self.null:
|
||||
null = "null=True, "
|
||||
if self.default is not None:
|
||||
if self.data_type in ["tinyint", "INT"]:
|
||||
if self.data_type in ("tinyint", "INT"):
|
||||
default = f"default={'True' if self.default == '1' else 'False'}, "
|
||||
elif self.data_type == "bool":
|
||||
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
||||
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
|
||||
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
|
||||
if "CURRENT_TIMESTAMP" == self.default:
|
||||
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
||||
default = "auto_now=True, "
|
||||
@ -76,7 +92,7 @@ class Column(BaseModel):
|
||||
class Inspect:
|
||||
_table_template = "class {table}(Model):\n"
|
||||
|
||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
||||
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
|
||||
self.conn = conn
|
||||
try:
|
||||
self.database = conn.database # type:ignore[attr-defined]
|
||||
@ -85,7 +101,7 @@ class Inspect:
|
||||
self.tables = tables
|
||||
|
||||
@property
|
||||
def field_map(self) -> dict:
|
||||
def field_map(self) -> FieldMapDict:
|
||||
raise NotImplementedError
|
||||
|
||||
async def inspect(self) -> str:
|
||||
@ -103,10 +119,10 @@ class Inspect:
|
||||
tables.append(model + "\n".join(fields))
|
||||
return result + "\n\n\n".join(tables)
|
||||
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
async def get_columns(self, table: str) -> list[Column]:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_all_tables(self) -> List[str]:
|
||||
async def get_all_tables(self) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
|
@ -1,11 +1,11 @@
|
||||
from typing import List
|
||||
from __future__ import annotations
|
||||
|
||||
from aerich.inspectdb import Column, Inspect
|
||||
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||
|
||||
|
||||
class InspectMySQL(Inspect):
|
||||
@property
|
||||
def field_map(self) -> dict:
|
||||
def field_map(self) -> FieldMapDict:
|
||||
return {
|
||||
"int": self.int_field,
|
||||
"smallint": self.smallint_field,
|
||||
@ -24,12 +24,12 @@ class InspectMySQL(Inspect):
|
||||
"longblob": self.binary_field,
|
||||
}
|
||||
|
||||
async def get_all_tables(self) -> List[str]:
|
||||
async def get_all_tables(self) -> list[str]:
|
||||
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
|
||||
ret = await self.conn.execute_query_dict(sql, [self.database])
|
||||
return list(map(lambda x: x["TABLE_NAME"], ret))
|
||||
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
async def get_columns(self, table: str) -> list[Column]:
|
||||
columns = []
|
||||
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
|
||||
from information_schema.COLUMNS c
|
||||
|
@ -1,18 +1,20 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
from aerich.inspectdb import Column, Inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from tortoise.backends.base_postgres.client import BasePostgresClient
|
||||
|
||||
|
||||
class InspectPostgres(Inspect):
|
||||
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
|
||||
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
|
||||
super().__init__(conn, tables)
|
||||
self.schema = conn.server_settings.get("schema") or "public"
|
||||
|
||||
@property
|
||||
def field_map(self) -> dict:
|
||||
def field_map(self) -> FieldMapDict:
|
||||
return {
|
||||
"int4": self.int_field,
|
||||
"int8": self.int_field,
|
||||
@ -34,12 +36,12 @@ class InspectPostgres(Inspect):
|
||||
"timestamp": self.datetime_field,
|
||||
}
|
||||
|
||||
async def get_all_tables(self) -> List[str]:
|
||||
async def get_all_tables(self) -> list[str]:
|
||||
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
|
||||
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
|
||||
return list(map(lambda x: x["table_name"], ret))
|
||||
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
async def get_columns(self, table: str) -> list[Column]:
|
||||
columns = []
|
||||
sql = f"""select c.column_name,
|
||||
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
|
||||
|
@ -1,11 +1,11 @@
|
||||
from typing import Callable, Dict, List
|
||||
from __future__ import annotations
|
||||
|
||||
from aerich.inspectdb import Column, Inspect
|
||||
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||
|
||||
|
||||
class InspectSQLite(Inspect):
|
||||
@property
|
||||
def field_map(self) -> Dict[str, Callable[..., str]]:
|
||||
def field_map(self) -> FieldMapDict:
|
||||
return {
|
||||
"INTEGER": self.int_field,
|
||||
"INT": self.bool_field,
|
||||
@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
|
||||
"BLOB": self.binary_field,
|
||||
}
|
||||
|
||||
async def get_columns(self, table: str) -> List[Column]:
|
||||
async def get_columns(self, table: str) -> list[Column]:
|
||||
columns = []
|
||||
sql = f"PRAGMA table_info({table})"
|
||||
ret = await self.conn.execute_query_dict(sql)
|
||||
@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
|
||||
)
|
||||
return columns
|
||||
|
||||
async def _get_columns_index(self, table: str) -> Dict[str, str]:
|
||||
async def _get_columns_index(self, table: str) -> dict[str, str]:
|
||||
sql = f"PRAGMA index_list ({table})"
|
||||
indexes = await self.conn.execute_query_dict(sql)
|
||||
ret = {}
|
||||
@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
|
||||
ret[index_info["name"]] = "unique" if index["unique"] else "index"
|
||||
return ret
|
||||
|
||||
async def get_all_tables(self) -> List[str]:
|
||||
async def get_all_tables(self) -> list[str]:
|
||||
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
|
||||
ret = await self.conn.execute_query_dict(sql)
|
||||
return list(map(lambda x: x["tbl_name"], ret))
|
||||
|
Loading…
x
Reference in New Issue
Block a user