Improve type hints of inspectdb (#371)
This commit is contained in:
parent
4e46d9d969
commit
b2f4029a4a
@ -1,9 +1,24 @@
|
|||||||
from typing import Any, List, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, Optional, TypedDict
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from tortoise import BaseDBAsyncClient
|
from tortoise import BaseDBAsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnInfoDict(TypedDict):
|
||||||
|
name: str
|
||||||
|
pk: str
|
||||||
|
index: str
|
||||||
|
null: str
|
||||||
|
default: str
|
||||||
|
length: str
|
||||||
|
comment: str
|
||||||
|
|
||||||
|
|
||||||
|
FieldMapDict = Dict[str, Callable[..., str]]
|
||||||
|
|
||||||
|
|
||||||
class Column(BaseModel):
|
class Column(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
data_type: str
|
data_type: str
|
||||||
@ -18,7 +33,7 @@ class Column(BaseModel):
|
|||||||
decimal_places: Optional[int] = None
|
decimal_places: Optional[int] = None
|
||||||
max_digits: Optional[int] = None
|
max_digits: Optional[int] = None
|
||||||
|
|
||||||
def translate(self) -> dict:
|
def translate(self) -> ColumnInfoDict:
|
||||||
comment = default = length = index = null = pk = ""
|
comment = default = length = index = null = pk = ""
|
||||||
if self.pk:
|
if self.pk:
|
||||||
pk = "pk=True, "
|
pk = "pk=True, "
|
||||||
@ -28,23 +43,24 @@ class Column(BaseModel):
|
|||||||
else:
|
else:
|
||||||
if self.index:
|
if self.index:
|
||||||
index = "index=True, "
|
index = "index=True, "
|
||||||
if self.data_type in ["varchar", "VARCHAR"]:
|
if self.data_type in ("varchar", "VARCHAR"):
|
||||||
length = f"max_length={self.length}, "
|
length = f"max_length={self.length}, "
|
||||||
if self.data_type in ["decimal", "numeric"]:
|
elif self.data_type in ("decimal", "numeric"):
|
||||||
length_parts = []
|
length_parts = []
|
||||||
if self.max_digits:
|
if self.max_digits:
|
||||||
length_parts.append(f"max_digits={self.max_digits}")
|
length_parts.append(f"max_digits={self.max_digits}")
|
||||||
if self.decimal_places:
|
if self.decimal_places:
|
||||||
length_parts.append(f"decimal_places={self.decimal_places}")
|
length_parts.append(f"decimal_places={self.decimal_places}")
|
||||||
length = ", ".join(length_parts)+", "
|
if length_parts:
|
||||||
|
length = ", ".join(length_parts) + ", "
|
||||||
if self.null:
|
if self.null:
|
||||||
null = "null=True, "
|
null = "null=True, "
|
||||||
if self.default is not None:
|
if self.default is not None:
|
||||||
if self.data_type in ["tinyint", "INT"]:
|
if self.data_type in ("tinyint", "INT"):
|
||||||
default = f"default={'True' if self.default == '1' else 'False'}, "
|
default = f"default={'True' if self.default == '1' else 'False'}, "
|
||||||
elif self.data_type == "bool":
|
elif self.data_type == "bool":
|
||||||
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
default = f"default={'True' if self.default == 'true' else 'False'}, "
|
||||||
elif self.data_type in ["datetime", "timestamptz", "TIMESTAMP"]:
|
elif self.data_type in ("datetime", "timestamptz", "TIMESTAMP"):
|
||||||
if "CURRENT_TIMESTAMP" == self.default:
|
if "CURRENT_TIMESTAMP" == self.default:
|
||||||
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
if "DEFAULT_GENERATED on update CURRENT_TIMESTAMP" == self.extra:
|
||||||
default = "auto_now=True, "
|
default = "auto_now=True, "
|
||||||
@ -76,7 +92,7 @@ class Column(BaseModel):
|
|||||||
class Inspect:
|
class Inspect:
|
||||||
_table_template = "class {table}(Model):\n"
|
_table_template = "class {table}(Model):\n"
|
||||||
|
|
||||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
def __init__(self, conn: BaseDBAsyncClient, tables: list[str] | None = None) -> None:
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
try:
|
try:
|
||||||
self.database = conn.database # type:ignore[attr-defined]
|
self.database = conn.database # type:ignore[attr-defined]
|
||||||
@ -85,7 +101,7 @@ class Inspect:
|
|||||||
self.tables = tables
|
self.tables = tables
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_map(self) -> dict:
|
def field_map(self) -> FieldMapDict:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def inspect(self) -> str:
|
async def inspect(self) -> str:
|
||||||
@ -103,10 +119,10 @@ class Inspect:
|
|||||||
tables.append(model + "\n".join(fields))
|
tables.append(model + "\n".join(fields))
|
||||||
return result + "\n\n\n".join(tables)
|
return result + "\n\n\n".join(tables)
|
||||||
|
|
||||||
async def get_columns(self, table: str) -> List[Column]:
|
async def get_columns(self, table: str) -> list[Column]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_all_tables(self) -> List[str]:
|
async def get_all_tables(self) -> list[str]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from typing import List
|
from __future__ import annotations
|
||||||
|
|
||||||
from aerich.inspectdb import Column, Inspect
|
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||||
|
|
||||||
|
|
||||||
class InspectMySQL(Inspect):
|
class InspectMySQL(Inspect):
|
||||||
@property
|
@property
|
||||||
def field_map(self) -> dict:
|
def field_map(self) -> FieldMapDict:
|
||||||
return {
|
return {
|
||||||
"int": self.int_field,
|
"int": self.int_field,
|
||||||
"smallint": self.smallint_field,
|
"smallint": self.smallint_field,
|
||||||
@ -24,12 +24,12 @@ class InspectMySQL(Inspect):
|
|||||||
"longblob": self.binary_field,
|
"longblob": self.binary_field,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_all_tables(self) -> List[str]:
|
async def get_all_tables(self) -> list[str]:
|
||||||
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
|
sql = "select TABLE_NAME from information_schema.TABLES where TABLE_SCHEMA=%s"
|
||||||
ret = await self.conn.execute_query_dict(sql, [self.database])
|
ret = await self.conn.execute_query_dict(sql, [self.database])
|
||||||
return list(map(lambda x: x["TABLE_NAME"], ret))
|
return list(map(lambda x: x["TABLE_NAME"], ret))
|
||||||
|
|
||||||
async def get_columns(self, table: str) -> List[Column]:
|
async def get_columns(self, table: str) -> list[Column]:
|
||||||
columns = []
|
columns = []
|
||||||
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
|
sql = """select c.*, s.NON_UNIQUE, s.INDEX_NAME
|
||||||
from information_schema.COLUMNS c
|
from information_schema.COLUMNS c
|
||||||
|
@ -1,18 +1,20 @@
|
|||||||
from typing import TYPE_CHECKING, List, Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
from aerich.inspectdb import Column, Inspect
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from tortoise.backends.base_postgres.client import BasePostgresClient
|
from tortoise.backends.base_postgres.client import BasePostgresClient
|
||||||
|
|
||||||
|
|
||||||
class InspectPostgres(Inspect):
|
class InspectPostgres(Inspect):
|
||||||
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
|
def __init__(self, conn: "BasePostgresClient", tables: list[str] | None = None) -> None:
|
||||||
super().__init__(conn, tables)
|
super().__init__(conn, tables)
|
||||||
self.schema = conn.server_settings.get("schema") or "public"
|
self.schema = conn.server_settings.get("schema") or "public"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_map(self) -> dict:
|
def field_map(self) -> FieldMapDict:
|
||||||
return {
|
return {
|
||||||
"int4": self.int_field,
|
"int4": self.int_field,
|
||||||
"int8": self.int_field,
|
"int8": self.int_field,
|
||||||
@ -34,12 +36,12 @@ class InspectPostgres(Inspect):
|
|||||||
"timestamp": self.datetime_field,
|
"timestamp": self.datetime_field,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_all_tables(self) -> List[str]:
|
async def get_all_tables(self) -> list[str]:
|
||||||
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
|
sql = "select TABLE_NAME from information_schema.TABLES where table_catalog=$1 and table_schema=$2"
|
||||||
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
|
ret = await self.conn.execute_query_dict(sql, [self.database, self.schema])
|
||||||
return list(map(lambda x: x["table_name"], ret))
|
return list(map(lambda x: x["table_name"], ret))
|
||||||
|
|
||||||
async def get_columns(self, table: str) -> List[Column]:
|
async def get_columns(self, table: str) -> list[Column]:
|
||||||
columns = []
|
columns = []
|
||||||
sql = f"""select c.column_name,
|
sql = f"""select c.column_name,
|
||||||
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
|
col_description('public.{table}'::regclass, ordinal_position) as column_comment,
|
||||||
|
@ -1,11 +1,11 @@
|
|||||||
from typing import Callable, Dict, List
|
from __future__ import annotations
|
||||||
|
|
||||||
from aerich.inspectdb import Column, Inspect
|
from aerich.inspectdb import Column, FieldMapDict, Inspect
|
||||||
|
|
||||||
|
|
||||||
class InspectSQLite(Inspect):
|
class InspectSQLite(Inspect):
|
||||||
@property
|
@property
|
||||||
def field_map(self) -> Dict[str, Callable[..., str]]:
|
def field_map(self) -> FieldMapDict:
|
||||||
return {
|
return {
|
||||||
"INTEGER": self.int_field,
|
"INTEGER": self.int_field,
|
||||||
"INT": self.bool_field,
|
"INT": self.bool_field,
|
||||||
@ -21,7 +21,7 @@ class InspectSQLite(Inspect):
|
|||||||
"BLOB": self.binary_field,
|
"BLOB": self.binary_field,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def get_columns(self, table: str) -> List[Column]:
|
async def get_columns(self, table: str) -> list[Column]:
|
||||||
columns = []
|
columns = []
|
||||||
sql = f"PRAGMA table_info({table})"
|
sql = f"PRAGMA table_info({table})"
|
||||||
ret = await self.conn.execute_query_dict(sql)
|
ret = await self.conn.execute_query_dict(sql)
|
||||||
@ -45,7 +45,7 @@ class InspectSQLite(Inspect):
|
|||||||
)
|
)
|
||||||
return columns
|
return columns
|
||||||
|
|
||||||
async def _get_columns_index(self, table: str) -> Dict[str, str]:
|
async def _get_columns_index(self, table: str) -> dict[str, str]:
|
||||||
sql = f"PRAGMA index_list ({table})"
|
sql = f"PRAGMA index_list ({table})"
|
||||||
indexes = await self.conn.execute_query_dict(sql)
|
indexes = await self.conn.execute_query_dict(sql)
|
||||||
ret = {}
|
ret = {}
|
||||||
@ -55,7 +55,7 @@ class InspectSQLite(Inspect):
|
|||||||
ret[index_info["name"]] = "unique" if index["unique"] else "index"
|
ret[index_info["name"]] = "unique" if index["unique"] else "index"
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def get_all_tables(self) -> List[str]:
|
async def get_all_tables(self) -> list[str]:
|
||||||
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
|
sql = "select tbl_name from sqlite_master where type='table' and name!='sqlite_sequence'"
|
||||||
ret = await self.conn.execute_query_dict(sql)
|
ret = await self.conn.execute_query_dict(sql)
|
||||||
return list(map(lambda x: x["tbl_name"], ret))
|
return list(map(lambda x: x["tbl_name"], ret))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user