Add type hints for ddl and inspectdb
This commit is contained in:
parent
51117867a6
commit
dd11bed5a0
@ -1,5 +1,5 @@
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Type
|
from typing import Any, List, Type, cast
|
||||||
|
|
||||||
from tortoise import BaseDBAsyncClient, Model
|
from tortoise import BaseDBAsyncClient, Model
|
||||||
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
|
||||||
@ -35,25 +35,26 @@ class BaseDDL:
|
|||||||
)
|
)
|
||||||
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
|
_RENAME_TABLE_TEMPLATE = 'ALTER TABLE "{old_table_name}" RENAME TO "{new_table_name}"'
|
||||||
|
|
||||||
def __init__(self, client: "BaseDBAsyncClient"):
|
def __init__(self, client: "BaseDBAsyncClient") -> None:
|
||||||
self.client = client
|
self.client = client
|
||||||
self.schema_generator = self.schema_generator_cls(client)
|
self.schema_generator = self.schema_generator_cls(client)
|
||||||
|
|
||||||
def create_table(self, model: "Type[Model]"):
|
def create_table(self, model: "Type[Model]") -> str:
|
||||||
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
|
return self.schema_generator._get_table_sql(model, True)["table_creation_string"].rstrip(
|
||||||
";"
|
";"
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_table(self, table_name: str):
|
def drop_table(self, table_name: str) -> str:
|
||||||
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
|
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
|
||||||
|
|
||||||
def create_m2m(
|
def create_m2m(
|
||||||
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
|
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
|
||||||
):
|
) -> str:
|
||||||
through = field_describe.get("through")
|
through = cast(str, field_describe.get("through"))
|
||||||
description = field_describe.get("description")
|
description = field_describe.get("description")
|
||||||
reference_id = reference_table_describe.get("pk_field").get("db_column")
|
pk_field = cast(dict, reference_table_describe.get("pk_field"))
|
||||||
db_field_types = reference_table_describe.get("pk_field").get("db_field_types")
|
reference_id = pk_field.get("db_column")
|
||||||
|
db_field_types = cast(dict, pk_field.get("db_field_types"))
|
||||||
return self._M2M_TABLE_TEMPLATE.format(
|
return self._M2M_TABLE_TEMPLATE.format(
|
||||||
table_name=through,
|
table_name=through,
|
||||||
backward_table=model._meta.db_table,
|
backward_table=model._meta.db_table,
|
||||||
@ -73,15 +74,15 @@ class BaseDDL:
|
|||||||
else "",
|
else "",
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_m2m(self, table_name: str):
|
def drop_m2m(self, table_name: str) -> str:
|
||||||
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
|
return self._DROP_TABLE_TEMPLATE.format(table_name=table_name)
|
||||||
|
|
||||||
def _get_default(self, model: "Type[Model]", field_describe: dict):
|
def _get_default(self, model: "Type[Model]", field_describe: dict) -> Any:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
default = field_describe.get("default")
|
default = field_describe.get("default")
|
||||||
if isinstance(default, Enum):
|
if isinstance(default, Enum):
|
||||||
default = default.value
|
default = default.value
|
||||||
db_column = field_describe.get("db_column")
|
db_column = cast(str, field_describe.get("db_column"))
|
||||||
auto_now_add = field_describe.get("auto_now_add", False)
|
auto_now_add = field_describe.get("auto_now_add", False)
|
||||||
auto_now = field_describe.get("auto_now", False)
|
auto_now = field_describe.get("auto_now", False)
|
||||||
if default is not None or auto_now_add:
|
if default is not None or auto_now_add:
|
||||||
@ -106,25 +107,34 @@ class BaseDDL:
|
|||||||
default = None
|
default = None
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
def add_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
|
||||||
|
return self._add_or_modify_column(model, field_describe, is_pk)
|
||||||
|
|
||||||
|
def _add_or_modify_column(self, model, field_describe: dict, is_pk: bool, modify=False) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
description = field_describe.get("description")
|
description = field_describe.get("description")
|
||||||
db_column = field_describe.get("db_column")
|
db_column = cast(str, field_describe.get("db_column"))
|
||||||
db_field_types = field_describe.get("db_field_types")
|
db_field_types = cast(dict, field_describe.get("db_field_types"))
|
||||||
default = self._get_default(model, field_describe)
|
default = self._get_default(model, field_describe)
|
||||||
if default is None:
|
if default is None:
|
||||||
default = ""
|
default = ""
|
||||||
return self._ADD_COLUMN_TEMPLATE.format(
|
if modify:
|
||||||
|
unique = ""
|
||||||
|
template = self._MODIFY_COLUMN_TEMPLATE
|
||||||
|
else:
|
||||||
|
unique = "UNIQUE" if field_describe.get("unique") else ""
|
||||||
|
template = self._ADD_COLUMN_TEMPLATE
|
||||||
|
return template.format(
|
||||||
table_name=db_table,
|
table_name=db_table,
|
||||||
column=self.schema_generator._create_string(
|
column=self.schema_generator._create_string(
|
||||||
db_column=db_column,
|
db_column=db_column,
|
||||||
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
|
field_type=db_field_types.get(self.DIALECT, db_field_types.get("")),
|
||||||
nullable="NOT NULL" if not field_describe.get("nullable") else "",
|
nullable="NOT NULL" if not field_describe.get("nullable") else "",
|
||||||
unique="UNIQUE" if field_describe.get("unique") else "",
|
unique=unique,
|
||||||
comment=self.schema_generator._column_comment_generator(
|
comment=self.schema_generator._column_comment_generator(
|
||||||
table=db_table,
|
table=db_table,
|
||||||
column=db_column,
|
column=db_column,
|
||||||
comment=field_describe.get("description"),
|
comment=description,
|
||||||
)
|
)
|
||||||
if description
|
if description
|
||||||
else "",
|
else "",
|
||||||
@ -133,37 +143,17 @@ class BaseDDL:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_column(self, model: "Type[Model]", column_name: str):
|
def drop_column(self, model: "Type[Model]", column_name: str) -> str:
|
||||||
return self._DROP_COLUMN_TEMPLATE.format(
|
return self._DROP_COLUMN_TEMPLATE.format(
|
||||||
table_name=model._meta.db_table, column_name=column_name
|
table_name=model._meta.db_table, column_name=column_name
|
||||||
)
|
)
|
||||||
|
|
||||||
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
|
||||||
db_table = model._meta.db_table
|
return self._add_or_modify_column(model, field_describe, is_pk, modify=True)
|
||||||
db_field_types = field_describe.get("db_field_types")
|
|
||||||
default = self._get_default(model, field_describe)
|
|
||||||
if default is None:
|
|
||||||
default = ""
|
|
||||||
return self._MODIFY_COLUMN_TEMPLATE.format(
|
|
||||||
table_name=db_table,
|
|
||||||
column=self.schema_generator._create_string(
|
|
||||||
db_column=field_describe.get("db_column"),
|
|
||||||
field_type=db_field_types.get(self.DIALECT) or db_field_types.get(""),
|
|
||||||
nullable="NOT NULL" if not field_describe.get("nullable") else "",
|
|
||||||
unique="",
|
|
||||||
comment=self.schema_generator._column_comment_generator(
|
|
||||||
table=db_table,
|
|
||||||
column=field_describe.get("db_column"),
|
|
||||||
comment=field_describe.get("description"),
|
|
||||||
)
|
|
||||||
if field_describe.get("description")
|
|
||||||
else "",
|
|
||||||
is_primary_key=is_pk,
|
|
||||||
default=default,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def rename_column(self, model: "Type[Model]", old_column_name: str, new_column_name: str):
|
def rename_column(
|
||||||
|
self, model: "Type[Model]", old_column_name: str, new_column_name: str
|
||||||
|
) -> str:
|
||||||
return self._RENAME_COLUMN_TEMPLATE.format(
|
return self._RENAME_COLUMN_TEMPLATE.format(
|
||||||
table_name=model._meta.db_table,
|
table_name=model._meta.db_table,
|
||||||
old_column_name=old_column_name,
|
old_column_name=old_column_name,
|
||||||
@ -172,7 +162,7 @@ class BaseDDL:
|
|||||||
|
|
||||||
def change_column(
|
def change_column(
|
||||||
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
|
self, model: "Type[Model]", old_column_name: str, new_column_name: str, new_column_type: str
|
||||||
):
|
) -> str:
|
||||||
return self._CHANGE_COLUMN_TEMPLATE.format(
|
return self._CHANGE_COLUMN_TEMPLATE.format(
|
||||||
table_name=model._meta.db_table,
|
table_name=model._meta.db_table,
|
||||||
old_column_name=old_column_name,
|
old_column_name=old_column_name,
|
||||||
@ -180,7 +170,7 @@ class BaseDDL:
|
|||||||
new_column_type=new_column_type,
|
new_column_type=new_column_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
def add_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
|
||||||
return self._ADD_INDEX_TEMPLATE.format(
|
return self._ADD_INDEX_TEMPLATE.format(
|
||||||
unique="UNIQUE " if unique else "",
|
unique="UNIQUE " if unique else "",
|
||||||
index_name=self.schema_generator._generate_index_name(
|
index_name=self.schema_generator._generate_index_name(
|
||||||
@ -190,7 +180,7 @@ class BaseDDL:
|
|||||||
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
|
column_names=", ".join(self.schema_generator.quote(f) for f in field_names),
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False):
|
def drop_index(self, model: "Type[Model]", field_names: List[str], unique=False) -> str:
|
||||||
return self._DROP_INDEX_TEMPLATE.format(
|
return self._DROP_INDEX_TEMPLATE.format(
|
||||||
index_name=self.schema_generator._generate_index_name(
|
index_name=self.schema_generator._generate_index_name(
|
||||||
"idx" if not unique else "uid", model, field_names
|
"idx" if not unique else "uid", model, field_names
|
||||||
@ -198,45 +188,52 @@ class BaseDDL:
|
|||||||
table_name=model._meta.db_table,
|
table_name=model._meta.db_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_index_by_name(self, model: "Type[Model]", index_name: str):
|
def drop_index_by_name(self, model: "Type[Model]", index_name: str) -> str:
|
||||||
return self._DROP_INDEX_TEMPLATE.format(
|
return self._DROP_INDEX_TEMPLATE.format(
|
||||||
index_name=index_name,
|
index_name=index_name,
|
||||||
table_name=model._meta.db_table,
|
table_name=model._meta.db_table,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
|
def _generate_fk_name(
|
||||||
|
self, db_table, field_describe: dict, reference_table_describe: dict
|
||||||
|
) -> str:
|
||||||
|
"""Generate fk name"""
|
||||||
|
db_column = cast(str, field_describe.get("raw_field"))
|
||||||
|
pk_field = cast(dict, reference_table_describe.get("pk_field"))
|
||||||
|
to_field = cast(str, pk_field.get("db_column"))
|
||||||
|
to_table = cast(str, reference_table_describe.get("table"))
|
||||||
|
return self.schema_generator._generate_fk_name(
|
||||||
|
from_table=db_table,
|
||||||
|
from_field=db_column,
|
||||||
|
to_table=to_table,
|
||||||
|
to_field=to_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_fk(
|
||||||
|
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
|
||||||
|
) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
|
|
||||||
db_column = field_describe.get("raw_field")
|
db_column = field_describe.get("raw_field")
|
||||||
reference_id = reference_table_describe.get("pk_field").get("db_column")
|
pk_field = cast(dict, reference_table_describe.get("pk_field"))
|
||||||
fk_name = self.schema_generator._generate_fk_name(
|
reference_id = pk_field.get("db_column")
|
||||||
from_table=db_table,
|
|
||||||
from_field=db_column,
|
|
||||||
to_table=reference_table_describe.get("table"),
|
|
||||||
to_field=reference_table_describe.get("pk_field").get("db_column"),
|
|
||||||
)
|
|
||||||
return self._ADD_FK_TEMPLATE.format(
|
return self._ADD_FK_TEMPLATE.format(
|
||||||
table_name=db_table,
|
table_name=db_table,
|
||||||
fk_name=fk_name,
|
fk_name=self._generate_fk_name(db_table, field_describe, reference_table_describe),
|
||||||
db_column=db_column,
|
db_column=db_column,
|
||||||
table=reference_table_describe.get("table"),
|
table=reference_table_describe.get("table"),
|
||||||
field=reference_id,
|
field=reference_id,
|
||||||
on_delete=field_describe.get("on_delete"),
|
on_delete=field_describe.get("on_delete"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def drop_fk(self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict):
|
def drop_fk(
|
||||||
|
self, model: "Type[Model]", field_describe: dict, reference_table_describe: dict
|
||||||
|
) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
return self._DROP_FK_TEMPLATE.format(
|
fk_name = self._generate_fk_name(db_table, field_describe, reference_table_describe)
|
||||||
table_name=db_table,
|
return self._DROP_FK_TEMPLATE.format(table_name=db_table, fk_name=fk_name)
|
||||||
fk_name=self.schema_generator._generate_fk_name(
|
|
||||||
from_table=db_table,
|
|
||||||
from_field=field_describe.get("raw_field"),
|
|
||||||
to_table=reference_table_describe.get("table"),
|
|
||||||
to_field=reference_table_describe.get("pk_field").get("db_column"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def alter_column_default(self, model: "Type[Model]", field_describe: dict):
|
def alter_column_default(self, model: "Type[Model]", field_describe: dict) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
default = self._get_default(model, field_describe)
|
default = self._get_default(model, field_describe)
|
||||||
return self._ALTER_DEFAULT_TEMPLATE.format(
|
return self._ALTER_DEFAULT_TEMPLATE.format(
|
||||||
@ -245,13 +242,13 @@ class BaseDDL:
|
|||||||
default="SET" + default if default is not None else "DROP DEFAULT",
|
default="SET" + default if default is not None else "DROP DEFAULT",
|
||||||
)
|
)
|
||||||
|
|
||||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
|
||||||
return self.modify_column(model, field_describe)
|
return self.modify_column(model, field_describe)
|
||||||
|
|
||||||
def set_comment(self, model: "Type[Model]", field_describe: dict):
|
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
|
||||||
return self.modify_column(model, field_describe)
|
return self.modify_column(model, field_describe)
|
||||||
|
|
||||||
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str):
|
def rename_table(self, model: "Type[Model]", old_table_name: str, new_table_name: str) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
return self._RENAME_TABLE_TEMPLATE.format(
|
return self._RENAME_TABLE_TEMPLATE.format(
|
||||||
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
|
table_name=db_table, old_table_name=old_table_name, new_table_name=new_table_name
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
from typing import Type
|
from typing import Type, cast
|
||||||
|
|
||||||
from tortoise import Model
|
from tortoise import Model
|
||||||
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
|
from tortoise.backends.asyncpg.schema_generator import AsyncpgSchemaGenerator
|
||||||
@ -18,7 +18,7 @@ class PostgresDDL(BaseDDL):
|
|||||||
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
|
_SET_COMMENT_TEMPLATE = 'COMMENT ON COLUMN "{table_name}"."{column}" IS {comment}'
|
||||||
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
|
_DROP_FK_TEMPLATE = 'ALTER TABLE "{table_name}" DROP CONSTRAINT "{fk_name}"'
|
||||||
|
|
||||||
def alter_column_null(self, model: "Type[Model]", field_describe: dict):
|
def alter_column_null(self, model: "Type[Model]", field_describe: dict) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
return self._ALTER_NULL_TEMPLATE.format(
|
return self._ALTER_NULL_TEMPLATE.format(
|
||||||
table_name=db_table,
|
table_name=db_table,
|
||||||
@ -26,9 +26,9 @@ class PostgresDDL(BaseDDL):
|
|||||||
set_drop="DROP" if field_describe.get("nullable") else "SET",
|
set_drop="DROP" if field_describe.get("nullable") else "SET",
|
||||||
)
|
)
|
||||||
|
|
||||||
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False):
|
def modify_column(self, model: "Type[Model]", field_describe: dict, is_pk: bool = False) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
db_field_types = field_describe.get("db_field_types")
|
db_field_types = cast(dict, field_describe.get("db_field_types"))
|
||||||
db_column = field_describe.get("db_column")
|
db_column = field_describe.get("db_column")
|
||||||
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
|
datatype = db_field_types.get(self.DIALECT) or db_field_types.get("")
|
||||||
return self._MODIFY_COLUMN_TEMPLATE.format(
|
return self._MODIFY_COLUMN_TEMPLATE.format(
|
||||||
@ -38,7 +38,7 @@ class PostgresDDL(BaseDDL):
|
|||||||
using=f' USING "{db_column}"::{datatype}',
|
using=f' USING "{db_column}"::{datatype}',
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_comment(self, model: "Type[Model]", field_describe: dict):
|
def set_comment(self, model: "Type[Model]", field_describe: dict) -> str:
|
||||||
db_table = model._meta.db_table
|
db_table = model._meta.db_table
|
||||||
return self._SET_COMMENT_TEMPLATE.format(
|
return self._SET_COMMENT_TEMPLATE.format(
|
||||||
table_name=db_table,
|
table_name=db_table,
|
||||||
|
@ -79,7 +79,7 @@ class Inspect:
|
|||||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
||||||
self.conn = conn
|
self.conn = conn
|
||||||
try:
|
try:
|
||||||
self.database = conn.database
|
self.database = conn.database # type:ignore[attr-defined]
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
self.tables = tables
|
self.tables = tables
|
||||||
|
@ -60,7 +60,8 @@ where c.TABLE_SCHEMA = %s
|
|||||||
comment=row["COLUMN_COMMENT"],
|
comment=row["COLUMN_COMMENT"],
|
||||||
unique=row["COLUMN_KEY"] == "UNI",
|
unique=row["COLUMN_KEY"] == "UNI",
|
||||||
extra=row["EXTRA"],
|
extra=row["EXTRA"],
|
||||||
unque=unique,
|
# TODO: why `unque`?
|
||||||
|
unque=unique, # type:ignore
|
||||||
index=index,
|
index=index,
|
||||||
length=row["CHARACTER_MAXIMUM_LENGTH"],
|
length=row["CHARACTER_MAXIMUM_LENGTH"],
|
||||||
max_digits=row["NUMERIC_PRECISION"],
|
max_digits=row["NUMERIC_PRECISION"],
|
||||||
|
@ -1,14 +1,15 @@
|
|||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from tortoise import BaseDBAsyncClient
|
|
||||||
|
|
||||||
from aerich.inspectdb import Column, Inspect
|
from aerich.inspectdb import Column, Inspect
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from tortoise.backends.base_postgres.client import BasePostgresClient
|
||||||
|
|
||||||
|
|
||||||
class InspectPostgres(Inspect):
|
class InspectPostgres(Inspect):
|
||||||
def __init__(self, conn: BaseDBAsyncClient, tables: Optional[List[str]] = None):
|
def __init__(self, conn: "BasePostgresClient", tables: Optional[List[str]] = None) -> None:
|
||||||
super().__init__(conn, tables)
|
super().__init__(conn, tables)
|
||||||
self.schema = self.conn.server_settings.get("schema") or "public"
|
self.schema = conn.server_settings.get("schema") or "public"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def field_map(self) -> dict:
|
def field_map(self) -> dict:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user